Improve test_pt_tf_model_equivalence on PT side #16731

Merged · 4 commits · Apr 19, 2022
tests/clip/test_modeling_clip.py (144 changes: 0 additions & 144 deletions)
@@ -28,7 +28,6 @@
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_torch,
require_vision,
slow,
@@ -602,149 +601,6 @@ def test_load_vision_text_config(self):
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf

import transformers

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning

if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return

tf_model_class = getattr(transformers, tf_model_class_name)

config.output_hidden_states = True

tf_model = tf_model_class(config)
pt_model = model_class(config)

# make sure we only forward TF inputs that actually exist in the function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")

pt_inputs = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}

# Check predictions on first output (logits/hidden-states) are close enough, given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# pass booleans through as-is; convert tensors to the appropriate TF dtype
if type(tensor) == bool:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)

# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)

# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)

self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
continue

tf_out = tf_output.numpy()
pt_out = pt_output.cpu().numpy()

self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

if len(tf_out.shape) > 0:

tf_nans = np.copy(np.isnan(tf_out))
pt_nans = np.copy(np.isnan(pt_out))

pt_out[tf_nans] = 0
tf_out[tf_nans] = 0
pt_out[pt_nans] = 0
tf_out[pt_nans] = 0

max_diff = np.amax(np.abs(tf_out - pt_out))
self.assertLessEqual(max_diff, 4e-2)

# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)

# Check predictions on first output (logits/hidden-states) are close enough, given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# plain Python booleans are converted to int32 tensors here
if type(tensor) == bool:
tensor = np.array(tensor, dtype=bool)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)

# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

with torch.no_grad():
pto = pt_model(**pt_inputs)

tfo = tf_model(tf_inputs_dict)

self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
continue

tf_out = tf_output.numpy()
pt_out = pt_output.cpu().numpy()

self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

if len(tf_out.shape) > 0:
tf_nans = np.copy(np.isnan(tf_out))
pt_nans = np.copy(np.isnan(pt_out))

pt_out[tf_nans] = 0
tf_out[tf_nans] = 0
pt_out[pt_nans] = 0
tf_out[pt_nans] = 0

max_diff = np.amax(np.abs(tf_out - pt_out))
self.assertLessEqual(max_diff, 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
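For reference, the comparison at the heart of the removed CLIP override can be condensed into a small standalone helper. This is an illustrative sketch rather than code from the PR; the name max_abs_diff is invented here, but the NaN handling and the 4e-2 tolerance mirror the removed test above.

import numpy as np

def max_abs_diff(tf_out: np.ndarray, pt_out: np.ndarray) -> float:
    # Zero out positions where either framework produced NaN, then return the
    # largest element-wise absolute difference (compared against 4e-2 above).
    assert tf_out.shape == pt_out.shape, "Output component shapes differ between TF and PyTorch"
    tf_out, pt_out = tf_out.copy(), pt_out.copy()
    nan_mask = np.isnan(tf_out) | np.isnan(pt_out)
    tf_out[nan_mask] = 0
    pt_out[nan_mask] = 0
    return float(np.amax(np.abs(tf_out - pt_out)))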
tests/lxmert/test_modeling_lxmert.py (146 changes: 27 additions & 119 deletions)
@@ -15,16 +15,13 @@


import copy
import os
import tempfile
import unittest

import numpy as np

import transformers
from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device

from ..test_configuration_common import ConfigTester
from ..test_modeling_common import ModelTesterMixin, ids_tensor
@@ -527,6 +524,8 @@ def prepare_config_and_inputs_for_common(self, return_obj_labels=False):

if return_obj_labels:
inputs_dict["obj_labels"] = obj_labels
else:
config.task_obj_predict = False

return config, inputs_dict
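The new else branch above disables the object-prediction pretraining head whenever obj_labels are not requested, so a plain forward pass never builds a loss head whose targets are missing. A minimal sketch of the intended effect (assumed behaviour, not code from the diff):

# When no object labels are requested, the config comes back with the
# object-prediction task switched off.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
assert config.task_obj_predict is False

# The pretraining path still requests and keeps the labels.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(return_obj_labels=True)
assert "obj_labels" in inputs_dict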

@@ -740,121 +739,30 @@ def test_retain_grad_hidden_states_attentions(self):
self.assertIsNotNone(hidden_states_vision.grad)
self.assertIsNotNone(attentions_vision.grad)

@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
return_obj_labels="PreTraining" in model_class.__name__
)

tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning

if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return

tf_model_class = getattr(transformers, tf_model_class_name)

config.output_hidden_states = True
config.task_obj_predict = False

pt_model = model_class(config)
tf_model = tf_model_class(config)

# Check we can load pt model in tf and vice-versa with model => model functions
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

def recursive_numpy_convert(iterable):
    return_dict = {}
    for key, value in iterable.items():
        if type(value) == bool:
            # booleans are passed through unchanged
            return_dict[key] = value
        elif isinstance(value, dict):
            return_dict[key] = recursive_numpy_convert(value)
        elif isinstance(value, (list, tuple)):
            return_dict[key] = tuple(
                tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value
            )
        else:
            return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
    return return_dict

tf_inputs_dict = recursive_numpy_convert(pt_inputs)

tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)

# Check predictions on first output (logits/hidden-states) are close enough, given low-level computational differences
pt_model.eval()

# Delete obj labels as we want to compute the hidden states and not the loss

if "obj_labels" in inputs_dict:
del inputs_dict["obj_labels"]

pt_inputs = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict = recursive_numpy_convert(pt_inputs)

with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].cpu().numpy()

tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))

pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0

max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
if max_diff >= 2e-2:
print("===")
print(model_class)
print(config)
print(inputs_dict)
print(pt_inputs)
self.assertLessEqual(max_diff, 6e-2)

# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

# Check predictions on first output (logits/hidden-states) are close enough, given low-level computational differences
pt_model.eval()

for key, value in pt_inputs.items():
if key in ("visual_feats", "visual_pos"):
pt_inputs[key] = value.to(torch.float32)
else:
pt_inputs[key] = value.to(torch.long)

with torch.no_grad():
pto = pt_model(**pt_inputs)

tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))

pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0

max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 6e-2)
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
    tf_inputs_dict = {}
    for key, value in pt_inputs_dict.items():
        if isinstance(value, dict):
            # recurse into nested input dicts
            tf_inputs_dict[key] = self.prepare_tf_inputs_from_pt_inputs(value)
        elif isinstance(value, (list, tuple)):
            tf_inputs_dict[key] = tuple(
                tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value
            )
        elif type(value) == bool:
            # booleans are passed through unchanged
            tf_inputs_dict[key] = value
        elif key in ("input_values", "pixel_values", "input_features"):
            tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
        # other general float inputs
        elif value.is_floating_point():
            tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
        else:
            tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)

    return tf_inputs_dict
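As a rough usage illustration (not part of this diff), the helper converts the PyTorch inputs produced by _prepare_for_class into their TF equivalents before the TF forward pass; the concrete tensors below are hypothetical:

import torch

# Hypothetical inputs of the kinds the helper handles; in the tests they come
# from self._prepare_for_class(inputs_dict, model_class).
pt_inputs = {
    "input_ids": torch.tensor([[101, 2023, 102]], dtype=torch.long),  # -> tf.int32
    "visual_feats": torch.rand(1, 10, 2048),  # floating point -> tf.float32
    "return_dict": True,  # booleans pass through unchanged
}

tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs)  # self: the test case instance
tf_outputs = tf_model(tf_inputs_dict, training=False)  # tf_model: the paired TF model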


@require_torch
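Taken together, the two file changes move model-specific conversion logic into hooks like prepare_tf_inputs_from_pt_inputs so that the common PT/TF equivalence test can own the control flow. Roughly, that flow looks like the sketch below, assembled from the removed overrides; it is not the actual test_modeling_common code, and it reuses the hypothetical max_abs_diff helper sketched earlier.

import tensorflow as tf
import torch
import transformers

def check_pt_tf_equivalence(self, pt_model, tf_model, pt_inputs):
    # Convert inputs once, via the model-specific hook added in this PR.
    tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs)

    # Cross-load weights in both directions (model => model functions).
    tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
    pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

    # Forward both models and compare outputs with the NaN-tolerant max-diff check.
    pt_model.eval()
    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs)
    tf_outputs = tf_model(tf_inputs_dict, training=False)

    self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
    for tf_out, pt_out in zip(tf_outputs.to_tuple(), pt_outputs.to_tuple()):
        if isinstance(tf_out, tf.Tensor) and isinstance(pt_out, torch.Tensor):
            self.assertLessEqual(max_abs_diff(tf_out.numpy(), pt_out.cpu().numpy()), 4e-2)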