Generate: PT Dynamo without graph breaks in the main greedy/sample loop (#21648)
gante authored Feb 15, 2023
1 parent 7a5533b commit 1567bef
Showing 2 changed files with 10 additions and 7 deletions.
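
The headline change lets PyTorch Dynamo trace the main greedy/sample generation loop without graph breaks. A minimal sketch of how this might be exercised, assuming PyTorch >= 2.0; the checkpoint name and the choice to compile `generate` directly are illustrative assumptions, not prescribed by the commit:

```python
# Sketch only: "gpt2" and compiling `generate` directly are assumptions,
# not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Dynamo traces the callable; fewer graph breaks in the generation loop
# mean larger captured graphs and less fallback to eager execution.
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = compiled_generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```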
src/transformers/generation/configuration_utils.py (3 additions & 0 deletions)
```diff
@@ -298,6 +298,9 @@ def __init__(self, **kwargs):
         self.validate()
 
     def __eq__(self, other):
+        if not isinstance(other, GenerationConfig):
+            return False
+
         self_dict = self.__dict__.copy()
         other_dict = other.__dict__.copy()
         # ignore metadata
```
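
The new `isinstance` guard means comparing a `GenerationConfig` against an arbitrary object now returns `False` instead of failing while copying `other.__dict__`. A short illustration of the behavior implied by the diff:

```python
from transformers import GenerationConfig

config = GenerationConfig(max_new_tokens=20)

# Previously, `other.__dict__.copy()` could raise AttributeError for
# non-GenerationConfig operands; with the guard, equality is simply False.
assert (config == 5) is False
assert (config == "not a config") is False

# Two configs built with the same values still compare equal
# (metadata fields are ignored, per the comment in the diff).
assert config == GenerationConfig(max_new_tokens=20)
```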
src/transformers/modeling_utils.py (7 additions & 7 deletions)
```diff
@@ -190,14 +190,14 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
             # Adding fix for https://github.com/pytorch/xla/issues/4152
             # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
             # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
-            if is_torch_tpu_available():
-                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
-                    return torch.bfloat16
-                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
-                    if t.dtype == torch.float:
-                        return torch.bfloat16
-                    if t.dtype == torch.double:
-                        return torch.float32
+            # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
+            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+                return torch.bfloat16
+            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+                if t.dtype == torch.float:
+                    return torch.bfloat16
+                if t.dtype == torch.double:
+                    return torch.float32
             return t.dtype
 
     if last_dtype is not None:
```
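
The fix relies on short-circuit evaluation: the cheap environment-variable membership test now runs first, so `is_torch_tpu_available()`, which induces a Dynamo graph break, is only called when one of the bf16 flags is actually set. A self-contained sketch of the pattern, with hypothetical stand-ins (`cheap_flag_set`, `expensive_check`) for the real helpers:

```python
import os

# Hypothetical stand-in for ENV_VARS_TRUE_VALUES in transformers.
TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def cheap_flag_set(name: str) -> bool:
    # Pure-Python env lookup: Dynamo can evaluate this without a graph break.
    return os.environ.get(name, "").upper() in TRUE_VALUES


def expensive_check() -> bool:
    # Stand-in for is_torch_tpu_available(), which breaks the Dynamo graph.
    print("expensive check ran")
    return False


# Old shape: the expensive call guards everything, so it always runs.
if expensive_check():
    if cheap_flag_set("XLA_USE_BF16"):
        pass

# New shape: `and` short-circuits, so the expensive call is skipped
# whenever the flag is unset, the common (non-TPU) path.
if cheap_flag_set("XLA_USE_BF16") and expensive_check():
    pass
```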
