Skip to content

Commit 50c8dbd

Browse files
Add Redis readiness verification (#3555)
1 parent 67ab74d commit 50c8dbd

File tree

11 files changed

+471
-127
lines changed

11 files changed

+471
-127
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
encoding: str = "utf-8",
230230
encoding_errors: str = "strict",
231231
decode_responses: bool = False,
232+
check_server_ready: bool = False,
232233
retry_on_timeout: bool = False,
233234
retry: Retry = Retry(
234235
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -276,6 +277,10 @@ def __init__(
276277
277278
When 'connection_pool' is provided - the retry configuration of the
278279
provided pool will be used.
280+
281+
Args:
282+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
283+
connect and send operations work even when Redis server is not ready.
279284
"""
280285
kwargs: Dict[str, Any]
281286
if event_dispatcher is None:
@@ -310,6 +315,7 @@ def __init__(
310315
"encoding": encoding,
311316
"encoding_errors": encoding_errors,
312317
"decode_responses": decode_responses,
318+
"check_server_ready": check_server_ready,
313319
"retry_on_error": retry_on_error,
314320
"retry": copy.deepcopy(retry),
315321
"max_connections": max_connections,

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def __init__(
289289
encoding_errors: str = "strict",
290290
decode_responses: bool = False,
291291
# Connection related kwargs
292+
check_server_ready: bool = False,
292293
health_check_interval: float = 0,
293294
socket_connect_timeout: Optional[float] = None,
294295
socket_keepalive: bool = False,
@@ -342,6 +343,7 @@ def __init__(
342343
"encoding_errors": encoding_errors,
343344
"decode_responses": decode_responses,
344345
# Connection related kwargs
346+
"check_server_ready": check_server_ready,
345347
"health_check_interval": health_check_interval,
346348
"socket_connect_timeout": socket_connect_timeout,
347349
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
encoding_errors: str = "strict",
149149
decode_responses: bool = False,
150150
parser_class: Type[BaseParser] = DefaultParser,
151+
check_server_ready: bool = False,
151152
socket_read_size: int = 65536,
152153
health_check_interval: float = 0,
153154
client_name: Optional[str] = None,
@@ -204,6 +205,7 @@ def __init__(
204205
self.health_check_interval = health_check_interval
205206
self.next_health_check: float = -1
206207
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208+
self.check_server_ready = check_server_ready
207209
self.redis_connect_func = redis_connect_func
208210
self._reader: Optional[asyncio.StreamReader] = None
209211
self._writer: Optional[asyncio.StreamWriter] = None
@@ -303,11 +305,13 @@ async def connect_check_health(
303305
try:
304306
if retry_socket_connect:
305307
await self.retry.call_with_retry(
306-
lambda: self._connect(), lambda error: self.disconnect()
308+
lambda: self._connect_check_server_ready(),
309+
lambda error: self.disconnect(),
307310
)
308311
else:
309-
await self._connect()
312+
await self._connect_check_server_ready()
310313
except asyncio.CancelledError:
314+
self._close()
311315
raise # in 3.7 and earlier, this is an Exception, not BaseException
312316
except (socket.timeout, asyncio.TimeoutError):
313317
raise TimeoutError("Timeout connecting to server")
@@ -342,6 +346,33 @@ async def connect_check_health(
342346
if task and inspect.isawaitable(task):
343347
await task
344348

349+
async def _connect_check_server_ready(self):
350+
await self._connect()
351+
352+
# Doing handshake since connect and send operations work even when Redis is not ready
353+
if self.check_server_ready:
354+
try:
355+
await self.send_command("PING", check_health=False)
356+
357+
if self.socket_timeout is not None:
358+
async with async_timeout(self.socket_timeout):
359+
response = str_if_bytes(await self._reader.read(1024))
360+
else:
361+
response = str_if_bytes(await self._reader.read(1024))
362+
363+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
364+
raise ResponseError(f"Invalid PING response: {response}")
365+
except (
366+
socket.timeout,
367+
asyncio.TimeoutError,
368+
ResponseError,
369+
ConnectionResetError,
370+
) as e:
371+
# `socket_keepalive_options` might contain invalid options
372+
# causing an error. Do not leave the connection open.
373+
self._close()
374+
raise ConnectionError(self._error_message(e))
375+
345376
@abstractmethod
346377
async def _connect(self):
347378
pass
@@ -531,8 +562,7 @@ async def send_packed_command(
531562
self._send_packed_command(command), self.socket_timeout
532563
)
533564
else:
534-
self._writer.writelines(command)
535-
await self._writer.drain()
565+
await self._send_packed_command(command)
536566
except asyncio.TimeoutError:
537567
await self.disconnect(nowait=True)
538568
raise TimeoutError("Timeout writing to socket") from None
@@ -775,7 +805,7 @@ async def _connect(self):
775805
except (OSError, TypeError):
776806
# `socket_keepalive_options` might contain invalid options
777807
# causing an error. Do not leave the connection open.
778-
writer.close()
808+
self._close()
779809
raise
780810

781811
def _host_error(self) -> str:
@@ -936,7 +966,6 @@ async def _connect(self):
936966
reader, writer = await asyncio.open_unix_connection(path=self.path)
937967
self._reader = reader
938968
self._writer = writer
939-
await self.on_connect()
940969

941970
def _host_error(self) -> str:
942971
return self.path

redis/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def __init__(
211211
encoding: str = "utf-8",
212212
encoding_errors: str = "strict",
213213
decode_responses: bool = False,
214+
check_server_ready: bool = False,
214215
retry_on_timeout: bool = False,
215216
retry: Retry = Retry(
216217
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -267,10 +268,11 @@ def __init__(
267268
provided pool will be used.
268269
269270
Args:
270-
271-
single_connection_client:
272-
if `True`, connection pool is not used. In that case `Redis`
273-
instance use is not thread safe.
271+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
272+
connect and send operations work even when Redis server is not ready.
273+
single_connection_client:
274+
if `True`, connection pool is not used. In that case `Redis`
275+
instance use is not thread safe.
274276
"""
275277
if event_dispatcher is None:
276278
self._event_dispatcher = EventDispatcher()
@@ -287,6 +289,7 @@ def __init__(
287289
"encoding": encoding,
288290
"encoding_errors": encoding_errors,
289291
"decode_responses": decode_responses,
292+
"check_server_ready": check_server_ready,
290293
"retry_on_error": retry_on_error,
291294
"retry": copy.deepcopy(retry),
292295
"max_connections": max_connections,

redis/connection.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def __init__(
237237
encoding: str = "utf-8",
238238
encoding_errors: str = "strict",
239239
decode_responses: bool = False,
240+
check_server_ready: bool = False,
240241
parser_class=DefaultParser,
241242
socket_read_size: int = 65536,
242243
health_check_interval: int = 0,
@@ -303,6 +304,7 @@ def __init__(
303304
self.redis_connect_func = redis_connect_func
304305
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
305306
self.handshake_metadata = None
307+
self.check_server_ready = check_server_ready
306308
self._sock = None
307309
self._socket_read_size = socket_read_size
308310
self.set_parser(parser_class)
@@ -386,17 +388,17 @@ def connect_check_health(
386388
return
387389
try:
388390
if retry_socket_connect:
389-
sock = self.retry.call_with_retry(
390-
lambda: self._connect(), lambda error: self.disconnect(error)
391+
self.retry.call_with_retry(
392+
lambda: self._connect_check_server_ready(),
393+
lambda error: self.disconnect(error),
391394
)
392395
else:
393-
sock = self._connect()
396+
self._connect_check_server_ready()
394397
except socket.timeout:
395398
raise TimeoutError("Timeout connecting to server")
396399
except OSError as e:
397400
raise ConnectionError(self._error_message(e))
398401

399-
self._sock = sock
400402
try:
401403
if self.redis_connect_func is None:
402404
# Use the default on_connect function
@@ -418,8 +420,27 @@ def connect_check_health(
418420
if callback:
419421
callback(self)
420422

423+
def _connect_check_server_ready(self):
424+
self._connect()
425+
426+
# Doing handshake since connect and send operations work even when Redis is not ready
427+
if self.check_server_ready:
428+
try:
429+
self.send_command("PING", check_health=False)
430+
431+
response = str_if_bytes(self._sock.recv(1024))
432+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
433+
raise ResponseError(f"Invalid PING response: {response}")
434+
except (ConnectionResetError, ResponseError) as err:
435+
try:
436+
self._sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
437+
except OSError:
438+
pass
439+
self._sock.close()
440+
raise ConnectionError(self._error_message(err))
441+
421442
@abstractmethod
422-
def _connect(self):
443+
def _connect(self) -> None:
423444
pass
424445

425446
@abstractmethod
@@ -758,7 +779,7 @@ def repr_pieces(self):
758779
pieces.append(("client_name", self.client_name))
759780
return pieces
760781

761-
def _connect(self):
782+
def _connect(self) -> None:
762783
"Create a TCP socket connection"
763784
# we want to mimic what socket.create_connection does to support
764785
# ipv4/ipv6, but we want to set options prior to calling
@@ -788,7 +809,8 @@ def _connect(self):
788809

789810
# set the socket_timeout now that we're connected
790811
sock.settimeout(self.socket_timeout)
791-
return sock
812+
self._sock = sock
813+
return
792814

793815
except OSError as _:
794816
err = _
@@ -1101,15 +1123,15 @@ def __init__(
11011123
self.ssl_ciphers = ssl_ciphers
11021124
super().__init__(**kwargs)
11031125

1104-
def _connect(self):
1126+
def _connect(self) -> None:
11051127
"""
11061128
Wrap the socket with SSL support, handling potential errors.
11071129
"""
1108-
sock = super()._connect()
1130+
super()._connect()
11091131
try:
1110-
return self._wrap_socket_with_ssl(sock)
1132+
self._sock = self._wrap_socket_with_ssl(self._sock)
11111133
except (OSError, RedisError):
1112-
sock.close()
1134+
self._sock.close()
11131135
raise
11141136

11151137
def _wrap_socket_with_ssl(self, sock):
@@ -1206,7 +1228,7 @@ def repr_pieces(self):
12061228
pieces.append(("client_name", self.client_name))
12071229
return pieces
12081230

1209-
def _connect(self):
1231+
def _connect(self) -> None:
12101232
"Create a Unix domain socket connection"
12111233
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
12121234
sock.settimeout(self.socket_connect_timeout)
@@ -1221,7 +1243,7 @@ def _connect(self):
12211243
sock.close()
12221244
raise
12231245
sock.settimeout(self.socket_timeout)
1224-
return sock
1246+
self._sock = sock
12251247

12261248
def _host_error(self):
12271249
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ async def test_reading_with_load_balancing_strategies(
729729
Connection,
730730
send_command=mock.DEFAULT,
731731
read_response=mock.DEFAULT,
732-
_connect=mock.DEFAULT,
732+
_connect_check_server_ready=mock.DEFAULT,
733733
can_read_destructive=mock.DEFAULT,
734734
on_connect=mock.DEFAULT,
735735
) as mocks:
@@ -761,7 +761,7 @@ def execute_command_mock_third(self, *args, **options):
761761
execute_command.side_effect = execute_command_mock_first
762762
mocks["send_command"].return_value = True
763763
mocks["read_response"].return_value = "OK"
764-
mocks["_connect"].return_value = True
764+
mocks["_connect_check_server_ready"].return_value = True
765765
mocks["can_read_destructive"].return_value = False
766766
mocks["on_connect"].return_value = True
767767

@@ -3117,13 +3117,19 @@ async def execute_command(self, *args, **kwargs):
31173117

31183118
return _create_client
31193119

3120+
@pytest.mark.parametrize("check_server_ready", [True, False])
31203121
async def test_ssl_connection_without_ssl(
3121-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3122+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
31223123
) -> None:
31233124
with pytest.raises(RedisClusterException) as e:
3124-
await create_client(mocked=False, ssl=False)
3125+
await create_client(
3126+
mocked=False, ssl=False, check_server_ready=check_server_ready
3127+
)
31253128
e = e.value.__cause__
3126-
assert "Connection closed by server" in str(e)
3129+
if check_server_ready:
3130+
assert "Invalid PING response" in str(e)
3131+
else:
3132+
assert "Connection closed by server" in str(e)
31273133

31283134
async def test_ssl_with_invalid_cert(
31293135
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)