Skip to content

Commit 0cfc10f

Browse files
authored
Add check message for unsupported allreduce op (#204)
* Add check message for unsupported allreduce op
1 parent aaf59dc commit 0cfc10f

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/dispatch_stub.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::allreduce(std::v
554554
ProcessGroupCCL& pg_ccl) {
555555
checkSameType(tensors[0], tensors);
556556
c10::DeviceType dev_type = tensors[0].device().type();
557+
check_supported_reduce_op(dev_type, opts.reduceOp);
557558
return get_ccl_stub(dev_type)->allreduce_(tensors, opts, pg_ccl);
558559
}
559560

@@ -562,6 +563,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::allreduce_coales
562563
ProcessGroupCCL& pg_ccl) {
563564
checkSameType(tensors[0], tensors);
564565
c10::DeviceType dev_type = tensors[0].device().type();
566+
check_supported_reduce_op(dev_type, opts.reduceOp);
565567
return get_ccl_stub(dev_type)->allreduce_coalesced_(tensors, opts, pg_ccl);
566568
}
567569

@@ -570,6 +572,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce(std::vect
570572
ProcessGroupCCL& pg_ccl) {
571573
checkSameType(tensors[0], tensors);
572574
c10::DeviceType dev_type = tensors[0].device().type();
575+
check_supported_reduce_op(dev_type, opts.reduceOp);
573576
return get_ccl_stub(dev_type)->reduce_(tensors, opts, pg_ccl);
574577
}
575578

@@ -638,6 +641,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce_scatter(s
638641
const ReduceScatterOptions& opts,
639642
ProcessGroupCCL& pg_ccl) {
640643
c10::DeviceType dev_type = outputTensors[0].device().type();
644+
check_supported_reduce_op(dev_type, opts.reduceOp);
641645
return get_ccl_stub(dev_type)->reduce_scatter_(outputTensors, inputTensors, opts, pg_ccl);
642646
}
643647

@@ -646,6 +650,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::_reduce_scatter_
646650
const ReduceScatterOptions& opts,
647651
ProcessGroupCCL& pg_ccl) {
648652
c10::DeviceType dev_type = inputTensor.device().type();
653+
check_supported_reduce_op(dev_type, opts.reduceOp);
649654
return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl);
650655
}
651656

@@ -657,6 +662,7 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> DispatchStub::reduce_scatter_t
657662
checkSameType(inputTensors[0], inputTensors);
658663
checkSameType(outputTensors[0], outputTensors);
659664
c10::DeviceType dev_type = inputTensors[0].device().type();
665+
check_supported_reduce_op(dev_type, opts.reduceOp);
660666
return get_ccl_stub(dev_type)->reduce_scatter_tensor_coalesced_(outputTensors, inputTensors, opts, pg_ccl);
661667
}
662668

src/dispatch_stub.h

+44
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,48 @@ class DispatchStub {
290290
}
291291
};
292292

293+
} // namespace oneccl_bindings_for_pytorch
294+
295+
namespace {
296+
297+
// Returns the enumerator name of a c10d reduce operation, for use in
// diagnostics such as the message built by check_supported_reduce_op().
//
// PREMUL_SUM and UNUSED are listed explicitly because
// check_supported_reduce_op() rejects them on XPU; without their own cases
// the resulting error would read "ReduceOp.UNKNOWN" and hide which op the
// caller actually passed.
//
// @param op  The reduce operation to name.
// @return    The enumerator name (e.g. "SUM"), or "UNKNOWN" for any value
//            that has no case below.
std::string reduce_op_to_string(c10d::ReduceOp op) {
  switch (op) {
    case c10d::ReduceOp::SUM:
      return "SUM";
    case c10d::ReduceOp::PRODUCT:
      return "PRODUCT";
    case c10d::ReduceOp::MIN:
      return "MIN";
    case c10d::ReduceOp::MAX:
      return "MAX";
    case c10d::ReduceOp::BAND:
      return "BAND";
    case c10d::ReduceOp::BOR:
      return "BOR";
    case c10d::ReduceOp::BXOR:
      return "BXOR";
    case c10d::ReduceOp::AVG:
      return "AVG";
    case c10d::ReduceOp::PREMUL_SUM:
      return "PREMUL_SUM";
    case c10d::ReduceOp::UNUSED:
      return "UNUSED";
    default:
      // Fallback for values without a dedicated case above.
      return "UNKNOWN";
  }
}
319+
320+
// Validates that `op` may be used for a reduction on `dev_type`.
//
// Only c10::DeviceType::XPU is restricted here; for every other device type
// the function returns without checking anything. On XPU, the ops BAND, BOR,
// BXOR, AVG, PREMUL_SUM and UNUSED fail a TORCH_CHECK whose message names the
// op via reduce_op_to_string().
//
// @param dev_type  Device type of the tensors taking part in the collective.
// @param op        Reduce operation requested by the caller.
void check_supported_reduce_op(c10::DeviceType dev_type, c10d::ReduceOp op) {
  // Guard clause: nothing to validate for devices other than XPU.
  if (dev_type != c10::DeviceType::XPU) {
    return;
  }

  bool rejected_on_xpu = false;
  switch (op) {
    case c10d::ReduceOp::BAND:
    case c10d::ReduceOp::BOR:
    case c10d::ReduceOp::BXOR:
    case c10d::ReduceOp::AVG:
    case c10d::ReduceOp::PREMUL_SUM:
    case c10d::ReduceOp::UNUSED:
      rejected_on_xpu = true;
      break;
    default:
      // Every other op is accepted.
      break;
  }

  // The message expression is only evaluated when the check fails.
  TORCH_CHECK(!rejected_on_xpu, ("Cannot use ReduceOp." + reduce_op_to_string(op) + " with XPU"));
}
336+
337+
} // namespace

0 commit comments

Comments
 (0)