Commit a71096e

Support reduce_scatter (#192)
* Support reduce scatter
1 parent dd6ebb2 commit a71096e

8 files changed, +436 -4 lines changed

src/ProcessGroupCCL.cpp

+10 -4
@@ -869,11 +869,17 @@ c10::intrusive_ptr<C10D_Work> ProcessGroupCCL::scatter(
 }
 
 c10::intrusive_ptr<C10D_Work> ProcessGroupCCL::reduce_scatter(
-    std::vector<at::Tensor>& /* unused */,
-    std::vector<std::vector<at::Tensor>>& /* unused */,
-    const ReduceScatterOptions& /* unused */)
+    std::vector<at::Tensor>& outputTensors,
+    std::vector<std::vector<at::Tensor>>& inputTensors,
+    const ReduceScatterOptions& opts)
 {
-  TORCH_CHECK(false, "ProcessGroupCCL does not support reduce_scatter");
+  std::vector<c10::IValue> tensor_param;
+  format_tensors_param(tensor_param, inputTensors);
+  format_tensors_param(tensor_param, outputTensors);
+  RECORD_FUNCTION("oneccl_bindings_for_pytorch::reduce_scatter", tensor_param);
+
+  auto work = DispatchStub::reduce_scatter(outputTensors, inputTensors, opts, *this);
+  return work;
 }
 
 c10::intrusive_ptr<C10D_Work> ProcessGroupCCL::_reduce_scatter_base(
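
Not part of the commit: to make the semantics of the newly supported collective concrete, here is a minimal standalone sketch (plain C++, no ATen or oneCCL; the rank count and values are made up) of what reduce_scatter computes. Rank r ends up with the element-wise reduction, across all ranks, of every rank's r-th input chunk.

// reduce_scatter semantics sketch: simulate N ranks in one process.
// inputs[rank][chunk] holds that rank's contribution for that chunk;
// after the collective, outputs[r] == sum over ranks of inputs[*][r].
#include <iostream>
#include <vector>

int main() {
  const int nranks = 3, chunk_len = 4;
  // inputs[r][c][k]: rank r, chunk c, element k (all chunks the same size here).
  std::vector<std::vector<std::vector<float>>> inputs(
      nranks, std::vector<std::vector<float>>(nranks, std::vector<float>(chunk_len)));
  for (int r = 0; r < nranks; ++r)
    for (int c = 0; c < nranks; ++c)
      for (int k = 0; k < chunk_len; ++k)
        inputs[r][c][k] = static_cast<float>(r + c + k);

  // Each rank receives the reduction of the chunk that matches its own rank index.
  std::vector<std::vector<float>> outputs(nranks, std::vector<float>(chunk_len, 0.f));
  for (int r = 0; r < nranks; ++r)          // destination rank
    for (int src = 0; src < nranks; ++src)  // contributing rank
      for (int k = 0; k < chunk_len; ++k)
        outputs[r][k] += inputs[src][r][k];

  for (int r = 0; r < nranks; ++r) {
    std::cout << "rank " << r << ":";
    for (float v : outputs[r]) std::cout << ' ' << v;
    std::cout << '\n';
  }
  return 0;
}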

src/cpu/cpu_ccl.cpp

+108
@@ -138,6 +138,11 @@ class VanillaCPU final: public DispatchStub {
                                                          const ReduceOptions& opts,
                                                          ProcessGroupCCL& pg) override;
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                    std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                    const ReduceScatterOptions& opts,
+                                                                    ProcessGroupCCL& pg_ccl) override;
+
   c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                           at::Tensor& inputTensor,
                                                                           const ReduceScatterOptions& opts,
@@ -194,6 +199,11 @@ class VanillaCPU final: public DispatchStub {
   std::condition_variable queueProduceCV_;
   std::condition_variable queueConsumeCV_;
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_oop(at::Tensor& outputTensor,
+                                                                at::Tensor& inputTensor,
+                                                                const ReduceOptions& opts,
+                                                                ProcessGroupCCL& pg_ccl);
+
 };
 
 struct RegisterCPUPMethods {
@@ -388,6 +398,45 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_(std::vecto
   return work;
 }
 
+// _reduce_oop implements an out-of-place reduce procedure.
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::_reduce_oop(at::Tensor& outputTensor,
+                                                                          at::Tensor& inputTensor,
+                                                                          const ReduceOptions& opts,
+                                                                          ProcessGroupCCL& pg_ccl) {
+  const int root = opts.rootRank + opts.rootTensor;
+  std::vector<at::Tensor> inputTensors{inputTensor};
+  std::vector<at::Tensor> outputTensors{outputTensor};
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  work = collective<get_ccl_comms, CPUWorkCCL>(
+    pg_ccl,
+    inputTensors,
+    outputTensors,
+    [=](at::Tensor input,
+        at::Tensor output,
+        ccl::reduce_attr attr,
+        ccl::communicator& comm) {
+
+      ccl::event ret_evt;
+      call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+        CCL_CHECK(ret_evt = ccl::reduce(input.data_ptr(),
+                                        output.data_ptr(),
+                                        (size_t) input.numel(),
+                                        cclDatatypes.at(input.scalar_type()),
+                                        cclOps.at(opts.reduceOp),
+                                        root,
+                                        comm));
+      });
+      return ret_evt;
+
+    },
+    c10d::OpType::REDUCE,
+    "oneccl_bindings_for_pytorch::cpu_work::_reduce_oop");
+
+  work->debugName = std::string("cpu::_reduce_oop");
+  enqueue(work);
+  return work;
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::broadcast_(std::vector<at::Tensor>& tensors,
                                                                          const BroadcastOptions &opts,
                                                                          ProcessGroupCCL& pg) {
@@ -596,6 +645,65 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::gather_(std::vecto
   return work;
 }
 
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                              std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                              const ReduceScatterOptions& opts,
+                                                                              ProcessGroupCCL& pg_ccl) {
+  checkSingleTensor(outputTensors);
+  auto outputTensor = outputTensors.back();
+  auto inputTensors_ = inputTensors.back();
+  bool same_size = check_same_size(inputTensors_);
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  if (same_size) {
+    auto inputFlattened = newLikeFlat(inputTensors_);
+    for (const auto j : c10::irange(inputTensors_.size())) {
+      inputFlattened[j].copy_(inputTensors_[j], true);
+    }
+    std::vector<at::Tensor> flattendInputTensors{inputFlattened};
+    work = collective<get_ccl_comms, CPUWorkCCL>(
+      pg_ccl,
+      flattendInputTensors,
+      outputTensors,
+      [=](at::Tensor input,
+          at::Tensor output,
+          ccl::reduce_attr attr,
+          ccl::communicator& comm) {
+
+        ccl::event ret_evt;
+        call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+          CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(),
+                                                  output.data_ptr(),
+                                                  (size_t) output.numel(),
+                                                  cclDatatypes.at(input.scalar_type()),
+                                                  cclOps.at(opts.reduceOp),
+                                                  comm));
+        });
+        return ret_evt;
+
+      },
+      c10d::OpType::REDUCE_SCATTER,
+      "oneccl_bindings_for_pytorch::cpu_work::reduce_scatter");
+    work->debugName = std::string("cpu::reduce_scatter");
+    enqueue(work);
+    return work;
+
+  } else {
+    // Use multiple reduce to simulate reduce_scatter.
+    const auto num_reduces = inputTensors_.size();
+    for (const int i : c10::irange(num_reduces)) {
+      auto& input = inputTensors_[i];
+      auto& output = (i == pg_ccl.getRank()) ? outputTensor : input;
+      auto reduceOpts = ReduceOptions{
+        opts.reduceOp,
+        static_cast<int64_t>(i),
+        static_cast<int64_t>(0),
+        opts.timeout};
+      work = _reduce_oop(output, input, reduceOpts, pg_ccl);
+    }
+    return work;
+  }
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::scatter_(std::vector<at::Tensor>& outputTensors,
                                                                        std::vector<std::vector<at::Tensor>>& inputTensors,
                                                                        const ScatterOptions& opts,
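
Not part of the commit: the equal-size branch above relies on newLikeFlat plus copy_ to pack the per-rank chunks into one contiguous send buffer before handing it to ccl::reduce_scatter with output.numel() as the receive count. A small standalone sketch of that packing step, using plain std::vector instead of tensors:

// Flattening sketch for the equal-size branch: per-chunk buffers are packed
// into one contiguous buffer laid out as [chunk 0 | chunk 1 | ... | chunk N-1],
// which is the send layout a flat reduce_scatter expects (recv_count per rank).
#include <iostream>
#include <vector>

std::vector<float> flatten_chunks(const std::vector<std::vector<float>>& chunks) {
  std::vector<float> flat;
  for (const auto& c : chunks)
    flat.insert(flat.end(), c.begin(), c.end());  // stand-in for newLikeFlat + copy_
  return flat;
}

int main() {
  std::vector<std::vector<float>> chunks = {{1, 2}, {3, 4}, {5, 6}};  // one chunk per rank
  std::vector<float> flat = flatten_chunks(chunks);
  const size_t recv_count = chunks.front().size();  // plays the role of output.numel()
  std::cout << "send buffer of " << flat.size() << " elements, "
            << recv_count << " received per rank\n";
  return 0;
}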

src/dispatch_stub.cpp

+32
@@ -161,6 +161,30 @@ class DebugCCLStub final: public DispatchStub {
     return work;
   }
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                    std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                    const ReduceScatterOptions& opts,
+                                                                    ProcessGroupCCL& pg_ccl) override {
+    std::stringstream os;
+    os << "oneccl_bindings_for_pytorch::" << dev_type << "::reduce_scatter: ";
+    format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++);
+    os << " input ";
+    format_tensors_size(os, inputTensors);
+    os << " output ";
+    format_tensors_size(os, outputTensors);
+    std::cout << os.str() << std::endl;
+
+    auto workStartTime_ = std::chrono::steady_clock::now();
+    auto work = hdlr->reduce_scatter_(outputTensors, inputTensors, opts, pg_ccl);
+    auto currentTimepoint = std::chrono::steady_clock::now();
+    auto timeElapsed =
+        std::chrono::duration_cast<std::chrono::microseconds>(
+            currentTimepoint - workStartTime_);
+    format_time_elapsed(os, timeElapsed);
+    std::cout << os.str() << std::endl;
+    return work;
+  }
+
   c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                           at::Tensor& inputTensor,
                                                                           const ReduceScatterOptions& opts,
@@ -609,6 +633,14 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::scatter(std::vec
   return get_ccl_stub(dev_type)->scatter_(outputTensors, inputTensors, opts, pg_ccl);
 }
 
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce_scatter(std::vector<at::Tensor>& outputTensors,
+                                                                               std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                               const ReduceScatterOptions& opts,
+                                                                               ProcessGroupCCL& pg_ccl) {
+  c10::DeviceType dev_type = outputTensors[0].device().type();
+  return get_ccl_stub(dev_type)->reduce_scatter_(outputTensors, inputTensors, opts, pg_ccl);
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::_reduce_scatter_base(at::Tensor& outputTensor,
                                                                                      at::Tensor& inputTensor,
                                                                                      const ReduceScatterOptions& opts,
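
Not part of the commit: the debug stub above logs the tensors, delegates to the real backend, and reports the elapsed time measured with std::chrono::steady_clock. A condensed standalone sketch of that timing-wrapper pattern, with a generic callable standing in for the delegated call (names here are illustrative, not the project's API):

// Timing-wrapper sketch: time a delegated call with std::chrono::steady_clock,
// mirroring the pattern DebugCCLStub applies around the backend call.
#include <chrono>
#include <iostream>
#include <thread>
#include <utility>

template <typename Fn, typename... Args>
auto timed_call(const char* name, Fn&& fn, Args&&... args) {
  auto start = std::chrono::steady_clock::now();
  auto result = std::forward<Fn>(fn)(std::forward<Args>(args)...);
  auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(
      std::chrono::steady_clock::now() - start);
  std::cout << name << " took " << elapsed.count() << " us\n";
  return result;
}

int main() {
  int value = timed_call("fake_reduce_scatter", [] {
    std::this_thread::sleep_for(std::chrono::milliseconds(5));  // stand-in for the collective
    return 42;
  });
  std::cout << "result: " << value << '\n';
  return 0;
}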

src/dispatch_stub.h

+12
@@ -95,6 +95,10 @@ class DispatchStub {
                                            const ReduceScatterOptions& opts,
                                            ProcessGroupCCL& pg_ccl);
 
+  static c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter(std::vector<at::Tensor>& outputTensors,
+                                                                          std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                          const ReduceScatterOptions& opts,
+                                                                          ProcessGroupCCL& pg_ccl);
   static c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& outputTensors,
     std::vector<at::Tensor>& inputTensors,
@@ -187,6 +191,14 @@
     fail(outputTensors[0].device().type(), "scatter");
     return c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL>();
   }
+
+  virtual c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                            std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                            const ReduceScatterOptions& opts,
+                                                                            ProcessGroupCCL& pg_ccl) {
+    fail(outputTensors[0].device().type(), "reduce_scatter");
+    return c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL>();
+  }
 
   virtual c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                                   at::Tensor& inputTensor,
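
Not part of the commit: the header change follows the existing dispatch pattern in this file, a virtual with a failing default body on DispatchStub, overridden only by device backends that support the collective, and reached through a per-device stub lookup. A condensed standalone sketch of that pattern (types and names here are illustrative, not the project's API):

// Dispatch-pattern sketch: a base stub whose defaults report "unsupported",
// device-specific stubs that override what they implement, and a registry
// keyed by device type that routes each call.
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>

enum class DeviceType { CPU, XPU };

struct Stub {
  virtual ~Stub() = default;
  virtual void reduce_scatter() {
    throw std::runtime_error("reduce_scatter not supported on this device");
  }
};

struct CpuStub : Stub {
  void reduce_scatter() override { std::cout << "cpu reduce_scatter\n"; }
};

std::map<DeviceType, std::unique_ptr<Stub>>& registry() {
  static std::map<DeviceType, std::unique_ptr<Stub>> r;
  return r;
}

void reduce_scatter(DeviceType dev) {
  auto it = registry().find(dev);
  if (it == registry().end()) throw std::runtime_error("no stub registered");
  it->second->reduce_scatter();  // virtual dispatch to the device backend
}

int main() {
  registry()[DeviceType::CPU] = std::make_unique<CpuStub>();
  reduce_scatter(DeviceType::CPU);   // prints "cpu reduce_scatter"
  try {
    reduce_scatter(DeviceType::XPU); // no stub registered -> error
  } catch (const std::exception& e) {
    std::cout << "error: " << e.what() << '\n';
  }
  return 0;
}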

src/gpu/dpcpp_ccl.cpp

+116
@@ -546,6 +546,11 @@ class XPUCCLStubs final: public DispatchStub {
                                                          const ReduceOptions& opts,
                                                          ProcessGroupCCL& pg_ccl) override;
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                    std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                    const ReduceScatterOptions& opts,
+                                                                    ProcessGroupCCL& pg_ccl) override;
+
   c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                           at::Tensor& inputTensor,
                                                                           const ReduceScatterOptions& opts,
@@ -629,6 +634,10 @@ class XPUCCLStubs final: public DispatchStub {
   c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> allreduce_impl(std::vector<at::Tensor>& tensors,
                                                                    const AllreduceOptions& opts,
                                                                    ProcessGroupCCL& pg_ccl);
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_oop(at::Tensor& outputTensor,
+                                                                at::Tensor& inputTensor,
+                                                                const ReduceOptions& opts,
+                                                                ProcessGroupCCL& pg_ccl);
 };
 
 struct RegisterXPUMethods {
@@ -837,6 +846,113 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> XPUCCLStubs::reduce_(std::vect
   return work;
 }
 
+// _reduce_oop implements an out-of-place reduce procedure.
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> XPUCCLStubs::_reduce_oop(at::Tensor& outputTensor,
+                                                                           at::Tensor& inputTensor,
+                                                                           const ReduceOptions& opts,
+                                                                           ProcessGroupCCL& pg_ccl) {
+  const int root = opts.rootRank + opts.rootTensor;
+  std::vector<at::Tensor> inputTensors{inputTensor};
+  std::vector<at::Tensor> outputTensors{outputTensor};
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  work = collective<get_ccl_comms, XPUWorkCCL>(
+    pg_ccl,
+    inputTensors,
+    outputTensors,
+    [=](at::Tensor input,
+        at::Tensor output,
+        ccl::reduce_attr attr,
+        ccl::communicator& comm,
+        ccl::stream& stream) {
+      RECORD_FUNCTION("oneccl_bindings_for_pytorch::xpu::reduce_oop", std::vector<c10::IValue>{input});
+
+      ccl::event ret_evt;
+      call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+        CCL_CHECK(ret_evt = ccl::reduce(input.data_ptr(),
+                                        output.data_ptr(),
+                                        (size_t) input.numel(),
+                                        cclDatatypes.at(input.scalar_type()),
+                                        cclOps.at(opts.reduceOp),
+                                        root,
+                                        comm,
+                                        stream));
+      });
+      return ret_evt;
+
+    },
+    c10d::OpType::REDUCE);
+
+  work->debugName = std::string("xpu::_reduce_oop");
+  execute(work);
+
+  return work;
+}
+
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> XPUCCLStubs::reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                               std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                               const ReduceScatterOptions& opts,
+                                                                               ProcessGroupCCL& pg_ccl) {
+  checkSingleTensor(outputTensors);
+  auto outputTensor = outputTensors.back();
+  auto inputTensors_ = inputTensors.back();
+  bool same_size = check_same_size(inputTensors_);
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  if (same_size) {
+    auto inputFlattened = newLikeFlat(inputTensors_);
+    for (const auto j : c10::irange(inputTensors_.size())) {
+      inputFlattened[j].copy_(inputTensors_[j], true);
+    }
+    std::vector<at::Tensor> flattendInputTensors{inputFlattened};
+
+    work = collective<get_ccl_comms, XPUWorkCCL>(
+      pg_ccl,
+      flattendInputTensors,
+      outputTensors,
+      [=](at::Tensor input,
+          at::Tensor output,
+          ccl::reduce_attr attr,
+          ccl::communicator& comm,
+          ccl::stream& stream) {
+        RECORD_FUNCTION("oneccl_bindings_for_pytorch::xpu::reduce_scatter", std::vector<c10::IValue>{input});
+
+        ccl::event ret_evt;
+        call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+          CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(),
+                                                  output.data_ptr(),
+                                                  (size_t) output.numel(),
+                                                  cclDatatypes.at(input.scalar_type()),
+                                                  cclOps.at(opts.reduceOp),
+                                                  comm,
+                                                  stream));
+        });
+        return ret_evt;
+
+      },
+      c10d::OpType::REDUCE_SCATTER);
+
+    work->debugName = std::string("xpu::reduce_scatter");
+    execute(work);
+    return work;
+  } else {
+    // Use multiple reduce to simulate reduce_scatter.
+    // Currently one-ccl doest support grouped primitives, we'll add coalescing when it supports.
+    // todo: startCoalescing
+    const auto num_reduces = inputTensors_.size();
+    for (const int i : c10::irange(num_reduces)) {
+      auto& input = inputTensors_[i];
+      auto& output = (i == pg_ccl.getRank()) ? outputTensor : input;
+      auto reduceOpts = ReduceOptions{
+        opts.reduceOp,
+        static_cast<int64_t>(i),
+        static_cast<int64_t>(0),
+        opts.timeout};
+      work = _reduce_oop(output, input, reduceOpts, pg_ccl);
+    }
+    // todo: endCoalescing
+    return work;
+  }
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> XPUCCLStubs::_reduce_scatter_base_(at::Tensor& outputTensor,
                                                                                      at::Tensor& inputTensor,
                                                                                      const ReduceScatterOptions& opts,
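
Not part of the commit: both the CPU and XPU paths fall back to one out-of-place reduce per chunk when the chunk sizes differ, with rank i acting as root for chunk i, so each rank keeps exactly its own reduced chunk. A minimal single-process model of that fallback, assuming sum as the reduction op (the values are made up):

// Fallback sketch: simulate reduce_scatter with one reduce per chunk, rooted
// at the rank that owns that chunk (rank i is root of chunk i).
#include <iostream>
#include <vector>

int main() {
  const int nranks = 3;
  // inputs[r][c]: rank r's contribution to chunk c; chunk sizes differ per chunk.
  std::vector<std::vector<std::vector<float>>> inputs = {
      {{1}, {1, 1}, {1, 1, 1}},   // rank 0
      {{2}, {2, 2}, {2, 2, 2}},   // rank 1
      {{3}, {3, 3}, {3, 3, 3}},   // rank 2
  };

  // One "reduce" per chunk: only the root rank (== chunk index) keeps the result.
  std::vector<std::vector<float>> outputs(nranks);
  for (int chunk = 0; chunk < nranks; ++chunk) {
    std::vector<float> reduced(inputs[0][chunk].size(), 0.f);
    for (int r = 0; r < nranks; ++r)
      for (size_t k = 0; k < reduced.size(); ++k)
        reduced[k] += inputs[r][chunk][k];
    outputs[chunk] = reduced;  // lands on rank `chunk`, matching the root choice above
  }

  for (int r = 0; r < nranks; ++r) {
    std::cout << "rank " << r << " output:";
    for (float v : outputs[r]) std::cout << ' ' << v;
    std::cout << '\n';
  }
  return 0;
}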
