Improve test_pt_tf_model_equivalence on PT side (#16731)
* Update test_pt_tf_model_equivalence on PT side

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
ydshieh and ydshieh authored Apr 19, 2022
1 parent 3dd57b1 commit e6d23a4
Showing 5 changed files with 255 additions and 602 deletions.
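In brief, this commit removes the per-model test_pt_tf_model_equivalence overrides (the CLIP and LXMERT ones are shown below) so those models rely on the common PT/TF equivalence test, keeping only small model-specific input-preparation helpers. The numeric check at the heart of the removed overrides zeroes NaN positions in both frameworks' outputs before taking the maximum absolute difference; a minimal standalone sketch of that comparison (illustrative names, not the commit's exact code):

import numpy as np

def max_abs_diff_ignoring_nans(tf_out: np.ndarray, pt_out: np.ndarray) -> float:
    # Work on copies so the caller's arrays are not modified in place.
    tf_out, pt_out = tf_out.copy(), pt_out.copy()
    tf_nans = np.isnan(tf_out)
    pt_nans = np.isnan(pt_out)
    # Zero the NaN positions in *both* outputs so a NaN in either
    # framework cannot poison the element-wise comparison.
    pt_out[tf_nans] = 0
    tf_out[tf_nans] = 0
    pt_out[pt_nans] = 0
    tf_out[pt_nans] = 0
    return float(np.amax(np.abs(tf_out - pt_out)))

The removed tests assert this difference stays below 4e-2 (CLIP) and 6e-2 (LXMERT).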
144 changes: 0 additions & 144 deletions tests/clip/test_modeling_clip.py
@@ -28,7 +28,6 @@
from transformers.testing_utils import (
    is_flax_available,
    is_pt_flax_cross_test,
    is_pt_tf_cross_test,
    require_torch,
    require_vision,
    slow,
@@ -602,149 +601,6 @@ def test_load_vision_text_config(self):
            text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    # overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        import numpy as np
        import tensorflow as tf

        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            if not hasattr(transformers, tf_model_class_name):
                # transformers does not have TF version yet
                return

            tf_model_class = getattr(transformers, tf_model_class_name)

            config.output_hidden_states = True

            tf_model = tf_model_class(config)
            pt_model = model_class(config)

            # make sure only TF inputs that actually exist in the function args are forwarded
            tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

            # remove all head masks
            tf_input_keys.discard("head_mask")
            tf_input_keys.discard("cross_attn_head_mask")
            tf_input_keys.discard("decoder_head_mask")

            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # skip key that does not exist in tf
                if type(tensor) == bool:
                    tf_inputs_dict[key] = tensor
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)

            # need to rename encoder-decoder "inputs" for PyTorch
            # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)
            tfo = tf_model(tf_inputs_dict, training=False)

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.cpu().numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:

                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
                pt_model = pt_model.to(torch_device)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # skip key that does not exist in tf
                if type(tensor) == bool:
                    tensor = np.array(tensor, dtype=bool)
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)

            # need to rename encoder-decoder "inputs" for PyTorch
            # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict)

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):

                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.cpu().numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:
                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

    # overwrite from common since FlaxCLIPModel returns nested output
    # which is not supported in the common test
    @is_pt_flax_cross_test
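For reference, the removed overrides exercised two weight round-trips with the transformers cross-loading utilities; a condensed sketch, assuming pt_model/tf_model are a matching PyTorch/TF pair and tf_inputs_dict was built as in the code above:

import os
import tempfile

import torch
import transformers

# model => model round trip
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

# checkpoint => model round trip
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
    torch.save(pt_model.state_dict(), pt_checkpoint_path)
    tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

    tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
    tf_model.save_weights(tf_checkpoint_path)
    pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

After each round trip the test re-runs both models on the same inputs and compares the outputs as sketched earlier.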
146 changes: 27 additions & 119 deletions tests/lxmert/test_modeling_lxmert.py
@@ -15,16 +15,13 @@


import copy
import os
import tempfile
import unittest

import numpy as np

import transformers
from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device

from ..test_configuration_common import ConfigTester
from ..test_modeling_common import ModelTesterMixin, ids_tensor
@@ -527,6 +524,8 @@ def prepare_config_and_inputs_for_common(self, return_obj_labels=False):

        if return_obj_labels:
            inputs_dict["obj_labels"] = obj_labels
        else:
            config.task_obj_predict = False

        return config, inputs_dict

@@ -740,121 +739,30 @@ def test_retain_grad_hidden_states_attentions(self):
        self.assertIsNotNone(hidden_states_vision.grad)
        self.assertIsNotNone(attentions_vision.grad)

    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
                return_obj_labels="PreTraining" in model_class.__name__
            )

            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            if not hasattr(transformers, tf_model_class_name):
                # transformers does not have TF version yet
                return

            tf_model_class = getattr(transformers, tf_model_class_name)

            config.output_hidden_states = True
            config.task_obj_predict = False

            pt_model = model_class(config)
            tf_model = tf_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            def recursive_numpy_convert(iterable):
                return_dict = {}
                for key, value in iterable.items():
                    if type(value) == bool:
                        return_dict[key] = value
                    if isinstance(value, dict):
                        return_dict[key] = recursive_numpy_convert(value)
                    else:
                        if isinstance(value, (list, tuple)):
                            return_dict[key] = (
                                tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value
                            )
                        else:
                            return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
                return return_dict

            tf_inputs_dict = recursive_numpy_convert(pt_inputs)

            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            # Delete obj labels as we want to compute the hidden states and not the loss

            if "obj_labels" in inputs_dict:
                del inputs_dict["obj_labels"]

            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            tf_inputs_dict = recursive_numpy_convert(pt_inputs)

            with torch.no_grad():
                pto = pt_model(**pt_inputs)
            tfo = tf_model(tf_inputs_dict, training=False)
            tf_hidden_states = tfo[0].numpy()
            pt_hidden_states = pto[0].cpu().numpy()

            tf_nans = np.copy(np.isnan(tf_hidden_states))
            pt_nans = np.copy(np.isnan(pt_hidden_states))

            pt_hidden_states[tf_nans] = 0
            tf_hidden_states[tf_nans] = 0
            pt_hidden_states[pt_nans] = 0
            tf_hidden_states[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
            # Debug info (remove when fixed)
            if max_diff >= 2e-2:
                print("===")
                print(model_class)
                print(config)
                print(inputs_dict)
                print(pt_inputs)
            self.assertLessEqual(max_diff, 6e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            for key, value in pt_inputs.items():
                if key in ("visual_feats", "visual_pos"):
                    pt_inputs[key] = value.to(torch.float32)
                else:
                    pt_inputs[key] = value.to(torch.long)

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict)
            tfo = tfo[0].numpy()
            pto = pto[0].cpu().numpy()
            tf_nans = np.copy(np.isnan(tfo))
            pt_nans = np.copy(np.isnan(pto))

            pto[tf_nans] = 0
            tfo[tf_nans] = 0
            pto[pt_nans] = 0
            tfo[pt_nans] = 0

            max_diff = np.amax(np.abs(tfo - pto))
            self.assertLessEqual(max_diff, 6e-2)
    def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):

        tf_inputs_dict = {}
        for key, value in pt_inputs_dict.items():
            # convert each PyTorch input to its TF counterpart; booleans pass through unchanged
            if isinstance(value, dict):
                tf_inputs_dict[key] = self.prepare_tf_inputs_from_pt_inputs(value)
            elif isinstance(value, (list, tuple)):
                tf_inputs_dict[key] = (self.prepare_tf_inputs_from_pt_inputs(iter_value) for iter_value in value)
            elif type(value) == bool:
                tf_inputs_dict[key] = value
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "input_features":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            # other general float inputs
            elif value.is_floating_point():
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)

        return tf_inputs_dict
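Inside the test class, a hypothetical call site for this helper mirrors how the removed override built its TF inputs (pt_model, tf_model, inputs_dict, and model_class as in the deleted test above):

pt_inputs = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs)

with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs)
tf_outputs = tf_model(tf_inputs_dict, training=False)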


@require_torch
