Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 committed Nov 17, 2021
1 parent e78d4b0 commit 10a382b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,12 +1798,12 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = dict(device=self.args.device)
if self.args.deepspeed_inference:
print(data.dtype)
print(kwargs)
return data.to("cuda:0")
# if self.args.deepspeed_inference:
# print(data.dtype)
# print(kwargs)
# return data.to("cuda:0")

elif self.deepspeed and data.dtype != torch.int64:
if self.deepspeed and data.dtype != torch.int64:
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
Expand Down Expand Up @@ -2404,8 +2404,8 @@ def _pad_across_processes(self, tensor, pad_index=-100):
they can safely be gathered.
"""
# XXX: hangs here with 2 gpus if we don't return
if self.args.deepspeed_inference:
return tensor
# if self.args.deepspeed_inference:
# return tensor

if isinstance(tensor, (list, tuple)):
return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
Expand Down

0 comments on commit 10a382b

Please sign in to comment.