@@ -126,27 +126,80 @@ func (c *simpleServiceClient) Ping(
126
126
return out , nil
127
127
}
128
128
129
// CustomDialer is a test dialer that waits for a configurable, per-address
// delay before connecting and records how many times each address was dialed.
//
// Delays of two seconds or more are treated as simulated connection timeouts:
// after waiting, DialContext returns an error instead of connecting.
//
// The zero value is usable; it applies no delays. A CustomDialer must not be
// copied after first use because it embeds a sync.Mutex.
type CustomDialer struct {
	// mu guards dialAttempts and the reads of delays done by the methods.
	mu sync.Mutex

	// delays maps a dial address to the delay applied before connecting.
	// Addresses without an entry, or with a non-positive delay, are dialed
	// immediately.
	delays map[string]time.Duration

	// dialAttempts counts DialContext calls per address.
	dialAttempts map[string]int
}

// DialContext records a dial attempt for addr, waits for the delay configured
// for that address (if any), and then opens a real TCP connection to it.
//
// It returns ctx.Err() if ctx is done before the delay elapses, and a
// "connection timeout" error, without dialing, when the configured delay is
// at least two seconds. Otherwise it returns whatever net.Dialer.DialContext
// returns.
func (d *CustomDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
	// Delays at or above this threshold simulate a connection that never
	// completes: the dial fails after the wait instead of connecting.
	const simulatedTimeoutThreshold = 2 * time.Second

	d.mu.Lock()
	if d.dialAttempts == nil {
		// Allocate lazily so a zero-value CustomDialer does not panic on
		// the increment below.
		d.dialAttempts = make(map[string]int)
	}
	delay := d.delays[addr] // a missing entry yields zero, i.e. no delay
	d.dialAttempts[addr]++
	attemptCount := d.dialAttempts[addr]
	d.mu.Unlock()

	fmt.Printf("Attempting to dial %s (attempt #%d)\n", addr, attemptCount)

	if delay > 0 {
		fmt.Printf("Simulating delay of %v for %s\n", delay, addr)

		// An explicit timer (rather than time.After) is released as soon as
		// the function returns, including when ctx is cancelled first.
		timer := time.NewTimer(delay)
		defer timer.Stop()

		select {
		case <-timer.C:
		case <-ctx.Done():
			return nil, ctx.Err()
		}

		if delay >= simulatedTimeoutThreshold {
			fmt.Printf("Connection to %s timed out after %v\n", addr, delay)

			return nil, fmt.Errorf("connection timeout")
		}
	}

	// Establish a real connection to the address.
	var dialer net.Dialer

	return dialer.DialContext(ctx, "tcp", addr)
}

// GetDialAttempts returns how many times DialContext has been called for addr
// since the attempt counters were last reset. Unknown addresses report zero.
func (d *CustomDialer) GetDialAttempts(addr string) int {
	d.mu.Lock()
	defer d.mu.Unlock()

	return d.dialAttempts[addr]
}
180
+
129
181
// TestGRPCLoadBalancingPolicies tests how different load balancing policies behave
130
182
// This is a test function, so we can ignore the staticcheck warnings about deprecated methods
131
183
// as we need to use these specific gRPC APIs for testing the load balancing behavior.
132
- //
133
- //nolint:staticcheck
134
184
func TestGRPCLoadBalancingPolicies (t * testing.T ) {
135
- // Start several real gRPC servers with different characteristics
136
- servers := make ([]* simpleServer , 3 )
137
- listeners := make ([]net.Listener , 3 )
138
- grpcServers := make ([]* grpc.Server , 3 )
139
- addresses := make ([]string , 3 )
140
-
141
- // Create servers with different behaviors
142
- for i := 0 ; i < 3 ; i ++ {
143
- // First server has a delay, others respond immediately
144
- delay := time .Duration (0 )
145
- if i == 0 {
146
- delay = 500 * time .Millisecond
147
- }
185
+ // Total number of servers to test
186
+ const totalServers = 6
187
+
188
+ // Setup servers
189
+ servers := make ([]* simpleServer , totalServers )
190
+ listeners := make ([]net.Listener , totalServers )
191
+ grpcServers := make ([]* grpc.Server , totalServers )
192
+ addresses := make ([]string , totalServers )
193
+
194
+ // Custom dialer with controlled delays
195
+ dialer := & CustomDialer {
196
+ delays : make (map [string ]time.Duration ),
197
+ dialAttempts : make (map [string ]int ),
198
+ }
148
199
149
- servers [i ] = & simpleServer {delay : delay }
200
+ // Start all servers
201
+ for i := 0 ; i < totalServers ; i ++ {
202
+ servers [i ] = & simpleServer {}
150
203
grpcServers [i ] = grpc .NewServer ()
151
204
RegisterSimpleServiceServer (grpcServers [i ], servers [i ])
152
205
@@ -158,14 +211,24 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
158
211
listeners [i ] = lis
159
212
addresses [i ] = lis .Addr ().String ()
160
213
214
+ // First 4 servers will have a "connection delay" of 2.5 seconds, simulating timeout
215
+ if i < 4 {
216
+ dialer.delays [addresses [i ]] = 2500 * time .Millisecond
217
+ } else {
218
+ // Last two servers connect quickly
219
+ dialer.delays [addresses [i ]] = 50 * time .Millisecond
220
+ }
221
+
222
+ t .Logf ("Started server %d at %s with delay %v" , i , addresses [i ], dialer.delays [addresses [i ]])
223
+
161
224
go func (gs * grpc.Server , l net.Listener ) {
162
225
_ = gs .Serve (l )
163
226
}(grpcServers [i ], lis )
164
227
}
165
228
166
229
// Cleanup after test
167
230
defer func () {
168
- for i := 0 ; i < 3 ; i ++ {
231
+ for i := 0 ; i < totalServers ; i ++ {
169
232
if grpcServers [i ] != nil {
170
233
grpcServers [i ].Stop ()
171
234
}
@@ -180,38 +243,56 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
180
243
resolver .Register (r )
181
244
182
245
// Prepare addresses for the resolver
183
- addrs := make ([]resolver.Address , 0 , len (addresses ))
184
- for _ , addr := range addresses {
246
+ addrs := make ([]resolver.Address , 0 , totalServers )
247
+ for i , addr := range addresses {
248
+ t .Logf ("Adding server %d at address %s to resolver" , i , addr )
185
249
addrs = append (addrs , resolver.Address {Addr : addr })
186
250
}
187
251
r .InitialState (resolver.State {Addresses : addrs })
188
252
189
253
// Test different load balancing policies
190
254
tests := []struct {
191
- name string
192
- balancingPolicy string
255
+ name string
256
+ balancingPolicy string
257
+ minExpectedDuration time.Duration
258
+ maxExpectedDuration time.Duration
193
259
}{
194
- {"RoundRobin" , "round_robin" },
195
- {"PickFirst" , "pick_first" },
260
+ {
261
+ name : "RoundRobin" ,
262
+ balancingPolicy : "round_robin" ,
263
+ minExpectedDuration : 50 * time .Millisecond , // Should connect to a fast server quickly
264
+ maxExpectedDuration : 1 * time .Second , // Should not take too long
265
+ },
266
+ {
267
+ name : "PickFirst" ,
268
+ balancingPolicy : "pick_first" ,
269
+ minExpectedDuration : 8 * time .Second , // Should try first 4 slow servers (4 * 2.5s with some overhead)
270
+ maxExpectedDuration : 15 * time .Second , // Upper bound
271
+ },
196
272
}
197
273
198
274
for _ , tc := range tests {
199
275
t .Run (tc .name , func (t * testing.T ) {
276
+ // Reset dial attempts for this test
277
+ dialer .dialAttempts = make (map [string ]int )
278
+
200
279
// Monitor connection establishment time
201
280
dialStart := time .Now ()
202
281
203
282
// Create context with timeout for connection establishment
204
- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
283
+ ctx , cancel := context .WithTimeout (context .Background (), 20 * time .Second )
205
284
defer cancel ()
206
285
207
- // #nosec G402 - Using insecure credentials is acceptable for testing
286
+ t .Logf ("Attempting to connect with %s balancing policy" , tc .balancingPolicy )
287
+
208
288
// Establish connection with our balancing policy
209
289
conn , err := grpc .DialContext (
210
290
ctx ,
211
- "test:///unused" , // Address doesn't matter as we use manual resolver
291
+ "test:///unused" ,
292
+ grpc .WithContextDialer (dialer .DialContext ),
212
293
grpc .WithTransportCredentials (insecure .NewCredentials ()),
213
294
grpc .WithDefaultServiceConfig (fmt .Sprintf (`{"loadBalancingPolicy": "%s"}` , tc .balancingPolicy )),
214
- grpc .WithBlock (), // Wait for connection establishment to complete
295
+ grpc .WithBlock (),
215
296
)
216
297
217
298
dialDuration := time .Since (dialStart )
@@ -221,6 +302,13 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
221
302
}
222
303
defer conn .Close ()
223
304
305
+ // Log all dial attempts
306
+ t .Logf ("Connection established in %v" , dialDuration )
307
+ for i , addr := range addresses {
308
+ attempts := dialer .GetDialAttempts (addr )
309
+ t .Logf ("Server %d at %s: %d dial attempts" , i , addr , attempts )
310
+ }
311
+
224
312
// Create client and make a request
225
313
client := NewSimpleServiceClient (conn )
226
314
_ , err = client .Ping (context .Background (), & emptypb.Empty {})
@@ -231,39 +319,58 @@ func TestGRPCLoadBalancingPolicies(t *testing.T) {
231
319
// Analyze behavior based on balancing policy
232
320
switch tc .balancingPolicy {
233
321
case "round_robin" :
234
- // For round_robin, we expect fast connection as it connects
235
- // to all servers in parallel and should quickly find working ones
236
- if dialDuration >= 400 * time .Millisecond {
237
- t .Logf ("round_robin dial took %v, expected less than 400ms" , dialDuration )
322
+ if dialDuration < tc .minExpectedDuration || dialDuration > tc .maxExpectedDuration {
323
+ t .Errorf ("round_robin dial took %v, expected between %v and %v" ,
324
+ dialDuration , tc .minExpectedDuration , tc .maxExpectedDuration )
238
325
}
239
326
240
- // Verify that requests execute successfully
241
- for i := 0 ; i < 10 ; i ++ {
242
- _ , err = client . Ping ( context . Background (), & emptypb. Empty {})
243
- if err != nil {
244
- t . Fatalf ( "Ping failed: %v" , err )
327
+ // Check if multiple servers were attempted
328
+ attemptedServers := 0
329
+ for _ , addr := range addresses {
330
+ if dialer . GetDialAttempts ( addr ) > 0 {
331
+ attemptedServers ++
245
332
}
246
333
}
247
334
248
- t .Logf ("round_robin successfully established connection in %v" , dialDuration )
335
+ // round_robin should try multiple servers in parallel
336
+ if attemptedServers < 2 {
337
+ t .Errorf ("Expected round_robin to attempt multiple servers, but only %d were attempted" , attemptedServers )
338
+ }
339
+
340
+ t .Logf ("round_robin successfully established connection" )
249
341
250
342
case "pick_first" :
251
- // For pick_first, connection time is important - if the first server is unavailable,
252
- // connection might take longer
253
- if servers [0 ].delay > 0 {
254
- t .Logf ("pick_first dial took %v (expected to be affected by the delay)" , dialDuration )
343
+ if dialDuration < tc .minExpectedDuration {
344
+ t .Errorf ("pick_first connected too quickly: %v, expected at least %v" ,
345
+ dialDuration , tc .minExpectedDuration )
255
346
}
256
347
257
- // Verify that requests execute successfully
258
- for i := 0 ; i < 10 ; i ++ {
259
- _ , err = client .Ping (context .Background (), & emptypb.Empty {})
260
- if err != nil {
261
- t .Fatalf ("Ping failed: %v" , err )
348
+ // Check sequential dialing pattern
349
+ for i := 1 ; i < totalServers ; i ++ {
350
+ prevAddr := addresses [i - 1 ]
351
+ currAddr := addresses [i ]
352
+
353
+ prevAttempts := dialer .GetDialAttempts (prevAddr )
354
+ currAttempts := dialer .GetDialAttempts (currAddr )
355
+
356
+ if currAttempts > 0 && prevAttempts == 0 {
357
+ t .Errorf ("pick_first should try servers sequentially, but server %d was attempted before server %d" ,
358
+ i , i - 1 )
262
359
}
263
360
}
264
361
265
- t .Logf ("pick_first successfully established connection in %v" , dialDuration )
362
+ t .Logf ("pick_first eventually found a working server after trying slow ones" )
266
363
}
364
+
365
+ // Make additional ping requests to verify connection works
366
+ for i := 0 ; i < 3 ; i ++ {
367
+ _ , err = client .Ping (context .Background (), & emptypb.Empty {})
368
+ if err != nil {
369
+ t .Fatalf ("Ping %d failed: %v" , i , err )
370
+ }
371
+ }
372
+
373
+ t .Logf ("Successfully completed ping requests with %s policy" , tc .balancingPolicy )
267
374
})
268
375
}
269
376
}
0 commit comments