Skip to content

Commit d625ab8

Browse files
authored
Refactor process state management (#70) (#73)
* add isValidStateTransition helper function * Replace Process.setState() with Process.swapState() * Refactor locking logic in Process
1 parent a3f82c1 commit d625ab8

File tree

2 files changed

+119
-128
lines changed

2 files changed

+119
-128
lines changed

proxy/process.go

+93-107
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ const (
3030
)
3131

3232
type Process struct {
33-
ID string
34-
config ModelConfig
35-
cmd *exec.Cmd
36-
logMonitor *LogMonitor
37-
healthCheckTimeout int
33+
ID string
34+
config ModelConfig
35+
cmd *exec.Cmd
36+
logMonitor *LogMonitor
37+
38+
healthCheckTimeout int
39+
healthCheckLoopInterval time.Duration
3840

3941
lastRequestHandled time.Time
4042

@@ -54,51 +56,57 @@ type Process struct {
5456
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
5557
ctx, cancel := context.WithCancel(context.Background())
5658
return &Process{
57-
ID: ID,
58-
config: config,
59-
cmd: nil,
60-
logMonitor: logMonitor,
61-
healthCheckTimeout: healthCheckTimeout,
62-
state: StateStopped,
63-
shutdownCtx: ctx,
64-
shutdownCancel: cancel,
59+
ID: ID,
60+
config: config,
61+
cmd: nil,
62+
logMonitor: logMonitor,
63+
healthCheckTimeout: healthCheckTimeout,
64+
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
65+
state: StateStopped,
66+
shutdownCtx: ctx,
67+
shutdownCancel: cancel,
6568
}
6669
}
6770

68-
func (p *Process) setState(newState ProcessState) error {
69-
// enforce valid state transitions
70-
invalidTransition := false
71-
if p.state == StateStopped {
72-
// stopped -> starting
73-
if newState != StateStarting {
74-
invalidTransition = true
75-
}
76-
} else if p.state == StateStarting {
77-
// starting -> ready | failed | stopping
78-
if newState != StateReady && newState != StateFailed && newState != StateStopping {
79-
invalidTransition = true
80-
}
81-
} else if p.state == StateReady {
82-
// ready -> stopping
83-
if newState != StateStopping {
84-
invalidTransition = true
85-
}
86-
} else if p.state == StateStopping {
87-
// stopping -> stopped | shutdown
88-
if newState != StateStopped && newState != StateShutdown {
89-
invalidTransition = true
90-
}
91-
} else if p.state == StateFailed || p.state == StateShutdown {
92-
invalidTransition = true
71+
// custom error types for swapping state
72+
var (
73+
ErrExpectedStateMismatch = errors.New("expected state mismatch")
74+
ErrInvalidStateTransition = errors.New("invalid state transition")
75+
)
76+
77+
// swapState performs a compare and swap of the state atomically. It returns the current state
78+
// and an error if the swap failed.
79+
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
80+
p.stateMutex.Lock()
81+
defer p.stateMutex.Unlock()
82+
83+
if p.state != expectedState {
84+
return p.state, ErrExpectedStateMismatch
9385
}
9486

95-
if invalidTransition {
96-
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
97-
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
87+
if !isValidTransition(p.state, newState) {
88+
return p.state, ErrInvalidStateTransition
9889
}
9990

10091
p.state = newState
101-
return nil
92+
return p.state, nil
93+
}
94+
95+
// Helper function to encapsulate transition rules
96+
func isValidTransition(from, to ProcessState) bool {
97+
switch from {
98+
case StateStopped:
99+
return to == StateStarting
100+
case StateStarting:
101+
return to == StateReady || to == StateFailed || to == StateStopping
102+
case StateReady:
103+
return to == StateStopping
104+
case StateStopping:
105+
return to == StateStopped || to == StateShutdown
106+
case StateFailed, StateShutdown:
107+
return false // No transitions allowed from these states
108+
}
109+
return false
102110
}
103111

104112
func (p *Process) CurrentState() ProcessState {
@@ -116,65 +124,48 @@ func (p *Process) start() error {
116124
return fmt.Errorf("can not start(), upstream proxy missing")
117125
}
118126

119-
// multiple start() calls will wait for the one that is actually starting to
120-
// complete before proceeding.
121-
// ===========
122-
curState := p.CurrentState()
123-
124-
if curState == StateReady {
125-
return nil
126-
}
127-
128-
if curState == StateStarting {
129-
p.waitStarting.Wait()
130-
131-
if state := p.CurrentState(); state != StateReady {
132-
return fmt.Errorf("start() failed current state: %v", state)
133-
}
134-
135-
return nil
127+
args, err := p.config.SanitizedCommand()
128+
if err != nil {
129+
return fmt.Errorf("unable to get sanitized command: %v", err)
136130
}
137-
// ===========
138-
139-
// There is the possibility of a hard to replicate race condition where
140-
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
141-
// below, it's value has changed!
142-
143-
p.stateMutex.Lock()
144-
defer p.stateMutex.Unlock()
145131

146-
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
147-
// to transition from to StateReady
148-
149-
if p.state != StateStopped {
150-
if p.state == StateReady {
151-
return nil
132+
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
133+
if err == ErrExpectedStateMismatch {
134+
// already starting, just wait for it to complete and expect
135+
// it to be be in the Ready start after. If not, return an error
136+
if curState == StateStarting {
137+
p.waitStarting.Wait()
138+
if state := p.CurrentState(); state == StateReady {
139+
return nil
140+
} else {
141+
return fmt.Errorf("process was already starting but wound up in state %v", state)
142+
}
143+
} else {
144+
return fmt.Errorf("processes was in state %v when start() was called", curState)
145+
}
152146
} else {
153-
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
147+
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
154148
}
155149
}
156150

157-
if err := p.setState(StateStarting); err != nil {
158-
return err
159-
}
160-
161151
p.waitStarting.Add(1)
162152
defer p.waitStarting.Done()
163153

164-
args, err := p.config.SanitizedCommand()
165-
if err != nil {
166-
return fmt.Errorf("unable to get sanitized command: %v", err)
167-
}
168-
169154
p.cmd = exec.Command(args[0], args[1:]...)
170155
p.cmd.Stdout = p.logMonitor
171156
p.cmd.Stderr = p.logMonitor
172157
p.cmd.Env = p.config.Env
173158

174159
err = p.cmd.Start()
175160

161+
// Set process state to failed
176162
if err != nil {
177-
p.setState(StateFailed)
163+
if curState, swapErr := p.swapState(StateStarting, StateFailed); err != nil {
164+
return fmt.Errorf(
165+
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
166+
err, curState, swapErr,
167+
)
168+
}
178169
return fmt.Errorf("start() failed: %v", err)
179170
}
180171

@@ -209,13 +200,16 @@ func (p *Process) start() error {
209200
)
210201
defer cancelHealthCheck()
211202

212-
// Health check loop
213203
loop:
204+
// Ready Check loop
214205
for {
215206
select {
216207
case <-checkDeadline.Done():
217-
p.setState(StateFailed)
218-
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
208+
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
209+
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
210+
} else {
211+
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
212+
}
219213
case <-p.shutdownCtx.Done():
220214
return errors.New("health check interrupted due to shutdown")
221215
default:
@@ -233,7 +227,7 @@ func (p *Process) start() error {
233227
}
234228
}
235229

236-
<-time.After(5 * time.Second)
230+
<-time.After(p.healthCheckLoopInterval)
237231
}
238232
}
239233

@@ -244,7 +238,7 @@ func (p *Process) start() error {
244238
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
245239

246240
for range time.Tick(time.Second) {
247-
if p.state != StateReady {
241+
if p.CurrentState() != StateReady {
248242
return
249243
}
250244

@@ -260,46 +254,38 @@ func (p *Process) start() error {
260254
}()
261255
}
262256

263-
return p.setState(StateReady)
257+
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
258+
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
259+
} else {
260+
return nil
261+
}
264262
}
265263

266264
func (p *Process) Stop() {
267265
// wait for any inflight requests before proceeding
268266
p.inFlightRequests.Wait()
269-
p.stateMutex.Lock()
270-
defer p.stateMutex.Unlock()
271267

272268
// calling Stop() when state is invalid is a no-op
273-
if err := p.setState(StateStopping); err != nil {
274-
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
269+
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
270+
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
275271
return
276272
}
277273

278274
// stop the process with a graceful exit timeout
279275
p.stopCommand(5 * time.Second)
280276

281-
if err := p.setState(StateStopped); err != nil {
282-
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
277+
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
278+
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
283279
}
284280
}
285281

286282
// Shutdown is called when llama-swap is shutting down. It will give a little bit
287283
// of time for any inflight requests to complete before shutting down. If the Process
288284
// is in the state of starting, it will cancel it and shut it down
289285
func (p *Process) Shutdown() {
290-
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
291286
p.shutdownCancel()
292-
293-
p.stateMutex.Lock()
294-
defer p.stateMutex.Unlock()
295-
p.setState(StateStopping)
296-
297-
// 5 seconds to stop the process
298287
p.stopCommand(5 * time.Second)
299-
if err := p.setState(StateShutdown); err != nil {
300-
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
301-
}
302-
p.setState(StateShutdown)
288+
p.state = StateShutdown
303289
}
304290

305291
// stopCommand will send a SIGTERM to the process and wait for it to exit.

proxy/process_test.go

+26-21
Original file line numberDiff line numberDiff line change
@@ -225,30 +225,32 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
225225
}
226226
}
227227

228-
func TestSetState(t *testing.T) {
228+
func TestProcess_SwapState(t *testing.T) {
229229
tests := []struct {
230230
name string
231231
currentState ProcessState
232+
expectedState ProcessState
232233
newState ProcessState
233234
expectedError error
234235
expectedResult ProcessState
235236
}{
236-
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
237-
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
238-
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
239-
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
240-
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
241-
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
242-
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
243-
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
244-
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
245-
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
246-
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
247-
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
248-
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
249-
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
250-
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
251-
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
237+
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
238+
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
239+
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
240+
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
241+
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
242+
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
243+
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
244+
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
245+
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
246+
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
247+
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
248+
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
249+
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
250+
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
251+
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
252+
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
253+
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
252254
}
253255

254256
for _, test := range tests {
@@ -257,7 +259,7 @@ func TestSetState(t *testing.T) {
257259
state: test.currentState,
258260
}
259261

260-
err := p.setState(test.newState)
262+
resultState, err := p.swapState(test.expectedState, test.newState)
261263
if err != nil && test.expectedError == nil {
262264
t.Errorf("Unexpected error: %v", err)
263265
} else if err == nil && test.expectedError != nil {
@@ -268,8 +270,8 @@ func TestSetState(t *testing.T) {
268270
}
269271
}
270272

271-
if p.state != test.expectedResult {
272-
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
273+
if resultState != test.expectedResult {
274+
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
273275
}
274276
})
275277
}
@@ -290,11 +292,14 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
290292
healthCheckTTLSeconds := 30
291293
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
292294

295+
// make it a lot faster
296+
process.healthCheckLoopInterval = time.Second
297+
293298
// start a goroutine to simulate a shutdown
294299
var wg sync.WaitGroup
295300
go func() {
296301
defer wg.Done()
297-
<-time.After(time.Second * 2)
302+
<-time.After(time.Millisecond * 500)
298303
process.Shutdown()
299304
}()
300305
wg.Add(1)

0 commit comments

Comments
 (0)