diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a3cefbdb9b31c9..176dc9181390ae 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1798,12 +1798,12 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, torch.Tensor):
             kwargs = dict(device=self.args.device)
-            if self.args.deepspeed_inference:
-                print(data.dtype)
-                print(kwargs)
-                return data.to("cuda:0")
+            # if self.args.deepspeed_inference:
+            #     print(data.dtype)
+            #     print(kwargs)
+            #     return data.to("cuda:0")

-            elif self.deepspeed and data.dtype != torch.int64:
+            if self.deepspeed and data.dtype != torch.int64:
                 # NLP models inputs are int64 and those get adjusted to the right dtype of the
                 # embedding. Other models such as wav2vec2's inputs are already float and thus
                 # may need special handling to match the dtypes of the model
@@ -2404,8 +2404,8 @@ def _pad_across_processes(self, tensor, pad_index=-100):
         they can safely be gathered.
         """
         # XXX: hangs here with 2 gpus if we don't return
-        if self.args.deepspeed_inference:
-            return tensor
+        # if self.args.deepspeed_inference:
+        #     return tensor

         if isinstance(tensor, (list, tuple)):
             return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
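
For context on the dtype branch this patch keeps: under DeepSpeed, `_prepare_input` leaves int64 token ids untouched and casts floating-point inputs (e.g. wav2vec2 features) to the dtype DeepSpeed is configured with. Below is a minimal standalone sketch of that logic, where `ds_dtype` stands in for whatever `self.args.hf_deepspeed_config.dtype()` returns; `prepare_input_sketch` and its parameters are illustrative names, not the Trainer API.

    import torch

    def prepare_input_sketch(data: torch.Tensor, device: str, ds_dtype=torch.float16) -> torch.Tensor:
        # Standalone version of the tensor branch of _prepare_input; ds_dtype is a
        # stand-in for the dtype the DeepSpeed config requests (an assumption here).
        kwargs = dict(device=device)
        if data.dtype != torch.int64:
            # int64 token ids are left alone (the embedding layer handles them);
            # float inputs such as wav2vec2 features get cast to the model's dtype.
            kwargs.update(dict(dtype=ds_dtype))
        return data.to(**kwargs)

    ids = prepare_input_sketch(torch.ones(2, 8, dtype=torch.int64), "cpu")  # stays int64
    feats = prepare_input_sketch(torch.randn(2, 8), "cpu")                  # cast to float16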