From 804fa20782255031e92ff029c5ca474839991f6f Mon Sep 17 00:00:00 2001 From: yhirose Date: Mon, 29 Dec 2025 12:23:58 -0500 Subject: [PATCH] Phase 2 --- docs/tls/checklist.md | 28 +- docs/tls/httplib_tls_migration.ja.md | 72 ++++ httplib.h | 554 +++++++++++++++++---------- 3 files changed, 445 insertions(+), 209 deletions(-) diff --git a/docs/tls/checklist.md b/docs/tls/checklist.md index 4ecb1cf..c765196 100644 --- a/docs/tls/checklist.md +++ b/docs/tls/checklist.md @@ -31,17 +31,21 @@ ### フェーズ 2: SSLClient と SSLSocketStream の移行 -- [ ] `SSLSocketStream` が `tls_session_t` を受け取るように変更 -- [ ] `SSLSocketStream::read()` → `tls_read()` に置き換え -- [ ] `SSLSocketStream::write()` → `tls_write()` に置き換え -- [ ] `SSLClient` が抽象化 API を使用 -- [ ] CA 証明書の読み込みが動作 -- [ ] クライアント証明書認証が動作 -- [ ] SNI とホスト名検証が動作 -- [ ] 大容量データ転送テスト通過 -- [ ] タイムアウト処理が正常動作 -- [ ] 不要になった `detail::ssl_new()` 等を削除 -- [ ] `make test_split` 通過 +- [x] `SSLSocketStream` が `tls_session_t` を受け取るように変更 +- [x] `SSLSocketStream::read()` → `tls_read()` に置き換え +- [x] `SSLSocketStream::write()` → `tls_write()` に置き換え +- [x] `SSLClient` が抽象化 API を使用 + - [x] `SSLClient::initialize_ssl()` を `tls_create_session()`, `tls_connect_nonblocking()` 等で移行 + - [x] `tls_set_sni()` を追加(SNI のみ設定、検証モードは変更しない) + - [x] `tls_connect_nonblocking()`, `tls_accept_nonblocking()` を追加 +- [x] CA 証明書の読み込みが動作 +- [x] クライアント証明書認証が動作 +- [x] SNI とホスト名検証が動作 +- [x] 大容量データ転送テスト通過 +- [x] タイムアウト処理が正常動作 + - [x] `tls_is_peer_closed()` に `socket_t sock` パラメータを追加して修正 +- [ ] 不要になった `detail::ssl_new()` 等を削除 → Phase 3 で `SSLServer` 移行後に実施 +- [x] `make test_split` 通過 ### フェーズ 3: SSLServer の移行 @@ -53,7 +57,7 @@ ### フェーズ 4: 残りの detail ヘルパーの削除とクリーンアップ -- [ ] `detail::is_ssl_peer_could_be_closed()` が削除または置き換え済み +- [x] `detail::is_ssl_peer_could_be_closed()` が削除または置き換え済み → `tls_is_peer_closed()` で置き換え - [ ] `detail::load_system_certs_on_windows()` が `tls_load_system_certs()` 内に統合済み - [ ] `detail::load_system_certs_on_macos()` が `tls_load_system_certs()` 内に統合済み - [ ] 未使用の SSL 関連 `detail::*` 関数がない diff --git a/docs/tls/httplib_tls_migration.ja.md b/docs/tls/httplib_tls_migration.ja.md index 92be8ba..b59e729 100644 --- a/docs/tls/httplib_tls_migration.ja.md +++ b/docs/tls/httplib_tls_migration.ja.md @@ -332,6 +332,78 @@ res.ssl_openssl_error() // OpenSSL エラーコード (ERR_get_error()) - `split.py` による分割ビルド (`make test_split`) との互換性を常に確認可能 - シングルファイル・ヘッダーオンリーの形態を維持 +### 実装時の教訓(Phase 2 より) + +Phase 2 の実装で学んだ重要な教訓: + +#### 1. 機能を無効化する際は代替実装を用意する + +一時的な回避策として `if (false)` やコメントアウトで機能を無効化すると、テストが通っているように見えても実際には機能が失われている場合がある。 + +**悪い例:** +```cpp +// TODO: need tls_peek() for is_ssl_peer_could_be_closed() +// if (detail::is_ssl_peer_could_be_closed(socket_.session, socket_.sock)) { +if (false) { + is_alive = false; +} +``` + +**良い例:** +```cpp +// 一時的に SSL* へキャストして既存関数を使用 +// TODO: tls_is_peer_closed() を TLS API に追加後、置き換える +if (detail::is_ssl_peer_could_be_closed( + static_cast(socket_.session), socket_.sock)) { + is_alive = false; +} +``` + +#### 2. 移行対象の機能を完全にリストアップする + +`SSL*` → `void*` への型変更時、全ての使用箇所を事前にリストアップする。特に: + +- 直接使用している関数 +- 間接的に依存している関数(例: `is_ssl_peer_could_be_closed()` は `SSL_peek()` を使用) +- テストで検証されている機能 + +#### 3. テスト失敗時はデバッグ出力で挙動を確認する + +テストが失敗した場合、「動作していない」と決めつけず、実際の挙動を確認する: + +```cpp +fprintf(stderr, "[DEBUG] wait_readable: max=%ld dur=%ld actual=%ld.%ld\n", + (long)max_timeout_msec_, (long)dur, (long)read_timeout_sec, + (long)read_timeout_usec); +``` + +Phase 2 では、タイムアウト自体は正常に動作していたが、サーバー側の接続終了検出が無効化されていたことが原因だった。 + +#### 4. 元のコードとの比較を早めに行う + +問題が発生したら、`git stash` で元のコードに戻してテストを実行し、変更が原因かどうかを素早く確認する: + +```bash +git stash # 変更を退避 +make test && ./test # 元のコードでテスト +git stash pop # 変更を復元 +``` + +#### 5. TODO コメントは具体的なタスクとして追跡する + +`// TODO:` コメントを書いただけで終わらせず、実際に実装するまで追跡する。本ドキュメントの「移行チェックリスト」または外部のタスク管理ツールで管理すること。 + +#### 6. テストの失敗メッセージを正確に読む + +``` +test.cc:11248: Failure +Value of: success + Actual: true +Expected: false +``` + +この情報から「サーバー側のコールバックが成功と認識している」ことが読み取れる。クライアント側ではなくサーバー側の問題であることを示唆していた。 + **httplib.h の構造:** ```text diff --git a/httplib.h b/httplib.h index 5bcc3c9..c61e407 100644 --- a/httplib.h +++ b/httplib.h @@ -1401,25 +1401,17 @@ private: struct ClientConnection { socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; + // Use void* directly since tls::tls_session_t is not yet defined here + void *session = nullptr; #endif bool is_open() const { return sock != INVALID_SOCKET; } ClientConnection() = default; - ~ClientConnection() { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (ssl) { - SSL_free(ssl); - ssl = nullptr; - } -#endif - if (sock != INVALID_SOCKET) { - detail::close_socket(sock); - sock = INVALID_SOCKET; - } - } + // Destructor defined after tls namespace is available (see implementation + // section) + ~ClientConnection(); ClientConnection(const ClientConnection &) = delete; ClientConnection &operator=(const ClientConnection &) = delete; @@ -1428,12 +1420,12 @@ struct ClientConnection { : sock(other.sock) #ifdef CPPHTTPLIB_OPENSSL_SUPPORT , - ssl(other.ssl) + session(other.session) #endif { other.sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; + other.session = nullptr; #endif } @@ -1441,11 +1433,11 @@ struct ClientConnection { if (this != &other) { sock = other.sock; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ssl = other.ssl; + session = other.session; #endif other.sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; + other.session = nullptr; #endif } return *this; @@ -1720,7 +1712,8 @@ protected: struct Socket { socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; + // Use void* directly since tls::tls_session_t is not yet defined here + void *session = nullptr; #endif bool is_open() const { return sock != INVALID_SOCKET; } @@ -2199,7 +2192,8 @@ private: bool verify_host_with_common_name(X509 *server_cert) const; bool check_host_name(const char *pattern, size_t pattern_len) const; - SSL_CTX *ctx_; + // Use void* directly since tls::tls_ctx_t is not yet defined here + void *ctx_; std::mutex ctx_mutex_; std::once_flag initialize_cert_; @@ -2818,12 +2812,21 @@ void tls_set_verify_client(tls_ctx_t ctx, bool require); // Session management tls_session_t tls_create_session(tls_ctx_t ctx, socket_t sock); void tls_free_session(tls_session_t session); +bool tls_set_sni(tls_session_t session, const char *hostname); bool tls_set_hostname(tls_session_t session, const char *hostname); // Handshake (non-blocking capable) TlsError tls_connect(tls_session_t session); TlsError tls_accept(tls_session_t session); +// Handshake with timeout (blocking until timeout) +bool tls_connect_nonblocking(tls_session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err); +bool tls_accept_nonblocking(tls_session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err); + // I/O (non-blocking capable) ssize_t tls_read(tls_session_t session, void *buf, size_t len, TlsError &err); ssize_t tls_write(tls_session_t session, const void *buf, size_t len, @@ -2832,7 +2835,7 @@ int tls_pending(tls_session_t session); void tls_shutdown(tls_session_t session, bool graceful); // Connection state -bool tls_is_peer_closed(tls_session_t session); +bool tls_is_peer_closed(tls_session_t session, socket_t sock); // Certificate verification tls_cert_t tls_get_peer_cert(tls_session_t session); @@ -4442,7 +4445,7 @@ private: class SSLSocketStream final : public Stream { public: SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, + socket_t sock, tls::tls_session_t session, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = @@ -4461,7 +4464,7 @@ public: private: socket_t sock_; - SSL *ssl_; + tls::tls_session_t session_; time_t read_timeout_sec_; time_t read_timeout_usec_; time_t write_timeout_sec_; @@ -10168,7 +10171,7 @@ inline void ClientImpl::close_socket(Socket &socket) { // It is also a bug if this happens while SSL is still active #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - assert(socket.ssl == nullptr); + assert(socket.session == nullptr); #endif if (socket.sock == INVALID_SOCKET) { return; } detail::close_socket(socket.sock); @@ -10236,7 +10239,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (detail::tls::tls_is_peer_closed(socket_.session, socket_.sock)) { is_alive = false; } } @@ -10423,7 +10426,7 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, is_alive = detail::is_socket_alive(socket_.sock); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (detail::tls::tls_is_peer_closed(socket_.session, socket_.sock)) { is_alive = false; } } @@ -10458,10 +10461,11 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && handle.connection_->ssl) { + if (is_ssl() && handle.connection_->session) { handle.socket_stream_ = detail::make_unique( - handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + handle.connection_->sock, handle.connection_->session, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_); } else { handle.socket_stream_ = detail::make_unique( handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, @@ -10681,8 +10685,8 @@ inline void ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { handle.connection_->sock = socket_.sock; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - handle.connection_->ssl = socket_.ssl; - socket_.ssl = nullptr; + handle.connection_->session = socket_.session; + socket_.session = nullptr; #endif socket_.sock = INVALID_SOCKET; } @@ -11285,7 +11289,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, if (is_ssl()) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (detail::tls::tls_is_peer_closed(socket_.session, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; output_error_log(error, &req); return false; @@ -12688,13 +12692,28 @@ inline void tls_free_session(tls_session_t session) { if (session) { SSL_free(static_cast(session)); } } +inline bool tls_set_sni(tls_session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) only - does not enable verification +#if defined(OPENSSL_IS_BORINGSSL) + return SSL_set_tlsext_host_name(ssl, hostname) == 1; +#else + // Direct call instead of macro to suppress -Wold-style-cast warning + return SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(hostname))) == 1; +#endif +} + inline bool tls_set_hostname(tls_session_t session, const char *hostname) { if (!session || !hostname) return false; auto ssl = static_cast(session); // Set SNI (Server Name Indication) - if (SSL_set_tlsext_host_name(ssl, hostname) != 1) { return false; } + if (!tls_set_sni(session, hostname)) { return false; } // Enable hostname verification auto param = SSL_get0_param(ssl); @@ -12741,6 +12760,96 @@ inline TlsError tls_accept(tls_session_t session) { return err; } +inline bool tls_connect_nonblocking(tls_session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_connect(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + if (err) { + err->code = map_ssl_error(ssl_err, err->sys_errno); + if (err->code == ErrorCode::Fatal) { + err->backend_code = ERR_get_error(); + } + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +inline bool tls_accept_nonblocking(tls_session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_accept(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + if (err) { + err->code = map_ssl_error(ssl_err, err->sys_errno); + if (err->code == ErrorCode::Fatal) { + err->backend_code = ERR_get_error(); + } + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + inline ssize_t tls_read(tls_session_t session, void *buf, size_t len, TlsError &err) { if (!session || !buf) { @@ -12801,16 +12910,20 @@ inline void tls_shutdown(tls_session_t session, bool graceful) { } } -inline bool tls_is_peer_closed(tls_session_t session) { +inline bool tls_is_peer_closed(tls_session_t session, socket_t sock) { if (!session) return true; + // Temporarily set socket to non-blocking to avoid blocking on SSL_peek + set_nonblocking(sock, true); + auto se = scope_exit([&]() { set_nonblocking(sock, false); }); + auto ssl = static_cast(session); char buf; auto ret = SSL_peek(ssl, &buf, 1); if (ret > 0) return false; auto err = SSL_get_error(ssl, ret); - return err == SSL_ERROR_ZERO_RETURN || err == SSL_ERROR_SYSCALL; + return err == SSL_ERROR_ZERO_RETURN; } inline tls_cert_t tls_get_peer_cert(tls_session_t session) { @@ -12847,6 +12960,20 @@ inline std::string tls_error_string(uint64_t code) { } // namespace tls } // namespace detail + +// ClientConnection destructor (defined here because tls namespace is now +// available) +inline ClientConnection::~ClientConnection() { + if (session) { + detail::tls::tls_shutdown(session, true); + detail::tls::tls_free_session(session); + session = nullptr; + } + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; + } +} #endif /* @@ -12965,22 +13092,27 @@ inline bool process_client_socket_ssl( // SSL socket stream implementation inline SSLSocketStream::SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, + socket_t sock, tls::tls_session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, std::chrono::time_point start_time) - : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + : sock_(sock), session_(session), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), max_timeout_msec_(max_timeout_msec), start_time_(start_time) { - SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); + // Clear AUTO_RETRY for proper non-blocking I/O timeout handling + // Note: tls_create_session() also clears this, but SSLClient currently + // uses ssl_new() which does not. Until full TLS API migration is complete, + // we need to ensure AUTO_RETRY is cleared here regardless of how the + // SSL session was created. + SSL_clear_mode(static_cast(session), SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { - return SSL_pending(ssl_) > 0; + return tls::tls_pending(session_) > 0; } inline bool SSLSocketStream::wait_readable() const { @@ -12998,39 +13130,41 @@ inline bool SSLSocketStream::wait_readable() const { inline bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); + is_socket_alive(sock_) && !tls::tls_is_peer_closed(session_, sock_); } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret == 0) { error_ = Error::ConnectionClosed; } + if (tls::tls_pending(session_) > 0) { + tls::TlsError err; + auto ret = tls::tls_read(session_, ptr, size, err); + if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } return ret; } else if (wait_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); + tls::TlsError err; + auto ret = tls::tls_read(session_, ptr, size, err); if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); auto n = 1000; #ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && + while (--n >= 0 && (err.code == tls::ErrorCode::WantRead || + (err.code == tls::ErrorCode::SyscallError && WSAGetLastError() == WSAETIMEDOUT))) { #else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { + while (--n >= 0 && err.code == tls::ErrorCode::WantRead) { #endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); + if (tls::tls_pending(session_) > 0) { + return tls::tls_read(session_, ptr, size, err); } else if (wait_readable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_read(ssl_, ptr, static_cast(size)); + ret = tls::tls_read(session_, ptr, size, err); if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); } else { break; } } assert(ret < 0); - } else if (ret == 0) { + } else if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { error_ = Error::ConnectionClosed; } return ret; @@ -13042,25 +13176,24 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { if (wait_writable()) { - auto handle_size = static_cast( - std::min(size, (std::numeric_limits::max)())); + auto handle_size = + std::min(size, (std::numeric_limits::max)()); - auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + tls::TlsError err; + auto ret = tls::tls_write(session_, ptr, handle_size, err); if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); auto n = 1000; #ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && + while (--n >= 0 && (err.code == tls::ErrorCode::WantWrite || + (err.code == tls::ErrorCode::SyscallError && WSAGetLastError() == WSAETIMEDOUT))) { #else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { + while (--n >= 0 && err.code == tls::ErrorCode::WantWrite) { #endif if (wait_writable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + ret = tls::tls_write(session_, ptr, handle_size, err); if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); } else { break; } @@ -13299,9 +13432,10 @@ inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_key_path, const std::string &private_key_password) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(TLS_client_method()); + ctx_ = detail::tls::tls_create_client_context(); - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + // TODO: Add tls_set_min_protocol_version() to TLS abstraction API + // SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { @@ -13309,18 +13443,13 @@ inline SSLClient::SSLClient(const std::string &host, int port, }); if (!client_cert_path.empty() && !client_key_path.empty()) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { + const char *password = + private_key_password.empty() ? nullptr : private_key_password.c_str(); + if (!detail::tls::tls_set_client_cert_file(ctx_, client_cert_path.c_str(), + client_key_path.c_str(), + password)) { last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); + detail::tls::tls_free_context(ctx_); ctx_ = nullptr; } } @@ -13330,7 +13459,9 @@ inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key, const std::string &private_key_password) : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(TLS_client_method()); + // TODO: This constructor uses OpenSSL-specific types (X509, EVP_PKEY) + // Consider adding tls_set_client_cert_native() to TLS abstraction API + ctx_ = detail::tls::tls_create_client_context(); detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { @@ -13338,23 +13469,25 @@ inline SSLClient::SSLClient(const std::string &host, int port, }); if (client_cert != nullptr && client_key != nullptr) { + // Temporarily cast to SSL_CTX* for OpenSSL-specific operations + auto ssl_ctx = static_cast(ctx_); if (!private_key_password.empty()) { SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); + ssl_ctx, reinterpret_cast( + const_cast(private_key_password.c_str()))); } - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + if (SSL_CTX_use_certificate(ssl_ctx, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ssl_ctx, client_key) != 1) { last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); + detail::tls::tls_free_context(ctx_); ctx_ = nullptr; } } } inline SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } + if (ctx_) { detail::tls::tls_free_context(ctx_); } // Make sure to shut down SSL since shutdown_ssl will resolve to the // base function rather than the derived function once we get to the // base class destructor, and won't free the SSL (causing a leak). @@ -13366,10 +13499,11 @@ inline bool SSLClient::is_valid() const { return ctx_; } inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { if (ca_cert_store) { if (ctx_) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + auto ssl_ctx = static_cast(ctx_); + if (SSL_CTX_get_cert_store(ssl_ctx) != ca_cert_store) { // Free memory allocated for old cert and use new store // `ca_cert_store` - SSL_CTX_set_cert_store(ctx_, ca_cert_store); + SSL_CTX_set_cert_store(ssl_ctx, ca_cert_store); ca_cert_store_ = ca_cert_store; } } else { @@ -13387,7 +13521,9 @@ inline long SSLClient::get_openssl_verify_result() const { return verify_result_; } -inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } +inline SSL_CTX *SSLClient::ssl_context() const { + return static_cast(ctx_); +} inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { if (!is_valid()) { @@ -13498,27 +13634,23 @@ inline bool SSLClient::load_certs() { std::call_once(initialize_cert_, [&]() { std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { + if (!detail::tls::tls_load_ca_file(ctx_, ca_cert_file_path_.c_str())) { last_openssl_error_ = ERR_get_error(); ret = false; } } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { + if (!detail::tls::tls_load_ca_dir(ctx_, ca_cert_dir_path_.c_str())) { last_openssl_error_ = ERR_get_error(); ret = false; } } else { - auto loaded = false; -#ifdef _WIN32 - loaded = - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC - loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); -#endif // _WIN32 - if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + // Load system certificates + if (!detail::tls::tls_load_system_certs(ctx_)) { + last_openssl_error_ = ERR_get_error(); + // Ignore error and continue - some systems may not have certs + } } }); @@ -13526,102 +13658,126 @@ inline bool SSLClient::load_certs() { } inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - if (server_certificate_verification_) { - if (!load_certs()) { - error = Error::SSLLoadingCerts; - output_error_log(error, nullptr); - return false; - } - SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); - } + using namespace detail::tls; - if (!detail::ssl_connect_or_accept_nonblocking( - socket.sock, ssl2, SSL_connect, connection_timeout_sec_, - connection_timeout_usec_, &last_ssl_error_)) { - error = Error::SSLConnection; - output_error_log(error, nullptr); - return false; - } - - if (server_certificate_verification_) { - auto verification_status = SSLVerifierResponse::NoDecisionMade; - - if (server_certificate_verifier_) { - verification_status = server_certificate_verifier_(ssl2); - } - - if (verification_status == SSLVerifierResponse::CertificateRejected) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (verification_status == SSLVerifierResponse::NoDecisionMade) { - verify_result_ = SSL_get_verify_result(ssl2); - - if (verify_result_ != X509_V_OK) { - last_openssl_error_ = static_cast(verify_result_); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - auto se = detail::scope_exit([&] { X509_free(server_cert); }); - - if (server_cert == nullptr) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (server_hostname_verification_) { - if (!verify_host(server_cert)) { - last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; - error = Error::SSLServerHostnameVerification; - output_error_log(error, nullptr); - return false; - } - } - } - } - - return true; - }, - [&](SSL *ssl2) { - // Set SNI only if host is not IP address - if (!detail::is_ip_address(host_)) { -#if defined(OPENSSL_IS_BORINGSSL) - SSL_set_tlsext_host_name(ssl2, host_.c_str()); -#else - // NOTE: Direct call instead of using the OpenSSL macro to suppress - // -Wold-style-cast warning - SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, - TLSEXT_NAMETYPE_host_name, - static_cast(const_cast(host_.c_str()))); -#endif - } - return true; - }); - - if (ssl) { - socket.ssl = ssl; - return true; + // Load CA certificates if server verification is enabled + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + output_error_log(error, nullptr); + return false; + } } - if (ctx_ == nullptr) { + // Create TLS session (uses ctx_ which has SSL_VERIFY_NONE by default) + tls_session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = tls_create_session(ctx_, socket.sock); + } + + if (!session) { error = Error::SSLConnection; last_openssl_error_ = ERR_get_error(); + shutdown_socket(socket); + close_socket(socket); + return false; } - shutdown_socket(socket); - close_socket(socket); - return false; + // Set SNI before handshake (only if host is not IP address) + if (!detail::is_ip_address(host_)) { + if (!tls_set_sni(session, host_.c_str())) { + tls_free_session(session); + error = Error::SSLConnection; + last_openssl_error_ = ERR_get_error(); + shutdown_socket(socket); + close_socket(socket); + return false; + } + } + + // Perform non-blocking TLS handshake with timeout + TlsError tls_err; + if (!tls_connect_nonblocking(session, socket.sock, connection_timeout_sec_, + connection_timeout_usec_, &tls_err)) { + // Map TlsError to legacy ssl_error for backward compatibility + if (tls_err.code == ErrorCode::WantRead) { + last_ssl_error_ = SSL_ERROR_WANT_READ; + } else if (tls_err.code == ErrorCode::WantWrite) { + last_ssl_error_ = SSL_ERROR_WANT_WRITE; + } else { + last_ssl_error_ = SSL_ERROR_SSL; + } + tls_free_session(session); + error = Error::SSLConnection; + output_error_log(error, nullptr); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + // Server certificate verification + if (server_certificate_verification_) { + // Cast to SSL* for backward compatibility with verifier callback + auto ssl = static_cast(session); + + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (server_certificate_verifier_) { + verification_status = server_certificate_verifier_(ssl); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + tls_free_session(session); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { + verify_result_ = tls_get_verify_result(session); + + if (verify_result_ != X509_V_OK) { + last_openssl_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + tls_free_session(session); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + auto server_cert = tls_get_peer_cert(session); + if (!server_cert) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + tls_free_session(session); + shutdown_socket(socket); + close_socket(socket); + return false; + } + auto se = detail::scope_exit([&] { tls_free_cert(server_cert); }); + + if (server_hostname_verification_) { + // verify_host() expects X509*, so cast from tls_cert_t + if (!verify_host(static_cast(server_cert))) { + last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; + error = Error::SSLServerHostnameVerification; + output_error_log(error, nullptr); + tls_free_session(session); + shutdown_socket(socket); + close_socket(socket); + return false; + } + } + } + } + + socket.session = session; + return true; } inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { @@ -13631,26 +13787,30 @@ inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { inline void SSLClient::shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully) { if (socket.sock == INVALID_SOCKET) { - assert(socket.ssl == nullptr); + assert(socket.session == nullptr); return; } - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, - shutdown_gracefully); - socket.ssl = nullptr; + if (socket.session) { + // Temporarily cast session to SSL* for ssl_delete() + // TODO: Replace with tls_shutdown/tls_free_session when ready + detail::ssl_delete(ctx_mutex_, static_cast(socket.session), + socket.sock, shutdown_gracefully); + socket.session = nullptr; } - assert(socket.ssl == nullptr); + assert(socket.session == nullptr); } inline bool SSLClient::process_socket( const Socket &socket, std::chrono::time_point start_time, std::function callback) { - assert(socket.ssl); + assert(socket.session); + // Temporarily cast session to SSL* for process_client_socket_ssl() + // TODO: Update process_client_socket_ssl to use tls_session_t return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); + static_cast(socket.session), socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, + max_timeout_msec_, start_time, std::move(callback)); } inline bool SSLClient::is_ssl() const { return true; }