Fix scheduler transition error on memory->erred #8549

Merged (26 commits, Mar 8, 2024)
34 changes: 28 additions & 6 deletions distributed/scheduler.py
@@ -1965,6 +1965,7 @@
)

v = a_recs.get(key, finish)
# The inner rec has higher priority? Is that always desired?
Collaborator:
I don't get this comment?

Member Author:
This is a general comment about the two-step transitions. The recommendations created by the first step are executed before the second step, which may create weird state (as it did in this case).
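To make the two-step dispatch easier to follow, here is a small, self-contained toy (the function names and the simplified state dict are invented for illustration; this is not the scheduler's actual code): a request with no direct handler is routed through "released" first, and a recommendation the first step makes for the same key takes priority over the originally requested finish state.

def memory_to_released(state, key):
    state[key] = "released"
    return {}  # no further recommendation for `key` in this toy

def released_to_erred(state, key):
    state[key] = "erred"
    return {}

TRANSITIONS = {
    ("memory", "released"): memory_to_released,
    ("released", "erred"): released_to_erred,
}

def transition(state, key, finish):
    start = state[key]
    if (start, finish) in TRANSITIONS:
        return TRANSITIONS[start, finish](state, key)
    # Two-step path: no direct handler, so go via "released" first ...
    a_recs = TRANSITIONS[start, "released"](state, key)
    # ... then let the inner recommendation win over the requested finish.
    v = a_recs.get(key, finish)
    return TRANSITIONS["released", v](state, key)

state = {"x": "memory"}
transition(state, "x", "erred")  # dispatched as memory -> released -> erred
print(state)  # {'x': 'erred'}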

func = self._TRANSITIONS_TABLE["released", v]
b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id)

@@ -2083,7 +2084,11 @@
assert not ts.who_has
assert not ts.processing_on
for dts in ts.dependencies:
assert dts.state not in {"forgotten", "erred"}
assert dts.state not in {"forgotten", "erred"}, (
str(ts),
str(dts),
self.transition_log,
)

if ts.has_lost_dependencies:
return {key: "forgotten"}, {}, {}
@@ -2481,7 +2486,9 @@
recommendations[key] = "forgotten"
elif ts.has_lost_dependencies:
recommendations[key] = "forgotten"
elif ts.who_wants or ts.waiters:
elif (ts.who_wants or ts.waiters) and not any(
dts.state == "erred" for dts in ts.dependencies
):
recommendations[key] = "waiting"

for dts in ts.waiters or ():
@@ -2506,14 +2513,13 @@
assert ts.exception_blame
assert not ts.who_has
assert not ts.waiting_on
assert not ts.waiters
Member Author:
This assertion does not work in two-step transitions.
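A second self-contained toy (class and function names are invented; these are not the real transition handlers) sketches the failure mode behind this: the first step of a two-step transition leaves intermediate state behind, so a strict validation check in the second step can fire even though the overall memory -> erred transition is legitimate.

class Task:
    def __init__(self):
        self.state = "memory"
        self.waiters = {"dependent-task"}  # dependents still waiting on this task

def release_task(ts):
    ts.state = "released"
    # Note: waiters are not cleared here; they are dealt with elsewhere.

def err_task(ts):
    # A strict validation like the removed assertion rejects the
    # intermediate state produced by the first step:
    assert not ts.waiters, "task still has waiters"
    ts.state = "erred"

ts = Task()
release_task(ts)     # step 1 of the two-step memory -> erred transition
try:
    err_task(ts)     # step 2 sees the intermediate state and fails
except AssertionError as exc:
    print("validation fired during two-step transition:", exc)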


failing_ts = ts.exception_blame
assert failing_ts

for dts in ts.dependents:
dts.exception_blame = failing_ts
if not dts.who_has:
dts.exception_blame = failing_ts

(Codecov: added line distributed/scheduler.py#L2522 was not covered by tests)
recommendations[dts.key] = "erred"

report_msg = {
@@ -2548,6 +2554,9 @@

for dts in ts.dependents:
if dts.state == "erred":
# Does this make sense?
# This goes via released
# dts -> released -> waiting
Collaborator:
Agreed, this makes no sense to me either. Is there a unit test anywhere that sheds light on it?

Member Author:
I haven't investigated this any further.

recommendations[dts.key] = "waiting"

w_msg = {
@@ -2622,8 +2631,8 @@
self,
key: Key,
stimulus_id: str,
*,
worker: str,
*,
cause: Key | None = None,
exception: Serialized | None = None,
traceback: Serialized | None = None,
@@ -2699,7 +2708,7 @@
)
)

for dts in ts.dependents:
for dts in ts.waiters or set():
dts.exception_blame = failing_ts
recommendations[dts.key] = "erred"

@@ -5040,6 +5049,19 @@
"stimulus_id": stimulus_id,
}
]
elif ts.state == "erred":
Collaborator:
Isn't this already covered by the next branch, elif ts.run_id != run_id?

Member Author:
No, AFAIR, we only issue a new run_id when the task is sent to a worker again.
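As a hedged sketch of the distinction (the handler below and the TaskState stand-in are invented; only the branch structure follows the new code above): the erred task was never re-sent to a worker, so its run_id still matches the one carried by the late task-finished message, and only the explicit state check answers the worker with free-keys.

def handle_late_task_finished(ts, worker, run_id, stimulus_id, worker_msgs):
    # Sketch only: mirrors the `elif ts.state == "erred"` branch in the diff.
    if ts.state == "erred":
        # The scheduler already marked the task as erred, so tell the worker
        # to drop its copy of the result.
        worker_msgs[worker] = [
            {"op": "free-keys", "keys": [ts.key], "stimulus_id": stimulus_id}
        ]
    elif ts.run_id != run_id:
        ...  # stale run of the task; handled by the pre-existing branch

class TS:
    # Minimal stand-in for the scheduler's TaskState, for this sketch only.
    def __init__(self):
        self.key, self.state, self.run_id = "x", "erred", 0

msgs = {}
handle_late_task_finished(
    TS(), "tcp://worker:1234", run_id=0, stimulus_id="remove-worker-123", worker_msgs=msgs
)
print(msgs)  # {'tcp://worker:1234': [{'op': 'free-keys', 'keys': ['x'], ...}]}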

Member Author (@hendrikmakait), Mar 5, 2024:
I've added another test for this. I don't like the way this handles context managers; if you have suggestions on how to simplify it, I'm all ears.

logger.debug(

(Codecov: added line distributed/scheduler.py#L5053 was not covered by tests)
"Received already erred task, worker: %s" ", key: %s",
worker,
key,
)
worker_msgs[worker] = [

(Codecov: added line distributed/scheduler.py#L5058 was not covered by tests)
{
"op": "free-keys",
"keys": [key],
"stimulus_id": stimulus_id,
}
]
elif ts.run_id != run_id:
if not ts.processing_on or ts.processing_on.address != worker:
logger.debug(
136 changes: 136 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -4890,3 +4890,139 @@ async def test_resubmit_different_task_same_key_warns_only_once(

async with Worker(s.address):
assert await c.gather(zs) == [2, 3, 4] # Kept old ys


def block(x, in_event, block_event):
in_event.set()
block_event.wait()
return x


@gen_cluster(
client=True,
nthreads=[("", 1, {"resources": {"a": 1}})],
config={"distributed.scheduler.allowed-failures": 0},
)
async def test_fan_out_pattern_deadlock(c, s, a):
"""Regression test for https://github.com/dask/distributed/issues/8548

This test heavily uses resources to force scheduling decisions.
"""
in_f, block_f = Event(), Event()
in_ha, block_ha = Event(), Event()
in_hb, block_hb = Event(), Event()

# Input task to 'g' that we can fail
with dask.annotate(resources={"b": 1}):
f = delayed(block)(1, in_f, block_f, dask_key_name="f")
g = delayed(inc)(f, dask_key_name="g")

# Fan-out from 'g' and run 'ha' and 'hb' on different workers
hb = delayed(block)(g, in_hb, block_hb, dask_key_name="hb")
with dask.annotate(resources={"a": 1}):
ha = delayed(block)(g, in_ha, block_ha, dask_key_name="ha")

f, ha, hb = c.compute([f, ha, hb])
with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
await block_f.set()
await in_ha.wait()
await in_hb.wait()
await in_f.clear()

# Make sure that the scheduler knows that both workers hold 'g' in memory
await async_poll_for(lambda: len(s.tasks["g"].who_has) == 2, timeout=5)
# Remove worker 'b' while it's processing 'hb'
await s.remove_worker(b.address, stimulus_id="remove_b1")
await block_hb.set()
await block_f.clear()

# Remove the new instance of the 'b' worker while it processes 'f'
# to trigger a transition for 'f' to 'erred'
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
await in_f.wait()
await in_f.clear()
await s.remove_worker(b.address, stimulus_id="remove_b2")
await block_f.set()
await block_f.clear()

await block_ha.set()
await ha

with pytest.raises(KilledWorker, match="Attempted to run task 'hb'"):
await hb

del ha, hb
# Make sure that 'hb' gets forgotten on worker 'a'
await async_poll_for(lambda: not a.state.tasks, timeout=5)
# Ensure that no other errors including transition failures were logged
assert (
logger.getvalue()
== "Task hb marked as failed because 1 workers died while trying to run it\nTask f marked as failed because 1 workers died while trying to run it\n"
)


@gen_cluster(
client=True,
nthreads=[("", 1, {"resources": {"a": 1}})],
config={"distributed.scheduler.allowed-failures": 0},
)
async def test_stimulus_from_erred_task(c, s, a):
"""This test heavily uses resources to force scheduling decisions."""
in_f, block_f = Event(), Event()
in_g, block_g = Event(), Event()

with dask.annotate(resources={"b": 1}):
f = delayed(block)(1, in_f, block_f, dask_key_name="f")

with dask.annotate(resources={"a": 1}):
g = delayed(block)(f, in_g, block_g, dask_key_name="g")

f, g = c.compute([f, g])
with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
frozen_stream_from_a_ctx = freeze_batched_send(a.batched_stream)
frozen_stream_from_a_ctx.__enter__()

async with Worker(s.address, nthreads=1, resources={"b": 1}) as b1:
await block_f.set()
await in_g.wait()
await in_f.clear()
frozen_stream_to_a_ctx = freeze_batched_send(s.stream_comms[a.address])
frozen_stream_to_a_ctx.__enter__()
await s.remove_worker(b1.address, stimulus_id="remove_b1")
await block_f.clear()

# Remove the new instance of the 'b' worker while it processes 'f'
# to trigger a transition for 'f' to 'erred'
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b2:
await in_f.wait()
await in_f.clear()
await s.remove_worker(b2.address, stimulus_id="remove_b2")
await block_f.set()

with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
await f

# g has already been transitioned to 'erred' because 'f' failed
with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
await g

# Finish 'g' and let the scheduler know so it can trigger cleanup
await block_g.set()
with mock.patch.object(
s, "stimulus_task_finished", wraps=s.stimulus_task_finished
) as wrapped_stimulus:
frozen_stream_from_a_ctx.__exit__(None, None, None)
# Make sure the `stimulus_task_finished` gets processed
await async_poll_for(lambda: wrapped_stimulus.call_count == 1, timeout=5)

# Allow the scheduler to talk to the worker again
frozen_stream_to_a_ctx.__exit__(None, None, None)
# Make sure all data gets forgotten on worker 'a'
await async_poll_for(lambda: not a.state.tasks, timeout=5)

# Ensure that no other errors including transition failures were logged
assert (
logger.getvalue()
== "Task f marked as failed because 1 workers died while trying to run it\n"
)