Skip to content

Commit 5acbde3

Browse files
chayimjamestiotiodvora-h
authored
Fixing cancelled async futures (#2666)
Co-authored-by: James R T <jamestiotio@gmail.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
1 parent 6d886d7 commit 5acbde3

File tree

7 files changed

+234
-75
lines changed

7 files changed

+234
-75
lines changed

.github/workflows/integration.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jobs:
5151
timeout-minutes: 30
5252
strategy:
5353
max-parallel: 15
54+
fail-fast: false
5455
matrix:
5556
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
5657
test-type: ['standalone', 'cluster']
@@ -108,6 +109,7 @@ jobs:
108109
name: Install package from commit hash
109110
runs-on: ubuntu-latest
110111
strategy:
112+
fail-fast: false
111113
matrix:
112114
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
113115
steps:

redis/asyncio/client.py

+69-30
Original file line numberDiff line numberDiff line change
@@ -500,28 +500,37 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
500500
):
501501
raise error
502502

503-
# COMMAND EXECUTION AND PROTOCOL PARSING
504-
async def execute_command(self, *args, **options):
505-
"""Execute a command and return a parsed response"""
506-
await self.initialize()
507-
pool = self.connection_pool
508-
command_name = args[0]
509-
conn = self.connection or await pool.get_connection(command_name, **options)
510-
511-
if self.single_connection_client:
512-
await self._single_conn_lock.acquire()
503+
async def _try_send_command_parse_response(self, conn, *args, **options):
513504
try:
514505
return await conn.retry.call_with_retry(
515506
lambda: self._send_command_parse_response(
516-
conn, command_name, *args, **options
507+
conn, args[0], *args, **options
517508
),
518509
lambda error: self._disconnect_raise(conn, error),
519510
)
511+
except asyncio.CancelledError:
512+
await conn.disconnect(nowait=True)
513+
raise
520514
finally:
521515
if self.single_connection_client:
522516
self._single_conn_lock.release()
523517
if not self.connection:
524-
await pool.release(conn)
518+
await self.connection_pool.release(conn)
519+
520+
# COMMAND EXECUTION AND PROTOCOL PARSING
521+
async def execute_command(self, *args, **options):
522+
"""Execute a command and return a parsed response"""
523+
await self.initialize()
524+
pool = self.connection_pool
525+
command_name = args[0]
526+
conn = self.connection or await pool.get_connection(command_name, **options)
527+
528+
if self.single_connection_client:
529+
await self._single_conn_lock.acquire()
530+
531+
return await asyncio.shield(
532+
self._try_send_command_parse_response(conn, *args, **options)
533+
)
525534

526535
async def parse_response(
527536
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -765,10 +774,18 @@ async def _disconnect_raise_connect(self, conn, error):
765774
is not a TimeoutError. Otherwise, try to reconnect
766775
"""
767776
await conn.disconnect()
777+
768778
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
769779
raise error
770780
await conn.connect()
771781

782+
async def _try_execute(self, conn, command, *arg, **kwargs):
783+
try:
784+
return await command(*arg, **kwargs)
785+
except asyncio.CancelledError:
786+
await conn.disconnect()
787+
raise
788+
772789
async def _execute(self, conn, command, *args, **kwargs):
773790
"""
774791
Connect manually upon disconnection. If the Redis server is down,
@@ -777,9 +794,11 @@ async def _execute(self, conn, command, *args, **kwargs):
777794
called by the # connection to resubscribe us to any channels and
778795
patterns we were previously listening to
779796
"""
780-
return await conn.retry.call_with_retry(
781-
lambda: command(*args, **kwargs),
782-
lambda error: self._disconnect_raise_connect(conn, error),
797+
return await asyncio.shield(
798+
conn.retry.call_with_retry(
799+
lambda: self._try_execute(conn, command, *args, **kwargs),
800+
lambda error: self._disconnect_raise_connect(conn, error),
801+
)
783802
)
784803

785804
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1181,6 +1200,18 @@ async def _disconnect_reset_raise(self, conn, error):
11811200
await self.reset()
11821201
raise
11831202

1203+
async def _try_send_command_parse_response(self, conn, *args, **options):
1204+
try:
1205+
return await conn.retry.call_with_retry(
1206+
lambda: self._send_command_parse_response(
1207+
conn, args[0], *args, **options
1208+
),
1209+
lambda error: self._disconnect_reset_raise(conn, error),
1210+
)
1211+
except asyncio.CancelledError:
1212+
await conn.disconnect()
1213+
raise
1214+
11841215
async def immediate_execute_command(self, *args, **options):
11851216
"""
11861217
Execute a command immediately, but don't auto-retry on a
@@ -1196,13 +1227,13 @@ async def immediate_execute_command(self, *args, **options):
11961227
command_name, self.shard_hint
11971228
)
11981229
self.connection = conn
1199-
1200-
return await conn.retry.call_with_retry(
1201-
lambda: self._send_command_parse_response(
1202-
conn, command_name, *args, **options
1203-
),
1204-
lambda error: self._disconnect_reset_raise(conn, error),
1205-
)
1230+
try:
1231+
return await asyncio.shield(
1232+
self._try_send_command_parse_response(conn, *args, **options)
1233+
)
1234+
except asyncio.CancelledError:
1235+
await conn.disconnect()
1236+
raise
12061237

12071238
def pipeline_execute_command(self, *args, **options):
12081239
"""
@@ -1369,6 +1400,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13691400
await self.reset()
13701401
raise
13711402

1403+
async def _try_execute(self, conn, execute, stack, raise_on_error):
1404+
try:
1405+
return await conn.retry.call_with_retry(
1406+
lambda: execute(conn, stack, raise_on_error),
1407+
lambda error: self._disconnect_raise_reset(conn, error),
1408+
)
1409+
except asyncio.CancelledError:
1410+
# not supposed to be possible, yet here we are
1411+
await conn.disconnect(nowait=True)
1412+
raise
1413+
finally:
1414+
await self.reset()
1415+
13721416
async def execute(self, raise_on_error: bool = True):
13731417
"""Execute all the commands in the current pipeline"""
13741418
stack = self.command_stack
@@ -1391,15 +1435,10 @@ async def execute(self, raise_on_error: bool = True):
13911435

13921436
try:
13931437
return await asyncio.shield(
1394-
conn.retry.call_with_retry(
1395-
lambda: execute(conn, stack, raise_on_error),
1396-
lambda error: self._disconnect_raise_reset(conn, error),
1397-
)
1438+
self._try_execute(conn, execute, stack, raise_on_error)
13981439
)
1399-
except asyncio.CancelledError:
1400-
# not supposed to be possible, yet here we are
1401-
await conn.disconnect(nowait=True)
1402-
raise
1440+
except RuntimeError:
1441+
await self.reset()
14031442
finally:
14041443
await self.reset()
14051444

redis/asyncio/cluster.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
10161016
finally:
10171017
self._free.append(connection)
10181018

1019+
async def _try_parse_response(self, cmd, connection, ret):
1020+
try:
1021+
cmd.result = await asyncio.shield(
1022+
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
1023+
)
1024+
except asyncio.CancelledError:
1025+
await connection.disconnect(nowait=True)
1026+
raise
1027+
except Exception as e:
1028+
cmd.result = e
1029+
ret = True
1030+
return ret
1031+
10191032
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10201033
# Acquire connection
10211034
connection = self.acquire_connection()
@@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10281041
# Read responses
10291042
ret = False
10301043
for cmd in commands:
1031-
try:
1032-
cmd.result = await self.parse_response(
1033-
connection, cmd.args[0], **cmd.kwargs
1034-
)
1035-
except Exception as e:
1036-
cmd.result = e
1037-
ret = True
1044+
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
10381045

10391046
# Release connection
10401047
self._free.append(connection)

tests/test_asyncio/test_cluster.py

-17
Original file line numberDiff line numberDiff line change
@@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
340340
rc = RedisCluster.from_url("rediss://localhost:16379")
341341
assert rc.connection_kwargs["connection_class"] is SSLConnection
342342

343-
async def test_asynckills(self, r) -> None:
344-
345-
await r.set("foo", "foo")
346-
await r.set("bar", "bar")
347-
348-
t = asyncio.create_task(r.get("foo"))
349-
await asyncio.sleep(1)
350-
t.cancel()
351-
try:
352-
await t
353-
except asyncio.CancelledError:
354-
pytest.fail("connection is left open with unread response")
355-
356-
assert await r.get("bar") == b"bar"
357-
assert await r.ping()
358-
assert await r.get("foo") == b"foo"
359-
360343
async def test_max_connections(
361344
self, create_redis: Callable[..., RedisCluster]
362345
) -> None:

tests/test_asyncio/test_connection.py

-21
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,6 @@ async def test_invalid_response(create_redis):
4444
await r.connection.disconnect()
4545

4646

47-
async def test_asynckills():
48-
49-
for b in [True, False]:
50-
r = Redis(single_connection_client=b)
51-
52-
await r.set("foo", "foo")
53-
await r.set("bar", "bar")
54-
55-
t = asyncio.create_task(r.get("foo"))
56-
await asyncio.sleep(1)
57-
t.cancel()
58-
try:
59-
await t
60-
except asyncio.CancelledError:
61-
pytest.fail("connection left open with unread response")
62-
63-
assert await r.get("bar") == b"bar"
64-
assert await r.ping()
65-
assert await r.get("foo") == b"foo"
66-
67-
6847
@pytest.mark.onlynoncluster
6948
async def test_single_connection():
7049
"""Test that concurrent requests on a single client are synchronised."""

0 commit comments

Comments
 (0)