
Commit

[dask] make random port search more resilient to random collisions (fixes #4057) (#4133)

* [dask] make random port search more resilient to random collisions

* linting

* more reliable ports check

* address review comments

* add error message
jameslamb authored Mar 31, 2021
1 parent 9388b2e commit 1ce4b22
Showing 2 changed files with 78 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python-package/lightgbm/dask.py
@@ -170,6 +170,44 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
return out


def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
    """Fix any duplicate IP-port pairs in a ``worker_map``."""
    worker_map = deepcopy(worker_map)
    workers_that_need_new_ports = []
    host_to_port = defaultdict(set)
    for worker, port in worker_map.items():
        host = urlparse(worker).hostname
        if port in host_to_port[host]:
            workers_that_need_new_ports.append(worker)
        else:
            host_to_port[host].add(port)

    # if any duplicates were found, search for new ports one by one
    for worker in workers_that_need_new_ports:
        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
        host = urlparse(worker).hostname
        retries_remaining = 100
        while retries_remaining > 0:
            retries_remaining -= 1
            # pin the search to this specific worker; pure=False stops Dask
            # from caching the result and handing back a previous port
            new_port = client.submit(
                _find_random_open_port,
                workers=[worker],
                allow_other_workers=False,
                pure=False
            ).result()
            if new_port not in host_to_port[host]:
                worker_map[worker] = new_port
                host_to_port[host].add(new_port)
                break
        else:
            # while/else: only raise if the loop exhausted all retries
            # without breaking, i.e. no usable port was ever found
            raise LightGBMError(
                "Failed to find an open port. Try re-running training or explicitly setting 'machines' or 'local_listen_port'."
            )

    return worker_map


def _train(
    client: Client,
    data: _DaskMatrixLike,
@@ -367,10 +405,19 @@ def _train(
        }
    else:
        _log_info("Finding random open ports for workers")
        # this approach with client.run() is faster than searching for ports
        # serially, but it can produce duplicates. Try the fast approach once,
        # then pass the result through a function that falls back to a slower
        # but more reliable search if duplicates are found.
        worker_address_to_port = client.run(
            _find_random_open_port,
            workers=list(worker_addresses)
        )
        worker_address_to_port = _possibly_fix_worker_map_duplicates(
            worker_map=worker_address_to_port,
            client=client
        )

    machines = ','.join([
        '%s:%d' % (urlparse(worker_address).hostname, port)
        for worker_address, port
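For reference, the duplicate-detection pass in _possibly_fix_worker_map_duplicates can be traced by hand. A minimal standalone sketch, using hypothetical worker addresses and no Dask client (only the first loop of the function is exercised):

from collections import defaultdict
from urllib.parse import urlparse

# Toy worker map: two workers on host 10.0.0.1 were handed the same port.
worker_map = {
    "tcp://10.0.0.1:8786": 12400,
    "tcp://10.0.0.1:8787": 12400,  # duplicate (host, port) pair
    "tcp://10.0.0.2:8786": 12400,  # same port on a different host is fine
}

workers_that_need_new_ports = []
host_to_port = defaultdict(set)
for worker, port in worker_map.items():
    host = urlparse(worker).hostname
    if port in host_to_port[host]:
        workers_that_need_new_ports.append(worker)
    else:
        host_to_port[host].add(port)

print(workers_that_need_new_ports)  # ['tcp://10.0.0.1:8787']

Only the collision on 10.0.0.1 triggers a retry: ports need to be unique per host, not globally, which is why the set is keyed by hostname rather than by the full worker address.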
31 changes: 31 additions & 0 deletions tests/python_package_test/test_dask.py
@@ -392,6 +392,37 @@ def test_find_random_open_port(client):
    client.close(timeout=CLIENT_CLOSE_TIMEOUT)


def test_possibly_fix_worker_map(capsys, client):
    client.wait_for_workers(2)
    worker_addresses = list(client.scheduler_info()["workers"].keys())

    retry_msg = 'Searching for a LightGBM training port for worker'

    # should handle worker maps without any duplicates
    map_without_duplicates = {
        worker_address: 12400 + i
        for i, worker_address in enumerate(worker_addresses)
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_without_duplicates
    )
    assert patched_map == map_without_duplicates
    assert retry_msg not in capsys.readouterr().out

    # should handle worker maps with duplicates
    map_with_duplicates = {
        worker_address: 12400
        for i, worker_address in enumerate(worker_addresses)
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_with_duplicates
    )
    assert retry_msg in capsys.readouterr().out
    assert len(set(patched_map.values())) == len(worker_addresses)


def test_training_does_not_fail_on_port_conflicts(client):
    _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')

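_find_random_open_port is called in both hunks above but not defined in this diff. As a hedged sketch of the general technique (an assumption about the helper's approach, not necessarily the exact code lightgbm ships): bind a socket to port 0 so the OS assigns a free ephemeral port, then release it. The release is also what makes collisions possible, since nothing holds the port between the search and the start of training.

import socket

def find_random_open_port() -> int:
    """Hypothetical port finder, for illustration only."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # port 0 asks the OS for any free ephemeral port
        s.bind(('', 0))
        port = s.getsockname()[1]
    # the socket closes here, so another process (or another worker on the
    # same host) can grab or be re-issued this port before training binds it
    return port

print(find_random_open_port())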
