diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 132513e0786a13..91779a38b4e938 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1473,7 +1473,8 @@ def train_step(self, data): label_kwargs = find_labels(self.__class__) label_to_output = self.get_label_to_output_name_mapping() output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss: + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer TF train steps leave this out data = data_adapter.expand_1d(data) x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify @@ -1580,7 +1581,8 @@ def test_step(self, data): label_kwargs = find_labels(self.__class__) label_to_output = self.get_label_to_output_name_mapping() output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss: + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer versions leave this out data = data_adapter.expand_1d(data) x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify