@@ -138,6 +138,11 @@ class VanillaCPU final: public DispatchStub {
                                                            const ReduceOptions& opts,
                                                            ProcessGroupCCL& pg) override;

+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                    std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                    const ReduceScatterOptions& opts,
+                                                                    ProcessGroupCCL& pg_ccl) override;
+
  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                          at::Tensor& inputTensor,
                                                                          const ReduceScatterOptions& opts,
@@ -194,6 +199,11 @@ class VanillaCPU final: public DispatchStub {
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_oop(at::Tensor& outputTensor,
+                                                                at::Tensor& inputTensor,
+                                                                const ReduceOptions& opts,
+                                                                ProcessGroupCCL& pg_ccl);
+
};

struct RegisterCPUPMethods {
@@ -388,6 +398,45 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_(std::vecto
  return work;
}

+// _reduce_oop implements an out-of-place reduce procedure.
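+// It reduces inputTensor into a separate outputTensor on the root rank
+// (root = rootRank + rootTensor) and serves as the building block for the
+// reduce_scatter_ fallback below when the input chunks differ in size.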
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::_reduce_oop(at::Tensor& outputTensor,
+                                                                          at::Tensor& inputTensor,
+                                                                          const ReduceOptions& opts,
+                                                                          ProcessGroupCCL& pg_ccl) {
+  const int root = opts.rootRank + opts.rootTensor;
+  std::vector<at::Tensor> inputTensors{inputTensor};
+  std::vector<at::Tensor> outputTensors{outputTensor};
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  work = collective<get_ccl_comms, CPUWorkCCL>(
+    pg_ccl,
+    inputTensors,
+    outputTensors,
+    [=](at::Tensor input,
+        at::Tensor output,
+        ccl::reduce_attr attr,
+        ccl::communicator& comm) {
+
+      ccl::event ret_evt;
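+      // Issue the oneCCL reduce while holding the process-group-wide mutex; the
+      // returned ccl::event is handed back to the collective machinery so the
+      // async work object can track completion.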
+      call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+        CCL_CHECK(ret_evt = ccl::reduce(input.data_ptr(),
+                                        output.data_ptr(),
+                                        (size_t) input.numel(),
+                                        cclDatatypes.at(input.scalar_type()),
+                                        cclOps.at(opts.reduceOp),
+                                        root,
+                                        comm));
+      });
+      return ret_evt;
+
+    },
+    c10d::OpType::REDUCE,
+    "oneccl_bindings_for_pytorch::cpu_work::_reduce_oop");
+
+  work->debugName = std::string("cpu::_reduce_oop");
+  enqueue(work);
+  return work;
+}
+
c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::broadcast_(std::vector<at::Tensor>& tensors,
                                                                         const BroadcastOptions &opts,
                                                                         ProcessGroupCCL& pg) {
@@ -596,6 +645,65 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::gather_(std::vecto
  return work;
}

+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                              std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                              const ReduceScatterOptions& opts,
+                                                                              ProcessGroupCCL& pg_ccl) {
+  checkSingleTensor(outputTensors);
+  auto outputTensor = outputTensors.back();
+  auto inputTensors_ = inputTensors.back();
+  bool same_size = check_same_size(inputTensors_);
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
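+  // When every per-rank input has the same numel, a single oneCCL reduce_scatter
+  // over a flattened send buffer is used; otherwise fall back to one out-of-place
+  // reduce per chunk via _reduce_oop.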
+  if (same_size) {
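+    // ccl::reduce_scatter consumes one contiguous send buffer, so copy the
+    // per-rank inputs into a flattened tensor first (non-blocking copies).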
+    auto inputFlattened = newLikeFlat(inputTensors_);
+    for (const auto j : c10::irange(inputTensors_.size())) {
+      inputFlattened[j].copy_(inputTensors_[j], true);
+    }
+    std::vector<at::Tensor> flattenedInputTensors{inputFlattened};
+    work = collective<get_ccl_comms, CPUWorkCCL>(
+      pg_ccl,
+      flattenedInputTensors,
+      outputTensors,
+      [=](at::Tensor input,
+          at::Tensor output,
+          ccl::reduce_attr attr,
+          ccl::communicator& comm) {
+
+        ccl::event ret_evt;
+        call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+          CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(),
+                                                  output.data_ptr(),
+                                                  (size_t) output.numel(),
+                                                  cclDatatypes.at(input.scalar_type()),
+                                                  cclOps.at(opts.reduceOp),
+                                                  comm));
+        });
+        return ret_evt;
+
+      },
+      c10d::OpType::REDUCE_SCATTER,
+      "oneccl_bindings_for_pytorch::cpu_work::reduce_scatter");
+    work->debugName = std::string("cpu::reduce_scatter");
+    enqueue(work);
+    return work;
+
+  } else {
+    // Use multiple reduces to simulate reduce_scatter.
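+    // One reduce is issued per chunk, rooted at rank i: the chunk whose index
+    // matches this rank is reduced into outputTensor, while the remaining chunks
+    // reuse their input tensor as the (root-only) destination. The work handle of
+    // the last reduce is returned to the caller.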
+    const auto num_reduces = inputTensors_.size();
+    for (const int i : c10::irange(num_reduces)) {
+      auto& input = inputTensors_[i];
+      auto& output = (i == pg_ccl.getRank()) ? outputTensor : input;
+      auto reduceOpts = ReduceOptions{
+        opts.reduceOp,
+        static_cast<int64_t>(i),
+        static_cast<int64_t>(0),
+        opts.timeout};
+      work = _reduce_oop(output, input, reduceOpts, pg_ccl);
+    }
+    return work;
+  }
+}
+
c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::scatter_(std::vector<at::Tensor>& outputTensors,
                                                                       std::vector<std::vector<at::Tensor>>& inputTensors,
                                                                       const ScatterOptions& opts,