From 10a8744699af1deda3f2293e9fcbfd3b32557137 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Dec 2022 17:11:03 +0000 Subject: [PATCH] Stop calling expand_1d on newer TF versions --- src/transformers/modeling_tf_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 132513e0786a13..91779a38b4e938 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1473,7 +1473,8 @@ def train_step(self, data): label_kwargs = find_labels(self.__class__) label_to_output = self.get_label_to_output_name_mapping() output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss: + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer TF train steps leave this out data = data_adapter.expand_1d(data) x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify @@ -1580,7 +1581,8 @@ def test_step(self, data): label_kwargs = find_labels(self.__class__) label_to_output = self.get_label_to_output_name_mapping() output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss: + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer versions leave this out data = data_adapter.expand_1d(data) x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify