-
Notifications
You must be signed in to change notification settings - Fork 62
feat: TlsSocket AsyncWriteSome and AsyncReadSome #376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
64e18d0
701425a
916678c
802a896
1b28ab8
8187dbc
0385a16
acc6ac2
3cbfc7a
e39ae8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -386,6 +386,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress | |
async_write_pending_ = 1; | ||
} | ||
|
||
// TODO implement async functionality | ||
void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. epoll |
||
auto res = ReadSome(v, len); | ||
cb(res); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -331,22 +331,245 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint | |
return res; | ||
} | ||
|
||
// TODO: to implement async functionality. | ||
void TlsSocket::AsyncWriteReq::CompleteAsyncReq(io::Result<size_t> result) { | ||
auto current = std::exchange(owner->async_write_req_, std::nullopt); | ||
current->caller_completion_cb(result); | ||
} | ||
|
||
void TlsSocket::AsyncReadReq::CompleteAsyncReq(io::Result<size_t> result) { | ||
auto current = std::exchange(owner->async_read_req_, std::nullopt); | ||
current->caller_completion_cb(result); | ||
} | ||
|
||
void TlsSocket::AsyncReqBase::HandleOpAsync(int op_val) { | ||
switch (op_val) { | ||
case Engine::EOF_STREAM: | ||
break; | ||
case Engine::NEED_READ_AND_MAYBE_WRITE: | ||
MaybeSendOutputAsync(true); | ||
break; | ||
case Engine::NEED_WRITE: | ||
MaybeSendOutputAsync(); | ||
break; | ||
default: | ||
LOG(DFATAL) << "Unsupported " << op_val; | ||
} | ||
} | ||
|
||
void TlsSocket::AsyncWriteReq::Run() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what I am trying to understand - what is the path to the second loop? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are no fiber blocks, everything is async and yes we need that. An example would be:
|
||
if (state == AsyncWriteReq::PushToEngine) { | ||
// We never preempt here | ||
io::Result<PushResult> push_res = owner->PushToEngine(vec, len); | ||
if (!push_res) { | ||
CompleteAsyncReq(make_unexpected(push_res.error())); | ||
// We are done with this AsyncWriteReq. Caller might started | ||
// a new one. | ||
return; | ||
} | ||
last_push = *push_res; | ||
state = AsyncWriteReq::HandleOpAsyncTag; | ||
} | ||
|
||
if (state == AsyncWriteReq::HandleOpAsyncTag) { | ||
state = AsyncWriteReq::MaybeSendOutputAsyncTag; | ||
if (last_push.engine_opcode < 0) { | ||
if (last_push.engine_opcode == Engine::EOF_STREAM) { | ||
VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); | ||
CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); | ||
return; | ||
} | ||
HandleOpAsync(last_push.engine_opcode); | ||
return; | ||
} | ||
} | ||
|
||
if (state == AsyncWriteReq::MaybeSendOutputAsyncTag) { | ||
state = AsyncWriteReq::PushToEngine; | ||
if (last_push.written > 0) { | ||
continuation = [this]() { CompleteAsyncReq(last_push.written); }; | ||
MaybeSendOutputAsync(); | ||
return; | ||
} | ||
// Run again we are not done. | ||
Run(); | ||
} | ||
} | ||
|
||
void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { | ||
io::Result<size_t> res = WriteSome(v, len); | ||
cb(res); | ||
CHECK(!async_write_req_.has_value()); | ||
async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len)); | ||
async_write_req_->Run(); | ||
} | ||
|
||
void TlsSocket::AsyncReadReq::Run() { | ||
DCHECK_GT(len, 0u); | ||
|
||
while (true) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here as well, why do we need a loop? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to check this one |
||
DCHECK(!dest.empty()); | ||
|
||
size_t read_len = std::min(dest.size(), size_t(INT_MAX)); | ||
|
||
Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len); | ||
|
||
int op_val = op_result; | ||
|
||
DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val; | ||
|
||
if (op_val > 0) { | ||
read_total += op_val; | ||
|
||
// I do not understand this code and what the hell I meant to do here. Seems to work | ||
// though. | ||
if (size_t(op_val) == read_len) { | ||
if (size_t(op_val) < dest.size()) { | ||
dest.remove_prefix(op_val); | ||
} else { | ||
++vec; | ||
--len; | ||
if (len == 0) { | ||
// We are done. Call completion callback. | ||
CompleteAsyncReq(read_total); | ||
return; | ||
} | ||
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len}; | ||
} | ||
// We read everything we asked for but there are still buffers left to fill. | ||
continue; | ||
} | ||
break; | ||
} | ||
|
||
DCHECK(!continuation); | ||
if (op_val == Engine::EOF_STREAM) { | ||
VLOG(1) << "EOF_STREAM received " << owner->next_sock_->native_handle(); | ||
CompleteAsyncReq(make_unexpected(make_error_code(errc::connection_aborted))); | ||
} | ||
HandleOpAsync(op_val); | ||
return; | ||
} | ||
|
||
// We are done. Call completion callback. | ||
CompleteAsyncReq(read_total); | ||
} | ||
|
||
// TODO: to implement async functionality. | ||
void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) { | ||
io::Result<size_t> res = ReadSome(v, len); | ||
cb(res); | ||
CHECK(!async_read_req_.has_value()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not how I would design it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, isn't that what I do here? We don't start/submit any async operation. We construct |
||
auto req = AsyncReadReq(this, std::move(cb), v, len); | ||
req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len}; | ||
async_read_req_.emplace(std::move(req)); | ||
async_read_req_->Run(); | ||
} | ||
|
||
SSL* TlsSocket::ssl_handle() { | ||
return engine_ ? engine_->native_handle() : nullptr; | ||
} | ||
|
||
void TlsSocket::AsyncReqBase::HandleUpstreamAsyncWrite(io::Result<size_t> write_result, | ||
Engine::Buffer buffer) { | ||
if (!write_result) { | ||
owner->state_ &= ~WRITE_IN_PROGRESS; | ||
|
||
// broken_pipe - happens when the other side closes the connection. do not log this. | ||
if (write_result.error() != errc::broken_pipe) { | ||
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_) | ||
<< ", write_total:" << owner->upstream_write_ << " " | ||
<< " pending output: " << owner->engine_->OutputPending() | ||
<< " HandleUpstreamAsyncWrite failed " << write_result.error(); | ||
} | ||
|
||
// We are done. Errornous exit. | ||
CompleteAsyncReq(write_result); | ||
return; | ||
} | ||
|
||
CHECK_GT(*write_result, 0u); | ||
owner->upstream_write_ += *write_result; | ||
owner->engine_->ConsumeOutputBuf(*write_result); | ||
buffer.remove_prefix(*write_result); | ||
|
||
// We are not done. Re-arm the async write until we drive it to completion or error. | ||
if (!buffer.empty()) { | ||
auto& scratch = scratch_iovec; | ||
scratch.iov_base = const_cast<uint8_t*>(buffer.data()); | ||
scratch.iov_len = buffer.size(); | ||
return owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { | ||
HandleUpstreamAsyncWrite(write_result, buffer); | ||
}); | ||
} | ||
|
||
if (owner->engine_->OutputPending() > 0) { | ||
LOG(INFO) << "ssl buffer is not empty with " << owner->engine_->OutputPending() | ||
<< " bytes. short write detected"; | ||
} | ||
|
||
owner->state_ &= ~WRITE_IN_PROGRESS; | ||
|
||
// If there is a continuation run it and let it yield back to the main loop. | ||
// Continuation is responsible for calling AsyncRequest::Run again. | ||
if (continuation) { | ||
auto cont = std::exchange(continuation, std::function<void()>{}); | ||
cont(); | ||
return; | ||
} | ||
|
||
// Yield back to main loop | ||
Run(); | ||
return; | ||
} | ||
|
||
void TlsSocket::AsyncReqBase::StartUpstreamWrite() { | ||
Engine::Buffer buffer = owner->engine_->PeekOutputBuf(); | ||
DCHECK(!buffer.empty()); | ||
DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0); | ||
|
||
DVLOG(2) << "HandleUpstreamWrite " << buffer.size(); | ||
// we do not allow concurrent writes from multiple fibers. | ||
owner->state_ |= WRITE_IN_PROGRESS; | ||
|
||
auto& scratch = scratch_iovec; | ||
scratch.iov_base = const_cast<uint8_t*>(buffer.data()); | ||
scratch.iov_len = buffer.size(); | ||
|
||
owner->next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) { | ||
HandleUpstreamAsyncWrite(write_result, buffer); | ||
}); | ||
} | ||
|
||
void TlsSocket::AsyncReqBase::MaybeSendOutputAsync(bool should_read) { | ||
auto body = [should_read, this]() { | ||
if (should_read) { | ||
return StartUpstreamRead(); | ||
} | ||
if (continuation) { | ||
auto cont = std::exchange(continuation, std::function<void()>{}); | ||
return cont(); | ||
} | ||
return Run(); | ||
}; | ||
|
||
if (owner->engine_->OutputPending() == 0) { | ||
return body(); | ||
} | ||
|
||
// This function is present in both read and write paths. | ||
// meaning that both of them can be called concurrently from differrent fibers and then | ||
// race over flushing the output buffer. We use state_ to prevent that. | ||
if (owner->state_ & WRITE_IN_PROGRESS) { | ||
// TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise | ||
// we might deadlock. See the sync version of HandleOp for more info | ||
return body(); | ||
} | ||
|
||
if (should_read) { | ||
DCHECK(!continuation); | ||
continuation = [this]() { | ||
// Yields to Run internally() | ||
StartUpstreamRead(); | ||
}; | ||
} | ||
StartUpstreamWrite(); | ||
} | ||
|
||
auto TlsSocket::MaybeSendOutput() -> error_code { | ||
if (engine_->OutputPending() == 0) | ||
return {}; | ||
|
@@ -374,6 +597,42 @@ auto TlsSocket::MaybeSendOutput() -> error_code { | |
return HandleUpstreamWrite(); | ||
} | ||
|
||
void TlsSocket::AsyncReqBase::StartUpstreamRead() { | ||
if (owner->state_ & READ_IN_PROGRESS) { | ||
// TODO we should yield instead | ||
Run(); | ||
return; | ||
} | ||
|
||
auto buffer = owner->engine_->PeekInputBuf(); | ||
owner->state_ |= READ_IN_PROGRESS; | ||
|
||
auto& scratch = scratch_iovec; | ||
scratch.iov_base = const_cast<uint8_t*>(buffer.data()); | ||
scratch.iov_len = buffer.size(); | ||
|
||
owner->next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) { | ||
owner->state_ &= ~READ_IN_PROGRESS; | ||
if (!read_result) { | ||
// log any errors as well as situations where we have unflushed output. | ||
if (read_result.error() != errc::connection_aborted || owner->engine_->OutputPending() > 0) { | ||
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_) | ||
<< ", write_total:" << owner->upstream_write_ << " " | ||
<< " pending output: " << owner->engine_->OutputPending() << " " | ||
<< "StartUpstreamRead failed " << read_result.error(); | ||
} | ||
// Erronous path. Apply the completion callback and exit. | ||
CompleteAsyncReq(read_result); | ||
return; | ||
} | ||
|
||
DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes"; | ||
owner->engine_->CommitInput(*read_result); | ||
// We are not done. Give back control to the main loop. | ||
Run(); | ||
}); | ||
} | ||
|
||
auto TlsSocket::HandleUpstreamRead() -> error_code { | ||
RETURN_ON_ERROR(MaybeSendOutput()); | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ouch -- we now got unit tests though 😄