Skip to content

Commit ddb965d

Browse files
Fix Multi-GPU Seed Problem (#220)
* fix multigpu * fix multi-gpu seeding --------- Co-authored-by: Gordon Guocheng Qian 钱国成 <guocheng.qian@outlook.com>
1 parent 56170dd commit ddb965d

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

launch.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def main(args, extras) -> None:
7979
ProgressCallback,
8080
)
8181
from threestudio.utils.config import ExperimentConfig, load_config
82+
from threestudio.utils.misc import get_rank
8283
from threestudio.utils.typing import Optional
8384

8485
logger = logging.getLogger("pytorch_lightning")
@@ -97,7 +98,8 @@ def main(args, extras) -> None:
9798
cfg: ExperimentConfig
9899
cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
99100

100-
pl.seed_everything(cfg.seed)
101+
# set a different seed for each device
102+
pl.seed_everything(cfg.seed + get_rank(), workers=True)
101103

102104
dm = threestudio.find(cfg.data_type)(cfg.data)
103105
system: BaseSystem = threestudio.find(cfg.system_type)(

threestudio/models/geometry/implicit_sdf.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
1111
from threestudio.models.mesh import Mesh
1212
from threestudio.models.networks import get_encoding, get_mlp
13-
from threestudio.utils.misc import get_rank
13+
from threestudio.utils.misc import broadcast, get_rank
1414
from threestudio.utils.typing import *
1515

1616

@@ -209,6 +209,10 @@ def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
209209
loss.backward()
210210
optim.step()
211211

212+
# explicit broadcast to ensure param consistency across ranks
213+
for param in self.parameters():
214+
broadcast(param, src=0)
215+
212216
def get_shifted_sdf(
213217
self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
214218
) -> Float[Tensor, "*N 1"]:

threestudio/utils/misc.py

+8
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,11 @@ def barrier():
110110
return
111111
else:
112112
torch.distributed.barrier()
113+
114+
115+
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` in place from rank ``src`` to every rank.

    When torch.distributed is not initialized/available this is a no-op.
    The (possibly updated) tensor is returned either way, so callers can
    use it in single- and multi-process runs alike.
    """
    if _distributed_available():
        # in-place collective: every rank's tensor is overwritten with src's data
        torch.distributed.broadcast(tensor, src=src)
    return tensor

0 commit comments

Comments (0)