feat: TlsSocket AsyncWriteSome and AsyncReadSome #376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open · wants to merge 10 commits into master · Changes from 4 commits

1 change: 1 addition & 0 deletions io/io.cc
@@ -50,6 +50,7 @@ struct AsyncReadState {

AsyncReadState(AsyncSource* source, const iovec* v, uint32_t length)
: arr(length), owner(source) {
cur = arr.data();
@kostasrim (Collaborator, Author) commented on Feb 10, 2025:

Ouch -- we now got unit tests though 😄

std::copy(v, v + length, arr.data());
}

1 change: 1 addition & 0 deletions util/fibers/epoll_socket.cc
@@ -386,6 +386,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
async_write_pending_ = 1;
}

// TODO implement async functionality
void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
@kostasrim (Collaborator, Author) commented:

epoll AsyncRead is synchronous

auto res = ReadSome(v, len);
cb(res);
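
Worth noting for callers: because the epoll backend completes the read inline, the callback fires before AsyncReadSome returns, while a truly asynchronous backend may defer it. A self-contained illustration of that semantic difference (the names below are invented, not the project's API):

```cpp
#include <functional>
#include <iostream>

// Epoll-style "async" read: completes synchronously and invokes the callback inline.
void AsyncReadInline(std::function<void(int)> cb) {
  cb(42);  // runs before this function returns
}

int main() {
  bool done = false;
  AsyncReadInline([&](int n) { done = (n > 0); });
  // On this backend `done` is already true here; on a truly async backend it
  // might not be. Callers must tolerate both orderings.
  std::cout << "completed inline: " << done << '\n';
}
```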
271 changes: 265 additions & 6 deletions util/tls/tls_socket.cc
@@ -331,22 +331,245 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint
return res;
}

- // TODO: to implement async functionality.
void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result<size_t> result) {
auto current = std::exchange(owner->async_write_req_, std::nullopt);
current->caller_completion_cb(result);
}

void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result<size_t> result) {
auto current = std::exchange(owner->async_read_req_, std::nullopt);
current->caller_completion_cb(result);
}

void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) {
switch (op_val) {
case Engine::EOF_STREAM:
break;
case Engine::NEED_READ_AND_MAYBE_WRITE:
MaybeSendOutputAsync(true);
break;
case Engine::NEED_WRITE:
MaybeSendOutputAsync();
break;
default:
LOG(DFATAL) << "Unsupported " << op_val;
}
}

void TlsSocket::AsyncWriteReq::Run() {
Owner commented:


why do we have while (true) here but we return in most cases?
it's hard to understand this logic - at first I thought Run is fiber-blocking.

Owner commented:


what I am trying to understand - what is the path to the second loop?

@kostasrim (Collaborator, Author) replied:


There are no fiber blocks, everything is async and yes we need that. An example would be:

  1. We push to the engine.
  2. We call HandleOpAsync because last_push.engine_opcode < 0 and we submit an async operation to the underlying socket. Its completion handler will call Run() again.
  3. The fiber continues doing its work while the underlying async operation is in flight.
  4. The async operation completes and invokes the completion handler, which calls Run() again. Our state now is MaybeSendOutputAsyncTag, so we reach that branch, but the condition last_push.written > 0 is not true. Now we loop, push to the engine again and proceed.
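
A minimal, self-contained sketch of that re-entrant flow (all types and names below are invented for illustration; they are not the PR's actual classes):

```cpp
#include <cstddef>
#include <functional>

// Stand-in for the underlying socket: it stores the completion callback and
// fires it later, the way a real async write completion would.
struct FakeUpstream {
  std::function<void(size_t)> pending;
  void AsyncWrite(std::function<void(size_t)> cb) { pending = std::move(cb); }
  void Complete(size_t n) { auto cb = std::move(pending); cb(n); }
};

struct WriteReq {
  enum State { kPushToEngine, kMaybeSendOutput };
  State state = kPushToEngine;
  FakeUpstream* upstream = nullptr;
  int rounds = 0;

  void Run() {
    if (state == kPushToEngine) {
      state = kMaybeSendOutput;
      // Steps 1-2: push to the engine, then submit the async write. Run()
      // returns immediately and the calling fiber keeps working (step 3).
      upstream->AsyncWrite([this](size_t) { Run(); });  // step 4 re-enters Run()
      return;
    }
    // state == kMaybeSendOutput: the upstream write finished. If the engine
    // accepted no payload yet, loop back and push again -- the "second loop".
    state = kPushToEngine;
    if (++rounds < 2) Run();
  }
};

int main() {
  FakeUpstream up;
  WriteReq req;
  req.upstream = &up;
  req.Run();        // submits the first async write and returns
  up.Complete(16);  // completion handler re-enters Run() and re-arms
  up.Complete(16);  // second completion finishes the request
}
```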

if (state == AsyncWriteReq::PushToEngine) {
// We never preempt here
io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
if (!push_res) {
CompleteAsyncReq(make_unexpected(push_res.error()));
// We are done with this AsyncWriteReq. The caller might have started
// a new one.
return;
}
last_push = *push_res;
state = AsyncWriteReq::HandleOpAsyncTag;
}

if (state == AsyncWriteReq::HandleOpAsyncTag) {
state = AsyncWriteReq::MaybeSendOutputAsyncTag;
if (last_push.engine_opcode < 0) {
if (last_push.engine_opcode == Engine::EOF_STREAM) {
VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle();
CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted)));
return;
}
HandleOpAsync(last_push.engine_opcode);
return;
}
}

if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) {
state = AsyncWriteReq::PushToEngine;
if (last_push.written > 0) {
continuation = [this]() { CompleteAsyncReq(last_push.written); };
MaybeSendOutputAsync();
return;
}
// Run again; we are not done.
Run();
}
}

void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
- io::Result<size_t> res = WriteSome(v, len);
- cb(res);
CHECK(!async_write_req_.has_value());
async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len));
async_write_req_->Run();
}

void TlsSocket::AsyncReadReq::Run() {
DCHECK_GT(len, 0u);

while (true) {
Owner commented:


here as well, why do we need a loop?

@kostasrim (Collaborator, Author) replied:


I need to check this one

DCHECK(!dest.empty());

size_t read_len = std::min(dest.size(), size_t(INT_MAX));

Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len);

int op_val = op_result;

DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val;

if (op_val > 0) {
read_total += op_val;

// We read exactly the number of bytes we asked the engine for. If read_len was
// clamped below the buffer size, keep filling the current buffer; otherwise
// advance to the next iovec.
if (size_t(op_val) == read_len) {
if (size_t(op_val) < dest.size()) {
dest.remove_prefix(op_val);
} else {
++vec;
--len;
if (len == 0) {
// We are done. Call completion callback.
CompleteAsyncReq(read_total);
return;
}
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len};
}
// We read everything we asked for but there are still buffers left to fill.
continue;
}
break;
}

DCHECK(!continuation);
if (op_val == Engine::EOF_STREAM) {
VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle();
// CompleteAsyncReq destroys this request, so we must not touch *this afterwards.
CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted)));
return;
}
HandleOpAsync(op_val);
return;
}

// We are done. Call completion callback.
CompleteAsyncReq(read_total);
}

- // TODO: to implement async functionality.
void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
- io::Result<size_t> res = ReadSome(v, len);
- cb(res);
CHECK(!async_read_req_.has_value());
Owner commented:


This is not how I would design it.
TlsEngine could already have a ready read buffer; you could return quickly by fetching it without doing any async work. Only if the buffer is empty is it worth dispatching the upstream read.

@kostasrim (Collaborator, Author) replied:


Well, isn't that what I do here? We don't start/submit any async operation. We construct the object and call Run(), and within this function we read the engine buffer first if there are bytes available -- so we are on the fast path already, that is, we don't dispatch upstream.
As I am changing it now to a unique_ptr, we always get an allocation. I think the cost will be insignificant because the allocator will be able to satisfy it without actually requesting more memory, and this is not a common path anyway.
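
A self-contained sketch of that fast path, with invented stand-ins for the engine and the upstream dispatch (not the actual helio types):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <vector>

// Invented stand-in for the TLS engine's buffer of already-decrypted bytes.
struct MiniEngine {
  std::vector<uint8_t> decrypted;
  size_t Read(uint8_t* dst, size_t n) {
    size_t take = std::min(n, decrypted.size());
    std::memcpy(dst, decrypted.data(), take);
    decrypted.erase(decrypted.begin(), decrypted.begin() + take);
    return take;  // 0 means nothing is buffered: must go to the upstream socket
  }
};

// Fast path first: complete inline from the engine buffer; dispatch an
// upstream read only when the buffer is empty.
void AsyncReadSketch(MiniEngine* engine, uint8_t* dst, size_t len,
                     std::function<void(size_t)> cb,
                     std::function<void()> dispatch_upstream) {
  if (size_t got = engine->Read(dst, len); got > 0) {
    cb(got);  // satisfied without submitting any async operation
    return;
  }
  dispatch_upstream();  // slow path: arm the upstream AsyncReadSome
}

int main() {
  MiniEngine eng;
  eng.decrypted = {1, 2, 3};
  uint8_t buf[8];
  AsyncReadSketch(&eng, buf, sizeof(buf),
                  [](size_t n) { /* completed inline with n bytes */ },
                  [] { /* would submit the upstream async read */ });
}
```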

auto req = AsyncReadReq(this, std::move(cb), v, len);
req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len};
async_read_req_.emplace(std::move(req));
async_read_req_->Run();
}

SSL* TlsSocket::ssl_handle() {
return engine_ ? engine_->native_handle() : nullptr;
}

void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result<size_t> write_result,
Engine::Buffer buffer) {
if (!write_result) {
owner->state_ &= ~WRITE_IN_PROGRESS;

// broken_pipe - happens when the other side closes the connection. do not log this.
if (write_result.error() != errc::broken_pipe) {
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
<< ", write_total:" << owner->upstream_write_ << " "
<< " pending output: " << owner->engine_->OutputPending()
<< " HandleUpstreamAsyncWrite failed " << write_result.error();
}

// We are done. Erroneous exit.
CompleteAsyncReq(write_result);
return;
}

CHECK_GT(*write_result, 0u);
owner->upstream_write_ += *write_result;
owner->engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);

// We are not done. Re-arm the async write until we drive it to completion or error.
if (!buffer.empty()) {
auto& scratch = scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();
return owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
HandleUpstreamAsyncWrite(write_result, buffer);
});
}

if (owner->engine_->OutputPending() > 0) {
LOG(INFO) << "ssl buffer is not empty with " << owner->engine_->OutputPending()
<< " bytes. short write detected";
}

owner->state_ &= ~WRITE_IN_PROGRESS;

// If there is a continuation, run it and let it yield back to the main loop.
// The continuation is responsible for calling AsyncRequest::Run again.
if (continuation) {
auto cont = std::exchange(continuation, std::function<void()>{});
cont();
return;
}

// Yield back to main loop
Run();
return;
}

void TlsSocket::AsyncReqBase::StartUpstreamWrite() {
Engine::Buffer buffer = owner->engine_->PeekOutputBuf();
DCHECK(!buffer.empty());
DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0);

DVLOG(2) << "HandleUpstreamWrite " << buffer.size();
// we do not allow concurrent writes from multiple fibers.
owner->state_ |= WRITE_IN_PROGRESS;

auto& scratch = scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();

owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
HandleUpstreamAsyncWrite(write_result, buffer);
});
}

void TlsSocket::AsyncReqBase::MaybeSendOutputAsync(bool should_read) {
auto body = [should_read, this]() {
if (should_read) {
return StartUpstreamRead();
}
if (continuation) {
auto cont = std::exchange(continuation, std::function<void()>{});
return cont();
}
return Run();
};

if (owner->engine_->OutputPending() == 0) {
return body();
}

// This function is present in both the read and write paths, meaning both can be
// called concurrently from different fibers and race over flushing the output
// buffer. We use state_ to prevent that.
if (owner->state_ & WRITE_IN_PROGRESS) {
// TODO: we must "yield" -> subscribe as a continuation to the write request, because
// otherwise we might deadlock. See the sync version of HandleOp for more info.
return body();
}

if (should_read) {
DCHECK(!continuation);
continuation = [this]() {
// Yields to Run() internally.
StartUpstreamRead();
};
}
StartUpstreamWrite();
}

auto TlsSocket::MaybeSendOutput() -> error_code {
if (engine_->OutputPending() == 0)
return {};
@@ -374,6 +597,42 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
return HandleUpstreamWrite();
}

void TlsSocket::AsyncReqBase::StartUpstreamRead() {
if (owner->state_ & READ_IN_PROGRESS) {
// TODO we should yield instead
Run();
return;
}

auto buffer = owner->engine_->PeekInputBuf();
owner->state_ |= READ_IN_PROGRESS;

auto& scratch = scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();

owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) {
owner->state_ &= ~READ_IN_PROGRESS;
if (!read_result) {
// log any errors as well as situations where we have unflushed output.
if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) {
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
<< ", write_total:" << owner->upstream_write_ << " "
<< " pending output: " << owner->engine_->OutputPending() << " "
<< "StartUpstreamRead failed " << read_result.error();
}
// Erroneous path. Invoke the completion callback and exit.
CompleteAsyncReq(read_result);
return;
}

DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes";
owner->engine_->CommitInput(*read_result);
// We are not done. Give back control to the main loop.
Run();
});
}

auto TlsSocket::HandleUpstreamRead() -> error_code {
RETURN_ON_ERROR(MaybeSendOutput());
