Skip to content

Commit

Permalink
[mthreads]Update base/benchmarks: resolve comments of P2P_inter
Browse files Browse the repository at this point in the history
  • Loading branch information
gliangMT committed Dec 11, 2024
1 parent 90b781b commit b97e82b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion base/benchmarks/interconnect-P2P_interserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def main(config, case_config, rank, world_size, local_rank):

Melements = case_config.Melements
torchsize = (Melements, 1024, 1024)
tensor = torch.rand(torchsize, dtype=torch.float32).to(local_rank)
if "mthreads" in config.vendor:
tensor = torch.rand(torchsize, dtype=torch.float32).to(local_rank)
else:
tensor = torch.rand(torchsize, dtype=torch.float32).cuda()

host_device_sync(config.vendor)
multi_device_sync(config.vendor)
Expand Down

0 comments on commit b97e82b

Please sign in to comment.