Fix VarType
co63oc committed Feb 4, 2024
1 parent 062b99e commit 249a7da
Showing 5 changed files with 36 additions and 41 deletions.
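The change is mechanical: comparisons against the framework-internal core.VarDesc.VarType enum are rewritten to use the public paddle dtype aliases (paddle.float16, paddle.bfloat16, paddle.float32, paddle.float64). As a minimal sketch of the premise this relies on, assuming a pre-PIR static-graph Paddle 2.x build where both spellings compare equal for a variable's dtype (the asserts below are expectations, not part of the commit):

    import paddle
    from paddle.base import core  # older releases expose this as paddle.fluid.core

    paddle.enable_static()

    # A static-graph variable reports its dtype as a VarDesc.VarType value;
    # the commit assumes the public alias compares equal to that value.
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float16")
    assert x.dtype == paddle.float16              # new spelling used throughout this diff
    assert x.dtype == core.VarDesc.VarType.FP16   # legacy spelling being replaced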
4 changes: 2 additions & 2 deletions python/paddle/static/amp/decorator.py
@@ -480,7 +480,7 @@ def apply_gradients(self, params_grads):
         found_inf = self._check_finite_and_unscale(params_grads)
         if (
             self._use_dynamic_loss_scaling
-            and self._amp_vartype == core.VarDesc.VarType.FP16
+            and self._amp_vartype == paddle.float16
         ):
             self._add_dynamic_loss_scaling(params_grads, found_inf)

@@ -507,7 +507,7 @@ def apply_gradients(self, params_grads):

     def _split_grads(self, params_grads):
         grads = [g for _, g in params_grads]
-        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
+        fp32_grads = [g for g in grads if g.dtype == paddle.float32]
         fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
         assert len(fp32_grads) + len(fp16_grads) == len(
             grads
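For context, the _split_grads hunk above simply partitions gradients by dtype. A standalone sketch of the same idea using the public aliases (the helper name and eager-mode usage below are illustrative, not the decorator's actual API):

    import paddle

    def split_grads_by_dtype(grads, low_precision_dtype=paddle.float16):
        # Partition gradient tensors into fp32 and low-precision buckets,
        # mirroring the _split_grads logic shown in the diff above.
        fp32_grads = [g for g in grads if g.dtype == paddle.float32]
        lp_grads = [g for g in grads if g.dtype == low_precision_dtype]
        assert len(fp32_grads) + len(lp_grads) == len(grads), (
            "every grad must be fp32 or the low-precision AMP dtype"
        )
        return fp32_grads, lp_grads

    g1 = paddle.zeros([2, 2], dtype='float32')
    g2 = paddle.zeros([2, 2], dtype='float16')
    print(split_grads_by_dtype([g1, g2]))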
4 changes: 2 additions & 2 deletions python/paddle/static/amp/fp16_lists.py
@@ -72,9 +72,9 @@ def get_low_precision_dtypestr(dtype):
     if isinstance(dtype, str):
         return check_amp_dtype(dtype)
     elif isinstance(dtype, core.VarDesc.VarType):
-        if dtype == core.VarDesc.VarType.FP16:
+        if dtype == paddle.float16:
             return "float16"
-        elif dtype == core.VarDesc.VarType.BF16:
+        elif dtype == paddle.bfloat16:
             return "bfloat16"
         else:
             raise ValueError(
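The branch above maps an AMP low-precision dtype to its string name. A standalone sketch of that mapping (the function name here is illustrative, not the module's API):

    import paddle

    def low_precision_dtype_to_str(dtype):
        # Mirrors the branches of get_low_precision_dtypestr shown above.
        if dtype == paddle.float16:
            return "float16"
        elif dtype == paddle.bfloat16:
            return "bfloat16"
        raise ValueError(
            f"low-precision AMP dtype must be float16 or bfloat16, got {dtype}"
        )

    print(low_precision_dtype_to_str(paddle.float16))   # float16
    print(low_precision_dtype_to_str(paddle.bfloat16))  # bfloat16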
15 changes: 5 additions & 10 deletions python/paddle/static/amp/fp16_utils.py
@@ -185,9 +185,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0

     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
-            op, in_name
-        ):
+        if src_dtype == paddle.float32 and _keep_fp32_input(op, in_name):
             continue
         for in_var_name in op.input(in_name):
             in_var = block._find_var_recursive(in_var_name)
@@ -210,10 +208,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 # set cast_op device to `all`, can reduce send cast_var.
                 # TODO: need remove this after we unified the dynamic
                 # and static pipeline interface.
-                if (
-                    src_dtype == core.VarDesc.VarType.FP32
-                    and in_var.stop_gradient
-                ):
+                if src_dtype == paddle.float32 and in_var.stop_gradient:
                     prev_op = None
                     if in_var.op is op:
                         prev_op = find_true_prev_op(
@@ -527,7 +522,7 @@ def get_promote_dtype(op, amp_dtype, block):
         if in_name:
             for in_var_name in op.input(in_name):
                 in_var = block._find_var_recursive(in_var_name)
-                if in_var and in_var.dtype == core.VarDesc.VarType.FP32:
+                if in_var and in_var.dtype == paddle.float32:
                     dst_dtype = core.VarDesc.VarType.FP32
                     break
         else:
@@ -915,7 +910,7 @@ def cast_parameters_to_fp16(
         if var_scope.find_var(param.name):
             param_t = var_scope.find_var(param.name).get_tensor()
             data = np.array(param_t)
-            if dest_type == core.VarDesc.VarType.BF16:
+            if dest_type == paddle.bfloat16:
                 p_array = _convert_float_to_bfloat16(place, data)
                 param_t.set(p_array, place)
             else:
@@ -952,7 +947,7 @@ def update_role_var_grad(main_prog, params_grads):
     OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
     for p, g in params_grads:
         op = g.op
-        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
+        if g.dtype == paddle.float32 and op.type == 'cast':
             role = op.attr('op_role')
             if role & int(BACKWARD) and op.has_attr('op_role_var'):
                 op._remove_attr("op_role_var")
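The cast_parameters_to_fp16 hunk above branches on the destination dtype before writing the converted array back into the scope tensor. A simplified numpy-only sketch of the fp16 branch (the bfloat16 path needs the Paddle-internal _convert_float_to_bfloat16 helper and is deliberately omitted; the function name here is illustrative):

    import numpy as np
    import paddle

    def cast_param_array(data, dest_type):
        # Host-side conversion of a parameter array, mirroring the fp16 path
        # of cast_parameters_to_fp16 shown above.
        if dest_type == paddle.float16:
            return data.astype(np.float16)
        return data

    param = np.random.rand(4, 4).astype(np.float32)
    print(cast_param_array(param, paddle.float16).dtype)  # float16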
4 changes: 2 additions & 2 deletions python/paddle/static/nn/common.py
@@ -343,7 +343,7 @@ def instance_norm(
     dtype = helper.input_dtype()

     # use fp32 for in parameter
-    if dtype == paddle.framework.core.VarDesc.VarType.FP16:
+    if dtype == paddle.float16:
         dtype = paddle.framework.core.VarDesc.VarType.FP32

     input_shape = input.shape
@@ -2765,7 +2765,7 @@ def batch_norm(
     dtype = helper.input_dtype()

     # use fp32 for bn parameter
-    if dtype == core.VarDesc.VarType.FP16 or dtype == core.VarDesc.VarType.BF16:
+    if dtype == paddle.float16 or dtype == paddle.bfloat16:
         dtype = core.VarDesc.VarType.FP32

     input_shape = input.shape
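Both norm layers above keep their scale/bias parameters in fp32 whenever the activations are half precision. A small sketch of that dtype-selection rule (the helper name is illustrative, not a Paddle API):

    import paddle

    def norm_param_dtype(input_dtype):
        # Normalization parameters stay in float32 for numerical stability even
        # when the layer's inputs are float16/bfloat16, as in the instance_norm
        # and batch_norm branches above.
        if input_dtype in (paddle.float16, paddle.bfloat16):
            return paddle.float32
        return input_dtype

    print(norm_param_dtype(paddle.float16))  # expected: paddle.float32
    print(norm_param_dtype(paddle.float64))  # expected: paddle.float64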
50 changes: 25 additions & 25 deletions python/paddle/static/quantization/quantization_pass.py
@@ -516,9 +516,9 @@ def _insert_quant_abs_max_op(
             var_dtype=var_node.dtype(),
         )
         scale_name = self._quantized_scale_name(name)
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -563,9 +563,9 @@ def _insert_quant_range_abs_max_op(
         )

         scale_name = self._quantized_scale_name(name)
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -595,9 +595,9 @@ def _insert_quant_range_abs_max_op(
             shape=[self._window_size],
             var_dtype=var_node.dtype(),
         )
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -645,9 +645,9 @@ def _insert_quant_moving_average_abs_max_op(
             var_dtype=var_node.dtype(),
         )
         scale_name = self._quantized_scale_name(name)
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -675,9 +675,9 @@ def _insert_quant_moving_average_abs_max_op(
             var_dtype=var_node.dtype(),
             shape=[1],
         )
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -753,9 +753,9 @@ def _insert_channel_quant_op(
             var_dtype=var_node.dtype(),
         )
         scale_name = self._quantized_scale_name(name)
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -1291,9 +1291,9 @@ def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
             var_dtype=output_var_node.dtype(),
         )

-        if output_var_node.dtype() == core.VarDesc.VarType.FP64:
+        if output_var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif output_var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif output_var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -1643,9 +1643,9 @@ def apply(self, graph):
                 ):
                     continue

-                if in_node.dtype() == core.VarDesc.VarType.FP64:
+                if in_node.dtype() == paddle.float64:
                     data_type = 'float64'
-                elif in_node.dtype() == core.VarDesc.VarType.FP32:
+                elif in_node.dtype() == paddle.float32:
                     data_type = 'float32'
                 else:
                     data_type = "float16"
@@ -1997,9 +1997,9 @@ def _inser_quant_dequant_moving_average_abs_max_op(
             var_dtype=var_node.dtype(),
         )
         scale_name = f"{var_node.name()}.quant_dequant@scale"
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -2037,9 +2037,9 @@ def _inser_quant_dequant_moving_average_abs_max_op(
             var_dtype=var_node.dtype(),
             shape=[1],
         )
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -2156,9 +2156,9 @@ def insert_quant_op(
         var_dtype=var_node.dtype(),
     )
     if not scale_var_node:
-        if var_node.dtype() == core.VarDesc.VarType.FP64:
+        if var_node.dtype() == paddle.float64:
             data_type = 'float64'
-        elif var_node.dtype() == core.VarDesc.VarType.FP32:
+        elif var_node.dtype() == paddle.float32:
             data_type = 'float32'
         else:
             data_type = "float16"
@@ -2226,9 +2226,9 @@ def insert_quant_op(
         var_dtype=var_node.dtype(),
         shape=[1],
     )
-    if var_node.dtype() == core.VarDesc.VarType.FP64:
+    if var_node.dtype() == paddle.float64:
         data_type = 'float64'
-    elif var_node.dtype() == core.VarDesc.VarType.FP32:
+    elif var_node.dtype() == paddle.float32:
         data_type = 'float32'
     else:
         data_type = "float16"
@@ -3419,7 +3419,7 @@ def _insert_quant_dequant_op(self, graph, var_node):
         )
         data_type = (
             'float64'
-            if var_node.dtype() == core.VarDesc.VarType.FP64
+            if var_node.dtype() == paddle.float64
             else 'float32'
         )
         _init_var_node(
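Every quantization hunk above repeats the same dtype-to-string mapping for the scale initializer. A hedged sketch of a helper that captures that repeated branch (this helper does not exist in quantization_pass.py; it is only illustrative):

    import paddle

    def scale_data_type(var_dtype):
        # Repeated branch from the hunks above: pick the numpy-facing
        # data-type string based on the variable node's dtype.
        if var_dtype == paddle.float64:
            return 'float64'
        elif var_dtype == paddle.float32:
            return 'float32'
        return 'float16'

    print(scale_data_type(paddle.float32))  # float32
    print(scale_data_type(paddle.float16))  # float16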
