From f6fc6c5d4debed16ea1017a2ec8ff071b08fb25a Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Fri, 31 Mar 2023 22:06:49 -0400 Subject: [PATCH 1/4] Simplify clientHook's refcounting scheme. There is no actual reason why we need separate refcounts for calls and clients. In preparation for generally unifying the refcounting in this library, I've decided we should just merge them. --- capability.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/capability.go b/capability.go index 67fde406..d41bb760 100644 --- a/capability.go +++ b/capability.go @@ -142,18 +142,21 @@ type clientHook struct { // Place for callers to attach arbitrary metadata to the client. metadata Metadata - // done is closed when refs == 0 and calls == 0. + // done is closed when refs == 0 done chan struct{} state mutex.Mutex[clientHookState] } type clientHookState struct { + // How many references there are to this clientHook. + // This includes both Clients that point to it and + // outstanding calls on it. + refs int + // resolved is closed after resolvedHook is set resolved chan struct{} - refs int // how many open Clients reference this clientHook - calls int // number of outstanding ClientHook accesses resolvedHook *clientHook // valid only if resolved is closed } @@ -232,14 +235,14 @@ func (c Client) startCall() (hook ClientHook, resolved, released bool, finish fu if c.h == nil { return nil, true, false, func() {} } - l.Value().calls++ + l.Value().refs++ isResolved := l.Value().isResolved() l.Unlock() savedHook := c.h return savedHook.ClientHook, isResolved, false, func() { savedHook.state.With(func(s *clientHookState) { - s.calls-- - if s.refs == 0 && s.calls == 0 { + s.refs-- + if s.refs == 0 { close(savedHook.done) } }) @@ -657,9 +660,7 @@ func (c Client) Release() { cl.Unlock() return } - if hl.Value().calls == 0 { - close(h.done) - } + close(h.done) hl.Unlock() cl.Unlock() <-h.done @@ -774,9 +775,7 @@ func (cp *clientPromise) fulfill(c Client) { } // Client still had references, so we're responsible for shutting it down. - if l.Value().calls == 0 { - close(cp.h.done) - } + close(cp.h.done) rh, l = resolveHook(cp.h, l) // swaps mutex on cp.h for mutex on rh if rh != nil { l.Value().refs += refs From 2815a6a7a27f62d1e1e9a0c5ee68b3fa29ad17a7 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Tue, 11 Apr 2023 22:33:43 -0400 Subject: [PATCH 2/4] Get rid of the bootstrapClient implementation of ClientHook. ...because fewer special cases are better. This also resolves everything but the documentation for #423. One of the tests had to be reworked a bit because the exact sequence of messages was a bit different, but still correct. --- rpc/level0_test.go | 50 +++++++++++++++++++++++++++++++++++++----- rpc/question.go | 17 ++------------- rpc/rpc.go | 54 ++++++++++------------------------------------ 3 files changed, 58 insertions(+), 63 deletions(-) diff --git a/rpc/level0_test.go b/rpc/level0_test.go index eb0340f8..a103274a 100644 --- a/rpc/level0_test.go +++ b/rpc/level0_test.go @@ -807,8 +807,35 @@ func TestSendBootstrapPipelineCall(t *testing.T) { } } - // 6. Release the client, read the finish. - client.Release() + // 6. Send back a return for the bootstrap message: + bootstrapExportID := uint32(99) + { + outMsg, err := p2.NewMessage() + require.NoError(t, err) + iface := capnp.NewInterface(outMsg.Message().Segment(), 0) + require.NoError(t, pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), + &rpcMessage{ + Which: rpccp.Message_Which_return, + Return: &rpcReturn{ + AnswerID: bootstrapQID, + Which: rpccp.Return_Which_results, + Results: &rpcPayload{ + Content: iface.ToPtr(), + CapTable: []rpcCapDescriptor{ + { + Which: rpccp.CapDescriptor_Which_senderHosted, + SenderHosted: bootstrapExportID, + }, + }, + }, + }, + }, + )) + require.NoError(t, outMsg.Send()) + outMsg.Release() + } + + // 7. Read the finish: { rmsg, release, err := recvMessage(ctx, p2) if err != nil { @@ -821,9 +848,22 @@ func TestSendBootstrapPipelineCall(t *testing.T) { if rmsg.Finish.QuestionID != bootstrapQID { t.Errorf("Received finish for question %d; want %d", rmsg.Finish.QuestionID, bootstrapQID) } - if !rmsg.Finish.ReleaseResultCaps { - t.Error("Received finish that does not release bootstrap") - } + require.False( + t, + rmsg.Finish.ReleaseResultCaps, + "Received finish that releases bootstrap (should receive separate releasemessage)", + ) + } + + // 8. Release the client, read the release message. + client.Release() + { + rmsg, release, err := recvMessage(ctx, p2) + require.NoError(t, err) + defer release() + require.Equal(t, rpccp.Message_Which_release, rmsg.Which) + require.Equal(t, bootstrapExportID, rmsg.Release.ID) + require.Equal(t, uint32(1), rmsg.Release.ReferenceCount) } } diff --git a/rpc/question.go b/rpc/question.go index 4a079a18..458ba229 100644 --- a/rpc/question.go +++ b/rpc/question.go @@ -15,8 +15,6 @@ type question struct { c *Conn id questionID - bootstrapPromise capnp.Resolver[capnp.Client] - p *capnp.Promise release capnp.ReleaseFunc // written before resolving p @@ -127,12 +125,7 @@ func (q *question) handleCancel(ctx context.Context) { q.c.er.ReportError(rpcerr.Annotate(err, "send finish")) } close(q.finishMsgSend) - q.p.Reject(rejectErr) - if q.bootstrapPromise != nil { - q.bootstrapPromise.Reject(rejectErr) - q.p.ReleaseClients() - } }) }) } @@ -278,14 +271,8 @@ func (q *question) mark(xform []capnp.PipelineOp) { } func (q *question) Reject(err error) { - if q != nil { - if q.bootstrapPromise != nil { - q.bootstrapPromise.Fulfill(capnp.ErrorClient(err)) - } - - if q.p != nil { - q.p.Reject(err) - } + if q != nil && q.p != nil { + q.p.Reject(err) } } diff --git a/rpc/rpc.go b/rpc/rpc.go index f624c44d..9f6936bf 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -308,12 +308,12 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) { } defer c.tasks.Done() - bootCtx, cancel := context.WithCancel(ctx) q := c.newQuestion(capnp.Method{}) - bc, q.bootstrapPromise = capnp.NewPromisedClient(bootstrapClient{ - c: q.p.Answer().Client().AddRef(), - cancel: cancel, - }) + bc = q.p.Answer().Client().AddRef() + go func() { + q.p.ReleaseClients() + q.release() + }() c.sendMessage(ctx, func(m rpccp.Message) error { boot, err := m.NewBootstrap() @@ -327,7 +327,7 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) { syncutil.With(&c.lk, func() { c.lk.questions[q.id] = nil }) - q.bootstrapPromise.Reject(exc.Annotate("rpc", "bootstrap", err)) + q.p.Reject(exc.Annotate("rpc", "bootstrap", err)) syncutil.With(&c.lk, func() { c.lk.questionID.remove(uint32(q.id)) }) @@ -337,7 +337,7 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) { c.tasks.Add(1) go func() { defer c.tasks.Done() - q.handleCancel(bootCtx) + q.handleCancel(ctx) }() }) @@ -345,32 +345,6 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) { }) } -type bootstrapClient struct { - c capnp.Client - cancel context.CancelFunc -} - -func (bc bootstrapClient) String() string { - return "bootstrapClient{c: " + bc.c.String() + "}" -} - -func (bc bootstrapClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp.ReleaseFunc) { - return bc.c.SendCall(ctx, s) -} - -func (bc bootstrapClient) Recv(ctx context.Context, r capnp.Recv) capnp.PipelineCaller { - return bc.c.RecvCall(ctx, r) -} - -func (bc bootstrapClient) Brand() capnp.Brand { - return bc.c.State().Brand -} - -func (bc bootstrapClient) Shutdown() { - bc.cancel() - bc.c.Release() -} - // Close sends an abort to the remote vat and closes the underlying // transport. func (c *Conn) Close() error { @@ -1151,9 +1125,9 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e c.er.ReportError(rpcerr.Annotate(pr.err, "incoming return")) } - if q.bootstrapPromise == nil && pr.err == nil { - // The result of the message contains actual data (not just a - // client or an error), so we save the ReleaseFunc for later: + if pr.err == nil { + // The result of the message contains actual data (not just + // an error), so we save the ReleaseFunc for later: q.release = in.Release } // We're going to potentially block fulfilling some promises so fork @@ -1161,13 +1135,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e go func() { c := unlockedConn q.p.Resolve(pr.result, pr.err) - if q.bootstrapPromise != nil { - q.bootstrapPromise.Fulfill(q.p.Answer().Client()) - q.p.ReleaseClients() - // We can release now; root pointer of the result is a client, so the - // message won't be accessed: - in.Release() - } else if pr.err != nil { + if pr.err != nil { // We can release now; the result is an error, so data from the message // won't be accessed: in.Release() From ad4f48e7b24ef3f32b8643359b4df5df987a5f21 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Tue, 11 Apr 2023 22:37:20 -0400 Subject: [PATCH 3/4] Correct documentation for ClientHook.Shutdown() Fixes #423 --- capability.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/capability.go b/capability.go index 278a2d0a..5fadf1b7 100644 --- a/capability.go +++ b/capability.go @@ -854,8 +854,8 @@ type ClientHook interface { // Shutdown releases any resources associated with this capability. // The behavior of calling any methods on the receiver after calling - // Shutdown is undefined. It is expected for the ClientHook to reject - // any outstanding call futures. + // Shutdown is undefined. Any already-outstanding calls should not + // be interrupted. Shutdown() // String formats the hook as a string (same as fmt.Stringer) From f9ad8800d47163c4d61d0b6d11602f0570a95d5c Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Thu, 13 Apr 2023 18:12:51 -0400 Subject: [PATCH 4/4] Remove ineffectual receive. This can't possibly have any effect, since we close the same channel immediately above. The close used to be conditional, so it presumably made more sense then. --- capability.go | 1 - 1 file changed, 1 deletion(-) diff --git a/capability.go b/capability.go index 1b88e10e..08506bab 100644 --- a/capability.go +++ b/capability.go @@ -659,7 +659,6 @@ func (c Client) Release() { close(h.done) hl.Unlock() cl.Unlock() - <-h.done h.Shutdown() c.GetFlowLimiter().Release() }