Skip to content

Commit

Permalink
Stop calling expand_1d on newer TF versions (#20786)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 authored Dec 16, 2022
1 parent 3ee9582 commit e65445b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,8 @@ def train_step(self, data):
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer TF train steps leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
Expand Down Expand Up @@ -1580,7 +1581,8 @@ def test_step(self, data):
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer versions leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
Expand Down

0 comments on commit e65445b

Please sign in to comment.