Skip to content

Commit 59cfdc9

Browse files
authored
Merge pull request capnproto#519 from zenhack/clienthook-snapshot
Clienthook snapshot
2 parents 2ad05d6 + b57e496 commit 59cfdc9

File tree

9 files changed

+249
-92
lines changed

9 files changed

+249
-92
lines changed

answer.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,9 @@ func (pc PipelineClient) Brand() Brand {
493493
r := mutex.With1(&pc.p.state, func(p *promiseState) resolution {
494494
return p.resolution(pc.p.method)
495495
})
496-
return r.client(pc.transform).State().Brand
496+
snapshot := r.client(pc.transform).Snapshot()
497+
defer snapshot.Release()
498+
return snapshot.Brand()
497499
default:
498500
return Brand{Value: pc}
499501
}

capability.go

Lines changed: 139 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -519,32 +519,17 @@ func (c Client) IsSame(c2 Client) bool {
519519
// Resolve only returns an error if the context is canceled; it returns nil even
520520
// if the capability resolves to an error.
521521
func (c Client) Resolve(ctx context.Context) error {
522-
for {
523-
h, resolved, released := c.startCall()
524-
defer h.Release()
525-
if released {
526-
return errors.New("cannot resolve released client")
527-
}
528-
529-
if resolved {
530-
return nil
531-
}
532-
533-
r, ok := h.Value().resolution.Get()
534-
if !ok {
535-
return nil
536-
}
537-
538-
resolvedCh := mutex.With1(r, func(s *resolveState) <-chan struct{} {
539-
return s.resolved
540-
})
541-
542-
select {
543-
case <-resolvedCh:
544-
case <-ctx.Done():
545-
return ctx.Err()
546-
}
522+
h, resolved, released := c.startCall()
523+
defer h.Release()
524+
if released {
525+
return errors.New("cannot resolve released client")
547526
}
527+
if resolved {
528+
return nil
529+
}
530+
h, err := resolveClientHook(ctx, h)
531+
h.Release()
532+
return err
548533
}
549534

550535
// AddRef creates a new Client that refers to the same capability as c.
@@ -577,39 +562,142 @@ func (c Client) WeakRef() WeakClient {
577562
return WeakClient{r: cursor}
578563
}
579564

580-
// State reads the current state of the client. It returns the zero
581-
// ClientState if c is nil, has resolved to null, or has been released.
582-
func (c Client) State() ClientState {
583-
h, resolved, _ := c.startCall()
584-
defer h.Release()
585-
if h == nil {
586-
return ClientState{}
587-
}
588-
return ClientState{
589-
Brand: h.Value().Brand(),
590-
IsPromise: !resolved,
591-
Metadata: &h.Value().metadata,
592-
}
565+
// Snapshot reads the current state of the client. It returns the zero
566+
// ClientSnapshot if c is nil, has resolved to null, or has been released.
567+
func (c Client) Snapshot() ClientSnapshot {
568+
h, _, _ := c.startCall()
569+
return ClientSnapshot{hook: h}
593570
}
594571

595572
// A Brand is an opaque value used to identify a capability.
596573
type Brand struct {
597574
Value any
598575
}
599576

600-
// ClientState is a snapshot of a client's identity.
601-
type ClientState struct {
602-
// Brand is the value returned from the hook's Brand method.
603-
Brand Brand
604-
// IsPromise is true if the client has not resolved yet.
605-
IsPromise bool
606-
// Arbitrary metadata. Note that, if a Client is a promise,
607-
// when it resolves its metadata will be replaced with that
608-
// of its resolution.
609-
//
610-
// TODO: this might change before the v3 API is stabilized;
611-
// we are not sure the above is the correct semantics.
612-
Metadata *Metadata
577+
// ClientSnapshot is a snapshot of a client's identity. If the Client
578+
// is a promise, then the corresponding ClientSnapshot will *not*
579+
// redirect to point at the resolution.
580+
type ClientSnapshot struct {
581+
hook *rc.Ref[clientHook]
582+
}
583+
584+
func (cs ClientSnapshot) IsValid() bool {
585+
return cs.hook.IsValid()
586+
}
587+
588+
// IsPromise returns true if the snapshot is a promise.
589+
func (cs ClientSnapshot) IsPromise() bool {
590+
if cs.hook == nil {
591+
return false
592+
}
593+
_, ret := cs.hook.Value().resolution.Get()
594+
return ret
595+
}
596+
597+
// Send implements ClientHook.Send
598+
func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
599+
return cs.hook.Value().Send(ctx, s)
600+
}
601+
602+
// Recv implements ClientHook.Recv
603+
func (cs ClientSnapshot) Recv(ctx context.Context, r Recv) PipelineCaller {
604+
return cs.hook.Value().Recv(ctx, r)
605+
}
606+
607+
// Client returns a client pointing at the most-resolved version of the snapshot.
608+
func (cs ClientSnapshot) Client() Client {
609+
cursor := rc.NewRefInPlace(func(c *clientCursor) func() {
610+
*c = clientCursor{hook: mutex.New(cs.hook.AddRef())}
611+
c.compress()
612+
return c.Release
613+
})
614+
c := Client{client: &client{
615+
state: mutex.New(clientState{cursor: cursor}),
616+
}}
617+
setupLeakReporting(c)
618+
return c
619+
}
620+
621+
// Brand is the value returned from the ClientHook's Brand method.
622+
// Returns the zero Brand if the receiver is the zero ClientSnapshot.
623+
func (cs ClientSnapshot) Brand() Brand {
624+
if cs.hook == nil {
625+
return Brand{}
626+
}
627+
return cs.hook.Value().Brand()
628+
}
629+
630+
// Return a the reference to the Metadata associated with this client hook.
631+
// Callers may store whatever they need here.
632+
func (cs ClientSnapshot) Metadata() *Metadata {
633+
return &cs.hook.Value().metadata
634+
}
635+
636+
// Create a copy of the snapshot, with its own underlying reference.
637+
func (cs ClientSnapshot) AddRef() ClientSnapshot {
638+
cs.hook = cs.hook.AddRef()
639+
return cs
640+
}
641+
642+
// Release the reference to the hook.
643+
func (cs ClientSnapshot) Release() {
644+
cs.hook.Release()
645+
}
646+
647+
func (cs *ClientSnapshot) Resolve1(ctx context.Context) error {
648+
var err error
649+
cs.hook, _, err = resolve1ClientHook(ctx, cs.hook)
650+
return err
651+
}
652+
653+
func (cs *ClientSnapshot) resolve1(ctx context.Context) (more bool, err error) {
654+
cs.hook, more, err = resolve1ClientHook(ctx, cs.hook)
655+
return
656+
}
657+
658+
func (cs *ClientSnapshot) Resolve(ctx context.Context) error {
659+
var err error
660+
cs.hook, err = resolveClientHook(ctx, cs.hook)
661+
return err
662+
}
663+
664+
func resolveClientHook(ctx context.Context, h *rc.Ref[clientHook]) (_ *rc.Ref[clientHook], err error) {
665+
for {
666+
var more bool
667+
h, more, err = resolve1ClientHook(ctx, h)
668+
if !more || err != nil {
669+
return h, err
670+
}
671+
}
672+
}
673+
674+
func resolve1ClientHook(ctx context.Context, h *rc.Ref[clientHook]) (_ *rc.Ref[clientHook], more bool, err error) {
675+
if !h.IsValid() {
676+
return h, false, nil
677+
}
678+
defer h.Release()
679+
680+
r, ok := h.Value().resolution.Get()
681+
if !ok {
682+
return h.AddRef(), false, nil
683+
}
684+
685+
resolvedCh := mutex.With1(r, func(s *resolveState) <-chan struct{} {
686+
return s.resolved
687+
})
688+
689+
select {
690+
case <-resolvedCh:
691+
rh := mutex.With1(r, func(r *resolveState) *rc.Ref[clientHook] {
692+
return r.resolvedHook
693+
})
694+
if rh == nil {
695+
return nil, false, nil
696+
}
697+
return rh.AddRef(), true, nil
698+
case <-ctx.Done():
699+
return h.AddRef(), true, ctx.Err()
700+
}
613701
}
614702

615703
// String returns a string that identifies this capability for debugging

capability_test.go

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
1213
)
1314

1415
func TestClient(t *testing.T) {
@@ -23,13 +24,14 @@ func TestClient(t *testing.T) {
2324
if !c.IsValid() {
2425
t.Error("new client is not valid")
2526
}
26-
state := c.State()
27-
if state.IsPromise {
27+
state := c.Snapshot()
28+
if state.IsPromise() {
2829
t.Error("c.State().IsPromise = true; want false")
2930
}
30-
if state.Brand.Value != int(42) {
31-
t.Errorf("c.State().Brand.Value = %#v; want 42", state.Brand.Value)
31+
if state.Brand().Value != int(42) {
32+
t.Errorf("c.State().Brand().Value = %#v; want 42", state.Brand().Value)
3233
}
34+
state.Release()
3335
ans, finish := c.SendCall(ctx, Send{})
3436
if _, err := ans.Struct(); err != nil {
3537
t.Error("SendCall:", err)
@@ -78,13 +80,14 @@ func TestReleasedClient(t *testing.T) {
7880
if c.IsValid() {
7981
t.Error("released client is valid")
8082
}
81-
state := c.State()
82-
if state.Brand.Value != nil {
83-
t.Errorf("c.State().Brand.Value = %#v; want <nil>", state.Brand.Value)
83+
state := c.Snapshot()
84+
if state.Brand().Value != nil {
85+
t.Errorf("c.Snapshot().Brand().Value = %#v; want <nil>", state.Brand().Value)
8486
}
85-
if state.IsPromise {
86-
t.Error("c.State().IsPromise = true; want false")
87+
if state.IsPromise() {
88+
t.Error("c.Snapshot().IsPromise = true; want false")
8789
}
90+
state.Release()
8891
ans, finish := c.SendCall(ctx, Send{})
8992
if _, err := ans.Struct(); err == nil {
9093
t.Error("SendCall did not return error")
@@ -116,6 +119,49 @@ func TestReleasedClient(t *testing.T) {
116119
t.Error("second Release made more calls to ClientHook.Shutdown")
117120
}
118121
}
122+
func TestResolve(t *testing.T) {
123+
test := func(t *testing.T, name string, f func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client])) {
124+
t.Run(name, func(t *testing.T) {
125+
t.Parallel()
126+
p1, r1 := NewLocalPromise[Client]()
127+
p2, r2 := NewLocalPromise[Client]()
128+
defer p1.Release()
129+
defer p2.Release()
130+
f(t, p1, p2, r1, r2)
131+
})
132+
}
133+
t.Run("Clients", func(t *testing.T) {
134+
test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
135+
r1.Fulfill(p2)
136+
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
137+
defer cancel()
138+
require.NotNil(t, p1.Resolve(ctx), "blocks on second promise")
139+
r2.Fulfill(Client{})
140+
require.NoError(t, p1.Resolve(context.Background()), "resolves after second resolution")
141+
assert.True(t, p1.IsSame(Client{}), "p1 resolves to null")
142+
assert.True(t, p2.IsSame(Client{}), "p2 resolves to null")
143+
assert.True(t, p1.IsSame(p2), "p1 & p2 are the same")
144+
})
145+
})
146+
t.Run("Snapshots", func(t *testing.T) {
147+
test(t, "Resolve1 only waits for one link", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
148+
s1 := p1.Snapshot()
149+
defer s1.Release()
150+
r1.Fulfill(p2)
151+
require.NoError(t, s1.Resolve1(context.Background()), "Resolve1 returns after first resolution")
152+
})
153+
test(t, "Resolve waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
154+
s1 := p1.Snapshot()
155+
defer s1.Release()
156+
r1.Fulfill(p2)
157+
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
158+
defer cancel()
159+
require.NotNil(t, s1.Resolve(ctx), "blocks on second promise")
160+
r2.Fulfill(Client{})
161+
require.NoError(t, s1.Resolve(context.Background()), "resolves after second resolution")
162+
})
163+
})
164+
}
119165

120166
func TestNullClient(t *testing.T) {
121167
ctx := context.Background()
@@ -141,13 +187,14 @@ func TestNullClient(t *testing.T) {
141187
if c.IsValid() {
142188
t.Error("null client is valid")
143189
}
144-
state := c.State()
145-
if state.Brand.Value != nil {
146-
t.Errorf("c.State().Brand = %#v; want <nil>", state.Brand)
190+
state := c.Snapshot()
191+
if state.Brand().Value != nil {
192+
t.Errorf("c.Snapshot().Brand() = %#v; want <nil>", state.Brand())
147193
}
148-
if state.IsPromise {
149-
t.Error("c.State().IsPromise = true; want false")
194+
if state.IsPromise() {
195+
t.Error("c.Snapshot().IsPromise = true; want false")
150196
}
197+
state.Release()
151198
ans, finish := c.SendCall(ctx, Send{})
152199
if _, err := ans.Struct(); err == nil {
153200
t.Error("SendCall did not return error")
@@ -186,13 +233,14 @@ func TestPromisedClient(t *testing.T) {
186233
if ca.IsSame(cb) {
187234
t.Error("before resolution, ca == cb")
188235
}
189-
state := ca.State()
190-
if state.Brand.Value != int(111) {
191-
t.Errorf("before resolution, ca.State().Brand.Value = %#v; want 111", state.Brand.Value)
236+
state := ca.Snapshot()
237+
if state.Brand().Value != int(111) {
238+
t.Errorf("before resolution, ca.Snapshot().Brand().Value = %#v; want 111", state.Brand().Value)
192239
}
193-
if !state.IsPromise {
194-
t.Error("before resolution, ca.State().IsPromise = false; want true")
240+
if !state.IsPromise() {
241+
t.Error("before resolution, ca.Snapshot().IsPromise = false; want true")
195242
}
243+
state.Release()
196244
_, finish := ca.SendCall(ctx, Send{})
197245
finish()
198246
pa.Fulfill(cb)
@@ -207,13 +255,14 @@ func TestPromisedClient(t *testing.T) {
207255
if !ca.IsSame(cb) {
208256
t.Errorf("after resolution, ca != cb (%v vs. %v)", ca, cb)
209257
}
210-
state = ca.State()
211-
if state.Brand.Value != int(222) {
212-
t.Errorf("after resolution, ca.State().Brand.Value = %#v; want 222", state.Brand.Value)
258+
state = ca.Snapshot()
259+
if state.Brand().Value != int(222) {
260+
t.Errorf("after resolution, ca.Snapshot().Brand().Value = %#v; want 222", state.Brand().Value)
213261
}
214-
if state.IsPromise {
215-
t.Error("after resolution, ca.State().IsPromise = true; want false")
262+
if state.IsPromise() {
263+
t.Error("after resolution, ca.Snapshot().IsPromise = true; want false")
216264
}
265+
state.Release()
217266

218267
if b.shutdowns > 0 {
219268
t.Error("b shut down before clients released")

captable_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ func TestCapTable(t *testing.T) {
2929

3030
errTest := errors.New("test")
3131
ct.Set(capnp.CapabilityID(0), capnp.ErrorClient(errTest))
32-
err := ct.At(0).State().Brand.Value.(error)
32+
snapshot := ct.At(0).Snapshot()
33+
defer snapshot.Release()
34+
err := snapshot.Brand().Value.(error)
3335
assert.ErrorIs(t, errTest, err, "should update client at index 0")
3436
}

0 commit comments

Comments
 (0)