Patch for deepspeed/runtime/zero/stage_1_and_2.py, method _average_expert_grad_norms.

For MoE parameter groups, when self.device == 'cpu' the scaled norm tensor is
moved to get_accelerator().current_device_name() before dist.all_reduce, and the
reduced value is moved back to self.device before being stored in norm_groups[i].
NOTE(review): presumably the process group's backend cannot reduce CPU tensors;
confirm against the communication backend in use.

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index e4009f6ac883..71a01b2391f8 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1946,8 +1946,10 @@ def _average_expert_grad_norms(self, norm_groups):
         for i, norm in enumerate(norm_groups):
             if self.is_moe_param_group[i]:
                 scaled_norm_tensor = norm * 1.0 / dist.get_world_size(group=self.real_dp_process_group[i])
+                if self.device == 'cpu':
+                    scaled_norm_tensor = scaled_norm_tensor.to(get_accelerator().current_device_name())
                 dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
-                norm_groups[i] = scaled_norm_tensor
+                norm_groups[i] = scaled_norm_tensor.to(self.device)
 
     def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         # compute combined scale factor for this group