Skip to content

Commit 3173234

Browse files
committed
propagate ProcessGroup timeout to Store
1 parent 4f80939 commit 3173234

9 files changed

+29
-16
lines changed

torch/distributed/distributed_c10d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def init_process_group(backend,
351351
elif world_size != -1:
352352
url += "?world_size={}".format(world_size)
353353

354-
store, rank, world_size = next(rendezvous(url))
354+
store, rank, world_size = next(rendezvous(url, timeout=timeout))
355355
if backend == Backend.GLOO:
356356
_default_pg = ProcessGroupGloo(
357357
store,

torch/distributed/rendezvous.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
_rendezvous_handlers = {}
11+
_default_store_timeout = timedelta(minutes=5)
1112

1213

1314
def register_rendezvous_handler(scheme, handler):
@@ -53,7 +54,7 @@ def _rendezvous_error(msg):
5354
return ValueError("Error initializing torch.distributed using " + msg)
5455

5556

56-
def _file_rendezvous_handler(url):
57+
def _file_rendezvous_handler(url, timeout=_default_store_timeout):
5758
def _error(msg):
5859
return _rendezvous_error("file:// rendezvous: " + msg)
5960

@@ -69,14 +70,14 @@ def _error(msg):
6970

7071
rank = int(query["rank"])
7172
world_size = int(query["world_size"])
72-
store = FileStore(path, world_size)
73+
store = FileStore(path, world_size, timeout)
7374
yield (store, rank, world_size)
7475

7576
# If this configuration is invalidated, there is nothing we can do about it
7677
raise RuntimeError("Unable to perform rerendezvous using file:// method")
7778

7879

79-
def _tcp_rendezvous_handler(url):
80+
def _tcp_rendezvous_handler(url, timeout=_default_store_timeout):
8081
def _error(msg):
8182
return _rendezvous_error("tcp:// rendezvous: " + msg)
8283

@@ -92,14 +93,14 @@ def _error(msg):
9293
rank = int(query["rank"])
9394
world_size = int(query["world_size"])
9495
start_daemon = rank == 0
95-
store = TCPStore(result.hostname, result.port, world_size, start_daemon)
96+
store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
9697
yield (store, rank, world_size)
9798

9899
# If this configuration is invalidated, there is nothing we can do about it
99100
raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
100101

101102

102-
def _env_rendezvous_handler(url):
103+
def _env_rendezvous_handler(url, timeout=_default_store_timeout):
103104
def _error(msg):
104105
return _rendezvous_error("env:// rendezvous: " + msg)
105106

@@ -140,7 +141,7 @@ def _env_error(var):
140141

141142
# Now start the TCP store daemon on the rank 0
142143
start_daemon = rank == 0
143-
store = TCPStore(master_addr, master_port, world_size, start_daemon)
144+
store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
144145
yield (store, rank, world_size)
145146

146147
# If this configuration is invalidated, there is nothing we can do about it

torch/lib/c10d/FileStore.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,11 @@ off_t refresh(
206206

207207
} // namespace
208208

209-
FileStore::FileStore(const std::string& path, int numWorkers)
210-
: Store(),
209+
FileStore::FileStore(
210+
const std::string& path,
211+
int numWorkers,
212+
std::chrono::milliseconds timeout)
213+
: Store(timeout),
211214
path_(path),
212215
pos_(0),
213216
numWorkers_(numWorkers),

torch/lib/c10d/FileStore.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ namespace c10d {
1010

1111
class FileStore : public Store {
1212
public:
13-
explicit FileStore(const std::string& path, int numWorkers);
13+
explicit FileStore(
14+
const std::string& path,
15+
int numWorkers,
16+
std::chrono::milliseconds timeout=kDefaultTimeout);
1417

1518
virtual ~FileStore();
1619

torch/lib/c10d/PrefixStore.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ namespace c10d {
66

77
class PrefixStore : public Store {
88
public:
9-
explicit PrefixStore(const std::string& prefix, Store& store);
9+
explicit PrefixStore(
10+
const std::string& prefix,
11+
Store& store);
1012

1113
virtual ~PrefixStore(){};
1214

torch/lib/c10d/ProcessGroupGloo.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class GlooStore : public ::gloo::rendezvous::Store {
6666
}
6767

6868
void wait(const std::vector<std::string>& keys) override {
69-
store_->wait(keys, Store::kDefaultTimeout);
69+
store_->wait(keys);
7070
}
7171

7272
void wait(

torch/lib/c10d/Store.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class Store {
1515
static constexpr std::chrono::milliseconds kNoTimeout =
1616
std::chrono::milliseconds::zero();
1717

18-
Store() : timeout_(kDefaultTimeout) {}
18+
Store(std::chrono::milliseconds timeout=kDefaultTimeout)
19+
: timeout_(timeout) {}
1920

2021
virtual ~Store();
2122

torch/lib/c10d/TCPStore.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,10 @@ TCPStore::TCPStore(
279279
const std::string& masterAddr,
280280
PortType masterPort,
281281
int numWorkers,
282-
bool isServer)
283-
: isServer_(isServer),
282+
bool isServer,
283+
std::chrono::milliseconds timeout)
284+
: Store(timeout),
285+
isServer_(isServer),
284286
tcpStoreAddr_(masterAddr),
285287
tcpStorePort_(masterPort),
286288
numWorkers_(numWorkers),

torch/lib/c10d/TCPStore.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class TCPStore : public Store {
4949
const std::string& masterAddr,
5050
PortType masterPort,
5151
int numWorkers,
52-
bool isServer = false);
52+
bool isServer = false,
53+
std::chrono::milliseconds timeout=kDefaultTimeout);
5354

5455
virtual ~TCPStore();
5556

0 commit comments

Comments
 (0)