Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT] Fix long exception of asyncio.gather (#2748) #2753

Merged
merged 2 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mars/dataframe/datasource/read_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def read_group_to_pandas(
def read_partitioned_to_pandas(
self,
f,
partitions: pq.ParquetPartitions,
partitions: "pq.ParquetPartitions",
partition_keys: List[Tuple],
columns=None,
nrows=None,
Expand Down
5 changes: 3 additions & 2 deletions mars/oscar/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import copy
from typing import Dict, Union

from ..errors import ServerClosed
Expand Down Expand Up @@ -65,15 +66,15 @@ async def _listen(self, client: Client):
message_futures = self._client_to_message_futures.get(client)
self._client_to_message_futures[client] = dict()
for future in message_futures.values():
future.set_exception(e)
future.set_exception(copy.copy(e))
finally:
await asyncio.sleep(0)

message_futures = self._client_to_message_futures.get(client)
self._client_to_message_futures[client] = dict()
error = ServerClosed(f"Remote server {client.dest_address} closed")
for future in message_futures.values():
future.set_exception(error)
future.set_exception(copy.copy(error))

async def call(
self,
Expand Down
40 changes: 40 additions & 0 deletions mars/oscar/backends/mars/tests/test_mars_actor_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import sys
import time
import traceback
from collections import deque

import pandas as pd
Expand All @@ -25,6 +26,7 @@
from ..... import oscar as mo
from ....backends.allocate_strategy import RandomSubPool
from ....debug import set_debug_options, DebugOptions
from ...router import Router

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -382,6 +384,44 @@ async def test_mars_batch_method(actor_pool_context):
await ref1.add_ret.batch(ref1.add_ret.delay(1), ref1.add.delay(2))


@pytest.mark.asyncio
async def test_gather_exception(actor_pool_context):
try:
Router.get_instance_or_empty()._cache.clear()
pool = actor_pool_context
ref1 = await mo.create_actor(DummyActor, 1, address=pool.external_address)
router = Router.get_instance_or_empty()
client = next(iter(router._cache.values()))

future = asyncio.Future()
client_channel = client.channel

class FakeChannel(type(client_channel)):
def __init__(self):
pass

def __getattr__(self, item):
return getattr(client_channel, item)

async def recv(self):
return await future

client.channel = FakeChannel()

class MyException(Exception):
pass

await ref1.add(1)
tasks = [ref1.add(i) for i in range(200)]
future.set_exception(MyException("Test recv exception!!"))
with pytest.raises(MyException) as ex:
await asyncio.gather(*tasks)
s = traceback.format_tb(ex.tb)
assert 10 > "\n".join(s).count("send") > 0
finally:
Router.get_instance_or_empty()._cache.clear()


@pytest.mark.asyncio
async def test_mars_destroy_has_actor(actor_pool_context):
pool = actor_pool_context
Expand Down