feat: TlsSocket AsyncWriteSome and AsyncReadSome #376
@@ -386,6 +386,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
  async_write_pending_ = 1;
}

// TODO implement async functionality
void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
  auto res = ReadSome(v, len);
  cb(res);

Review comment: epoll
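For context, this is the callback contract both sockets implement: the callee eventually invokes the io::AsyncProgressCb with an io::Result<size_t> carrying either the byte count or an error. A hypothetical caller sketch (sock and buf are illustrative names, not from this diff):

char buf[1024];
iovec v{buf, sizeof(buf)};
// Hypothetical usage: arm a single async read and inspect the result
// in the completion callback.
sock->AsyncReadSome(&v, 1, [](io::Result<size_t> res) {
  if (!res) {
    VLOG(1) << "read failed: " << res.error().message();
    return;
  }
  VLOG(1) << "read " << *res << " bytes";
});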
@@ -331,22 +331,254 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint
  return res;
}

-// TODO: to implement async functionality.
void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result<size_t> result) {
  auto current = std::exchange(owner->async_write_req_, std::nullopt);
  current->caller_completion_cb(result);
}

void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result<size_t> result) {
  auto current = std::exchange(owner->async_read_req_, std::nullopt);
  current->caller_completion_cb(result);
}

void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) {
  switch (op_val) {
    case Engine::EOF_STREAM:
      VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle();
      CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted)));
      break;
    case Engine::NEED_READ_AND_MAYBE_WRITE:
      MaybeSendOutputAsyncWithRead();
      break;
    case Engine::NEED_WRITE:
      MaybeSendOutputAsync();
      break;
    default:
      LOG(DFATAL) << "Unsupported " << op_val;
  }
}

Review thread on AsyncWriteReq::Run:
- why do we have
- what I am trying to understand - what is the path to the second loop?
- There are no fiber blocks, everything is async and yes we need that. An example would be:

void TlsSocket::AsyncWriteReq::Run() {
  while (true) {
    if (state == AsyncWriteReq::PushToEngine) {
      // We never preempt here
      io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
      if (!push_res) {
        CompleteAsyncReq(make_unexpected(push_res.error()));
        // We are done with this AsyncWriteReq. The caller might have started a new one.
        return;
      }
      last_push = *push_res;
      state = AsyncWriteReq::HandleOpAsyncTag;
    }

    if (state == AsyncWriteReq::HandleOpAsyncTag) {
      state = AsyncWriteReq::MaybeSendOutputAsyncTag;
      if (last_push.engine_opcode < 0) {
        HandleOpAsync(last_push.engine_opcode);
        return;
      }
    }

    if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) {
      state = AsyncWriteReq::PushToEngine;
      if (last_push.written > 0) {
        state = AsyncWriteReq::Done;
        MaybeSendOutputAsync();
        return;
      }
    }

    if (state == AsyncWriteReq::Done) {
      return CompleteAsyncReq(last_push.written);
    }
  }
}

Review comment: seems like a switch case without breaks would work here
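A minimal sketch of that suggestion, using the same state tags and fields as the code above (illustrative only, not the PR's implementation). Note the outer loop is still needed because MaybeSendOutputAsyncTag can set state back to PushToEngine:

  while (true) {
    switch (state) {
      case AsyncWriteReq::PushToEngine: {
        io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
        if (!push_res) {
          return CompleteAsyncReq(make_unexpected(push_res.error()));
        }
        last_push = *push_res;
        state = AsyncWriteReq::HandleOpAsyncTag;
        [[fallthrough]];
      }
      case AsyncWriteReq::HandleOpAsyncTag:
        state = AsyncWriteReq::MaybeSendOutputAsyncTag;
        if (last_push.engine_opcode < 0) {
          return HandleOpAsync(last_push.engine_opcode);
        }
        [[fallthrough]];
      case AsyncWriteReq::MaybeSendOutputAsyncTag:
        state = AsyncWriteReq::PushToEngine;
        if (last_push.written > 0) {
          state = AsyncWriteReq::Done;
          return MaybeSendOutputAsync();
        }
        continue;  // written == 0: push to the engine again
      case AsyncWriteReq::Done:
        return CompleteAsyncReq(last_push.written);
    }
  }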

void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = WriteSome(v, len);
-  cb(res);
  CHECK(!async_write_req_.has_value());
  async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len));
  async_write_req_->Run();
}

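As a usage sketch (hypothetical caller code, assuming payload is a std::string; not part of the diff): the CHECK above means a caller must not issue a second async write before the completion callback fires.

iovec v{const_cast<char*>(payload.data()), payload.size()};
tls_sock->AsyncWriteSome(&v, 1, [](io::Result<size_t> res) {
  // Invoked once the write has been driven to completion or error.
  if (!res) {
    VLOG(1) << "write failed: " << res.error().message();
  }
});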

Review thread on the while loop in AsyncReadReq::Run:
- here as well, why do we need a loop?
- I need to check this one

void TlsSocket::AsyncReadReq::Run() {
  DCHECK_GT(len, 0u);

  while (true) {
    DCHECK(!dest.empty());

    size_t read_len = std::min(dest.size(), size_t(INT_MAX));

    Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len);

    int op_val = op_result;

    DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val;

    if (op_val > 0) {
      read_total += op_val;

      // I do not understand this code and what the hell I meant to do here. Seems to work
      // though.
      // op_val == read_len means the engine returned everything we asked for, so either
      // advance within the current buffer (it was clamped to INT_MAX above) or move to the
      // next iovec. A short read falls through to the break below.
      if (size_t(op_val) == read_len) {
        if (size_t(op_val) < dest.size()) {
          dest.remove_prefix(op_val);
        } else {
          ++vec;
          --len;
          if (len == 0) {
            // We are done. Call completion callback.
            CompleteAsyncReq(read_total);
            return;
          }
          dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len};
        }
        // We read everything we asked for but there are still buffers left to fill.
        continue;
      }
      break;
    }

    DCHECK(!continuation);
    HandleOpAsync(op_val);
    return;
  }

  // We are done. Call completion callback.
  CompleteAsyncReq(read_total);
}

-// TODO: to implement async functionality.
void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = ReadSome(v, len);
-  cb(res);
  CHECK(!async_read_req_.has_value());
  auto req = AsyncReadReq(this, std::move(cb), v, len);
  req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len};
  async_read_req_.emplace(std::move(req));
  async_read_req_->Run();
}

Review thread:
- This is not how I would design it.
- Well, isn't that what I do here? We don't start/submit any async operation. We construct
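To the loop question above: one path to a second iteration is a scatter read across several iovec entries; when the engine fills the first entry exactly, Run() advances vec and iterates without completing. A hypothetical sketch:

char hdr[8], body[256];
iovec v[2] = {{hdr, sizeof(hdr)}, {body, sizeof(body)}};
tls_sock->AsyncReadSome(v, 2, [](io::Result<size_t> res) {
  // On success, res is the total number of plaintext bytes placed
  // across hdr and body.
});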

SSL* TlsSocket::ssl_handle() {
  return engine_ ? engine_->native_handle() : nullptr;
}

void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result<size_t> write_result,
                                                       Engine::Buffer buffer) {
  if (!write_result) {
    owner->state_ &= ~WRITE_IN_PROGRESS;

    // broken_pipe happens when the other side closes the connection. Do not log this.
    if (write_result.error() != errc::broken_pipe) {
      VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
              << ", write_total:" << owner->upstream_write_ << " "
              << " pending output: " << owner->engine_->OutputPending()
              << " HandleUpstreamAsyncWrite failed " << write_result.error();
    }
    // Run pending if blocked.
    RunPending();

    // We are done. Erroneous exit.
    CompleteAsyncReq(write_result);
    return;
  }

  CHECK_GT(*write_result, 0u);
  owner->upstream_write_ += *write_result;
  owner->engine_->ConsumeOutputBuf(*write_result);
  // We could preempt while calling WriteSome, and the engine could get more data to write.
  // Therefore we sync the buffer.
  buffer = owner->engine_->PeekOutputBuf();

  // We are not done. Re-arm the async write until we drive it to completion or error.
  if (!buffer.empty()) {
    auto& scratch = scratch_iovec;
    scratch.iov_base = const_cast<uint8_t*>(buffer.data());
    scratch.iov_len = buffer.size();
    return owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
      HandleUpstreamAsyncWrite(write_result, buffer);
    });
  }

  if (owner->engine_->OutputPending() > 0) {
    LOG(INFO) << "ssl buffer is not empty with " << owner->engine_->OutputPending()
              << " bytes. short write detected";
  }

  owner->state_ &= ~WRITE_IN_PROGRESS;

  // Run pending if blocked.
  RunPending();

  // If there is a continuation, run it and let it yield back to the main loop.
  // The continuation is responsible for calling AsyncRequest::Run again.
  if (continuation) {
    auto cont = std::exchange(continuation, std::function<void()>{});
    cont();
    return;
  }

  // Yield back to the main loop.
  Run();
}

void TlsSocket::AsyncReqBase::StartUpstreamWrite() {
  Engine::Buffer buffer = owner->engine_->PeekOutputBuf();
  DCHECK(!buffer.empty());
  DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0);

  DVLOG(2) << "HandleUpstreamWrite " << buffer.size();
  // We do not allow concurrent writes from multiple fibers.
  owner->state_ |= WRITE_IN_PROGRESS;

  auto& scratch = scratch_iovec;
  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
  scratch.iov_len = buffer.size();

  owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
    HandleUpstreamAsyncWrite(write_result, buffer);
  });
}

void TlsSocket::AsyncReqBase::MaybeSendOutputAsync() {
  if (owner->engine_->OutputPending() == 0) {
    return;
  }

  // This function is present in both the read and write paths, meaning both can be called
  // concurrently from different fibers and then race over flushing the output buffer.
  // We use state_ to prevent that.
  if (owner->state_ & WRITE_IN_PROGRESS) {
    DCHECK(owner->pending_blocked_ == nullptr);
    owner->pending_blocked_ = this;
    return;
  }

  StartUpstreamWrite();
}

void TlsSocket::AsyncReqBase::MaybeSendOutputAsyncWithRead() {
  if (owner->engine_->OutputPending() == 0) {
    return StartUpstreamRead();
  }

  // This function is present in both the read and write paths, meaning both can be called
  // concurrently from different fibers and then race over flushing the output buffer.
  // We use state_ to prevent that.
  if (owner->state_ & WRITE_IN_PROGRESS) {
    DCHECK(owner->pending_blocked_ == nullptr);
    owner->pending_blocked_ = this;
    return;
  }

  DCHECK(!continuation);
  continuation = [this]() {
    // Yields to Run() internally.
    StartUpstreamRead();
  };

  StartUpstreamWrite();
}
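Sequencing-wise, the continuation defers the upstream read until the pending TLS output has been flushed. An illustrative trace, reconstructed from the code above rather than taken from the PR discussion:

// MaybeSendOutputAsyncWithRead()       output pending, no write in progress
//   -> StartUpstreamWrite()            sets WRITE_IN_PROGRESS, arms the async write
//     -> HandleUpstreamAsyncWrite()    write completed, output drained
//       -> continuation()              the deferred read
//         -> StartUpstreamRead()       arms the async read
//           -> Run()                   engine now has input; resume the state machine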

auto TlsSocket::MaybeSendOutput() -> error_code {
  if (engine_->OutputPending() == 0)
    return {};

@@ -374,6 +606,48 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
  return HandleUpstreamWrite();
}

void TlsSocket::AsyncReqBase::StartUpstreamRead() {
  if (owner->state_ & READ_IN_PROGRESS) {
    return;
  }

  auto buffer = owner->engine_->PeekInputBuf();
  owner->state_ |= READ_IN_PROGRESS;

  auto& scratch = scratch_iovec;
  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
  scratch.iov_len = buffer.size();

  owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) {
    owner->state_ &= ~READ_IN_PROGRESS;
    RunPending();
    if (!read_result) {
      // Log any errors as well as situations where we have unflushed output.
      if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) {
        VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
                << ", write_total:" << owner->upstream_write_ << " "
                << " pending output: " << owner->engine_->OutputPending() << " "
                << "StartUpstreamRead failed " << read_result.error();
      }
      // Erroneous path. Apply the completion callback and exit.
      CompleteAsyncReq(read_result);
      return;
    }

    DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes";
    owner->engine_->CommitInput(*read_result);
    // We are not done. Give back control to the main loop.
    Run();
  });
}

void TlsSocket::AsyncReqBase::RunPending() {
  if (owner->pending_blocked_) {
    auto* blocked = std::exchange(owner->pending_blocked_, nullptr);
    blocked->Run();
  }
}

auto TlsSocket::HandleUpstreamRead() -> error_code {
  RETURN_ON_ERROR(MaybeSendOutput());
Review comment: Ouch -- we now got unit tests though 😄