@@ -554,6 +554,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::allreduce(std::v
554
554
ProcessGroupCCL& pg_ccl) {
555
555
checkSameType (tensors[0 ], tensors);
556
556
c10::DeviceType dev_type = tensors[0 ].device ().type ();
557
+ check_supported_reduce_op (dev_type, opts.reduceOp );
557
558
return get_ccl_stub (dev_type)->allreduce_ (tensors, opts, pg_ccl);
558
559
}
559
560
@@ -562,6 +563,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::allreduce_coales
562
563
ProcessGroupCCL& pg_ccl) {
563
564
checkSameType (tensors[0 ], tensors);
564
565
c10::DeviceType dev_type = tensors[0 ].device ().type ();
566
+ check_supported_reduce_op (dev_type, opts.reduceOp );
565
567
return get_ccl_stub (dev_type)->allreduce_coalesced_ (tensors, opts, pg_ccl);
566
568
}
567
569
@@ -570,6 +572,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce(std::vect
570
572
ProcessGroupCCL& pg_ccl) {
571
573
checkSameType (tensors[0 ], tensors);
572
574
c10::DeviceType dev_type = tensors[0 ].device ().type ();
575
+ check_supported_reduce_op (dev_type, opts.reduceOp );
573
576
return get_ccl_stub (dev_type)->reduce_ (tensors, opts, pg_ccl);
574
577
}
575
578
@@ -638,6 +641,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce_scatter(s
638
641
const ReduceScatterOptions& opts,
639
642
ProcessGroupCCL& pg_ccl) {
640
643
c10::DeviceType dev_type = outputTensors[0 ].device ().type ();
644
+ check_supported_reduce_op (dev_type, opts.reduceOp );
641
645
return get_ccl_stub (dev_type)->reduce_scatter_ (outputTensors, inputTensors, opts, pg_ccl);
642
646
}
643
647
@@ -646,6 +650,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::_reduce_scatter_
646
650
const ReduceScatterOptions& opts,
647
651
ProcessGroupCCL& pg_ccl) {
648
652
c10::DeviceType dev_type = inputTensor.device ().type ();
653
+ check_supported_reduce_op (dev_type, opts.reduceOp );
649
654
return get_ccl_stub (dev_type)->_reduce_scatter_base_ (outputTensor, inputTensor, opts, pg_ccl);
650
655
}
651
656
@@ -657,6 +662,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce_scatter_t
657
662
checkSameType (inputTensors[0 ], inputTensors);
658
663
checkSameType (outputTensors[0 ], outputTensors);
659
664
c10::DeviceType dev_type = inputTensors[0 ].device ().type ();
665
+ check_supported_reduce_op (dev_type, opts.reduceOp );
660
666
return get_ccl_stub (dev_type)->reduce_scatter_tensor_coalesced_ (outputTensors, inputTensors, opts, pg_ccl);
661
667
}
662
668
0 commit comments