From 64e18d01f7c79a879e73ff7e01279e922d91ac44 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 20 Jan 2025 18:51:02 +0200 Subject: [PATCH 1/8] feat: TlsSocket AsyncWriteSome and AsyncReadSome Signed-off-by: kostas --- util/fibers/epoll_socket.cc | 1 + util/tls/CMakeLists.txt | 1 + util/tls/tls_socket.cc | 249 ++++++++++++++++++++++++++++++++++-- util/tls/tls_socket.h | 61 ++++++++- util/tls/tls_socket_test.cc | 246 +++++++++++++++++++++++++++++++++++ 5 files changed, 549 insertions(+), 9 deletions(-) create mode 100644 util/tls/tls_socket_test.cc diff --git a/util/fibers/epoll_socket.cc b/util/fibers/epoll_socket.cc index e9536fc1..1aaca377 100644 --- a/util/fibers/epoll_socket.cc +++ b/util/fibers/epoll_socket.cc @@ -388,6 +388,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress async_write_pending_ = 1; } +// TODO implement async functionality void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { auto res = ReadSome(v, len); cb(res); diff --git a/util/tls/CMakeLists.txt b/util/tls/CMakeLists.txt index 21fcba63..fb2fd27f 100644 --- a/util/tls/CMakeLists.txt +++ b/util/tls/CMakeLists.txt @@ -4,3 +4,4 @@ add_library(tls_lib tls_engine.cc tls_socket.cc) cxx_link(tls_lib fibers2 OpenSSL::SSL) cxx_test(tls_engine_test tls_lib LABELS CI) +cxx_test(tls_socket_test tls_lib LABELS CI) diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index 67229d57..5d2bb2e4 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -329,23 +329,217 @@ io::Result TlsSocket::PushToEngine(const iovec* ptr, uint return res; } -// TODO: to implement async functionality. +void TlsSocket::HandleOpAsync(int op_val) { + switch (op_val) { + case Engine::EOF_STREAM: + VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle(); + async_write_req_->caller_completion_cb( + make_unexpected(make_error_code(errc::connection_aborted))); + break; + case Engine::NEED_READ_AND_MAYBE_WRITE: + HandleUpstreamAsyncRead(); + break; + case Engine::NEED_WRITE: + MaybeSendOutputAsync(); + break; + default: + LOG(DFATAL) << "Unsupported " << op_val; + } +} + +void TlsSocket::AsyncWriteReq::Run() { + if (state == AsyncWriteReq::PushToEngine) { + io::Result push_res = owner->PushToEngine(vec, len); + if (!push_res) { + caller_completion_cb(make_unexpected(push_res.error())); + return; + } + last_push = *push_res; + state = AsyncWriteReq::HandleOpAsync; + } + + if (state == AsyncWriteReq::HandleOpAsync) { + state = AsyncWriteReq::MaybeSendOutputAsync; + if (last_push.engine_opcode < 0) { + owner->HandleOpAsync(last_push.engine_opcode); + } + } + + if (state == AsyncWriteReq::MaybeSendOutputAsync) { + state = AsyncWriteReq::PushToEngine; + if (last_push.written > 0) { + DCHECK(!continuation); + continuation = [this]() { + state = AsyncWriteReq::Done; + caller_completion_cb(last_push.written); + }; + owner->MaybeSendOutputAsync(); + } + } +} + void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { - io::Result res = WriteSome(v, len); - cb(res); + CHECK(!async_write_req_.has_value()); + async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len)); + async_write_req_->Run(); } +void TlsSocket::AsyncReadReq::Run() { + DCHECK_GT(len, 0u); + + while (true) { + DCHECK(!dest.empty()); + + size_t read_len = std::min(dest.size(), size_t(INT_MAX)); + + Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len); + + int op_val = op_result; + + DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; + + if (op_val > 0) { + read_total += op_val; + + // I do not understand this code and what the hell I meant to do here. Seems to work + // though. + if (size_t(op_val) == read_len) { + if (size_t(op_val) < dest.size()) { + dest.remove_prefix(op_val); + } else { + ++vec; + --len; + if (len == 0) { + // We are done. Call completion callback. + caller_completion_cb(read_total); + return; + } + dest = Engine::MutableBuffer{reinterpret_cast(vec->iov_base), vec->iov_len}; + } + // We read everything we asked for but there are still buffers left to fill. + continue; + } + break; + } + + // Will automatically call Run() + owner->HandleOpAsync(op_val); + } + + // We are done. Call completion callback. + caller_completion_cb(read_total); + + // clean up so we can queue more reads +} -// TODO: to implement async functionality. void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { - io::Result res = ReadSome(v, len); - cb(res); + CHECK(!async_read_req_.has_value()); + auto req = AsyncReadReq(this, std::move(cb), v, len); + req.dest = {reinterpret_cast(v->iov_base), v->iov_len}; + async_read_req_.emplace(std::move(req)); + async_read_req_->Run(); } SSL* TlsSocket::ssl_handle() { return engine_ ? engine_->native_handle() : nullptr; } +void TlsSocket::HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer) { + if (!write_result) { + state_ &= ~WRITE_IN_PROGRESS; + + // broken_pipe - happens when the other side closes the connection. do not log this. + if (write_result.error() != errc::broken_pipe) { + VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error(); + } + + // We are done. Errornous exit. + async_write_req_->caller_completion_cb(write_result); + return; + } + + CHECK_GT(*write_result, 0u); + upstream_write_ += *write_result; + engine_->ConsumeOutputBuf(*write_result); + buffer.remove_prefix(*write_result); + + // We are not done. Re-arm the async write until we drive it to completion or error. + if (!buffer.empty()) { + auto& scratch = async_write_req_->scratch_iovec; + scratch.iov_base = const_cast(buffer.data()); + scratch.iov_len = buffer.size(); + next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { + HandleUpstreamAsyncWrite(write_result, buffer); + }); + } + + if (engine_->OutputPending() > 0) { + LOG(INFO) << "ssl buffer is not empty with " << engine_->OutputPending() + << " bytes. short write detected"; + } + + state_ &= ~WRITE_IN_PROGRESS; + + // If there is a continuation run it and let it yield back to the main loop + if (async_write_req_->continuation) { + auto cont = std::exchange(async_write_req_->continuation, std::function{}); + cont(); + return; + } + + // Yield back to main loop + return async_write_req_->Run(); +} + +void TlsSocket::StartUpstreamWrite() { + Engine::Buffer buffer = engine_->PeekOutputBuf(); + DCHECK(!buffer.empty()); + DCHECK((state_ & WRITE_IN_PROGRESS) == 0); + + if (buffer.empty()) { + // We are done + return; + } + + DVLOG(2) << "HandleUpstreamWrite " << buffer.size(); + // we do not allow concurrent writes from multiple fibers. + state_ |= WRITE_IN_PROGRESS; + + auto& scratch = async_write_req_->scratch_iovec; + scratch.iov_base = const_cast(buffer.data()); + scratch.iov_len = buffer.size(); + + next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { + HandleUpstreamAsyncWrite(write_result, buffer); + }); +} + +void TlsSocket::MaybeSendOutputAsync() { + if (engine_->OutputPending() == 0) { + if (async_write_req_->continuation) { + auto cont = std::exchange(async_write_req_->continuation, std::function{}); + cont(); + return; + } + async_write_req_->Run(); + } + + // This function is present in both read and write paths. + // meaning that both of them can be called concurrently from differrent fibers and then + // race over flushing the output buffer. We use state_ to prevent that. + if (state_ & WRITE_IN_PROGRESS) { + if (async_write_req_->continuation) { + // TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise + // we might deadlock. See the sync version of HandleOp for more info + auto cont = std::exchange(async_write_req_->continuation, std::function{}); + cont(); + return; + } + } + + StartUpstreamWrite(); +} + auto TlsSocket::MaybeSendOutput() -> error_code { if (engine_->OutputPending() == 0) return {}; @@ -373,6 +567,46 @@ auto TlsSocket::MaybeSendOutput() -> error_code { return HandleUpstreamWrite(); } +void TlsSocket::StartUpstreamRead() { + auto buffer = engine_->PeekInputBuf(); + state_ |= READ_IN_PROGRESS; + + auto& scratch = async_write_req_->scratch_iovec; + scratch.iov_base = const_cast(buffer.data()); + scratch.iov_len = buffer.size(); + + next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) { + state_ &= ~READ_IN_PROGRESS; + if (!read_result) { + // log any errors as well as situations where we have unflushed output. + if (read_result.error() != errc::connection_aborted || engine_->OutputPending() > 0) { + VSOCK(1) << "HandleUpstreamRead failed " << read_result.error(); + } + // Erronous path. Apply the completion callback and exit. + async_write_req_->caller_completion_cb(read_result); + return; + } + + DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes"; + engine_->CommitInput(*read_result); + // We are not done. Give back control to the main loop. + async_write_req_->Run(); + }); +} + +void TlsSocket::HandleUpstreamAsyncRead() { + auto on_success = [this]() { + if (state_ & READ_IN_PROGRESS) { + async_write_req_->Run(); + } + + StartUpstreamRead(); + }; + + async_write_req_->continuation = on_success; + MaybeSendOutputAsync(); +} + auto TlsSocket::HandleUpstreamRead() -> error_code { RETURN_ON_ERROR(MaybeSendOutput()); @@ -481,8 +715,7 @@ unsigned TlsSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { } void TlsSocket::ReturnProvided(const ProvidedBuffer& pbuf) { - proactor()->DeallocateBuffer( - io::MutableBytes{const_cast(pbuf.start), pbuf.allocated}); + proactor()->DeallocateBuffer(io::MutableBytes{const_cast(pbuf.start), pbuf.allocated}); } bool TlsSocket::IsUDS() const { diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 83a470b0..4fbeddbe 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -7,6 +7,7 @@ #include #include +#include #include "util/fiber_socket_base.h" #include "util/tls/tls_engine.h" @@ -90,7 +91,6 @@ class TlsSocket final : public FiberSocketBase { virtual void SetProactor(ProactorBase* p) override; private: - struct PushResult { size_t written = 0; int engine_opcode = 0; // Engine::OpCode @@ -114,6 +114,65 @@ class TlsSocket final : public FiberSocketBase { std::unique_ptr engine_; size_t upstream_write_ = 0; + struct AsyncReqBase { + AsyncReqBase(TlsSocket* owner, io::AsyncProgressCb caller_cb, const iovec* vec, uint32_t len) + : owner(owner), caller_completion_cb(std::move(caller_cb)), vec(vec), len(len) { + } + + TlsSocket* owner; + // Callback passed from the user. + io::AsyncProgressCb caller_completion_cb; + + const iovec* vec; + uint32_t len; + + std::function continuation; + }; + + struct AsyncWriteReq : AsyncReqBase { + using AsyncReqBase::AsyncReqBase; + + iovec scratch_iovec; + // TODO simplify state transitions + // TODO handle async yields to avoid deadlocks (see HandleOp) + enum State { PushToEngine, HandleOpAsync, MaybeSendOutputAsync, Done }; + State state = PushToEngine; + PushResult last_push; + + // Main loop + void Run(); + }; + + friend AsyncWriteReq; + + struct AsyncReadReq : AsyncReqBase { + using AsyncReqBase::AsyncReqBase; + + Engine::MutableBuffer dest; + size_t read_total = 0; + + // Main loop + void Run(); + }; + + friend AsyncReadReq; + + // Asynchronous helpers + void MaybeSendOutputAsync(); + + void HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer); + void HandleUpstreamAsyncRead(); + + void HandleOpAsync(int op_val); + + void StartUpstreamWrite(); + void StartUpstreamRead(); + + // TODO clean up the optional before we yield such that progress callback can dispatch another + // async operation + std::optional async_write_req_; + std::optional async_read_req_; + enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 }; uint8_t state_{0}; }; diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc new file mode 100644 index 00000000..eae010dc --- /dev/null +++ b/util/tls/tls_socket_test.cc @@ -0,0 +1,246 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#include "util/tls/tls_socket.h" + +#include + +#include + +#include "base/gtest.h" +#include "base/logging.h" +#include "util/fiber_socket_base.h" +#include "util/fibers/fibers.h" +#include "util/fibers/synchronization.h" + +#ifdef __linux__ +#include "util/fibers/uring_proactor.h" +#include "util/fibers/uring_socket.h" +#endif +#include "util/fibers/epoll_proactor.h" + +namespace util { +namespace fb2 { + +constexpr uint32_t kRingDepth = 8; +using namespace testing; + +#ifdef __linux__ +void InitProactor(ProactorBase* p) { + if (p->GetKind() == ProactorBase::IOURING) { + static_cast(p)->Init(0, kRingDepth); + } else { + static_cast(p)->Init(0); + } +} +#else +void InitProactor(ProactorBase* p) { + static_cast(p)->Init(0); +} +#endif + +using namespace std; + +enum TlsContextRole { SERVER, CLIENT }; + +SSL_CTX* CreateSslCntx(TlsContextRole role) { + std::string tls_key_file; + std::string tls_key_cert; + std::string tls_ca_cert_file; + SSL_CTX* ctx; + + if (role == TlsContextRole::SERVER) { + ctx = SSL_CTX_new(TLS_server_method()); + // TODO init those to build on ci + } else { + ctx = SSL_CTX_new(TLS_client_method()); + } + unsigned mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + + bool res = SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1; + EXPECT_FALSE(res); + res = SSL_CTX_use_certificate_chain_file(ctx, tls_key_cert.c_str()) != 1; + EXPECT_FALSE(res); + res = SSL_CTX_load_verify_locations(ctx, tls_ca_cert_file.data(), nullptr) != 1; + EXPECT_FALSE(res); + res = 1 == SSL_CTX_set_cipher_list(ctx, "DEFAULT"); + EXPECT_TRUE(res); + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); + SSL_CTX_set_verify(ctx, mask, NULL); + SSL_CTX_set_dh_auto(ctx, 1); + return ctx; +} + +class TlsFiberSocketTest : public testing::TestWithParam { + protected: + void SetUp() final; + void TearDown() final; + + static void SetUpTestCase() { + testing::FLAGS_gtest_death_test_style = "threadsafe"; + } + + using IoResult = int; + + // TODO clean up + virtual void HandleRequest() { + tls_socket_ = std::make_unique(conn_socket_.release()); + tls_socket_->InitSSL(CreateSslCntx(SERVER)); + tls_socket_->Accept(); + + uint8_t buf[16]; + auto res = tls_socket_->Recv(buf); + EXPECT_TRUE(res.has_value()); + EXPECT_TRUE(res.value() == 16); + + auto write_res = tls_socket_->Write(buf); + EXPECT_FALSE(write_res); + } + + unique_ptr proactor_; + thread proactor_thread_; + unique_ptr listen_socket_; + unique_ptr conn_socket_; + unique_ptr tls_socket_; + + uint16_t listen_port_ = 0; + Fiber accept_fb_; + Fiber conn_fb_; + std::error_code accept_ec_; + FiberSocketBase::endpoint_type listen_ep_; + uint32_t conn_sock_err_mask_ = 0; +}; + +INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTest, + testing::Values("epoll" +#ifdef __linux__ + , + "uring" +#endif + ), + [](const auto& info) { return string(info.param); }); + +void TlsFiberSocketTest::SetUp() { +#if __linux__ + bool use_uring = GetParam() == "uring"; + ProactorBase* proactor = nullptr; + if (use_uring) + proactor = new UringProactor; + else + proactor = new EpollProactor; +#else + ProactorBase* proactor = new EpollProactor; +#endif + + atomic_bool init_done{false}; + + proactor_thread_ = thread{[proactor, &init_done] { + InitProactor(proactor); + init_done.store(true, memory_order_release); + proactor->Run(); + }}; + + proactor_.reset(proactor); + + error_code ec = proactor_->AwaitBrief([&] { + listen_socket_.reset(proactor_->CreateSocket()); + return listen_socket_->Listen(0, 0); + }); + + CHECK(!ec); + listen_port_ = listen_socket_->LocalEndpoint().port(); + DCHECK_GT(listen_port_, 0); + + auto address = boost::asio::ip::make_address("127.0.0.1"); + listen_ep_ = FiberSocketBase::endpoint_type{address, listen_port_}; + + accept_fb_ = proactor_->LaunchFiber("AcceptFb", [this] { + auto accept_res = listen_socket_->Accept(); + VLOG_IF(1, !accept_res) << "Accept res: " << accept_res.error(); + + if (accept_res) { + VLOG(1) << "Accepted connection " << *accept_res; + FiberSocketBase* sock = *accept_res; + conn_socket_.reset(sock); + conn_socket_->SetProactor(proactor_.get()); + conn_socket_->RegisterOnErrorCb([this](uint32_t mask) { + LOG(INFO) << "Error mask: " << mask; + conn_sock_err_mask_ = mask; + }); + conn_fb_ = proactor_->LaunchFiber([this]() { HandleRequest(); }); + } else { + accept_ec_ = accept_res.error(); + } + }); +} + +void TlsFiberSocketTest::TearDown() { + VLOG(1) << "TearDown"; + + proactor_->Await([&] { + std::ignore = listen_socket_->Shutdown(SHUT_RDWR); + if (conn_socket_) { + std::ignore = conn_socket_->Close(); + } else { + std::ignore = tls_socket_->Close(); + } + }); + + conn_fb_.JoinIfNeeded(); + accept_fb_.JoinIfNeeded(); + + // We close here because we need to wake up listening socket. + proactor_->Await([&] { std::ignore = listen_socket_->Close(); }); + + proactor_->Stop(); + proactor_thread_.join(); + proactor_.reset(); +} + +TEST_P(TlsFiberSocketTest, Basic) { + unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); + tls_sock->InitSSL(CreateSslCntx(CLIENT)); + + LOG(INFO) << "before wait "; + proactor_->Await([&] { + ThisFiber::SetName("ConnectFb"); + + LOG(INFO) << "Connecting to " << listen_ep_; + error_code ec = tls_sock->Connect(listen_ep_); + uint8_t buf[16] = {120}; + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &buf, .iov_len = 16}; + + tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); + + done.Wait(); + + // TODO with iouring this max outs the memory and crashes + // TODO investigate why + tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); + + done.Wait(); + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }); +} + +} // namespace fb2 +} // namespace util From 701425ac609a6135d7549358e7c53a1b1d5aada3 Mon Sep 17 00:00:00 2001 From: kostas Date: Fri, 7 Feb 2025 16:39:32 +0200 Subject: [PATCH 2/8] fixes and certificates Signed-off-by: kostas --- util/tls/CMakeLists.txt | 1 + util/tls/certificates/ca-cert.pem | 34 ++++++++ util/tls/certificates/ca-key.pem | 52 +++++++++++++ util/tls/certificates/server-cert.pem | 32 ++++++++ util/tls/certificates/server-key.pem | 52 +++++++++++++ util/tls/tls_socket.cc | 108 +++++++++++++------------- util/tls/tls_socket.h | 32 ++++---- util/tls/tls_socket_test.cc | 70 +++++++++++------ 8 files changed, 287 insertions(+), 94 deletions(-) create mode 100644 util/tls/certificates/ca-cert.pem create mode 100644 util/tls/certificates/ca-key.pem create mode 100644 util/tls/certificates/server-cert.pem create mode 100644 util/tls/certificates/server-key.pem diff --git a/util/tls/CMakeLists.txt b/util/tls/CMakeLists.txt index fb2fd27f..b595082f 100644 --- a/util/tls/CMakeLists.txt +++ b/util/tls/CMakeLists.txt @@ -5,3 +5,4 @@ add_library(tls_lib tls_engine.cc tls_socket.cc) cxx_link(tls_lib fibers2 OpenSSL::SSL) cxx_test(tls_engine_test tls_lib LABELS CI) cxx_test(tls_socket_test tls_lib LABELS CI) +target_compile_definitions(tls_socket_test PRIVATE TEST_CERT_PATH="${CMAKE_SOURCE_DIR}/util/tls/certificates") diff --git a/util/tls/certificates/ca-cert.pem b/util/tls/certificates/ca-cert.pem new file mode 100644 index 00000000..8681a870 --- /dev/null +++ b/util/tls/certificates/ca-cert.pem @@ -0,0 +1,34 @@ +-----BEGIN CERTIFICATE----- +MIIF5zCCA8+gAwIBAgIUN1lzJZ5fsK/ikGfiK9rR2VMCVyYwDQYJKoZIhvcNAQEL +BQAwgYExCzAJBgNVBAYTAkdSMQwwCgYDVQQIDANTS0cxFTATBgNVBAcMDFRoZXNz +YWxvbmlraTELMAkGA1UECgwCS0sxFDASBgNVBAsMC0FjbWVTdHVkaW9zMQswCQYD +VQQDDAJHcjEdMBsGCSqGSIb3DQEJARYOYWNtZUBnbWFpbC5jb20wIBcNMjUwMjA3 +MTQzMTE1WhgPMjA3NTAyMDcxNDMxMTVaMIGBMQswCQYDVQQGEwJHUjEMMAoGA1UE +CAwDU0tHMRUwEwYDVQQHDAxUaGVzc2Fsb25pa2kxCzAJBgNVBAoMAktLMRQwEgYD +VQQLDAtBY21lU3R1ZGlvczELMAkGA1UEAwwCR3IxHTAbBgkqhkiG9w0BCQEWDmFj +bWVAZ21haWwuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAg/Zy +kiHrSqi5aKe0oUHj2GyHZssYNoIKNQ3jIBfKZOtwIRcq3QbMmeVSJGpZtmU7nKok +j1QXok3H9JR4tWvCTQH0pGC8HPgWqxb6GpxbUS/3ZABsiTEwqPZlIIJsNo7nNdcq +hNSpGC0jmjQ/Mqn61rnZwrFnTPn2oyq8La2k3vPg3J4zds3Xh/LvZSuOdU2bHFTt +ptInc1/gIRbafFterF8LxW+d5I/ZwRaME1Rjdy1bn00cDeU0+JvSMn/h1/gIOSeD +t00JEB7CG1mtNmO8XaQNu1UvJQr5CIz4vDGg7t7fK3gOHBF3ygqt63pDeThpu9LU +bGKTOiOt+QMVLtbM2nR8jybkJCg6iYwezG9x9bCsZr0t9XfVpRvc4MaEr8rf4APo +4oYfcEnIWAo4niItPSOPcVxwlzyUA8bjq0NpR9chZMi6oC8Zr8TVbGS6B/mwrL37 +IpB/3sfXQTruAoBjhxjF0JkWAkFYrMrxzW584wnu9qXYPpijyT1OF1GBcWnvVyLu +XFL4N+Brrw5k9RCxVL2nHy3zlhQUn1RG442AI4DkrzbRB1ADlKcQr7yka2JrMpCv +pIqXxOiavqtthAnfX19gwf6A4AbSseLwqqu34aS1etJPdulD6KyeGYwLhEg5tKpN +/XXkSjNcyimJzTjErkWsYLwPSE8DNQetvupNAnMCAwEAAaNTMFEwHQYDVR0OBBYE +FBcHh0U0qUBoll5fP77vuHA7R1gCMB8GA1UdIwQYMBaAFBcHh0U0qUBoll5fP77v +uHA7R1gCMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggIBAGz7O995 +O+xOicmWKYaoPoeG3eldzSlsf4uJ6O/DrK1yyQXdUgDnt6BBq10oP4VaE5eTtwjb +RxFe1bp7magAv+cPq//Do2CNMc7rYDDhoCLYpRlvA5rxgHs7fvp6ImRDj4Q3nLcM +g41XmTlTjP6fPLjd0H1p5V7sayQ+l/0QMgMXJ3hJQtpq/lTDMBW5iKc5mHvABnfr +OPoP8JRaG6yHaCVuUr/RqoDV+1O/3zXLkRhb+pBOTvbtLn8G7XvJxcbXnX/gz/h7 +JZrCdCNgr6pXplqs3BlCX3clQhwpc1SRrxOw2wZdIQnbpEghyLFpYNjsuR70MsFN +pen7CcAcw3sq0Q5frMa+ygkRTNNmfuqvcbAZBs+BCEn4cnzCMQcHLW+m8YIGql6U +zdHjEofZ1oS4gI34NwsyTzouTxq/90gQvOZpXfAUxiPBSPwMWNQ3bA7xXbz6bvt5 +sgD/9jRlM0r6/cGs6jCQqpuAuvTiV0Pt9BM3BP8iiqn45QXLOLpbeIRTtKGv4gM3 +Q8AiNPitlg5MNaw87xCvbYkgfjAG08eaaema6za/4iK39od1/HtgRcPaTXf8biwr +aVnY3/e1OGGNNHckw3Y2nsMAfFiU1++cvg35PXdU+YkYzuuXiRL2bPs83YTtXWET +NKAA0KWByKALMro4PBlpF/kTxr2CbKVeYBHV +-----END CERTIFICATE----- diff --git a/util/tls/certificates/ca-key.pem b/util/tls/certificates/ca-key.pem new file mode 100644 index 00000000..313e879d --- /dev/null +++ b/util/tls/certificates/ca-key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQCD9nKSIetKqLlo +p7ShQePYbIdmyxg2ggo1DeMgF8pk63AhFyrdBsyZ5VIkalm2ZTucqiSPVBeiTcf0 +lHi1a8JNAfSkYLwc+BarFvoanFtRL/dkAGyJMTCo9mUggmw2juc11yqE1KkYLSOa +ND8yqfrWudnCsWdM+fajKrwtraTe8+DcnjN2zdeH8u9lK451TZscVO2m0idzX+Ah +Ftp8W16sXwvFb53kj9nBFowTVGN3LVufTRwN5TT4m9Iyf+HX+Ag5J4O3TQkQHsIb +Wa02Y7xdpA27VS8lCvkIjPi8MaDu3t8reA4cEXfKCq3rekN5OGm70tRsYpM6I635 +AxUu1szadHyPJuQkKDqJjB7Mb3H1sKxmvS31d9WlG9zgxoSvyt/gA+jihh9wSchY +CjieIi09I49xXHCXPJQDxuOrQ2lH1yFkyLqgLxmvxNVsZLoH+bCsvfsikH/ex9dB +Ou4CgGOHGMXQmRYCQVisyvHNbnzjCe72pdg+mKPJPU4XUYFxae9XIu5cUvg34Guv +DmT1ELFUvacfLfOWFBSfVEbjjYAjgOSvNtEHUAOUpxCvvKRrYmsykK+kipfE6Jq+ +q22ECd9fX2DB/oDgBtKx4vCqq7fhpLV60k926UPorJ4ZjAuESDm0qk39deRKM1zK +KYnNOMSuRaxgvA9ITwM1B62+6k0CcwIDAQABAoICACeYvp4swV6ArEnD8MZmcAjT +3/kvPc+1S3zJ8voBSYDoyJeVTQ5PaPtQvUoiA1NgoveKcjfzwre34ST5nBLMB9x1 +lsPwJuIGaz7hQSDVA+2jl/cQzYCJGxHIBWYw3GmujaAxNRfwe+C+Qq2VudTo/lSK +JdZuxxFo++HQA/Es5ojj4vgwHD8s2tx3P/A6lp+KLt3cegcRjjbncOhc5Chmfkz4 +pB6VNGqN44g1zMhMDSCIorJ1P9LHkRJ8JyFyEAFu8oC746EP44VLxXDRgtEMMkxi +2p/4mpHh7gHr0wMdXS3wAEUZ3Bn9/9THSZKb+D0aeVeblpQDLCxI4n4St7t7RrJj +7fcX2a/WeFWxUyt9VdnEkp404fLJVhpIET/vrFjfcQjxBEd20jgVzR9rnGEVz+eV +/ExHujZ5FkUYie0ol7Dr3CBj7Gf9/dTjsqORPgWvcwxBGSa9eUBs2Xf+o0jLajXS +ujhi5jNMmb6D4WW3yciC1UXP4azmNTbqoHjZAx+c4d52qL5CQcQNylgvPHD8qYUo +43yEa0ytdKd4m9x4/I0lOt8qEDNcXRQlmnc5i9A+OSbymP+rIRAR3gmG07Z2l2Nu +cE/YgrSRIxV2oIdpDhIPs/qUaTWFzYEkyjElP/DkzJB6aDH6w5Apx5HLu9YD4WNk +W9f/dEApjw/rrwsNVR1BAoIBAQC3ps3hNnWmNVEMq/HYcLzmM4nRjkoCRs9M4PG2 +zoAX1kDbcf6By8F+7AOUmJeXcCFY/v6sP6IyD0gPDzEMoSjAKj/ErzNL6VIvz0+j +1F9yl6Uv7wz5xYlU+h93hCyMaHcTWwdfrjvqOUGwpw+DauWMVZ4NM0qQuImOujT5 +5qO0dyRm201ZqnSXkZOJQQf7W13bYXKjQDFJswvCBupJsWlfaKXkmUoe4cAa2qwg +EC6jaUV+6cnTN5ESj3nm+UFWu9ka+1FZYHQq9+QqHkkNoEBNDPNp5ZaJF0rVxbpA +LV7k+IYUTtIPWaRoWsTn+5qq1opQzcwd+PiDYUfdX97eV9ehAoIBAQC38tZFz72v +VYA7U7Q+wD9VXqXBo3MFsf0o7pPHoIRlxtIzgDfnZ6dYiI5NWZATPwKxot12RHtO +LTk6RUUN9YWPhmRZFMlSztlP03f4+KBl+GWpFJZURVQvQy6zrxlj9kF9lhHPtEQp +km5tXH2eon/saVcSNq4gIYuhnqGBiUXq4FaNsD1Rb4cRn8zK2AcoOz2bCr++odTD +Et7JQoYv8REldG3Bm8OkO2dykgp8KeHuhK1ZbMgIo1IrSOq/1GHPFygzczjxbmlV +2fZu5HMr2kZPZC4z/h0oJOxwO4gW1MIW4n+BqrQQiYrDV/9NiW+m3hlCTDoZYL2r +l5jMbnZ6epGTAoIBAAPNe2PXadY4MmZtxQMzSmYF2Suyo4uqha5U1gxv+C0GLa+d +i6SKYIZNQsG36yOimb4rAYD1jFk3Acn2CZD2YU6hUVK0Qf6nZSFCTKbaxeMsiqoU +bBNb6L0OtMoXvYhmvVh0QRHVHL570wViYCrbcsdWGoCxeDDI8Wg4KNKn2OnqsaFD +lzVtFx7wT7q+0vh8atQZD3Ob56lcALlSxVUjTEhCdXTnS6aaDA1CS+AaFa0ih+LZ +2mj6NJHK0L5cmOK/3v31CDkuixk8qsfIesDCebJeu0eNDnHmPpFwl6uuhNF+59/R +xf21YHccsgkPp6Mz8Ac+S6SvPA4UXJWT35yA+yECggEAY7phVyoA0e1N/1wrLZY3 +AAa3YRtHgf/0m9t0/VbWUQOQ9OD/7hJxVPt8Aw2aogSYZkxBOxx3qXO1QhVKEf5Z +se2PvAgb+iww/ylMMwxAkegw3ZFOy1NnB9SpnjtBTcO1z+urrmsyRUOhYOMzK+03 +46lczoAcuUjWlgIV18/fuy5zXo/9Pohztydm9VZX0wUDKmqSeDzux/AUHxNVAur0 +e9T6qGvVjtWyCRiKXLSsTA9cmE30yVIae2Ml+mifupH4dqRya9qLe0MXaxmqI01M +r0BGGGQd1KToFxT/fDlPHO8hZ+BvjUO0mqG7xYLMqLBjC2GeYiHj0wL2kWcWDuMi +xwKCAQAV6l4+hmt4WPAZoLvgUxyhPf50c36Q8ZGrNMExi6aAmZc+hfURkd5FNjI4 +SAuQEe2yrYCVUkwzkMfqFa3YzWP6mSoHS3tCHRG4lOKq+dbBHOFtSYLu2U3T3J/b +bifP/kNi/c7do8MnJOH7YHE+DDKoWcVIk4kelJtwZIzxTS5PXYqisan1X/WHSsAy +CASp1z4cQjTYsyYO9w5pESlLYLIfoeppOJlEf76SY8vlwyTQXaVZpQgruNE4hPHk ++Z8zm660pmzkCnT8UR9lvdtlatbDRgsPPAO7GJd7RUSA2KTiy1eZ9AGjMhwQUSc1 +kJqZvaQ0+xLiJzLEFcBbj9W9ifuI +-----END PRIVATE KEY----- diff --git a/util/tls/certificates/server-cert.pem b/util/tls/certificates/server-cert.pem new file mode 100644 index 00000000..e6a613f1 --- /dev/null +++ b/util/tls/certificates/server-cert.pem @@ -0,0 +1,32 @@ +-----BEGIN CERTIFICATE----- +MIIFkDCCA3gCFAl+cQgAc8X1kKpQRGpm1yuClHAyMA0GCSqGSIb3DQEBCwUAMIGB +MQswCQYDVQQGEwJHUjEMMAoGA1UECAwDU0tHMRUwEwYDVQQHDAxUaGVzc2Fsb25p +a2kxCzAJBgNVBAoMAktLMRQwEgYDVQQLDAtBY21lU3R1ZGlvczELMAkGA1UEAwwC +R3IxHTAbBgkqhkiG9w0BCQEWDmFjbWVAZ21haWwuY29tMCAXDTI1MDIwNzE0MzEx +N1oYDzIwNzUwMjA3MTQzMTE3WjCBhDELMAkGA1UEBhMCR1IxDDAKBgNVBAgMA1NL +RzEVMBMGA1UEBwwMVGhlc3NhbG9uaWtpMQswCQYDVQQKDAJLSzENMAsGA1UECwwE +Q29tcDELMAkGA1UEAwwCR3IxJzAlBgkqhkiG9w0BCQEWGGRvZXNfbm90X2V4aXN0 +QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBANPj8Dd0 +1rMlAsO+sbusg8Lx+QM0gSHXsDSz9Gd7MAqGz+whj9bogmfkWQVuMZLtGXeYTFOX +BRve/VK+4v2ubr0S7qkQFOFJLsFhiFjjl5EdwqWroO38imRst5CVsEHopk9kF6L2 +gd/W9Gh0OnZK2XY514N5l/dtn4YIsx5yApPdMk6eOwSzjYgHyWFmAxqAtr8fED2I +q8GoBCSuECe8Zg0Yisf53L0RVhlBnb4YYA9XGKpNo/B1kXJfy/mXAzqS255EsPuq +rIMy8Zdx08BxuOITvbaP1PsvBuhzB380KCcqnY1DGaeKm+t5mBI8a2hpjI+CZmif +ZYS1PfVDURq7tGCyW1zG5lv7rm0Z2FqT5IMwaOwvNzIdoZL4/wi5uUXfoVGu3NTR +llNNAc8lHFeBm5NjXn321BwfgoCvh0LRMxW7pogdkmUmozcrxK0YNGCRWO3vP0nU +xLJZUf5THEpL5YIIFSzngB1NlaFjxDMOMzEub4VQkpFulfjmH/BXTDM9nfe0zXyR +QCfTsqVWpLD6ZQvWgirB0ywwpVihB9B+XxFER6/qkkA5WY0M4CXg8nT3SY5G1Zql +4CPO4JLt2XZNkOze+mnTkAJ9LXCdvsiWPwsYsP6T8sw05DhRJGG0j1VQ1eyWXnwT +Pn6gH9Y3BX/gO/JKxQNIhxgdG9pDpcTIn5y5AgMBAAEwDQYJKoZIhvcNAQELBQAD +ggIBAEL1mvCUfpGZGBdH+1kkQjOuw0Cq7OCzsywEu/XM/I64Y4z5GoAVatB573VC +BGK40NhenKMk2o5KmIUs8iO3psOMWCu0/HQ1fZnz7hI0RiJLau51yfq8RzHx54Dd +8y5P7JcuezPFeD9DsZWis4OHG1B/XpETwpApzwFympWMSmr7Y1LxWNtHDvv6d0iQ +yjyojWkN7zh0lbPxVM8teeG4FsIE0EKUKUanS3aj67AtXmAjYWthVf/ptt2uMr81 +BYMbwbotwsZ0KrDyyr3z/SZShocENV3l3Fsr2/vLBIlw4rDG815jChOo22YHut3V +2eLHkaCQ8nihRr0kqsIaiMlB9d2NzGobXxvx3g00GFUPiR8rOWer08Igb5HlEC87 +oxrlgBcIohcJUAMYZJCBkLDfOejPxTbVvaMvI4DVo9ujIuWjQy3zn+l7ZzqrGMa+ +Bt4UAMCarTMsv+j6zGlDMqVWyWWvtfD7iB28Y/UpsD6A2rbMGcVuRURwn2ZauTYE +PVgOCv7oFSgNzAGaHRFkbHgstWlY0RsW4kduFjcBf6fV71RO3OJGtePDLgQa7oIx +sw9CPJ/VeHEAOvBHVbFj1N3v490z597Z/lUzBGjsCGDsyIAj3z6BNfGT7gRn54NO +G2rnT9kvaqyC/46nrT24BaHFJAI05sLcN49Kr86WD8gTzNAU +-----END CERTIFICATE----- diff --git a/util/tls/certificates/server-key.pem b/util/tls/certificates/server-key.pem new file mode 100644 index 00000000..c39abd82 --- /dev/null +++ b/util/tls/certificates/server-key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQDT4/A3dNazJQLD +vrG7rIPC8fkDNIEh17A0s/RnezAKhs/sIY/W6IJn5FkFbjGS7Rl3mExTlwUb3v1S +vuL9rm69Eu6pEBThSS7BYYhY45eRHcKlq6Dt/IpkbLeQlbBB6KZPZBei9oHf1vRo +dDp2Stl2OdeDeZf3bZ+GCLMecgKT3TJOnjsEs42IB8lhZgMagLa/HxA9iKvBqAQk +rhAnvGYNGIrH+dy9EVYZQZ2+GGAPVxiqTaPwdZFyX8v5lwM6ktueRLD7qqyDMvGX +cdPAcbjiE722j9T7Lwbocwd/NCgnKp2NQxmnipvreZgSPGtoaYyPgmZon2WEtT31 +Q1Eau7RgsltcxuZb+65tGdhak+SDMGjsLzcyHaGS+P8IublF36FRrtzU0ZZTTQHP +JRxXgZuTY1599tQcH4KAr4dC0TMVu6aIHZJlJqM3K8StGDRgkVjt7z9J1MSyWVH+ +UxxKS+WCCBUs54AdTZWhY8QzDjMxLm+FUJKRbpX45h/wV0wzPZ33tM18kUAn07Kl +VqSw+mUL1oIqwdMsMKVYoQfQfl8RREev6pJAOVmNDOAl4PJ090mORtWapeAjzuCS +7dl2TZDs3vpp05ACfS1wnb7Ilj8LGLD+k/LMNOQ4USRhtI9VUNXsll58Ez5+oB/W +NwV/4DvySsUDSIcYHRvaQ6XEyJ+cuQIDAQABAoICABCsKRDoBePLN/w4g0op4SVD +/mzA4x9LVi9TcJn64LUXVgF1w5hsq6QqnNGXUdnGg8A8ENdr9Og0Q9kQsZI1+Tsx +4+sUG5x9cmsfdkfOQrUVsyTvjAZl3mrX/hqnmJqbCIkLLmvxexcmlg0pBhecPJ+3 +nexXszwyGUEF6rgszuSdHVH/09ODIFIRkMgz628YrSh9NH0vBZrDkm3jb1x9D8ec +hIEHOVX8KPrsRZH5X4edehCectWfHp5yCL3/Iq8ncpXxwD5RN+lL5yQcPgXwvNQ3 +KvCUQTUxhljggjjBXR511TdSDhD3kFy3MN7Qd8AbvAZnw5CcaDPIwhMMJLPWjLMB +WggcBh75oCILeZHjJ+gx7t/AKt/ZQu/ppy+USBynfdKqAol/CcEPj0C0D8448AyB +R5uf2KpKau9DFY9QTkYyKbv10z4jvFd72rvXQZTkwYNsvGmZ9gWy6Y8qHTokX/OA +u1V3NyRIfzmkT+UpD+GQklSQo14KYM4e7steNOFckyNlIukx0RGRGc7a7a8P0K9l +tQYKdjQbkHRlqCGGX+dO+jU++0Xz2UU6vycoqF3YGA5VnLF6wMggEX21A4vTnyQm +LUEPs6LWWqSsluEpaJjU3tdU4fIPo1AMPtTyMeCR8PituRrzgcJd+wVVncB3GH0V +TgvpeQ2bFC7j7we4foN9AoIBAQD72PhOFSk1s//JMhZDZWqOI57MjPoNlL6F2RIp +1hHc9eKgWoP4777kOgDtgc7C3JJDKsNPAqJFRmoLXAOhYWYOYqsmEOkG4ZjU9L5E ++qQJCyg0CmsZMBjmvOTV/yuPKNb/SDsh58z7s7pGENogXDH0vRW2Gxu8InhBmcff +t7OLgHX3/lWVtw8mjX+5VO+2NB0ycQdPRkY6bZG22BuLGizza8bbFpWNHjJZmeVu +0nJc31j9PN6eeIz4Z6wZu4xAtr6UVojw1BUxlmvZ5BMICYS6sK0M4zBLxSXQ5Wqh +lCUQd0xChOa/lLwIGqSe33GbZrXhtB4Sy8ne0+0iAIiXNUGdAoIBAQDXYk/qrQhw +nE57IFj0es84BHViPJrfJjn0R2yBLvmZedHUwEVTQhVnvB0ikRNJ2+mKiIAdkjLF +Yrb27wvpW6ItRlbWzix1nbC/RsuEO+BIsKdTFkXWJWbtDY5du18SBSqVr1wBwf1b +v6ql7vIrNbLJ9E5rDEH+UHqcj3nlJMwTJNF9Jpulc+eCH2ch5TuGwESKzV/9F/8U +okoXuHqByYfZF8UCgBNqNQVWT2DiirgaNw3Mza/0rQ8k7aZg4Z4WSUHsGTyXlq4J +jWlB0o4fxrusd4cYKwcMyMHtKGFSm7xBb1Z8PuwWAfVXszdlVYYBK8GKSJjhvH5B +K7fCOMsUJrrNAoIBAD1zYIr04NxIslXuUb5aJZjPGjVBBNaBf8d7AtKQeEVY/dYw +n9kC7qoTeRx6uu+TEGExMvy5YMzUdJWW+w/KizNhYe9k2uch7r+vhCmimpnWThX8 +oMtBkCHk31VT8NX6mhMqFbudKsgTv5TPEdophMr5xC1uCeNq6brgAgQVd8rHKoG7 +XjistRasGgknr2He72zaZXUzaXliONbLflT/qw4uMxRMO2t2fcSdJ31V/i+pE7ae +vpceRQ9rhHO27m7v6Cqbvsg8h3tU/7Xnz7j7UZaX+3GUkbk7PpHtGIqacjzFTyc/ +9Gm0qfi0P4zAaqEHe8O5xkjBztz/CvJr/OggR4UCggEBAIBXqN9iNEFGIs3jvJ4S +ACCVJ41eJ7sJAEe7t8BSyZDWsl5gI+801aR2x7WtVR0R2dwe8pisYWyVIgmK8EIh +xEXOQDjHql57lLKl8Ofe9gramRo9j2fH6ckf5tGbsU7/nRyM3fp+Kgbd80XlWJC+ +8sa8uW24ZCqysh1QsYYFo0VVDy/QLbctlapIJCBihFILh8xeDPC3t9wHyLbRys5D +1JtcOpz+zJLg/UktC8JyfrnATIzZlBvsc7XBlv7r8lO9W3bgouaBdzth9HKwkNgG +iBaBMxMHsK/BgS1cfoHHIyqquZJXvD5w9E1KEZxklfFkrXNFRzRcKa+T6W/mf7yG +R5ECggEBAJRQU9HXP7JAKhCah55MHQxOXbemGf6uyTXdDFIzt0757dLvg+YBNouF +A1VIF2rXlcOI5LZpchqnVLdYyD7FzClEK5Ae3CBsf+cL1M+3eYMWKJDg86L7g2GK +IAOzMo2/Ri/Tq+wZ5VEG4sDWw8nd1KCLZXT8gqg16Meb7Mz14Nj4RUXamcQbz6gV +fsq5NyVwVghmvYCCQSIvo3wYQejwZMdlS0kokrq65ouPoJNw8VH9hNrWxqfyyZZp +e2aNv9kD6VShkNyCdFWCvkvQj7Zbq36IBhUmsILPXOtWiXSVIRjxQ0bP6CixJED/ +VND/KvYrY0yf03p9a/p1DFIqLpRtGLs= +-----END PRIVATE KEY----- diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index 5d2bb2e4..c8076344 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -329,12 +329,11 @@ io::Result TlsSocket::PushToEngine(const iovec* ptr, uint return res; } -void TlsSocket::HandleOpAsync(int op_val) { +void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { switch (op_val) { case Engine::EOF_STREAM: - VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle(); - async_write_req_->caller_completion_cb( - make_unexpected(make_error_code(errc::connection_aborted))); + VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); + caller_completion_cb(make_unexpected(make_error_code(errc::connection_aborted))); break; case Engine::NEED_READ_AND_MAYBE_WRITE: HandleUpstreamAsyncRead(); @@ -355,17 +354,17 @@ void TlsSocket::AsyncWriteReq::Run() { return; } last_push = *push_res; - state = AsyncWriteReq::HandleOpAsync; + state = AsyncWriteReq::HandleOpAsyncTag; } - if (state == AsyncWriteReq::HandleOpAsync) { - state = AsyncWriteReq::MaybeSendOutputAsync; + if (state == AsyncWriteReq::HandleOpAsyncTag) { + state = AsyncWriteReq::MaybeSendOutputAsyncTag; if (last_push.engine_opcode < 0) { - owner->HandleOpAsync(last_push.engine_opcode); + HandleOpAsync(last_push.engine_opcode); } } - if (state == AsyncWriteReq::MaybeSendOutputAsync) { + if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) { state = AsyncWriteReq::PushToEngine; if (last_push.written > 0) { DCHECK(!continuation); @@ -373,7 +372,7 @@ void TlsSocket::AsyncWriteReq::Run() { state = AsyncWriteReq::Done; caller_completion_cb(last_push.written); }; - owner->MaybeSendOutputAsync(); + MaybeSendOutputAsync(); } } } @@ -423,7 +422,7 @@ void TlsSocket::AsyncReadReq::Run() { } // Will automatically call Run() - owner->HandleOpAsync(op_val); + return HandleOpAsync(op_val); } // We are done. Call completion callback. @@ -444,57 +443,58 @@ SSL* TlsSocket::ssl_handle() { return engine_ ? engine_->native_handle() : nullptr; } -void TlsSocket::HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer) { +void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_result, + Engine::Buffer buffer) { if (!write_result) { - state_ &= ~WRITE_IN_PROGRESS; + owner->state_ &= ~WRITE_IN_PROGRESS; // broken_pipe - happens when the other side closes the connection. do not log this. if (write_result.error() != errc::broken_pipe) { - VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error(); + // VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error(); } // We are done. Errornous exit. - async_write_req_->caller_completion_cb(write_result); + caller_completion_cb(write_result); return; } CHECK_GT(*write_result, 0u); - upstream_write_ += *write_result; - engine_->ConsumeOutputBuf(*write_result); + owner->upstream_write_ += *write_result; + owner->engine_->ConsumeOutputBuf(*write_result); buffer.remove_prefix(*write_result); // We are not done. Re-arm the async write until we drive it to completion or error. if (!buffer.empty()) { - auto& scratch = async_write_req_->scratch_iovec; + auto& scratch = scratch_iovec; scratch.iov_base = const_cast(buffer.data()); scratch.iov_len = buffer.size(); - next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { + owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { HandleUpstreamAsyncWrite(write_result, buffer); }); } - if (engine_->OutputPending() > 0) { - LOG(INFO) << "ssl buffer is not empty with " << engine_->OutputPending() + if (owner->engine_->OutputPending() > 0) { + LOG(INFO) << "ssl buffer is not empty with " << owner->engine_->OutputPending() << " bytes. short write detected"; } - state_ &= ~WRITE_IN_PROGRESS; + owner->state_ &= ~WRITE_IN_PROGRESS; // If there is a continuation run it and let it yield back to the main loop - if (async_write_req_->continuation) { - auto cont = std::exchange(async_write_req_->continuation, std::function{}); + if (continuation) { + auto cont = std::exchange(continuation, std::function{}); cont(); return; } // Yield back to main loop - return async_write_req_->Run(); + return Run(); } -void TlsSocket::StartUpstreamWrite() { - Engine::Buffer buffer = engine_->PeekOutputBuf(); +void TlsSocket::AsyncReqBase::StartUpstreamWrite() { + Engine::Buffer buffer = owner->engine_->PeekOutputBuf(); DCHECK(!buffer.empty()); - DCHECK((state_ & WRITE_IN_PROGRESS) == 0); + DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0); if (buffer.empty()) { // We are done @@ -503,35 +503,35 @@ void TlsSocket::StartUpstreamWrite() { DVLOG(2) << "HandleUpstreamWrite " << buffer.size(); // we do not allow concurrent writes from multiple fibers. - state_ |= WRITE_IN_PROGRESS; + owner->state_ |= WRITE_IN_PROGRESS; - auto& scratch = async_write_req_->scratch_iovec; + auto& scratch = scratch_iovec; scratch.iov_base = const_cast(buffer.data()); scratch.iov_len = buffer.size(); - next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { + owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { HandleUpstreamAsyncWrite(write_result, buffer); }); } -void TlsSocket::MaybeSendOutputAsync() { - if (engine_->OutputPending() == 0) { - if (async_write_req_->continuation) { - auto cont = std::exchange(async_write_req_->continuation, std::function{}); +void TlsSocket::AsyncReqBase::MaybeSendOutputAsync() { + if (owner->engine_->OutputPending() == 0) { + if (continuation) { + auto cont = std::exchange(continuation, std::function{}); cont(); return; } - async_write_req_->Run(); + Run(); } // This function is present in both read and write paths. // meaning that both of them can be called concurrently from differrent fibers and then // race over flushing the output buffer. We use state_ to prevent that. - if (state_ & WRITE_IN_PROGRESS) { - if (async_write_req_->continuation) { + if (owner->state_ & WRITE_IN_PROGRESS) { + if (continuation) { // TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise // we might deadlock. See the sync version of HandleOp for more info - auto cont = std::exchange(async_write_req_->continuation, std::function{}); + auto cont = std::exchange(continuation, std::function{}); cont(); return; } @@ -567,43 +567,43 @@ auto TlsSocket::MaybeSendOutput() -> error_code { return HandleUpstreamWrite(); } -void TlsSocket::StartUpstreamRead() { - auto buffer = engine_->PeekInputBuf(); - state_ |= READ_IN_PROGRESS; +void TlsSocket::AsyncReqBase::StartUpstreamRead() { + auto buffer = owner->engine_->PeekInputBuf(); + owner->state_ |= READ_IN_PROGRESS; - auto& scratch = async_write_req_->scratch_iovec; + auto& scratch = scratch_iovec; scratch.iov_base = const_cast(buffer.data()); scratch.iov_len = buffer.size(); - next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) { - state_ &= ~READ_IN_PROGRESS; + owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) { + owner->state_ &= ~READ_IN_PROGRESS; if (!read_result) { // log any errors as well as situations where we have unflushed output. - if (read_result.error() != errc::connection_aborted || engine_->OutputPending() > 0) { - VSOCK(1) << "HandleUpstreamRead failed " << read_result.error(); + if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) { + /// VSOCK(1) << "HandleUpstreamRead failed " << read_result.error(); } // Erronous path. Apply the completion callback and exit. - async_write_req_->caller_completion_cb(read_result); + caller_completion_cb(read_result); return; } DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes"; - engine_->CommitInput(*read_result); + owner->engine_->CommitInput(*read_result); // We are not done. Give back control to the main loop. - async_write_req_->Run(); + Run(); }); } -void TlsSocket::HandleUpstreamAsyncRead() { +void TlsSocket::AsyncReqBase::HandleUpstreamAsyncRead() { auto on_success = [this]() { - if (state_ & READ_IN_PROGRESS) { - async_write_req_->Run(); + if (owner->state_ & READ_IN_PROGRESS) { + Run(); } StartUpstreamRead(); }; - async_write_req_->continuation = on_success; + continuation = on_success; MaybeSendOutputAsync(); } diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 4fbeddbe..4479b546 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -127,20 +127,33 @@ class TlsSocket final : public FiberSocketBase { uint32_t len; std::function continuation; + iovec scratch_iovec; + + // Asynchronous helpers + void MaybeSendOutputAsync(); + + void HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer); + void HandleUpstreamAsyncRead(); + + void HandleOpAsync(int op_val); + + void StartUpstreamWrite(); + void StartUpstreamRead(); + + virtual void Run() = 0; }; struct AsyncWriteReq : AsyncReqBase { using AsyncReqBase::AsyncReqBase; - iovec scratch_iovec; // TODO simplify state transitions // TODO handle async yields to avoid deadlocks (see HandleOp) - enum State { PushToEngine, HandleOpAsync, MaybeSendOutputAsync, Done }; + enum State { PushToEngine, HandleOpAsyncTag, MaybeSendOutputAsyncTag, Done }; State state = PushToEngine; PushResult last_push; // Main loop - void Run(); + void Run() override; }; friend AsyncWriteReq; @@ -152,22 +165,11 @@ class TlsSocket final : public FiberSocketBase { size_t read_total = 0; // Main loop - void Run(); + void Run() override; }; friend AsyncReadReq; - // Asynchronous helpers - void MaybeSendOutputAsync(); - - void HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer); - void HandleUpstreamAsyncRead(); - - void HandleOpAsync(int op_val); - - void StartUpstreamWrite(); - void StartUpstreamRead(); - // TODO clean up the optional before we yield such that progress callback can dispatch another // async operation std::optional async_write_req_; diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index eae010dc..d744f0ac 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -6,8 +6,10 @@ #include +#include #include +#include "absl/strings/str_cat.h" #include "base/gtest.h" #include "base/logging.h" #include "util/fiber_socket_base.h" @@ -45,9 +47,11 @@ using namespace std; enum TlsContextRole { SERVER, CLIENT }; SSL_CTX* CreateSslCntx(TlsContextRole role) { - std::string tls_key_file; - std::string tls_key_cert; - std::string tls_ca_cert_file; + std::string base_path = TEST_CERT_PATH; + std::string tls_key_file = absl::StrCat(base_path, "/server-key.pem"); + std::string tls_key_cert = absl::StrCat(base_path, "/server-cert.pem"); + std::string tls_ca_cert_file = absl::StrCat(base_path, "/ca-cert.pem"); + SSL_CTX* ctx; if (role == TlsContextRole::SERVER) { @@ -87,7 +91,8 @@ class TlsFiberSocketTest : public testing::TestWithParam { // TODO clean up virtual void HandleRequest() { tls_socket_ = std::make_unique(conn_socket_.release()); - tls_socket_->InitSSL(CreateSslCntx(SERVER)); + ssl_ctx_ = CreateSslCntx(SERVER); + tls_socket_->InitSSL(ssl_ctx_); tls_socket_->Accept(); uint8_t buf[16]; @@ -104,6 +109,7 @@ class TlsFiberSocketTest : public testing::TestWithParam { unique_ptr listen_socket_; unique_ptr conn_socket_; unique_ptr tls_socket_; + SSL_CTX* ssl_ctx_; uint16_t listen_port_ = 0; Fiber accept_fb_; @@ -197,11 +203,14 @@ void TlsFiberSocketTest::TearDown() { proactor_->Stop(); proactor_thread_.join(); proactor_.reset(); + + SSL_CTX_free(ssl_ctx_); } TEST_P(TlsFiberSocketTest, Basic) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); - tls_sock->InitSSL(CreateSslCntx(CLIENT)); + SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); + tls_sock->InitSSL(ssl_ctx); LOG(INFO) << "before wait "; proactor_->Await([&] { @@ -209,29 +218,39 @@ TEST_P(TlsFiberSocketTest, Basic) { LOG(INFO) << "Connecting to " << listen_ep_; error_code ec = tls_sock->Connect(listen_ep_); - uint8_t buf[16] = {120}; - VLOG(1) << "Before writesome"; - - Done done; - iovec v{.iov_base = &buf, .iov_len = 16}; - - tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { - EXPECT_TRUE(result.has_value()); - EXPECT_EQ(*result, 16); - done.Notify(); - }); + EXPECT_FALSE(ec); + { + uint8_t buf[16]; + std::fill(std::begin(buf), std::end(buf), uint8_t(120)); + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &buf, .iov_len = 16}; + + tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); - done.Wait(); + done.Wait(); + } + { + uint8_t buf[16]; + Done done; + iovec v{.iov_base = &buf, .iov_len = 16}; + tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); - // TODO with iouring this max outs the memory and crashes - // TODO investigate why - tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { - EXPECT_TRUE(result.has_value()); - EXPECT_EQ(*result, 16); - done.Notify(); - }); + done.Wait(); - done.Wait(); + for (uint8_t c : buf) { + EXPECT_EQ(c, 120); + } + } VLOG(1) << "closing client sock " << tls_sock->native_handle(); std::ignore = tls_sock->Close(); @@ -240,6 +259,7 @@ TEST_P(TlsFiberSocketTest, Basic) { ASSERT_FALSE(ec) << ec.message(); ASSERT_FALSE(accept_ec_); }); + SSL_CTX_free(ssl_ctx); } } // namespace fb2 From 916678cfc4bb20e070df024a7edb475485c5f01a Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 10 Feb 2025 22:01:05 +0200 Subject: [PATCH 3/8] refactor, add fixes and tests Signed-off-by: kostas --- io/io.cc | 1 + util/tls/tls_socket.cc | 128 +++++++++++++++++++++--------------- util/tls/tls_socket.h | 13 +++- util/tls/tls_socket_test.cc | 99 +++++++++++++++++++++++++--- 4 files changed, 178 insertions(+), 63 deletions(-) diff --git a/io/io.cc b/io/io.cc index 3727ca5b..7606657a 100644 --- a/io/io.cc +++ b/io/io.cc @@ -50,6 +50,7 @@ struct AsyncReadState { AsyncReadState(AsyncSource* source, const iovec* v, uint32_t length) : arr(length), owner(source) { + cur = arr.data(); std::copy(v, v + length, arr.data()); } diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index c8076344..adf4bb77 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -329,14 +329,22 @@ io::Result TlsSocket::PushToEngine(const iovec* ptr, uint return res; } +void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result result) { + auto current = std::exchange(owner->async_write_req_, std::nullopt); + current->caller_completion_cb(result); +} + +void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result result) { + auto current = std::exchange(owner->async_read_req_, std::nullopt); + current->caller_completion_cb(result); +} + void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { switch (op_val) { case Engine::EOF_STREAM: - VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); - caller_completion_cb(make_unexpected(make_error_code(errc::connection_aborted))); break; case Engine::NEED_READ_AND_MAYBE_WRITE: - HandleUpstreamAsyncRead(); + MaybeSendOutputAsync(true); break; case Engine::NEED_WRITE: MaybeSendOutputAsync(); @@ -348,9 +356,12 @@ void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { void TlsSocket::AsyncWriteReq::Run() { if (state == AsyncWriteReq::PushToEngine) { + // We never preempt here io::Result push_res = owner->PushToEngine(vec, len); if (!push_res) { - caller_completion_cb(make_unexpected(push_res.error())); + CompleteAsyncReq(make_unexpected(push_res.error())); + // We are done with this AsyncWriteReq. Caller might started + // a new one. return; } last_push = *push_res; @@ -360,20 +371,25 @@ void TlsSocket::AsyncWriteReq::Run() { if (state == AsyncWriteReq::HandleOpAsyncTag) { state = AsyncWriteReq::MaybeSendOutputAsyncTag; if (last_push.engine_opcode < 0) { + if (last_push.engine_opcode == Engine::EOF_STREAM) { + VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); + CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); + return; + } HandleOpAsync(last_push.engine_opcode); + return; } } if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) { state = AsyncWriteReq::PushToEngine; if (last_push.written > 0) { - DCHECK(!continuation); - continuation = [this]() { - state = AsyncWriteReq::Done; - caller_completion_cb(last_push.written); - }; + continuation = [this]() { CompleteAsyncReq(last_push.written); }; MaybeSendOutputAsync(); + return; } + // Run again we are not done. + Run(); } } @@ -410,7 +426,7 @@ void TlsSocket::AsyncReadReq::Run() { --len; if (len == 0) { // We are done. Call completion callback. - caller_completion_cb(read_total); + CompleteAsyncReq(read_total); return; } dest = Engine::MutableBuffer{reinterpret_cast(vec->iov_base), vec->iov_len}; @@ -421,14 +437,17 @@ void TlsSocket::AsyncReadReq::Run() { break; } - // Will automatically call Run() - return HandleOpAsync(op_val); + DCHECK(!continuation); + if (op_val == Engine::EOF_STREAM) { + VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); + CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); + } + HandleOpAsync(op_val); + return; } // We are done. Call completion callback. - caller_completion_cb(read_total); - - // clean up so we can queue more reads + CompleteAsyncReq(read_total); } void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { @@ -450,11 +469,14 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ // broken_pipe - happens when the other side closes the connection. do not log this. if (write_result.error() != errc::broken_pipe) { - // VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error(); + VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_) + << ", write_total:" << owner->upstream_write_ << " " + << " pending output: " << owner->engine_->OutputPending() + << " HandleUpstreamAsyncWrite failed " << write_result.error(); } // We are done. Errornous exit. - caller_completion_cb(write_result); + CompleteAsyncReq(write_result); return; } @@ -468,7 +490,7 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ auto& scratch = scratch_iovec; scratch.iov_base = const_cast(buffer.data()); scratch.iov_len = buffer.size(); - owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { + return owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { HandleUpstreamAsyncWrite(write_result, buffer); }); } @@ -480,7 +502,8 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ owner->state_ &= ~WRITE_IN_PROGRESS; - // If there is a continuation run it and let it yield back to the main loop + // If there is a continuation run it and let it yield back to the main loop. + // Continuation is responsible for calling AsyncRequest::Run again. if (continuation) { auto cont = std::exchange(continuation, std::function{}); cont(); @@ -488,7 +511,8 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ } // Yield back to main loop - return Run(); + Run(); + return; } void TlsSocket::AsyncReqBase::StartUpstreamWrite() { @@ -496,11 +520,6 @@ void TlsSocket::AsyncReqBase::StartUpstreamWrite() { DCHECK(!buffer.empty()); DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0); - if (buffer.empty()) { - // We are done - return; - } - DVLOG(2) << "HandleUpstreamWrite " << buffer.size(); // we do not allow concurrent writes from multiple fibers. owner->state_ |= WRITE_IN_PROGRESS; @@ -514,29 +533,38 @@ void TlsSocket::AsyncReqBase::StartUpstreamWrite() { }); } -void TlsSocket::AsyncReqBase::MaybeSendOutputAsync() { - if (owner->engine_->OutputPending() == 0) { +void TlsSocket::AsyncReqBase::MaybeSendOutputAsync(bool should_read) { + auto body = [should_read, this]() { + if (should_read) { + return StartUpstreamRead(); + } if (continuation) { auto cont = std::exchange(continuation, std::function{}); - cont(); - return; + return cont(); } - Run(); + return Run(); + }; + + if (owner->engine_->OutputPending() == 0) { + return body(); } // This function is present in both read and write paths. // meaning that both of them can be called concurrently from differrent fibers and then // race over flushing the output buffer. We use state_ to prevent that. if (owner->state_ & WRITE_IN_PROGRESS) { - if (continuation) { - // TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise - // we might deadlock. See the sync version of HandleOp for more info - auto cont = std::exchange(continuation, std::function{}); - cont(); - return; - } + // TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise + // we might deadlock. See the sync version of HandleOp for more info + return body(); } + if (should_read) { + DCHECK(!continuation); + continuation = [this]() { + // Yields to Run internally() + StartUpstreamRead(); + }; + } StartUpstreamWrite(); } @@ -568,6 +596,12 @@ auto TlsSocket::MaybeSendOutput() -> error_code { } void TlsSocket::AsyncReqBase::StartUpstreamRead() { + if (owner->state_ & READ_IN_PROGRESS) { + // TODO we should yield instead + Run(); + return; + } + auto buffer = owner->engine_->PeekInputBuf(); owner->state_ |= READ_IN_PROGRESS; @@ -580,10 +614,13 @@ void TlsSocket::AsyncReqBase::StartUpstreamRead() { if (!read_result) { // log any errors as well as situations where we have unflushed output. if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) { - /// VSOCK(1) << "HandleUpstreamRead failed " << read_result.error(); + VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_) + << ", write_total:" << owner->upstream_write_ << " " + << " pending output: " << owner->engine_->OutputPending() << " " + << "StartUpstreamRead failed " << read_result.error(); } // Erronous path. Apply the completion callback and exit. - caller_completion_cb(read_result); + CompleteAsyncReq(read_result); return; } @@ -594,19 +631,6 @@ void TlsSocket::AsyncReqBase::StartUpstreamRead() { }); } -void TlsSocket::AsyncReqBase::HandleUpstreamAsyncRead() { - auto on_success = [this]() { - if (owner->state_ & READ_IN_PROGRESS) { - Run(); - } - - StartUpstreamRead(); - }; - - continuation = on_success; - MaybeSendOutputAsync(); -} - auto TlsSocket::HandleUpstreamRead() -> error_code { RETURN_ON_ERROR(MaybeSendOutput()); diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 4479b546..df45320a 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -130,10 +130,9 @@ class TlsSocket final : public FiberSocketBase { iovec scratch_iovec; // Asynchronous helpers - void MaybeSendOutputAsync(); + void MaybeSendOutputAsync(bool should_read = false); void HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer); - void HandleUpstreamAsyncRead(); void HandleOpAsync(int op_val); @@ -141,12 +140,18 @@ class TlsSocket final : public FiberSocketBase { void StartUpstreamRead(); virtual void Run() = 0; + virtual void CompleteAsyncReq(io::Result result) = 0; }; + // Helper function that resets the internal async request, applies the + // user AsyncProgressCb and returns. We need this, because progress callbacks + // can start another async request and for that to work, we need to clean up + // the one we are running on. + void CompleteAsyncRequest(io::Result result); + struct AsyncWriteReq : AsyncReqBase { using AsyncReqBase::AsyncReqBase; - // TODO simplify state transitions // TODO handle async yields to avoid deadlocks (see HandleOp) enum State { PushToEngine, HandleOpAsyncTag, MaybeSendOutputAsyncTag, Done }; State state = PushToEngine; @@ -154,6 +159,7 @@ class TlsSocket final : public FiberSocketBase { // Main loop void Run() override; + virtual void CompleteAsyncReq(io::Result result) override; }; friend AsyncWriteReq; @@ -166,6 +172,7 @@ class TlsSocket final : public FiberSocketBase { // Main loop void Run() override; + virtual void CompleteAsyncReq(io::Result result) override; }; friend AsyncReadReq; diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index d744f0ac..6e53c657 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -207,25 +207,24 @@ void TlsFiberSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } -TEST_P(TlsFiberSocketTest, Basic) { +TEST_P(TlsFiberSocketTest, AsyncRW) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); - LOG(INFO) << "before wait "; proactor_->Await([&] { ThisFiber::SetName("ConnectFb"); LOG(INFO) << "Connecting to " << listen_ep_; error_code ec = tls_sock->Connect(listen_ep_); EXPECT_FALSE(ec); + uint8_t res[16]; + std::fill(std::begin(res), std::end(res), uint8_t(120)); { - uint8_t buf[16]; - std::fill(std::begin(buf), std::end(buf), uint8_t(120)); VLOG(1) << "Before writesome"; Done done; - iovec v{.iov_base = &buf, .iov_len = 16}; + iovec v{.iov_base = &res, .iov_len = 16}; tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { EXPECT_TRUE(result.has_value()); @@ -247,9 +246,93 @@ TEST_P(TlsFiberSocketTest, Basic) { done.Wait(); - for (uint8_t c : buf) { - EXPECT_EQ(c, 120); - } + EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0); + } + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }); + SSL_CTX_free(ssl_ctx); +} + +class TlsFiberSocketTestPartialRW : public TlsFiberSocketTest { + virtual void HandleRequest() { + tls_socket_ = std::make_unique(conn_socket_.release()); + ssl_ctx_ = CreateSslCntx(SERVER); + tls_socket_->InitSSL(ssl_ctx_); + tls_socket_->Accept(); + + uint8_t buf[payload_sz_]; + auto res = tls_socket_->ReadAtLeast(buf, payload_sz_); + EXPECT_TRUE(res.has_value()); + EXPECT_TRUE(res.value() == payload_sz_) << res.value(); + + absl::Span partial_write(buf, payload_sz_ / 2); + // We split the write to two small ones. + auto write_res = tls_socket_->Write(partial_write); + EXPECT_FALSE(write_res); + write_res = tls_socket_->Write(partial_write); + EXPECT_FALSE(write_res); + } + + public: + static constexpr size_t payload_sz_ = 32768; +}; + +INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTestPartialRW, + testing::Values("epoll" +#ifdef __linux__ + // , + "uring" +#endif + ), + [](const auto& info) { return string(info.param); }); + +TEST_P(TlsFiberSocketTestPartialRW, PartialAsyncReadWrite) { + unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); + SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); + tls_sock->InitSSL(ssl_ctx); + + proactor_->Await([&] { + ThisFiber::SetName("ConnectFb"); + + LOG(INFO) << "Connecting to " << listen_ep_; + error_code ec = tls_sock->Connect(listen_ep_); + EXPECT_FALSE(ec); + uint8_t res[payload_sz_]; + std::fill(std::begin(res), std::end(res), uint8_t(120)); + { + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &res, .iov_len = payload_sz_}; + + // TODO replace this to show that here are partial reads/writes + tls_sock->AsyncWrite(&v, 1, [done](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + } + { + uint8_t buf[payload_sz_]; + Done done; + iovec v{.iov_base = &buf, .iov_len = payload_sz_}; + + // TODO replace this to show that here are partial reads/writes + tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + + EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); } VLOG(1) << "closing client sock " << tls_sock->native_handle(); From 1b28ab8d1876d69a70d787ab0ddd20c3b129e387 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 7 Apr 2025 13:57:56 +0300 Subject: [PATCH 4/8] clean up --- util/tls/tls_socket.cc | 123 +++++++++++++++++++----------------- util/tls/tls_socket.h | 10 ++- util/tls/tls_socket_test.cc | 21 +++--- 3 files changed, 82 insertions(+), 72 deletions(-) diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index d7d66435..cb482330 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -334,19 +334,30 @@ io::Result TlsSocket::PushToEngine(const iovec* ptr, uint void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result result) { auto current = std::exchange(owner->async_write_req_, std::nullopt); current->caller_completion_cb(result); + if (owner->pending_blocked_) { + auto* blocked = std::exchange(owner->pending_blocked_, nullptr); + blocked->Run(); + } } void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result result) { auto current = std::exchange(owner->async_read_req_, std::nullopt); current->caller_completion_cb(result); + // Run pending if blocked + if (owner->pending_blocked_) { + auto* blocked = std::exchange(owner->pending_blocked_, nullptr); + blocked->Run(); + } } void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { switch (op_val) { case Engine::EOF_STREAM: + VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); + CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); break; case Engine::NEED_READ_AND_MAYBE_WRITE: - MaybeSendOutputAsync(true); + MaybeSendOutputAsyncWithRead(); break; case Engine::NEED_WRITE: MaybeSendOutputAsync(); @@ -357,41 +368,40 @@ void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { } void TlsSocket::AsyncWriteReq::Run() { - if (state == AsyncWriteReq::PushToEngine) { - // We never preempt here - io::Result push_res = owner->PushToEngine(vec, len); - if (!push_res) { - CompleteAsyncReq(make_unexpected(push_res.error())); - // We are done with this AsyncWriteReq. Caller might started - // a new one. - return; + while (true) { + if (state == AsyncWriteReq::PushToEngine) { + // We never preempt here + io::Result push_res = owner->PushToEngine(vec, len); + if (!push_res) { + CompleteAsyncReq(make_unexpected(push_res.error())); + // We are done with this AsyncWriteReq. Caller might started + // a new one. + return; + } + last_push = *push_res; + state = AsyncWriteReq::HandleOpAsyncTag; } - last_push = *push_res; - state = AsyncWriteReq::HandleOpAsyncTag; - } - if (state == AsyncWriteReq::HandleOpAsyncTag) { - state = AsyncWriteReq::MaybeSendOutputAsyncTag; - if (last_push.engine_opcode < 0) { - if (last_push.engine_opcode == Engine::EOF_STREAM) { - VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); - CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); + if (state == AsyncWriteReq::HandleOpAsyncTag) { + state = AsyncWriteReq::MaybeSendOutputAsyncTag; + if (last_push.engine_opcode < 0) { + HandleOpAsync(last_push.engine_opcode); return; } - HandleOpAsync(last_push.engine_opcode); - return; } - } - if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) { - state = AsyncWriteReq::PushToEngine; - if (last_push.written > 0) { - continuation = [this]() { CompleteAsyncReq(last_push.written); }; - MaybeSendOutputAsync(); - return; + if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) { + state = AsyncWriteReq::PushToEngine; + if (last_push.written > 0) { + state = AsyncWriteReq::Done; + MaybeSendOutputAsync(); + return; + } + } + + if (state == AsyncWriteReq::Done) { + return CompleteAsyncReq(last_push.written); } - // Run again we are not done. - Run(); } } @@ -440,10 +450,6 @@ void TlsSocket::AsyncReadReq::Run() { } DCHECK(!continuation); - if (op_val == Engine::EOF_STREAM) { - VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); - CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); - } HandleOpAsync(op_val); return; } @@ -535,38 +541,43 @@ void TlsSocket::AsyncReqBase::StartUpstreamWrite() { }); } -void TlsSocket::AsyncReqBase::MaybeSendOutputAsync(bool should_read) { - auto body = [should_read, this]() { - if (should_read) { - return StartUpstreamRead(); - } - if (continuation) { - auto cont = std::exchange(continuation, std::function{}); - return cont(); - } - return Run(); - }; - +void TlsSocket::AsyncReqBase::MaybeSendOutputAsync() { if (owner->engine_->OutputPending() == 0) { - return body(); + return; } // This function is present in both read and write paths. // meaning that both of them can be called concurrently from differrent fibers and then // race over flushing the output buffer. We use state_ to prevent that. if (owner->state_ & WRITE_IN_PROGRESS) { - // TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise - // we might deadlock. See the sync version of HandleOp for more info - return body(); + DCHECK(owner->pending_blocked_ == nullptr); + owner->pending_blocked_ = this; + return; } - if (should_read) { - DCHECK(!continuation); - continuation = [this]() { - // Yields to Run internally() - StartUpstreamRead(); - }; + StartUpstreamWrite(); +} + +void TlsSocket::AsyncReqBase::MaybeSendOutputAsyncWithRead() { + if (owner->engine_->OutputPending() == 0) { + return StartUpstreamRead(); } + + // This function is present in both read and write paths. + // meaning that both of them can be called concurrently from differrent fibers and then + // race over flushing the output buffer. We use state_ to prevent that. + if (owner->state_ & WRITE_IN_PROGRESS) { + DCHECK(owner->pending_blocked_ == nullptr); + owner->pending_blocked_ = this; + return; + } + + DCHECK(!continuation); + continuation = [this]() { + // Yields to Run internally() + StartUpstreamRead(); + }; + StartUpstreamWrite(); } @@ -599,8 +610,6 @@ auto TlsSocket::MaybeSendOutput() -> error_code { void TlsSocket::AsyncReqBase::StartUpstreamRead() { if (owner->state_ & READ_IN_PROGRESS) { - // TODO we should yield instead - Run(); return; } diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index df45320a..e3aadc0f 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -130,7 +130,10 @@ class TlsSocket final : public FiberSocketBase { iovec scratch_iovec; // Asynchronous helpers - void MaybeSendOutputAsync(bool should_read = false); + void MaybeSendOutputAsyncWithRead(); + + // Returns true if we did not start an upstream write + void MaybeSendOutputAsync(); void HandleUpstreamAsyncWrite(io::Result write_result, Engine::Buffer buffer); @@ -159,7 +162,7 @@ class TlsSocket final : public FiberSocketBase { // Main loop void Run() override; - virtual void CompleteAsyncReq(io::Result result) override; + void CompleteAsyncReq(io::Result result) override; }; friend AsyncWriteReq; @@ -172,7 +175,7 @@ class TlsSocket final : public FiberSocketBase { // Main loop void Run() override; - virtual void CompleteAsyncReq(io::Result result) override; + void CompleteAsyncReq(io::Result result) override; }; friend AsyncReadReq; @@ -181,6 +184,7 @@ class TlsSocket final : public FiberSocketBase { // async operation std::optional async_write_req_; std::optional async_read_req_; + AsyncReqBase* pending_blocked_ = nullptr; enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 }; uint8_t state_{0}; diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index 823e00cd..847d752c 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -56,7 +56,6 @@ SSL_CTX* CreateSslCntx(TlsContextRole role) { if (role == TlsContextRole::SERVER) { ctx = SSL_CTX_new(TLS_server_method()); - // TODO init those to build on ci } else { ctx = SSL_CTX_new(TLS_client_method()); } @@ -216,7 +215,7 @@ TEST_P(TlsSocketTest, ShortWrite) { server_read_fb.Join(); } -class TlsFiberSocketTest : public testing::TestWithParam { +class AsyncTlsSocketTest : public testing::TestWithParam { protected: void SetUp() final; void TearDown() final; @@ -258,7 +257,7 @@ class TlsFiberSocketTest : public testing::TestWithParam { uint32_t conn_sock_err_mask_ = 0; }; -INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTest, +INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest, testing::Values("epoll" #ifdef __linux__ , @@ -267,7 +266,7 @@ INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTest, ), [](const auto& info) { return string(info.param); }); -void TlsFiberSocketTest::SetUp() { +void AsyncTlsSocketTest::SetUp() { #if __linux__ bool use_uring = GetParam() == "uring"; ProactorBase* proactor = nullptr; @@ -314,7 +313,7 @@ void TlsFiberSocketTest::SetUp() { }); } -void TlsFiberSocketTest::TearDown() { +void AsyncTlsSocketTest::TearDown() { VLOG(1) << "TearDown"; proactor_->Await([&] { @@ -338,7 +337,7 @@ void TlsFiberSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } -TEST_P(TlsFiberSocketTest, AsyncRW) { +TEST_P(AsyncTlsSocketTest, AsyncRW) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); @@ -390,7 +389,7 @@ TEST_P(TlsFiberSocketTest, AsyncRW) { SSL_CTX_free(ssl_ctx); } -class TlsFiberSocketTestPartialRW : public TlsFiberSocketTest { +class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { virtual void HandleRequest() { tls_socket_ = std::make_unique(conn_socket_.release()); ssl_ctx_ = CreateSslCntx(SERVER); @@ -414,7 +413,7 @@ class TlsFiberSocketTestPartialRW : public TlsFiberSocketTest { static constexpr size_t payload_sz_ = 32768; }; -INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTestPartialRW, +INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW, testing::Values("epoll" #ifdef __linux__ // , @@ -423,7 +422,7 @@ INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTestPartialRW, ), [](const auto& info) { return string(info.param); }); -TEST_P(TlsFiberSocketTestPartialRW, PartialAsyncReadWrite) { +TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); @@ -442,8 +441,7 @@ TEST_P(TlsFiberSocketTestPartialRW, PartialAsyncReadWrite) { Done done; iovec v{.iov_base = &res, .iov_len = payload_sz_}; - // TODO replace this to show that here are partial reads/writes - tls_sock->AsyncWrite(&v, 1, [done](auto result) mutable { + tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { EXPECT_FALSE(result); done.Notify(); }); @@ -455,7 +453,6 @@ TEST_P(TlsFiberSocketTestPartialRW, PartialAsyncReadWrite) { Done done; iovec v{.iov_base = &buf, .iov_len = payload_sz_}; - // TODO replace this to show that here are partial reads/writes tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { EXPECT_FALSE(result); done.Notify(); From 8187dbc5c839047bc3e39c895a3ce33566bb08ed Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 7 Apr 2025 17:29:22 +0300 Subject: [PATCH 5/8] add renegotiate test --- util/tls/tls_socket.cc | 26 ++- util/tls/tls_socket.h | 2 + util/tls/tls_socket_test.cc | 367 ++++++++++++++++++++++-------------- 3 files changed, 241 insertions(+), 154 deletions(-) diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index cb482330..2f803a06 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -334,20 +334,11 @@ io::Result TlsSocket::PushToEngine(const iovec* ptr, uint void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result result) { auto current = std::exchange(owner->async_write_req_, std::nullopt); current->caller_completion_cb(result); - if (owner->pending_blocked_) { - auto* blocked = std::exchange(owner->pending_blocked_, nullptr); - blocked->Run(); - } } void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result result) { auto current = std::exchange(owner->async_read_req_, std::nullopt); current->caller_completion_cb(result); - // Run pending if blocked - if (owner->pending_blocked_) { - auto* blocked = std::exchange(owner->pending_blocked_, nullptr); - blocked->Run(); - } } void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { @@ -482,6 +473,8 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ << " pending output: " << owner->engine_->OutputPending() << " HandleUpstreamAsyncWrite failed " << write_result.error(); } + // Run pending if blocked + RunPending(); // We are done. Errornous exit. CompleteAsyncReq(write_result); @@ -491,7 +484,9 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ CHECK_GT(*write_result, 0u); owner->upstream_write_ += *write_result; owner->engine_->ConsumeOutputBuf(*write_result); - buffer.remove_prefix(*write_result); + // We could preempt while calling WriteSome, and the engine could get more data to write. + // Therefore we sync the buffer. + buffer = owner->engine_->PeekOutputBuf(); // We are not done. Re-arm the async write until we drive it to completion or error. if (!buffer.empty()) { @@ -510,6 +505,9 @@ void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result write_ owner->state_ &= ~WRITE_IN_PROGRESS; + // Run pending if blocked + RunPending(); + // If there is a continuation run it and let it yield back to the main loop. // Continuation is responsible for calling AsyncRequest::Run again. if (continuation) { @@ -622,6 +620,7 @@ void TlsSocket::AsyncReqBase::StartUpstreamRead() { owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) { owner->state_ &= ~READ_IN_PROGRESS; + RunPending(); if (!read_result) { // log any errors as well as situations where we have unflushed output. if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) { @@ -642,6 +641,13 @@ void TlsSocket::AsyncReqBase::StartUpstreamRead() { }); } +void TlsSocket::AsyncReqBase::RunPending() { + if (owner->pending_blocked_) { + auto* blocked = std::exchange(owner->pending_blocked_, nullptr); + blocked->Run(); + } +} + auto TlsSocket::HandleUpstreamRead() -> error_code { RETURN_ON_ERROR(MaybeSendOutput()); diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index e3aadc0f..a0bf2b3b 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -142,6 +142,8 @@ class TlsSocket final : public FiberSocketBase { void StartUpstreamWrite(); void StartUpstreamRead(); + void RunPending(); + virtual void Run() = 0; virtual void CompleteAsyncReq(io::Result result) = 0; }; diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index 847d752c..7a5cff1a 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -91,15 +91,15 @@ class TlsSocketTest : public testing::TestWithParam { FiberSocketBase::endpoint_type listen_ep_; }; -INSTANTIATE_TEST_SUITE_P(Engines, TlsSocketTest, - testing::Values("epoll" -#ifdef __linux__ - , - "uring" -#endif - ), - [](const auto& info) { return string(info.param); }); - +// INSTANTIATE_TEST_SUITE_P(Engines, TlsSocketTest, +// testing::Values("epoll" +//#ifdef __linux__ +// , +// "uring" +//#endif +// ), +// [](const auto& info) { return string(info.param); }); +// void TlsSocketTest::SetUp() { #if __linux__ bool use_uring = GetParam() == "uring"; @@ -165,55 +165,55 @@ void TlsSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } -TEST_P(TlsSocketTest, ShortWrite) { - unique_ptr client_sock; - { - SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); - - proactor_->Await([&] { - client_sock.reset(new tls::TlsSocket(proactor_->CreateSocket())); - client_sock->InitSSL(ssl_ctx); - }); - SSL_CTX_free(ssl_ctx); - } - - error_code ec = proactor_->Await([&] { - LOG(INFO) << "Connecting to " << listen_ep_; - return client_sock->Connect(listen_ep_); - }); - ASSERT_FALSE(ec) << ec.message(); - - auto client_fb = proactor_->LaunchFiber([&] { - uint8_t buf[256]; - iovec iov{buf, sizeof(buf)}; - - client_sock->ReadSome(&iov, 1); - }); - - // Server side. - auto server_read_fb = proactor_->LaunchFiber([&] { - // This read actually causes the fiber to flush pending writes and preempt on iouring. - uint8_t buf[256]; - iovec iov; - iov.iov_base = buf; - iov.iov_len = sizeof(buf); - server_socket_->ReadSome(&iov, 1); - }); - - auto write_res = proactor_->Await([&] { - ThisFiber::Yield(); - uint8_t buf[16] = {0}; - - VLOG(1) << "Writing to client"; - return server_socket_->Write(buf); - }); - - ASSERT_FALSE(write_res) << write_res; - LOG(INFO) << "Finished"; - client_fb.Join(); - proactor_->Await([&] { std::ignore = client_sock->Close(); }); - server_read_fb.Join(); -} +// TEST_P(TlsSocketTest, ShortWrite) { +// unique_ptr client_sock; +// { +// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); +// +// proactor_->Await([&] { +// client_sock.reset(new tls::TlsSocket(proactor_->CreateSocket())); +// client_sock->InitSSL(ssl_ctx); +// }); +// SSL_CTX_free(ssl_ctx); +// } +// +// error_code ec = proactor_->Await([&] { +// LOG(INFO) << "Connecting to " << listen_ep_; +// return client_sock->Connect(listen_ep_); +// }); +// ASSERT_FALSE(ec) << ec.message(); +// +// auto client_fb = proactor_->LaunchFiber([&] { +// uint8_t buf[256]; +// iovec iov{buf, sizeof(buf)}; +// +// client_sock->ReadSome(&iov, 1); +// }); +// +// // Server side. +// auto server_read_fb = proactor_->LaunchFiber([&] { +// // This read actually causes the fiber to flush pending writes and preempt on iouring. +// uint8_t buf[256]; +// iovec iov; +// iov.iov_base = buf; +// iov.iov_len = sizeof(buf); +// server_socket_->ReadSome(&iov, 1); +// }); +// +// auto write_res = proactor_->Await([&] { +// ThisFiber::Yield(); +// uint8_t buf[16] = {0}; +// +// VLOG(1) << "Writing to client"; +// return server_socket_->Write(buf); +// }); +// +// ASSERT_FALSE(write_res) << write_res; +// LOG(INFO) << "Finished"; +// client_fb.Join(); +// proactor_->Await([&] { std::ignore = client_sock->Close(); }); +// server_read_fb.Join(); +// } class AsyncTlsSocketTest : public testing::TestWithParam { protected: @@ -257,15 +257,16 @@ class AsyncTlsSocketTest : public testing::TestWithParam { uint32_t conn_sock_err_mask_ = 0; }; -INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest, - testing::Values("epoll" -#ifdef __linux__ - , - "uring" -#endif - ), - [](const auto& info) { return string(info.param); }); - +// Epoll is blocking so this test works only on iouring +// INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest, +// testing::Values( +//#ifdef __linux__ +// , +// "uring" +//#endif +// ), +// [](const auto& info) { return string(info.param); }); +// void AsyncTlsSocketTest::SetUp() { #if __linux__ bool use_uring = GetParam() == "uring"; @@ -336,60 +337,144 @@ void AsyncTlsSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } +// +// TEST_P(AsyncTlsSocketTest, AsyncRW) { +// unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); +// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); +// tls_sock->InitSSL(ssl_ctx); +// +// proactor_->Await([&] { +// ThisFiber::SetName("ConnectFb"); +// +// LOG(INFO) << "Connecting to " << listen_ep_; +// error_code ec = tls_sock->Connect(listen_ep_); +// EXPECT_FALSE(ec); +// uint8_t res[16]; +// std::fill(std::begin(res), std::end(res), uint8_t(120)); +// { +// VLOG(1) << "Before writesome"; +// +// Done done; +// iovec v{.iov_base = &res, .iov_len = 16}; +// +// tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { +// EXPECT_TRUE(result.has_value()); +// EXPECT_EQ(*result, 16); +// done.Notify(); +// }); +// +// done.Wait(); +// } +// { +// uint8_t buf[16]; +// Done done; +// iovec v{.iov_base = &buf, .iov_len = 16}; +// tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { +// EXPECT_TRUE(result.has_value()); +// EXPECT_EQ(*result, 16); +// done.Notify(); +// }); +// +// done.Wait(); +// +// EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0); +// } +// +// VLOG(1) << "closing client sock " << tls_sock->native_handle(); +// std::ignore = tls_sock->Close(); +// accept_fb_.Join(); +// VLOG(1) << "After join"; +// ASSERT_FALSE(ec) << ec.message(); +// ASSERT_FALSE(accept_ec_); +// }); +// SSL_CTX_free(ssl_ctx); +//} -TEST_P(AsyncTlsSocketTest, AsyncRW) { - unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); - SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); - tls_sock->InitSSL(ssl_ctx); - - proactor_->Await([&] { - ThisFiber::SetName("ConnectFb"); - - LOG(INFO) << "Connecting to " << listen_ep_; - error_code ec = tls_sock->Connect(listen_ep_); - EXPECT_FALSE(ec); - uint8_t res[16]; - std::fill(std::begin(res), std::end(res), uint8_t(120)); - { - VLOG(1) << "Before writesome"; - - Done done; - iovec v{.iov_base = &res, .iov_len = 16}; - - tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { - EXPECT_TRUE(result.has_value()); - EXPECT_EQ(*result, 16); - done.Notify(); - }); - - done.Wait(); - } - { - uint8_t buf[16]; - Done done; - iovec v{.iov_base = &buf, .iov_len = 16}; - tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { - EXPECT_TRUE(result.has_value()); - EXPECT_EQ(*result, 16); - done.Notify(); - }); +class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { + virtual void HandleRequest() { + tls_socket_ = std::make_unique(conn_socket_.release()); + ssl_ctx_ = CreateSslCntx(SERVER); + tls_socket_->InitSSL(ssl_ctx_); + tls_socket_->Accept(); - done.Wait(); + uint8_t buf[payload_sz_]; + auto res = tls_socket_->ReadAtLeast(buf, payload_sz_); + EXPECT_TRUE(res.has_value()); + EXPECT_TRUE(res.value() == payload_sz_) << res.value(); - EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0); - } + absl::Span partial_write(buf, payload_sz_ / 2); + // We split the write to two small ones. + auto write_res = tls_socket_->Write(partial_write); + EXPECT_FALSE(write_res); + write_res = tls_socket_->Write(partial_write); + EXPECT_FALSE(write_res); + } - VLOG(1) << "closing client sock " << tls_sock->native_handle(); - std::ignore = tls_sock->Close(); - accept_fb_.Join(); - VLOG(1) << "After join"; - ASSERT_FALSE(ec) << ec.message(); - ASSERT_FALSE(accept_ec_); - }); - SSL_CTX_free(ssl_ctx); -} + public: + static constexpr size_t payload_sz_ = 32768; +}; -class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { +// INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW, +// testing::Values("epoll" +//#ifdef __linux__ +// , +// "uring" +//#endif +// ), +// [](const auto& info) { return string(info.param); }); +// +// TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { +// unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); +// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); +// tls_sock->InitSSL(ssl_ctx); +// +// proactor_->Await([&] { +// ThisFiber::SetName("ConnectFb"); +// +// LOG(INFO) << "Connecting to " << listen_ep_; +// error_code ec = tls_sock->Connect(listen_ep_); +// EXPECT_FALSE(ec); +// uint8_t res[payload_sz_]; +// std::fill(std::begin(res), std::end(res), uint8_t(120)); +// { +// VLOG(1) << "Before writesome"; +// +// Done done; +// iovec v{.iov_base = &res, .iov_len = payload_sz_}; +// +// tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { +// EXPECT_FALSE(result); +// done.Notify(); +// }); +// +// done.Wait(); +// } +// { +// uint8_t buf[payload_sz_]; +// Done done; +// iovec v{.iov_base = &buf, .iov_len = payload_sz_}; +// +// tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { +// EXPECT_FALSE(result); +// done.Notify(); +// }); +// +// done.Wait(); +// +// EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); +// } +// +// VLOG(1) << "closing client sock " << tls_sock->native_handle(); +// std::ignore = tls_sock->Close(); +// accept_fb_.Join(); +// VLOG(1) << "After join"; +// ASSERT_FALSE(ec) << ec.message(); +// ASSERT_FALSE(accept_ec_); +// }); +// SSL_CTX_free(ssl_ctx); +// } + +class AsyncTlsSocketRenegotiate : public AsyncTlsSocketTest { virtual void HandleRequest() { tls_socket_ = std::make_unique(conn_socket_.release()); ssl_ctx_ = CreateSslCntx(SERVER); @@ -409,20 +494,20 @@ class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { EXPECT_FALSE(write_res); } + tls::TlsSocket* Handle() { + return tls_socket_.get(); + } + public: static constexpr size_t payload_sz_ = 32768; }; -INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW, - testing::Values("epoll" #ifdef __linux__ - // , - "uring" -#endif - ), +INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketRenegotiate, testing::Values("uring"), [](const auto& info) { return string(info.param); }); +#endif -TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { +TEST_P(AsyncTlsSocketRenegotiate, Renegotiate) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); @@ -430,37 +515,32 @@ TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { proactor_->Await([&] { ThisFiber::SetName("ConnectFb"); - LOG(INFO) << "Connecting to " << listen_ep_; error_code ec = tls_sock->Connect(listen_ep_); EXPECT_FALSE(ec); + + uint8_t send_buf[payload_sz_]; uint8_t res[payload_sz_]; - std::fill(std::begin(res), std::end(res), uint8_t(120)); + std::fill(std::begin(send_buf), std::end(send_buf), uint8_t(120)); { - VLOG(1) << "Before writesome"; + Done done_read, done_write; + iovec send_vec{.iov_base = &send_buf, .iov_len = payload_sz_}; + iovec read_vec{.iov_base = &res, .iov_len = payload_sz_}; - Done done; - iovec v{.iov_base = &res, .iov_len = payload_sz_}; - - tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { + // We don't need to call ssl_renegotiate here, the first read will also negotiate the protocol + tls_sock->AsyncRead(&read_vec, 1, [&](auto result) mutable { EXPECT_FALSE(result); - done.Notify(); + done_read.Notify(); }); - done.Wait(); - } - { - uint8_t buf[payload_sz_]; - Done done; - iovec v{.iov_base = &buf, .iov_len = payload_sz_}; - - tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { + // Here AsyncWrite will resume later since write_in_progress bit is set + tls_sock->AsyncWrite(&send_vec, 1, [&](auto result) mutable { EXPECT_FALSE(result); - done.Notify(); + done_write.Notify(); }); - done.Wait(); - - EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); + done_write.Wait(); + done_read.Wait(); + EXPECT_EQ(memcmp(begin(res), begin(send_buf), payload_sz_), 0); } VLOG(1) << "closing client sock " << tls_sock->native_handle(); @@ -472,6 +552,5 @@ TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { }); SSL_CTX_free(ssl_ctx); } - } // namespace fb2 } // namespace util From 0385a16fb7d7a403bdb96cf60e2fd70e10c50d9a Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 7 Apr 2025 17:33:36 +0300 Subject: [PATCH 6/8] check in commented out code --- util/tls/tls_socket_test.cc | 357 ++++++++++++++++++------------------ 1 file changed, 178 insertions(+), 179 deletions(-) diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index 7a5cff1a..2fed738a 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -91,15 +91,15 @@ class TlsSocketTest : public testing::TestWithParam { FiberSocketBase::endpoint_type listen_ep_; }; -// INSTANTIATE_TEST_SUITE_P(Engines, TlsSocketTest, -// testing::Values("epoll" -//#ifdef __linux__ -// , -// "uring" -//#endif -// ), -// [](const auto& info) { return string(info.param); }); -// +INSTANTIATE_TEST_SUITE_P(Engines, TlsSocketTest, + testing::Values("epoll" +#ifdef __linux__ + , + "uring" +#endif + ), + [](const auto& info) { return string(info.param); }); + void TlsSocketTest::SetUp() { #if __linux__ bool use_uring = GetParam() == "uring"; @@ -165,55 +165,55 @@ void TlsSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } -// TEST_P(TlsSocketTest, ShortWrite) { -// unique_ptr client_sock; -// { -// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); -// -// proactor_->Await([&] { -// client_sock.reset(new tls::TlsSocket(proactor_->CreateSocket())); -// client_sock->InitSSL(ssl_ctx); -// }); -// SSL_CTX_free(ssl_ctx); -// } -// -// error_code ec = proactor_->Await([&] { -// LOG(INFO) << "Connecting to " << listen_ep_; -// return client_sock->Connect(listen_ep_); -// }); -// ASSERT_FALSE(ec) << ec.message(); -// -// auto client_fb = proactor_->LaunchFiber([&] { -// uint8_t buf[256]; -// iovec iov{buf, sizeof(buf)}; -// -// client_sock->ReadSome(&iov, 1); -// }); -// -// // Server side. -// auto server_read_fb = proactor_->LaunchFiber([&] { -// // This read actually causes the fiber to flush pending writes and preempt on iouring. -// uint8_t buf[256]; -// iovec iov; -// iov.iov_base = buf; -// iov.iov_len = sizeof(buf); -// server_socket_->ReadSome(&iov, 1); -// }); -// -// auto write_res = proactor_->Await([&] { -// ThisFiber::Yield(); -// uint8_t buf[16] = {0}; -// -// VLOG(1) << "Writing to client"; -// return server_socket_->Write(buf); -// }); -// -// ASSERT_FALSE(write_res) << write_res; -// LOG(INFO) << "Finished"; -// client_fb.Join(); -// proactor_->Await([&] { std::ignore = client_sock->Close(); }); -// server_read_fb.Join(); -// } +TEST_P(TlsSocketTest, ShortWrite) { + unique_ptr client_sock; + { + SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); + + proactor_->Await([&] { + client_sock.reset(new tls::TlsSocket(proactor_->CreateSocket())); + client_sock->InitSSL(ssl_ctx); + }); + SSL_CTX_free(ssl_ctx); + } + + error_code ec = proactor_->Await([&] { + LOG(INFO) << "Connecting to " << listen_ep_; + return client_sock->Connect(listen_ep_); + }); + ASSERT_FALSE(ec) << ec.message(); + + auto client_fb = proactor_->LaunchFiber([&] { + uint8_t buf[256]; + iovec iov{buf, sizeof(buf)}; + + client_sock->ReadSome(&iov, 1); + }); + + // Server side. + auto server_read_fb = proactor_->LaunchFiber([&] { + // This read actually causes the fiber to flush pending writes and preempt on iouring. + uint8_t buf[256]; + iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + server_socket_->ReadSome(&iov, 1); + }); + + auto write_res = proactor_->Await([&] { + ThisFiber::Yield(); + uint8_t buf[16] = {0}; + + VLOG(1) << "Writing to client"; + return server_socket_->Write(buf); + }); + + ASSERT_FALSE(write_res) << write_res; + LOG(INFO) << "Finished"; + client_fb.Join(); + proactor_->Await([&] { std::ignore = client_sock->Close(); }); + server_read_fb.Join(); +} class AsyncTlsSocketTest : public testing::TestWithParam { protected: @@ -257,16 +257,15 @@ class AsyncTlsSocketTest : public testing::TestWithParam { uint32_t conn_sock_err_mask_ = 0; }; -// Epoll is blocking so this test works only on iouring -// INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest, -// testing::Values( -//#ifdef __linux__ -// , -// "uring" -//#endif -// ), -// [](const auto& info) { return string(info.param); }); -// +INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest, + testing::Values("epoll" +#ifdef __linux__ + , + "uring" +#endif + ), + [](const auto& info) { return string(info.param); }); + void AsyncTlsSocketTest::SetUp() { #if __linux__ bool use_uring = GetParam() == "uring"; @@ -337,58 +336,58 @@ void AsyncTlsSocketTest::TearDown() { SSL_CTX_free(ssl_ctx_); } -// -// TEST_P(AsyncTlsSocketTest, AsyncRW) { -// unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); -// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); -// tls_sock->InitSSL(ssl_ctx); -// -// proactor_->Await([&] { -// ThisFiber::SetName("ConnectFb"); -// -// LOG(INFO) << "Connecting to " << listen_ep_; -// error_code ec = tls_sock->Connect(listen_ep_); -// EXPECT_FALSE(ec); -// uint8_t res[16]; -// std::fill(std::begin(res), std::end(res), uint8_t(120)); -// { -// VLOG(1) << "Before writesome"; -// -// Done done; -// iovec v{.iov_base = &res, .iov_len = 16}; -// -// tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { -// EXPECT_TRUE(result.has_value()); -// EXPECT_EQ(*result, 16); -// done.Notify(); -// }); -// -// done.Wait(); -// } -// { -// uint8_t buf[16]; -// Done done; -// iovec v{.iov_base = &buf, .iov_len = 16}; -// tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { -// EXPECT_TRUE(result.has_value()); -// EXPECT_EQ(*result, 16); -// done.Notify(); -// }); -// -// done.Wait(); -// -// EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0); -// } -// -// VLOG(1) << "closing client sock " << tls_sock->native_handle(); -// std::ignore = tls_sock->Close(); -// accept_fb_.Join(); -// VLOG(1) << "After join"; -// ASSERT_FALSE(ec) << ec.message(); -// ASSERT_FALSE(accept_ec_); -// }); -// SSL_CTX_free(ssl_ctx); -//} + +TEST_P(AsyncTlsSocketTest, AsyncRW) { + unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); + SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); + tls_sock->InitSSL(ssl_ctx); + + proactor_->Await([&] { + ThisFiber::SetName("ConnectFb"); + + LOG(INFO) << "Connecting to " << listen_ep_; + error_code ec = tls_sock->Connect(listen_ep_); + EXPECT_FALSE(ec); + uint8_t res[16]; + std::fill(std::begin(res), std::end(res), uint8_t(120)); + { + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &res, .iov_len = 16}; + + tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); + + done.Wait(); + } + { + uint8_t buf[16]; + Done done; + iovec v{.iov_base = &buf, .iov_len = 16}; + tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 16); + done.Notify(); + }); + + done.Wait(); + + EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0); + } + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }); + SSL_CTX_free(ssl_ctx); +} class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { virtual void HandleRequest() { @@ -414,65 +413,65 @@ class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest { static constexpr size_t payload_sz_ = 32768; }; -// INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW, -// testing::Values("epoll" -//#ifdef __linux__ -// , -// "uring" -//#endif -// ), -// [](const auto& info) { return string(info.param); }); -// -// TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { -// unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); -// SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); -// tls_sock->InitSSL(ssl_ctx); -// -// proactor_->Await([&] { -// ThisFiber::SetName("ConnectFb"); -// -// LOG(INFO) << "Connecting to " << listen_ep_; -// error_code ec = tls_sock->Connect(listen_ep_); -// EXPECT_FALSE(ec); -// uint8_t res[payload_sz_]; -// std::fill(std::begin(res), std::end(res), uint8_t(120)); -// { -// VLOG(1) << "Before writesome"; -// -// Done done; -// iovec v{.iov_base = &res, .iov_len = payload_sz_}; -// -// tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { -// EXPECT_FALSE(result); -// done.Notify(); -// }); -// -// done.Wait(); -// } -// { -// uint8_t buf[payload_sz_]; -// Done done; -// iovec v{.iov_base = &buf, .iov_len = payload_sz_}; -// -// tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { -// EXPECT_FALSE(result); -// done.Notify(); -// }); -// -// done.Wait(); -// -// EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); -// } -// -// VLOG(1) << "closing client sock " << tls_sock->native_handle(); -// std::ignore = tls_sock->Close(); -// accept_fb_.Join(); -// VLOG(1) << "After join"; -// ASSERT_FALSE(ec) << ec.message(); -// ASSERT_FALSE(accept_ec_); -// }); -// SSL_CTX_free(ssl_ctx); -// } +INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW, + testing::Values("epoll" +#ifdef __linux__ + , + "uring" +#endif + ), + [](const auto& info) { return string(info.param); }); + +TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { + unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); + SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); + tls_sock->InitSSL(ssl_ctx); + + proactor_->Await([&] { + ThisFiber::SetName("ConnectFb"); + + LOG(INFO) << "Connecting to " << listen_ep_; + error_code ec = tls_sock->Connect(listen_ep_); + EXPECT_FALSE(ec); + uint8_t res[payload_sz_]; + std::fill(std::begin(res), std::end(res), uint8_t(120)); + { + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &res, .iov_len = payload_sz_}; + + tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + } + { + uint8_t buf[payload_sz_]; + Done done; + iovec v{.iov_base = &buf, .iov_len = payload_sz_}; + + tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + + EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); + } + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }); + SSL_CTX_free(ssl_ctx); +} class AsyncTlsSocketRenegotiate : public AsyncTlsSocketTest { virtual void HandleRequest() { From 3cbfc7a0b0d64aab87884e3548f7a61e584a72ec Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 10 Apr 2025 18:20:21 +0300 Subject: [PATCH 7/8] fix stack size blow --- util/tls/tls_socket_test.cc | 176 +++++++++++++++++++----------------- 1 file changed, 94 insertions(+), 82 deletions(-) diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index 2fed738a..b14c832c 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -292,8 +292,10 @@ void AsyncTlsSocketTest::SetUp() { CHECK(!ec); listen_ep_ = listen_socket_->LocalEndpoint(); + std::string name("accept"); + Fiber::Opts opts{.name = name, .stack_size = 128 * 1024}; - accept_fb_ = proactor_->LaunchFiber("AcceptFb", [this] { + accept_fb_ = proactor_->LaunchFiber(opts, [this] { auto accept_res = listen_socket_->Accept(); VLOG_IF(1, !accept_res) << "Accept res: " << accept_res.error(); @@ -427,49 +429,53 @@ TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) { SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); - proactor_->Await([&] { - ThisFiber::SetName("ConnectFb"); - - LOG(INFO) << "Connecting to " << listen_ep_; - error_code ec = tls_sock->Connect(listen_ep_); - EXPECT_FALSE(ec); - uint8_t res[payload_sz_]; - std::fill(std::begin(res), std::end(res), uint8_t(120)); - { - VLOG(1) << "Before writesome"; - - Done done; - iovec v{.iov_base = &res, .iov_len = payload_sz_}; - - tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { - EXPECT_FALSE(result); - done.Notify(); - }); - - done.Wait(); - } - { - uint8_t buf[payload_sz_]; - Done done; - iovec v{.iov_base = &buf, .iov_len = payload_sz_}; - - tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { - EXPECT_FALSE(result); - done.Notify(); - }); - - done.Wait(); - - EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); - } - - VLOG(1) << "closing client sock " << tls_sock->native_handle(); - std::ignore = tls_sock->Close(); - accept_fb_.Join(); - VLOG(1) << "After join"; - ASSERT_FALSE(ec) << ec.message(); - ASSERT_FALSE(accept_ec_); - }); + std::string name = "main stack"; + Fiber::Opts opts{.name = name, .stack_size = 128 * 1024}; + proactor_->Await( + [&] { + ThisFiber::SetName("ConnectFb"); + + LOG(INFO) << "Connecting to " << listen_ep_; + error_code ec = tls_sock->Connect(listen_ep_); + EXPECT_FALSE(ec); + uint8_t res[payload_sz_]; + std::fill(std::begin(res), std::end(res), uint8_t(120)); + { + VLOG(1) << "Before writesome"; + + Done done; + iovec v{.iov_base = &res, .iov_len = payload_sz_}; + + tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + } + { + uint8_t buf[payload_sz_]; + Done done; + iovec v{.iov_base = &buf, .iov_len = payload_sz_}; + + tls_sock->AsyncRead(&v, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done.Notify(); + }); + + done.Wait(); + + EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0); + } + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }, + opts); SSL_CTX_free(ssl_ctx); } @@ -511,44 +517,50 @@ TEST_P(AsyncTlsSocketRenegotiate, Renegotiate) { SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT); tls_sock->InitSSL(ssl_ctx); - proactor_->Await([&] { - ThisFiber::SetName("ConnectFb"); - - error_code ec = tls_sock->Connect(listen_ep_); - EXPECT_FALSE(ec); - - uint8_t send_buf[payload_sz_]; - uint8_t res[payload_sz_]; - std::fill(std::begin(send_buf), std::end(send_buf), uint8_t(120)); - { - Done done_read, done_write; - iovec send_vec{.iov_base = &send_buf, .iov_len = payload_sz_}; - iovec read_vec{.iov_base = &res, .iov_len = payload_sz_}; - - // We don't need to call ssl_renegotiate here, the first read will also negotiate the protocol - tls_sock->AsyncRead(&read_vec, 1, [&](auto result) mutable { - EXPECT_FALSE(result); - done_read.Notify(); - }); - - // Here AsyncWrite will resume later since write_in_progress bit is set - tls_sock->AsyncWrite(&send_vec, 1, [&](auto result) mutable { - EXPECT_FALSE(result); - done_write.Notify(); - }); - - done_write.Wait(); - done_read.Wait(); - EXPECT_EQ(memcmp(begin(res), begin(send_buf), payload_sz_), 0); - } - - VLOG(1) << "closing client sock " << tls_sock->native_handle(); - std::ignore = tls_sock->Close(); - accept_fb_.Join(); - VLOG(1) << "After join"; - ASSERT_FALSE(ec) << ec.message(); - ASSERT_FALSE(accept_ec_); - }); + std::string name = "main stack"; + Fiber::Opts opts{.name = name, .stack_size = 128 * 1024}; + + proactor_->Await( + [&] { + ThisFiber::SetName("ConnectFb"); + + error_code ec = tls_sock->Connect(listen_ep_); + EXPECT_FALSE(ec); + + uint8_t send_buf[payload_sz_]; + uint8_t res[payload_sz_]; + std::fill(std::begin(send_buf), std::end(send_buf), uint8_t(120)); + { + Done done_read, done_write; + iovec send_vec{.iov_base = &send_buf, .iov_len = payload_sz_}; + iovec read_vec{.iov_base = &res, .iov_len = payload_sz_}; + + // We don't need to call ssl_renegotiate here, the first read will also negotiate the + // protocol + tls_sock->AsyncRead(&read_vec, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done_read.Notify(); + }); + + // Here AsyncWrite will resume later since write_in_progress bit is set + tls_sock->AsyncWrite(&send_vec, 1, [&](auto result) mutable { + EXPECT_FALSE(result); + done_write.Notify(); + }); + + done_write.Wait(); + done_read.Wait(); + EXPECT_EQ(memcmp(begin(res), begin(send_buf), payload_sz_), 0); + } + + VLOG(1) << "closing client sock " << tls_sock->native_handle(); + std::ignore = tls_sock->Close(); + accept_fb_.Join(); + VLOG(1) << "After join"; + ASSERT_FALSE(ec) << ec.message(); + ASSERT_FALSE(accept_ec_); + }, + opts); SSL_CTX_free(ssl_ctx); } } // namespace fb2 From e39ae8b51d37444a83f20e83e3d936638a41e4e9 Mon Sep 17 00:00:00 2001 From: kostas Date: Fri, 11 Apr 2025 10:12:08 +0300 Subject: [PATCH 8/8] fix mac --- util/tls/tls_socket_test.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc index b14c832c..6fc5c491 100644 --- a/util/tls/tls_socket_test.cc +++ b/util/tls/tls_socket_test.cc @@ -507,10 +507,12 @@ class AsyncTlsSocketRenegotiate : public AsyncTlsSocketTest { static constexpr size_t payload_sz_ = 32768; }; +// TODO once we fix epoll AsyncRead from blocking to nonblocking, we should add it here as well +// For now also disable this on mac since there is no iouring on mac #ifdef __linux__ + INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketRenegotiate, testing::Values("uring"), [](const auto& info) { return string(info.param); }); -#endif TEST_P(AsyncTlsSocketRenegotiate, Renegotiate) { unique_ptr tls_sock = std::make_unique(proactor_->CreateSocket()); @@ -563,5 +565,8 @@ TEST_P(AsyncTlsSocketRenegotiate, Renegotiate) { opts); SSL_CTX_free(ssl_ctx); } + +#endif + } // namespace fb2 } // namespace util