diff --git a/execution/graphql/result_writer.go b/execution/graphql/result_writer.go
index dc1ce288dc..f97cbf721b 100644
--- a/execution/graphql/result_writer.go
+++ b/execution/graphql/result_writer.go
@@ -35,6 +35,11 @@ func (e *EngineResultWriter) Complete() {
 
 }
 
+// Heartbeat is a no-op for EngineResultWriter and always returns nil.
+func (e *EngineResultWriter) Heartbeat() error {
+	return nil
+}
+
 func (e *EngineResultWriter) Close(_ resolve.SubscriptionCloseKind) {
 
 }
diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/event_loop_test.go
index 7acb69ca83..11389630a9 100644
--- a/v2/pkg/engine/resolve/event_loop_test.go
+++ b/v2/pkg/engine/resolve/event_loop_test.go
@@ -51,6 +51,14 @@ func (f *FakeSubscriptionWriter) Complete() {
 	f.messageCountOnComplete = len(f.writtenMessages)
 }
 
+// Heartbeat writes directly to the writtenMessages slice, as the real implementations implicitly flush.
+func (f *FakeSubscriptionWriter) Heartbeat() error {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.writtenMessages = append(f.writtenMessages, "heartbeat")
+	return nil
+}
+
 func (f *FakeSubscriptionWriter) Close(SubscriptionCloseKind) {
 	f.mu.Lock()
 	defer f.mu.Unlock()
@@ -123,7 +131,7 @@ func TestEventLoop(t *testing.T) {
 		SubgraphErrorPropagationMode:  SubgraphErrorPropagationModePassThrough,
 		DefaultErrorExtensionCode:     "TEST",
 		MaxRecyclableParserSize:       1024 * 1024,
-		MultipartSubHeartbeatInterval: DefaultHeartbeatInterval,
+		SubscriptionHeartbeatInterval: DefaultHeartbeatInterval,
 		Reporter:                      testReporter,
 	})
 
diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go
index 0d0bed5f53..20a606fce1 100644
--- a/v2/pkg/engine/resolve/resolve.go
+++ b/v2/pkg/engine/resolve/resolve.go
@@ -21,10 +21,6 @@ const (
 	DefaultHeartbeatInterval = 5 * time.Second
 )
 
-var (
-	multipartHeartbeat = []byte("{}")
-)
-
 // ConnectionIDs is used to create unique connection IDs for each subscription
 // Whenever a new connection is created, use this to generate a new ID
 // It is public because it can be used in more high level packages to instantiate a new connection
@@ -69,7 +65,7 @@ type Resolver struct {
 
 	propagateSubgraphErrors      bool
 	propagateSubgraphStatusCodes bool
-	// Multipart heartbeat interval
+	// Subscription heartbeat interval
 	heartbeatInterval time.Duration
 	// maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out
 	maxSubscriptionFetchTimeout time.Duration
@@ -143,8 +139,8 @@ type ResolverOptions struct {
 	ResolvableOptions ResolvableOptions
 	// AllowedCustomSubgraphErrorFields defines which fields are allowed in the subgraph error when in passthrough mode
 	AllowedSubgraphErrorFields []string
-	// MultipartSubHeartbeatInterval defines the interval in which a heartbeat is sent to all multipart subscriptions
-	MultipartSubHeartbeatInterval time.Duration
+	// SubscriptionHeartbeatInterval defines the interval in which a heartbeat is sent to all subscriptions (whether or not this does anything is determined by the subscription response writer)
+	SubscriptionHeartbeatInterval time.Duration
 	// MaxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out
 	MaxSubscriptionFetchTimeout time.Duration
 	// ApolloRouterCompatibilitySubrequestHTTPError is a compatibility flag for Apollo Router, it is used to handle HTTP errors in subrequests differently
@@ -158,8 +154,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 		options.MaxConcurrency = 32
 	}
 
-	if options.MultipartSubHeartbeatInterval <= 0 {
-		options.MultipartSubHeartbeatInterval = DefaultHeartbeatInterval
+	if options.SubscriptionHeartbeatInterval <= 0 {
+		options.SubscriptionHeartbeatInterval = DefaultHeartbeatInterval
 	}
 
 	// We transform the allowed fields into a map for faster lookups
@@ -202,7 +198,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 		triggerUpdateBuf:             bytes.NewBuffer(make([]byte, 0, 1024)),
 		allowedErrorExtensionFields:  allowedExtensionFields,
 		allowedErrorFields:           allowedErrorFields,
-		heartbeatInterval:            options.MultipartSubHeartbeatInterval,
+		heartbeatInterval:            options.SubscriptionHeartbeatInterval,
 		maxSubscriptionFetchTimeout:  options.MaxSubscriptionFetchTimeout,
 	}
 	resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency)
@@ -310,8 +306,8 @@ func (s *sub) startWorker() {
 	s.startWorkerWithoutHeartbeat()
 }
 
-// startWorkerWithHeartbeat is similar to startWorker but sends heartbeats to the client when
-// subscription over multipart is used. It sends a heartbeat to the client every heartbeatInterval.
+// startWorkerWithHeartbeat is similar to startWorker but sends heartbeats to the client when enabled.
+// It sends a heartbeat to the client every heartbeatInterval. Heartbeats are handled by the SubscriptionResponseWriter interface.
 // TODO: Implement a shared timer implementation to avoid creating a new ticker for each subscription.
 func (s *sub) startWorkerWithHeartbeat() {
 	heartbeatTicker := time.NewTicker(s.resolver.heartbeatInterval)
@@ -330,7 +326,7 @@ func (s *sub) startWorkerWithHeartbeat() {
 
 			return
 		case <-heartbeatTicker.C:
-			s.resolver.handleHeartbeat(s, multipartHeartbeat)
+			s.resolver.handleHeartbeat(s)
 		case work := <-s.workChan:
 			work.fn()
 
@@ -501,7 +497,7 @@ func (r *Resolver) handleEvent(event subscriptionEvent) {
 }
 
 // handleHeartbeat sends a heartbeat to the client. It needs to be executed on the same goroutine as the writer.
-func (r *Resolver) handleHeartbeat(sub *sub, data []byte) {
+func (r *Resolver) handleHeartbeat(sub *sub) {
 	if r.options.Debug {
 		fmt.Printf("resolver:heartbeat\n")
 	}
@@ -518,24 +514,16 @@ func (r *Resolver) handleHeartbeat(sub *sub, data []byte) {
 		fmt.Printf("resolver:heartbeat:subscription:%d\n", sub.id.SubscriptionID)
 	}
 
-	if _, err := sub.writer.Write(data); err != nil {
-		if errors.Is(err, context.Canceled) {
-			// If Write fails (e.g. client disconnected), remove the subscription.
-			_ = r.AsyncUnsubscribeSubscription(sub.id)
-			return
-		}
-		r.asyncErrorWriter.WriteError(sub.ctx, err, nil, sub.writer)
-	}
-	err := sub.writer.Flush()
-	if err != nil {
-		// If flush fails (e.g. client disconnected), remove the subscription.
+	if err := sub.writer.Heartbeat(); err != nil {
+		// If heartbeat fails (e.g. client disconnected), remove the subscription.
 		_ = r.AsyncUnsubscribeSubscription(sub.id)
 		return
 	}
 
 	if r.options.Debug {
-		fmt.Printf("resolver:heartbeat:subscription:flushed:%d\n", sub.id.SubscriptionID)
+		fmt.Printf("resolver:heartbeat:subscription:done:%d\n", sub.id.SubscriptionID)
 	}
+
 	if r.reporter != nil {
 		r.reporter.SubscriptionUpdateSent()
 	}
diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go
index 67b83ab20d..cc0d162710 100644
--- a/v2/pkg/engine/resolve/resolve_test.go
+++ b/v2/pkg/engine/resolve/resolve_test.go
@@ -86,7 +86,7 @@ func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLRespon
 	}
 }
 
-var multipartSubHeartbeatInterval = 100 * time.Millisecond
+var subscriptionHeartbeatInterval = 100 * time.Millisecond
 
 func newResolver(ctx context.Context) *Resolver {
 	return New(ctx, ResolverOptions{
@@ -95,7 +95,7 @@ func newResolver(ctx context.Context) *Resolver {
 		PropagateSubgraphErrors:       true,
 		PropagateSubgraphStatusCodes:  true,
 		AsyncErrorWriter:              &TestErrorWriter{},
-		MultipartSubHeartbeatInterval: multipartSubHeartbeatInterval,
+		SubscriptionHeartbeatInterval: subscriptionHeartbeatInterval,
 	})
 }
 
@@ -4777,6 +4777,14 @@ func (s *SubscriptionRecorder) Complete() {
 	s.complete.Store(true)
 }
 
+// Heartbeat records a "heartbeat" entry in the recorded messages and always returns nil.
+func (s *SubscriptionRecorder) Heartbeat() error {
+	s.mux.Lock()
+	defer s.mux.Unlock()
+	s.messages = append(s.messages, "heartbeat")
+	return nil
+}
+
 func (s *SubscriptionRecorder) Close(_ SubscriptionCloseKind) {
 	s.closed.Store(true)
 }
diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go
index 926508f5ca..d340908a86 100644
--- a/v2/pkg/engine/resolve/response.go
+++ b/v2/pkg/engine/resolve/response.go
@@ -67,6 +67,8 @@ type SubscriptionResponseWriter interface {
 	ResponseWriter
 	Flush() error
 	Complete()
+	// Heartbeat is called by the resolver every heartbeat interval; the writer decides what, if anything, is sent. A non-nil error causes the resolver to remove the subscription.
+	Heartbeat() error
 	Close(kind SubscriptionCloseKind)
 }
 