Skip to content

Commit dcee90b

Browse files
authored
fix shard2 global norm (#58402)
1 parent a5274e2 commit dcee90b

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

+2
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,8 @@ def _create_slice_param(self, param):
540540
# the buffer underlining is not initialized yet
541541
slice_param = EagerParamBase(shape=[1], dtype=param.dtype)
542542
slice_param.name = param.name
543+
if hasattr(param, "is_distributed"):
544+
slice_param.is_distributed = param.is_distributed
543545
self._slice_params[param.name] = slice_param
544546
return slice_param
545547

0 commit comments

Comments
 (0)