Skip to content

Commit ccc6245

Browse files
fix(GODT-1586): Implement context cancellation for commands
Check for context cancellation for long running commands such as list, fetch, lsub and search. The command reading process has also been updated so that it returns the error in the channel. This way we can report the error back to the client.
1 parent bdbefef commit ccc6245

File tree

9 files changed

+127
-69
lines changed

9 files changed

+127
-69
lines changed

internal/session/command.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package session
22

33
import (
44
"context"
5-
"errors"
6-
"io"
75
"runtime/pprof"
86
"strconv"
97

@@ -14,39 +12,38 @@ import (
1412
type command struct {
1513
tag string
1614
cmd *proto.Command
15+
err error
1716
}
1817

19-
func (s *Session) getCommandCh(ctx context.Context, del string) <-chan command {
18+
func (s *Session) startCommandReader(ctx context.Context, del string) <-chan command {
2019
cmdCh := make(chan command)
2120

2221
go func() {
2322
labels := pprof.Labels("go", "CommandReader", "SessionID", strconv.Itoa(s.sessionID))
2423
pprof.Do(ctx, labels, func(_ context.Context) {
2524
defer close(cmdCh)
25+
2626
for {
2727
tag, cmd, err := s.readCommand(del)
28-
if err != nil {
29-
if errors.Is(err, io.EOF) {
30-
return
31-
} else if err := response.Bad(tag).WithError(err).Send(s); err != nil {
32-
return
28+
29+
if err == nil && cmd.GetStartTLS() != nil {
30+
// TLS needs to be handled here in order to ensure that next command read is over the
31+
// tls connection.
32+
if e := s.handleStartTLS(tag, cmd.GetStartTLS()); e != nil {
33+
cmd = nil
34+
err = e
35+
} else {
36+
continue
3337
}
3438

35-
continue
3639
}
3740

38-
switch {
39-
case cmd.GetStartTLS() != nil:
40-
if err := s.handleStartTLS(tag, cmd.GetStartTLS()); err != nil {
41-
if err := response.Bad(tag).WithError(err).Send(s); err != nil {
42-
return
43-
}
41+
select {
42+
case cmdCh <- command{tag: tag, cmd: cmd, err: err}:
4443

45-
continue
46-
}
44+
case <-ctx.Done():
45+
return
4746

48-
default:
49-
cmdCh <- command{tag: tag, cmd: cmd}
5047
}
5148
}
5249
})

internal/session/handle_list.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ func (s *Session) handleList(ctx context.Context, tag string, cmd *proto.List, c
1717

1818
return s.state.List(ctx, cmd.GetReference(), nameUTF8, false, func(matches map[string]state.Match) error {
1919
for _, match := range matches {
20-
ch <- response.List().
20+
select {
21+
case ch <- response.List().
2122
WithName(match.Name).
2223
WithDelimiter(match.Delimiter).
23-
WithAttributes(match.Atts)
24+
WithAttributes(match.Atts):
25+
26+
case <-ctx.Done():
27+
return ctx.Err()
28+
}
2429
}
2530

2631
ch <- response.Ok(tag).WithMessage("LIST")

internal/session/handle_lsub.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ func (s *Session) handleLsub(ctx context.Context, tag string, cmd *proto.Lsub, c
1717

1818
return s.state.List(ctx, cmd.GetReference(), nameUTF8, true, func(matches map[string]state.Match) error {
1919
for _, match := range matches {
20-
ch <- response.Lsub().
20+
select {
21+
case ch <- response.Lsub().
2122
WithName(match.Name).
2223
WithDelimiter(match.Delimiter).
23-
WithAttributes(match.Atts)
24+
WithAttributes(match.Atts):
25+
26+
case <-ctx.Done():
27+
return ctx.Err()
28+
}
2429
}
2530

2631
ch <- response.Ok(tag).WithMessage("LSUB")

internal/session/handle_search.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ func (s *Session) handleSearch(ctx context.Context, tag string, cmd *proto.Searc
3131
return err
3232
}
3333

34-
ch <- response.Search(seq...)
34+
select {
35+
case ch <- response.Search(seq...):
36+
37+
case <-ctx.Done():
38+
return ctx.Err()
39+
}
3540

3641
var items []response.Item
3742

internal/session/session.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package session
55
import (
66
"context"
77
"crypto/tls"
8+
"errors"
89
"fmt"
910
"io"
1011
"net"
@@ -114,15 +115,15 @@ func (s *Session) Serve(ctx context.Context) error {
114115
return err
115116
}
116117

117-
if err := s.serve(ctx, s.getCommandCh(ctx, s.backend.GetDelimiter())); err != nil {
118+
if err := s.serve(ctx); err != nil {
118119
logrus.WithError(err).Errorf("Failed to serve session %v", s.sessionID)
119120
return err
120121
}
121122

122123
return nil
123124
}
124125

125-
func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
126+
func (s *Session) serve(ctx context.Context) error {
126127
ctx, cancel := context.WithCancel(ctx)
127128
defer cancel()
128129

@@ -134,15 +135,31 @@ func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
134135
cmd *proto.Command
135136
)
136137

138+
cmdCh := s.startCommandReader(ctx, s.backend.GetDelimiter())
139+
137140
for {
138141
select {
139142
case res, ok := <-cmdCh:
140143
if !ok {
144+
logrus.Debugf("Failed to read from command channel")
141145
return nil
142146
}
143147

144148
tag, cmd = res.tag, res.cmd
145149

150+
if res.err != nil {
151+
logrus.WithError(res.err).Debugf("Error during command parsing")
152+
153+
if errors.Is(res.err, io.EOF) {
154+
logrus.Debugf("Connection to client lost")
155+
return nil
156+
} else if err := response.Bad(tag).WithError(res.err).Send(s); err != nil {
157+
return err
158+
}
159+
160+
continue
161+
}
162+
146163
case <-s.state.Done():
147164
return nil
148165

@@ -178,12 +195,15 @@ func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
178195
responseCh := s.handleOther(withStartTime(ctx, time.Now()), tag, cmd, profiler)
179196
for res := range responseCh {
180197
if err := res.Send(s); err != nil {
181-
// Consume all remaining channel response since the connection is no longer available.
182-
// Failing to do so can cause a deadlock in the program as `s.handleOther` never finishes
183-
// executing and can hold onto a number of locks indefinitely.
184-
for range responseCh {
185-
// ...
186-
}
198+
go func() {
199+
// Consume all remaining channel response since the connection is no longer available.
200+
// Failing to do so can cause a deadlock in the program as `s.handleOther` never finishes
201+
// executing and can hold onto a number of locks indefinitely.
202+
// Consumed on a separate go routine to not block the return.
203+
for range responseCh {
204+
// ...
205+
}
206+
}()
187207

188208
return fmt.Errorf("failed to send response to client: %w", err)
189209
}

internal/state/mailbox_fetch.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes
2929
return err
3030
}
3131

32-
ch <- response.Fetch(seq).WithItems(items...)
32+
select {
33+
case ch <- response.Fetch(seq).WithItems(items...):
34+
35+
case <-ctx.Done():
36+
return ctx.Err()
37+
}
3338
}
3439

3540
return nil

0 commit comments

Comments
 (0)