@@ -126,27 +126,82 @@ func (c *simpleServiceClient) Ping(
126
126
return out , nil
127
127
}
128
128
129
// CustomDialer is a test dialer that waits a configurable, per-address delay
// before connecting and records how many times each address has been dialed.
//
// A configured delay of 2s or more is treated as a simulated connection
// timeout: DialContext waits for the delay and then returns an error instead
// of opening a connection. Shorter delays are followed by a real TCP dial.
//
// The zero value is usable; addresses without a configured delay are dialed
// immediately.
type CustomDialer struct {
	// delays maps a target address to the delay applied before connecting.
	// DialContext reads it while holding mu.
	delays map[string]time.Duration
	// mu guards the delays lookup and every access to dialAttempts made by
	// the methods of this type.
	mu sync.Mutex
	// dialAttempts counts DialContext calls per address. It is allocated
	// lazily on the first dial when nil.
	dialAttempts map[string]int
}

// DialContext has the signature gRPC expects for grpc.WithContextDialer.
//
// It increments the attempt counter for addr, logs the attempt to stdout and,
// when a positive delay is configured for addr, waits for that delay or for
// ctx to be done, whichever comes first. It returns ctx.Err() when ctx ends
// during the wait, a "connection timeout" error when the configured delay is
// 2s or more, and otherwise the result of a plain TCP dial to addr.
func (d *CustomDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
	// Delays at or above this threshold simulate a connection timeout rather
	// than a slow-but-successful connection.
	const timeoutThreshold = 2 * time.Second

	d.mu.Lock()
	delay, exists := d.delays[addr]
	if d.dialAttempts == nil {
		// Writing to a nil map panics; allocate on first use so a dialer
		// built without dialAttempts still works.
		d.dialAttempts = make(map[string]int)
	}
	d.dialAttempts[addr]++
	attemptCount := d.dialAttempts[addr]
	d.mu.Unlock()

	// Log outside the critical section so slow output never blocks other dials.
	fmt.Printf("Attempting to dial %s (attempt #%d)\n", addr, attemptCount)

	if exists && delay > 0 {
		fmt.Printf("Simulating delay of %v for %s\n", delay, addr)

		// An explicit timer (instead of time.After) lets us release it as
		// soon as ctx is cancelled rather than keeping it pending for the
		// full delay.
		timer := time.NewTimer(delay)
		defer timer.Stop()

		select {
		case <-timer.C:
			if delay >= timeoutThreshold {
				fmt.Printf("Connection to %s timed out after %v\n", addr, delay)

				return nil, fmt.Errorf("connection timeout")
			}
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}

	// Establish a real connection to the address.
	dialer := &net.Dialer{}

	return dialer.DialContext(ctx, "tcp", addr)
}

// GetDialAttempts returns how many times DialContext has been called for addr.
// It returns 0 for an address that has never been dialed.
func (d *CustomDialer) GetDialAttempts(addr string) int {
	d.mu.Lock()
	defer d.mu.Unlock()

	return d.dialAttempts[addr]
}
180
+
129
181
// TestGRPCLoadBalancingPolicies tests how different load balancing policies behave
130
182
// This is a test function, so we can ignore the staticcheck warnings about deprecated methods
131
183
// as we need to use these specific gRPC APIs for testing the load balancing behavior.
132
184
//
133
185
//nolint:staticcheck
134
186
func TestGRPCLoadBalancingPolicies (t * testing.T ) {
135
- // Start several real gRPC servers with different characteristics
136
- servers := make ([]* simpleServer , 3 )
137
- listeners := make ([]net.Listener , 3 )
138
- grpcServers := make ([]* grpc.Server , 3 )
139
- addresses := make ([]string , 3 )
140
-
141
- // Create servers with different behaviors
142
- for i := 0 ; i < 3 ; i ++ {
143
- // First server has a delay, others respond immediately
144
- delay := time .Duration (0 )
145
- if i == 0 {
146
- delay = 500 * time .Millisecond
147
- }
187
+ // Total number of servers to test
188
+ const totalServers = 6
189
+
190
+ // Setup servers
191
+ servers := make ([]* simpleServer , totalServers )
192
+ listeners := make ([]net.Listener , totalServers )
193
+ grpcServers := make ([]* grpc.Server , totalServers )
194
+ addresses := make ([]string , totalServers )
195
+
196
+ // Custom dialer with controlled delays
197
+ dialer := & CustomDialer {
198
+ delays : make (map [string ]time.Duration ),
199
+ dialAttempts : make (map [string ]int ),
200
+ }
148
201
149
- servers [i ] = & simpleServer {delay : delay }
202
+ // Start all servers
203
+ for i := 0 ; i < totalServers ; i ++ {
204
+ servers [i ] = & simpleServer {}
150
205
grpcServers [i ] = grpc .NewServer ()
151
206
RegisterSimpleServiceServer (grpcServers [i ], servers [i ])
152
207
@@ -158,14 +213,24 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
158
213
listeners [i ] = lis
159
214
addresses [i ] = lis .Addr ().String ()
160
215
216
+ // First 4 servers will have a "connection delay" of 2.5 seconds, simulating timeout
217
+ if i < 4 {
218
+ dialer.delays [addresses [i ]] = 2500 * time .Millisecond
219
+ } else {
220
+ // Last two servers connect quickly
221
+ dialer.delays [addresses [i ]] = 50 * time .Millisecond
222
+ }
223
+
224
+ t .Logf ("Started server %d at %s with delay %v" , i , addresses [i ], dialer.delays [addresses [i ]])
225
+
161
226
go func (gs * grpc.Server , l net.Listener ) {
162
227
_ = gs .Serve (l )
163
228
}(grpcServers [i ], lis )
164
229
}
165
230
166
231
// Cleanup after test
167
232
defer func () {
168
- for i := 0 ; i < 3 ; i ++ {
233
+ for i := 0 ; i < totalServers ; i ++ {
169
234
if grpcServers [i ] != nil {
170
235
grpcServers [i ].Stop ()
171
236
}
@@ -180,38 +245,56 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
180
245
resolver .Register (r )
181
246
182
247
// Prepare addresses for the resolver
183
- addrs := make ([]resolver.Address , 0 , len (addresses ))
184
- for _ , addr := range addresses {
248
+ addrs := make ([]resolver.Address , 0 , totalServers )
249
+ for i , addr := range addresses {
250
+ t .Logf ("Adding server %d at address %s to resolver" , i , addr )
185
251
addrs = append (addrs , resolver.Address {Addr : addr })
186
252
}
187
253
r .InitialState (resolver.State {Addresses : addrs })
188
254
189
255
// Test different load balancing policies
190
256
tests := []struct {
191
- name string
192
- balancingPolicy string
257
+ name string
258
+ balancingPolicy string
259
+ minExpectedDuration time.Duration
260
+ maxExpectedDuration time.Duration
193
261
}{
194
- {"RoundRobin" , "round_robin" },
195
- {"PickFirst" , "pick_first" },
262
+ {
263
+ name : "RoundRobin" ,
264
+ balancingPolicy : "round_robin" ,
265
+ minExpectedDuration : 50 * time .Millisecond , // Should connect to a fast server quickly
266
+ maxExpectedDuration : 1 * time .Second , // Should not take too long
267
+ },
268
+ {
269
+ name : "PickFirst" ,
270
+ balancingPolicy : "pick_first" ,
271
+ minExpectedDuration : 8 * time .Second , // Should try first 4 slow servers (4 * 2.5s with some overhead)
272
+ maxExpectedDuration : 15 * time .Second , // Upper bound
273
+ },
196
274
}
197
275
198
276
for _ , tc := range tests {
199
277
t .Run (tc .name , func (t * testing.T ) {
278
+ // Reset dial attempts for this test
279
+ dialer .dialAttempts = make (map [string ]int )
280
+
200
281
// Monitor connection establishment time
201
282
dialStart := time .Now ()
202
283
203
284
// Create context with timeout for connection establishment
204
- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
285
+ ctx , cancel := context .WithTimeout (context .Background (), 20 * time .Second )
205
286
defer cancel ()
206
287
207
- // #nosec G402 - Using insecure credentials is acceptable for testing
288
+ t .Logf ("Attempting to connect with %s balancing policy" , tc .balancingPolicy )
289
+
208
290
// Establish connection with our balancing policy
209
291
conn , err := grpc .DialContext (
210
292
ctx ,
211
- "test:///unused" , // Address doesn't matter as we use manual resolver
293
+ "test:///unused" ,
294
+ grpc .WithContextDialer (dialer .DialContext ),
212
295
grpc .WithTransportCredentials (insecure .NewCredentials ()),
213
296
grpc .WithDefaultServiceConfig (fmt .Sprintf (`{"loadBalancingPolicy": "%s"}` , tc .balancingPolicy )),
214
- grpc .WithBlock (), // Wait for connection establishment to complete
297
+ grpc .WithBlock (),
215
298
)
216
299
217
300
dialDuration := time .Since (dialStart )
@@ -221,6 +304,13 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
221
304
}
222
305
defer conn .Close ()
223
306
307
+ // Log all dial attempts
308
+ t .Logf ("Connection established in %v" , dialDuration )
309
+ for i , addr := range addresses {
310
+ attempts := dialer .GetDialAttempts (addr )
311
+ t .Logf ("Server %d at %s: %d dial attempts" , i , addr , attempts )
312
+ }
313
+
224
314
// Create client and make a request
225
315
client := NewSimpleServiceClient (conn )
226
316
_ , err = client .Ping (context .Background (), & emptypb.Empty {})
@@ -231,39 +321,58 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
231
321
// Analyze behavior based on balancing policy
232
322
switch tc .balancingPolicy {
233
323
case "round_robin" :
234
- // For round_robin, we expect fast connection as it connects
235
- // to all servers in parallel and should quickly find working ones
236
- if dialDuration >= 400 * time .Millisecond {
237
- t .Logf ("round_robin dial took %v, expected less than 400ms" , dialDuration )
324
+ if dialDuration < tc .minExpectedDuration || dialDuration > tc .maxExpectedDuration {
325
+ t .Errorf ("round_robin dial took %v, expected between %v and %v" ,
326
+ dialDuration , tc .minExpectedDuration , tc .maxExpectedDuration )
238
327
}
239
328
240
- // Verify that requests execute successfully
241
- for i := 0 ; i < 10 ; i ++ {
242
- _ , err = client . Ping ( context . Background (), & emptypb. Empty {})
243
- if err != nil {
244
- t . Fatalf ( "Ping failed: %v" , err )
329
+ // Check if multiple servers were attempted
330
+ attemptedServers := 0
331
+ for _ , addr := range addresses {
332
+ if dialer . GetDialAttempts ( addr ) > 0 {
333
+ attemptedServers ++
245
334
}
246
335
}
247
336
248
- t .Logf ("round_robin successfully established connection in %v" , dialDuration )
337
+ // round_robin should try multiple servers in parallel
338
+ if attemptedServers < 2 {
339
+ t .Errorf ("Expected round_robin to attempt multiple servers, but only %d were attempted" , attemptedServers )
340
+ }
341
+
342
+ t .Logf ("round_robin successfully established connection" )
249
343
250
344
case "pick_first" :
251
- // For pick_first, connection time is important - if the first server is unavailable,
252
- // connection might take longer
253
- if servers [0 ].delay > 0 {
254
- t .Logf ("pick_first dial took %v (expected to be affected by the delay)" , dialDuration )
345
+ if dialDuration < tc .minExpectedDuration {
346
+ t .Errorf ("pick_first connected too quickly: %v, expected at least %v" ,
347
+ dialDuration , tc .minExpectedDuration )
255
348
}
256
349
257
- // Verify that requests execute successfully
258
- for i := 0 ; i < 10 ; i ++ {
259
- _ , err = client .Ping (context .Background (), & emptypb.Empty {})
260
- if err != nil {
261
- t .Fatalf ("Ping failed: %v" , err )
350
+ // Check sequential dialing pattern
351
+ for i := 1 ; i < totalServers ; i ++ {
352
+ prevAddr := addresses [i - 1 ]
353
+ currAddr := addresses [i ]
354
+
355
+ prevAttempts := dialer .GetDialAttempts (prevAddr )
356
+ currAttempts := dialer .GetDialAttempts (currAddr )
357
+
358
+ if currAttempts > 0 && prevAttempts == 0 {
359
+ t .Errorf ("pick_first should try servers sequentially, but server %d was attempted before server %d" ,
360
+ i , i - 1 )
262
361
}
263
362
}
264
363
265
- t .Logf ("pick_first successfully established connection in %v" , dialDuration )
364
+ t .Logf ("pick_first eventually found a working server after trying slow ones" )
266
365
}
366
+
367
+ // Make additional ping requests to verify connection works
368
+ for i := 0 ; i < 3 ; i ++ {
369
+ _ , err = client .Ping (context .Background (), & emptypb.Empty {})
370
+ if err != nil {
371
+ t .Fatalf ("Ping %d failed: %v" , i , err )
372
+ }
373
+ }
374
+
375
+ t .Logf ("Successfully completed ping requests with %s policy" , tc .balancingPolicy )
267
376
})
268
377
}
269
378
}
0 commit comments