diff --git a/io/io.cc b/io/io.cc
index 3727ca5b..7606657a 100644
--- a/io/io.cc
+++ b/io/io.cc
@@ -50,6 +50,7 @@ struct AsyncReadState {
 
   AsyncReadState(AsyncSource* source, const iovec* v, uint32_t length)
       : arr(length), owner(source) {
+    cur = arr.data();
     std::copy(v, v + length, arr.data());
   }
 
diff --git a/util/fibers/epoll_socket.cc b/util/fibers/epoll_socket.cc
index 04a5c241..aa39cc39 100644
--- a/util/fibers/epoll_socket.cc
+++ b/util/fibers/epoll_socket.cc
@@ -386,6 +386,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
   async_write_pending_ = 1;
 }
 
+// TODO: implement real async read; for now this falls back to the synchronous ReadSome.
 void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
   auto res = ReadSome(v, len);
   cb(res);
diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc
index c5efe8ce..2f803a06 100644
--- a/util/tls/tls_socket.cc
+++ b/util/tls/tls_socket.cc
@@ -331,22 +331,254 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint
   return res;
 }
 
-// TODO: to implement async functionality.
+void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result<size_t> result) {
+  auto current = std::exchange(owner->async_write_req_, std::nullopt);
+  current->caller_completion_cb(result);
+}
+
+void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result<size_t> result) {
+  auto current = std::exchange(owner->async_read_req_, std::nullopt);
+  current->caller_completion_cb(result);
+}
+
+void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) {
+  switch (op_val) {
+    case Engine::EOF_STREAM:
+      VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle();
+      CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted)));
+      break;
+    case Engine::NEED_READ_AND_MAYBE_WRITE:
+      MaybeSendOutputAsyncWithRead();
+      break;
+    case Engine::NEED_WRITE:
+      MaybeSendOutputAsync();
+      break;
+    default:
+      LOG(DFATAL) << "Unsupported " << op_val;
+  }
+}
+
+void TlsSocket::AsyncWriteReq::Run() {
+  while (true) {
+    if (state == AsyncWriteReq::PushToEngine) {
+      // We never preempt here.
+      io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
+      if (!push_res) {
+        CompleteAsyncReq(make_unexpected(push_res.error()));
+        // We are done with this AsyncWriteReq. The caller might have started a new one.
+        return;
+      }
+      last_push = *push_res;
+      state = AsyncWriteReq::HandleOpAsyncTag;
+    }
+
+    if (state == AsyncWriteReq::HandleOpAsyncTag) {
+      state = AsyncWriteReq::MaybeSendOutputAsyncTag;
+      if (last_push.engine_opcode < 0) {
+        HandleOpAsync(last_push.engine_opcode);
+        return;
+      }
+    }
+
+    if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) {
+      state = AsyncWriteReq::PushToEngine;
+      if (last_push.written > 0) {
+        state = AsyncWriteReq::Done;
+        MaybeSendOutputAsync();
+        return;
+      }
+    }
+
+    if (state == AsyncWriteReq::Done) {
+      return CompleteAsyncReq(last_push.written);
+    }
+  }
+}
+
 void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = WriteSome(v, len);
-  cb(res);
+  CHECK(!async_write_req_.has_value());
+  async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len));
+  async_write_req_->Run();
+}
+
+void TlsSocket::AsyncReadReq::Run() {
+  DCHECK_GT(len, 0u);
+
+  while (true) {
+    DCHECK(!dest.empty());
+
+    size_t read_len = std::min(dest.size(), size_t(INT_MAX));
+
+    Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len);
+
+    int op_val = op_result;
+
+    DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val;
+
+    if (op_val > 0) {
+      read_total += op_val;
+
+      // The engine returned everything we asked for: either keep filling the current
+      // buffer or advance to the next iovec entry. A short engine read falls through
+      // to the completion below.
+      if (size_t(op_val) == read_len) {
+        if (size_t(op_val) < dest.size()) {
+          dest.remove_prefix(op_val);
+        } else {
+          ++vec;
+          --len;
+          if (len == 0) {
+            // We are done. Call the completion callback.
+            CompleteAsyncReq(read_total);
+            return;
+          }
+          dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len};
+        }
+        // We read everything we asked for but there are still buffers left to fill.
+        continue;
+      }
+      break;
+    }
+
+    DCHECK(!continuation);
+    HandleOpAsync(op_val);
+    return;
+  }
+
+  // We are done. Call the completion callback.
+  CompleteAsyncReq(read_total);
 }
 
-// TODO: to implement async functionality.
 void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = ReadSome(v, len);
-  cb(res);
+  CHECK(!async_read_req_.has_value());
+  auto req = AsyncReadReq(this, std::move(cb), v, len);
+  req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len};
+  async_read_req_.emplace(std::move(req));
+  async_read_req_->Run();
 }
 
 SSL* TlsSocket::ssl_handle() {
   return engine_ ? engine_->native_handle() : nullptr;
 }
 
+void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result<size_t> write_result,
+                                                       Engine::Buffer buffer) {
+  if (!write_result) {
+    owner->state_ &= ~WRITE_IN_PROGRESS;
+
+    // broken_pipe happens when the other side closes the connection; do not log it.
+    if (write_result.error() != errc::broken_pipe) {
+      VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
+              << ", write_total:" << owner->upstream_write_ << " "
+              << " pending output: " << owner->engine_->OutputPending()
+              << " HandleUpstreamAsyncWrite failed " << write_result.error();
+    }
+    // Run a pending request if one was blocked on us.
+    RunPending();
+
+    // We are done. Erroneous exit.
+    CompleteAsyncReq(write_result);
+    return;
+  }
+
+  CHECK_GT(*write_result, 0u);
+  owner->upstream_write_ += *write_result;
+  owner->engine_->ConsumeOutputBuf(*write_result);
+  // We could preempt while calling WriteSome, and the engine could get more data to write.
+  // Therefore we sync the buffer.
+  buffer = owner->engine_->PeekOutputBuf();
+
+  // We are not done. Re-arm the async write until we drive it to completion or error.
+  if (!buffer.empty()) {
+    auto& scratch = scratch_iovec;
+    scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+    scratch.iov_len = buffer.size();
+    return owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
+      HandleUpstreamAsyncWrite(write_result, buffer);
+    });
+  }
+
+  if (owner->engine_->OutputPending() > 0) {
+    LOG(INFO) << "ssl buffer is not empty with " << owner->engine_->OutputPending()
+              << " bytes. short write detected";
+  }
+
+  owner->state_ &= ~WRITE_IN_PROGRESS;
+
+  // Run a pending request if one was blocked on us.
+  RunPending();
+
+  // If there is a continuation, run it and let it yield back to the main loop.
+  // The continuation is responsible for calling AsyncRequest::Run again.
+  if (continuation) {
+    auto cont = std::exchange(continuation, std::function<void()>{});
+    cont();
+    return;
+  }
+
+  // Hand control back to the main loop.
+  Run();
+  return;
+}
+
+void TlsSocket::AsyncReqBase::StartUpstreamWrite() {
+  Engine::Buffer buffer = owner->engine_->PeekOutputBuf();
+  DCHECK(!buffer.empty());
+  DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0);
+
+  DVLOG(2) << "StartUpstreamWrite " << buffer.size();
+  // We do not allow concurrent writes from multiple fibers.
+  owner->state_ |= WRITE_IN_PROGRESS;
+
+  auto& scratch = scratch_iovec;
+  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+  scratch.iov_len = buffer.size();
+
+  owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
+    HandleUpstreamAsyncWrite(write_result, buffer);
+  });
+}
+
+void TlsSocket::AsyncReqBase::MaybeSendOutputAsync() {
+  if (owner->engine_->OutputPending() == 0) {
+    return;
+  }
+
+  // This function is reached from both the read and the write paths, which can run
+  // concurrently from different fibers and race over flushing the output buffer.
+  // We use state_ to prevent that.
+  if (owner->state_ & WRITE_IN_PROGRESS) {
+    DCHECK(owner->pending_blocked_ == nullptr);
+    owner->pending_blocked_ = this;
+    return;
+  }
+
+  StartUpstreamWrite();
+}
+
+void TlsSocket::AsyncReqBase::MaybeSendOutputAsyncWithRead() {
+  if (owner->engine_->OutputPending() == 0) {
+    return StartUpstreamRead();
+  }
+
+  // This function is reached from both the read and the write paths, which can run
+  // concurrently from different fibers and race over flushing the output buffer.
+  // We use state_ to prevent that.
+  if (owner->state_ & WRITE_IN_PROGRESS) {
+    DCHECK(owner->pending_blocked_ == nullptr);
+    owner->pending_blocked_ = this;
+    return;
+  }
+
+  DCHECK(!continuation);
+  continuation = [this]() {
+    // Yields to Run() internally.
+    StartUpstreamRead();
+  };
+
+  StartUpstreamWrite();
+}
+
 auto TlsSocket::MaybeSendOutput() -> error_code {
   if (engine_->OutputPending() == 0)
     return {};
@@ -374,6 +606,48 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
   return HandleUpstreamWrite();
 }
 
+void TlsSocket::AsyncReqBase::StartUpstreamRead() {
+  if (owner->state_ & READ_IN_PROGRESS) {
+    return;
+  }
+
+  auto buffer = owner->engine_->PeekInputBuf();
+  owner->state_ |= READ_IN_PROGRESS;
+
+  auto& scratch = scratch_iovec;
+  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+  scratch.iov_len = buffer.size();
+
+  owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) {
+    owner->state_ &= ~READ_IN_PROGRESS;
+    RunPending();
+    if (!read_result) {
+      // Log any error, as well as situations where we still have unflushed output.
+      if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) {
+        VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
+                << ", write_total:" << owner->upstream_write_ << " "
+                << " pending output: " << owner->engine_->OutputPending() << " "
+                << "StartUpstreamRead failed " << read_result.error();
+      }
+      // Erroneous path. Apply the completion callback and exit.
+      CompleteAsyncReq(read_result);
+      return;
+    }
+
+    DVLOG(1) << "StartUpstreamRead " << *read_result << " bytes";
+    owner->engine_->CommitInput(*read_result);
+    // We are not done. Give control back to the main loop.
+    Run();
+  });
+}
+
+void TlsSocket::AsyncReqBase::RunPending() {
+  if (owner->pending_blocked_) {
+    auto* blocked = std::exchange(owner->pending_blocked_, nullptr);
+    blocked->Run();
+  }
+}
+
 auto TlsSocket::HandleUpstreamRead() -> error_code {
   RETURN_ON_ERROR(MaybeSendOutput());
 
diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h
index 83a470b0..a0bf2b3b 100644
--- a/util/tls/tls_socket.h
+++ b/util/tls/tls_socket.h
@@ -7,6 +7,7 @@
 #include <openssl/ssl.h>
 
 #include <memory>
+#include <optional>
 
 #include "util/fiber_socket_base.h"
 #include "util/tls/tls_engine.h"
@@ -90,7 +91,6 @@ class TlsSocket final : public FiberSocketBase {
   virtual void SetProactor(ProactorBase* p) override;
 
  private:
-
   struct PushResult {
     size_t written = 0;
     int engine_opcode = 0;  // Engine::OpCode
@@ -114,6 +114,80 @@ class TlsSocket final : public FiberSocketBase {
   std::unique_ptr<Engine> engine_;
   size_t upstream_write_ = 0;
 
+  struct AsyncReqBase {
+    AsyncReqBase(TlsSocket* owner, io::AsyncProgressCb caller_cb, const iovec* vec, uint32_t len)
+        : owner(owner), caller_completion_cb(std::move(caller_cb)), vec(vec), len(len) {
+    }
+
+    TlsSocket* owner;
+    // Callback passed from the user.
+    io::AsyncProgressCb caller_completion_cb;
+
+    const iovec* vec;
+    uint32_t len;
+
+    std::function<void()> continuation;
+    iovec scratch_iovec;
+
+    // Asynchronous helpers.
+    void MaybeSendOutputAsyncWithRead();
+
+    // Flushes pending engine output; defers itself if an upstream write is already in progress.
+    void MaybeSendOutputAsync();
+
+    void HandleUpstreamAsyncWrite(io::Result<size_t> write_result, Engine::Buffer buffer);
+
+    void HandleOpAsync(int op_val);
+
+    void StartUpstreamWrite();
+    void StartUpstreamRead();
+
+    void RunPending();
+
+    virtual void Run() = 0;
+    virtual void CompleteAsyncReq(io::Result<size_t> result) = 0;
+  };
+
+  // Helper that resets the internal async request, applies the user AsyncProgressCb and
+  // returns. We need this because progress callbacks can start another async request, and
+  // for that to work we must first clean up the one we are running on.
+  void CompleteAsyncRequest(io::Result<size_t> result);
+
+  struct AsyncWriteReq : AsyncReqBase {
+    using AsyncReqBase::AsyncReqBase;
+
+    // TODO: handle async yields to avoid deadlocks (see HandleOp).
+    enum State { PushToEngine, HandleOpAsyncTag, MaybeSendOutputAsyncTag, Done };
+    State state = PushToEngine;
+    PushResult last_push;
+
+    // Main loop.
+    void Run() override;
+    void CompleteAsyncReq(io::Result<size_t> result) override;
+  };
+
+  friend AsyncWriteReq;
+
+  struct AsyncReadReq : AsyncReqBase {
+    using AsyncReqBase::AsyncReqBase;
+
+    Engine::MutableBuffer dest;
+    size_t read_total = 0;
+
+    // Main loop.
+    void Run() override;
+    void CompleteAsyncReq(io::Result<size_t> result) override;
+  };
+
+  friend AsyncReadReq;
+
+  // TODO: clean up the optional before we yield, so that the progress callback can dispatch
+  // another async operation.
+  std::optional<AsyncWriteReq> async_write_req_;
+  std::optional<AsyncReadReq> async_read_req_;
+  AsyncReqBase* pending_blocked_ = nullptr;
+
   enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 };
   uint8_t state_{0};
 };
diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc
index 876e0669..6fc5c491 100644
--- a/util/tls/tls_socket_test.cc
+++ b/util/tls/tls_socket_test.cc
@@ -56,7 +56,6 @@ SSL_CTX* CreateSslCntx(TlsContextRole role) {
 
   if (role == TlsContextRole::SERVER) {
     ctx = SSL_CTX_new(TLS_server_method());
-    // TODO init those to build on ci
   } else {
     ctx = SSL_CTX_new(TLS_client_method());
   }
@@ -136,9 +135,7 @@ void TlsSocketTest::SetUp() {
     VLOG(1) << "Accepted connection " << sock->native_handle();
     sock->SetProactor(proactor_.get());
-    sock->RegisterOnErrorCb([](uint32_t mask) {
-      LOG(ERROR) << "Error mask: " << mask;
-    });
+    sock->RegisterOnErrorCb([](uint32_t mask) { LOG(ERROR) << "Error mask: " << mask; });
     server_socket_ = std::make_unique<tls::TlsSocket>(sock);
     ssl_ctx_ = CreateSslCntx(SERVER);
     server_socket_->InitSSL(ssl_ctx_);
@@ -218,5 +215,358 @@ TEST_P(TlsSocketTest, ShortWrite) {
   server_read_fb.Join();
 }
 
+class AsyncTlsSocketTest : public testing::TestWithParam<string_view> {
+ protected:
+  void SetUp() final;
+  void TearDown() final;
+
+  static void SetUpTestCase() {
+    testing::FLAGS_gtest_death_test_style = "threadsafe";
+  }
+
+  using IoResult = int;
+
+  // TODO: clean up.
+  // Accepts a TLS connection, reads 16 bytes and echoes them back.
+  virtual void HandleRequest() {
+    tls_socket_ = std::make_unique<tls::TlsSocket>(conn_socket_.release());
+    ssl_ctx_ = CreateSslCntx(SERVER);
+    tls_socket_->InitSSL(ssl_ctx_);
+    tls_socket_->Accept();
+
+    uint8_t buf[16];
+    auto res = tls_socket_->Recv(buf);
+    EXPECT_TRUE(res.has_value());
+    EXPECT_TRUE(res.value() == 16);
+
+    auto write_res = tls_socket_->Write(buf);
+    EXPECT_FALSE(write_res);
+  }
+
+  unique_ptr<ProactorBase> proactor_;
+  thread proactor_thread_;
+  unique_ptr<FiberSocketBase> listen_socket_;
+  unique_ptr<FiberSocketBase> conn_socket_;
+  unique_ptr<tls::TlsSocket> tls_socket_;
+  SSL_CTX* ssl_ctx_;
+
+  uint16_t listen_port_ = 0;
+  Fiber accept_fb_;
+  Fiber conn_fb_;
+  std::error_code accept_ec_;
+  FiberSocketBase::endpoint_type listen_ep_;
+  uint32_t conn_sock_err_mask_ = 0;
+};
+
+INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTest,
+                         testing::Values("epoll"
+#ifdef __linux__
+                                         ,
+                                         "uring"
+#endif
+                                         ),
+                         [](const auto& info) { return string(info.param); });
+
+void AsyncTlsSocketTest::SetUp() {
+#if __linux__
+  bool use_uring = GetParam() == "uring";
+  ProactorBase* proactor = nullptr;
+  if (use_uring)
+    proactor = new UringProactor;
+  else
+    proactor = new EpollProactor;
+#else
+  ProactorBase* proactor = new EpollProactor;
+#endif
+
+  proactor_thread_ = thread{[proactor] {
+    InitProactor(proactor);
+    proactor->Run();
+  }};
+
+  proactor_.reset(proactor);
+
+  error_code ec = proactor_->AwaitBrief([&] {
+    listen_socket_.reset(proactor_->CreateSocket());
+    return listen_socket_->Listen(0, 0);
+  });
+
+  CHECK(!ec);
+  listen_ep_ = listen_socket_->LocalEndpoint();
+  std::string name("accept");
+  Fiber::Opts opts{.name = name, .stack_size = 128 * 1024};
+
+  accept_fb_ = proactor_->LaunchFiber(opts, [this] {
+    auto accept_res = listen_socket_->Accept();
+    VLOG_IF(1, !accept_res) << "Accept res: " << accept_res.error();
+
+    if (accept_res) {
+      VLOG(1) << "Accepted connection " << *accept_res;
+      FiberSocketBase* sock = *accept_res;
+      conn_socket_.reset(sock);
+      conn_socket_->SetProactor(proactor_.get());
+      conn_socket_->RegisterOnErrorCb([this](uint32_t mask) {
+        LOG(INFO) << "Error mask: " << mask;
+        conn_sock_err_mask_ = mask;
+      });
+      conn_fb_ = proactor_->LaunchFiber([this]() { HandleRequest(); });
+    } else {
+      accept_ec_ = accept_res.error();
+    }
+  });
+}
+
+void AsyncTlsSocketTest::TearDown() {
+  VLOG(1) << "TearDown";
+
+  proactor_->Await([&] {
+    std::ignore = listen_socket_->Shutdown(SHUT_RDWR);
+    if (conn_socket_) {
+      std::ignore = conn_socket_->Close();
+    } else {
+      std::ignore = tls_socket_->Close();
+    }
+  });
+
+  conn_fb_.JoinIfNeeded();
+  accept_fb_.JoinIfNeeded();
+
+  proactor_->Await([&] { std::ignore = listen_socket_->Close(); });
+
+  proactor_->Stop();
+  proactor_thread_.join();
+  proactor_.reset();
+
+  SSL_CTX_free(ssl_ctx_);
+}
+
+TEST_P(AsyncTlsSocketTest, AsyncRW) {
+  unique_ptr<tls::TlsSocket> tls_sock = std::make_unique<tls::TlsSocket>(proactor_->CreateSocket());
+  SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT);
+  tls_sock->InitSSL(ssl_ctx);
+
+  proactor_->Await([&] {
+    ThisFiber::SetName("ConnectFb");
+
+    LOG(INFO) << "Connecting to " << listen_ep_;
+    error_code ec = tls_sock->Connect(listen_ep_);
+    EXPECT_FALSE(ec);
+    uint8_t res[16];
+    std::fill(std::begin(res), std::end(res), uint8_t(120));
+    {
+      VLOG(1) << "Before writesome";
+
+      Done done;
+      iovec v{.iov_base = &res, .iov_len = 16};
+
+      tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable {
+        EXPECT_TRUE(result.has_value());
+        EXPECT_EQ(*result, 16);
+        done.Notify();
+      });
+
+      done.Wait();
+    }
+    {
+      uint8_t buf[16];
+      Done done;
+      iovec v{.iov_base = &buf, .iov_len = 16};
+      tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable {
+        EXPECT_TRUE(result.has_value());
+        EXPECT_EQ(*result, 16);
+        done.Notify();
+      });
+
+      done.Wait();
+
+      EXPECT_EQ(memcmp(begin(res), begin(buf), 16), 0);
+    }
+
+    VLOG(1) << "closing client sock " << tls_sock->native_handle();
+    std::ignore = tls_sock->Close();
+    accept_fb_.Join();
+    VLOG(1) << "After join";
+    ASSERT_FALSE(ec) << ec.message();
+    ASSERT_FALSE(accept_ec_);
+  });
+  SSL_CTX_free(ssl_ctx);
+}
+
+class AsyncTlsSocketTestPartialRW : public AsyncTlsSocketTest {
+  virtual void HandleRequest() {
+    tls_socket_ = std::make_unique<tls::TlsSocket>(conn_socket_.release());
+    ssl_ctx_ = CreateSslCntx(SERVER);
+    tls_socket_->InitSSL(ssl_ctx_);
+    tls_socket_->Accept();
+
+    uint8_t buf[payload_sz_];
+    auto res = tls_socket_->ReadAtLeast(buf, payload_sz_);
+    EXPECT_TRUE(res.has_value());
+    EXPECT_TRUE(res.value() == payload_sz_) << res.value();
+
+    absl::Span<const uint8_t> partial_write(buf, payload_sz_ / 2);
+    // We split the write into two smaller ones.
+    auto write_res = tls_socket_->Write(partial_write);
+    EXPECT_FALSE(write_res);
+    write_res = tls_socket_->Write(partial_write);
+    EXPECT_FALSE(write_res);
+  }
+
+ public:
+  static constexpr size_t payload_sz_ = 32768;
+};
+
+INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketTestPartialRW,
+                         testing::Values("epoll"
+#ifdef __linux__
+                                         ,
+                                         "uring"
+#endif
+                                         ),
+                         [](const auto& info) { return string(info.param); });
+
+TEST_P(AsyncTlsSocketTestPartialRW, PartialAsyncReadWrite) {
+  unique_ptr<tls::TlsSocket> tls_sock = std::make_unique<tls::TlsSocket>(proactor_->CreateSocket());
+  SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT);
+  tls_sock->InitSSL(ssl_ctx);
+
+  std::string name = "main stack";
+  Fiber::Opts opts{.name = name, .stack_size = 128 * 1024};
+  proactor_->Await(
+      [&] {
+        ThisFiber::SetName("ConnectFb");
+
+        LOG(INFO) << "Connecting to " << listen_ep_;
+        error_code ec = tls_sock->Connect(listen_ep_);
+        EXPECT_FALSE(ec);
+        uint8_t res[payload_sz_];
+        std::fill(std::begin(res), std::end(res), uint8_t(120));
+        {
+          VLOG(1) << "Before writesome";
+
+          Done done;
+          iovec v{.iov_base = &res, .iov_len = payload_sz_};
+
+          tls_sock->AsyncWrite(&v, 1, [&](auto result) mutable {
+            EXPECT_FALSE(result);
+            done.Notify();
+          });
+
+          done.Wait();
+        }
+        {
+          uint8_t buf[payload_sz_];
+          Done done;
+          iovec v{.iov_base = &buf, .iov_len = payload_sz_};
+
+          tls_sock->AsyncRead(&v, 1, [&](auto result) mutable {
+            EXPECT_FALSE(result);
+            done.Notify();
+          });
+
+          done.Wait();
+
+          EXPECT_EQ(memcmp(begin(res), begin(buf), payload_sz_), 0);
+        }
+
+        VLOG(1) << "closing client sock " << tls_sock->native_handle();
+        std::ignore = tls_sock->Close();
+        accept_fb_.Join();
+        VLOG(1) << "After join";
+        ASSERT_FALSE(ec) << ec.message();
+        ASSERT_FALSE(accept_ec_);
+      },
+      opts);
+  SSL_CTX_free(ssl_ctx);
+}
+
+class AsyncTlsSocketRenegotiate : public AsyncTlsSocketTest {
+  virtual void HandleRequest() {
+    tls_socket_ = std::make_unique<tls::TlsSocket>(conn_socket_.release());
+    ssl_ctx_ = CreateSslCntx(SERVER);
+    tls_socket_->InitSSL(ssl_ctx_);
+    tls_socket_->Accept();
+
+    uint8_t buf[payload_sz_];
+    auto res = tls_socket_->ReadAtLeast(buf, payload_sz_);
+    EXPECT_TRUE(res.has_value());
+    EXPECT_TRUE(res.value() == payload_sz_) << res.value();
+
+    absl::Span<const uint8_t> partial_write(buf, payload_sz_ / 2);
+    // We split the write into two smaller ones.
+    auto write_res = tls_socket_->Write(partial_write);
+    EXPECT_FALSE(write_res);
+    write_res = tls_socket_->Write(partial_write);
+    EXPECT_FALSE(write_res);
+  }
+
+  tls::TlsSocket* Handle() {
+    return tls_socket_.get();
+  }
+
+ public:
+  static constexpr size_t payload_sz_ = 32768;
+};
+
+// TODO: once we change the epoll AsyncRead from blocking to non-blocking, we should add it
+// here as well. For now this is also disabled on macOS, since io_uring is not available there.
+#ifdef __linux__
+
+INSTANTIATE_TEST_SUITE_P(Engines, AsyncTlsSocketRenegotiate, testing::Values("uring"),
+                         [](const auto& info) { return string(info.param); });
+
+TEST_P(AsyncTlsSocketRenegotiate, Renegotiate) {
+  unique_ptr<tls::TlsSocket> tls_sock = std::make_unique<tls::TlsSocket>(proactor_->CreateSocket());
+  SSL_CTX* ssl_ctx = CreateSslCntx(CLIENT);
+  tls_sock->InitSSL(ssl_ctx);
+
+  std::string name = "main stack";
+  Fiber::Opts opts{.name = name, .stack_size = 128 * 1024};
+
+  proactor_->Await(
+      [&] {
+        ThisFiber::SetName("ConnectFb");
+
+        error_code ec = tls_sock->Connect(listen_ep_);
+        EXPECT_FALSE(ec);
+
+        uint8_t send_buf[payload_sz_];
+        uint8_t res[payload_sz_];
+        std::fill(std::begin(send_buf), std::end(send_buf), uint8_t(120));
+        {
+          Done done_read, done_write;
+          iovec send_vec{.iov_base = &send_buf, .iov_len = payload_sz_};
+          iovec read_vec{.iov_base = &res, .iov_len = payload_sz_};
+
+          // We do not need to call ssl_renegotiate here; the first read also negotiates
+          // the protocol.
+          tls_sock->AsyncRead(&read_vec, 1, [&](auto result) mutable {
+            EXPECT_FALSE(result);
+            done_read.Notify();
+          });
+
+          // Here AsyncWrite will resume later, since the WRITE_IN_PROGRESS bit is set.
+          tls_sock->AsyncWrite(&send_vec, 1, [&](auto result) mutable {
+            EXPECT_FALSE(result);
+            done_write.Notify();
+          });
+
+          done_write.Wait();
+          done_read.Wait();
+          EXPECT_EQ(memcmp(begin(res), begin(send_buf), payload_sz_), 0);
+        }
+
+        VLOG(1) << "closing client sock " << tls_sock->native_handle();
+        std::ignore = tls_sock->Close();
+        accept_fb_.Join();
+        VLOG(1) << "After join";
+        ASSERT_FALSE(ec) << ec.message();
+        ASSERT_FALSE(accept_ec_);
+      },
+      opts);
+  SSL_CTX_free(ssl_ctx);
+}
+
+#endif
+
 }  // namespace fb2
 }  // namespace util
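
Reviewer note: a minimal usage sketch (not part of the patch) showing how a caller is expected to drive the new completion-callback API, modeled on the AsyncRW test above. EchoOnce is a hypothetical helper; it assumes a connected tls::TlsSocket running on its proactor fiber, and uses helio's fiber-aware Done event for synchronization.

#include <sys/uio.h>

#include "base/logging.h"
#include "util/fibers/synchronization.h"  // util::fb2::Done
#include "util/tls/tls_socket.h"

// Hypothetical helper: writes 16 bytes, then reads the 16-byte echo, one async call each.
void EchoOnce(util::tls::TlsSocket* sock) {
  uint8_t buf[16] = {42};
  iovec v{.iov_base = buf, .iov_len = sizeof(buf)};

  util::fb2::Done write_done;
  sock->AsyncWriteSome(&v, 1, [write_done](io::Result<size_t> res) mutable {
    CHECK(res.has_value());  // Bytes pushed to the engine, or the upstream error.
    write_done.Notify();
  });
  write_done.Wait();

  util::fb2::Done read_done;
  sock->AsyncReadSome(&v, 1, [read_done](io::Result<size_t> res) mutable {
    CHECK(res.has_value());  // ReadSome semantics: may return fewer than 16 bytes.
    read_done.Notify();
  });
  read_done.Wait();
}

Note that AsyncWriteSome/AsyncReadSome CHECK that no async request of the same kind is already in flight on the socket, so a caller must wait for the completion callback before issuing the next request of that kind.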