[Deepspeed Inference] HF Integration
stas00 committed Nov 17, 2021
1 parent 790cdc2 commit e78d4b0
Showing 3 changed files with 73 additions and 4 deletions.
45 changes: 45 additions & 0 deletions src/transformers/deepspeed.py
@@ -33,6 +33,51 @@
logger = logging.get_logger(__name__)


+# XXX: Reza - need the rest of the map
+inference_custom_map = dict(
+    electra=dict(ElectraLayer=("output.dense",)),
+    roberta=dict(RobertaLayer=("output.dense",)),
+    t5=dict(T5Block=("SelfAttention.o", "EncDecAttention.o", "DenseReluDense.wo")),
+)
+
+# XXX: Reza - need the rest of the models that are automated
+inference_auto = [
+    "gpt_neo",
+]
+
+
+def deepspeed_inference_init(trainer, model_arch):
+    """
+    Initialize a DeepSpeed Inference engine for the model of ``trainer``, based on ``model_arch``, and return it.
+    """
+
+    dep_version_check("deepspeed")
+    import deepspeed
+
+    args = trainer.args
+
+    if model_arch in inference_auto:
+        kwargs = dict(
+            replace_method="auto",
+            replace_with_kernel_inject=True,
+        )
+    elif model_arch in inference_custom_map:
+        kwargs = dict(injection_policy=inference_custom_map[model_arch])
+    else:
+        raise ValueError(
+            f"[Deepspeed Inference] {model_arch} hasn't yet been mapped out, please file an Issue to request support for it"
+        )
+
+    deepspeed_inference_engine = deepspeed.init_inference(
+        trainer.model,
+        mp_size=args.world_size,
+        dtype=torch.float,  # XXX: Reza: how to define other types? ds config file?
+        **kwargs,
+    )
+
+    return deepspeed_inference_engine
+
+
def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None

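Note: below is a minimal standalone sketch, not part of this commit, of what deepspeed_inference_init resolves to for a custom-mapped architecture such as t5. One caveat worth flagging: deepspeed.init_inference expects actual module classes as injection_policy keys, while inference_custom_map above stores class names as strings, so those would presumably need resolving to classes first.

import deepspeed
import torch
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Block

model = T5ForConditionalGeneration.from_pretrained("t5-small")
# kernel injection via an explicit per-layer policy, mirroring the t5 entry above
engine = deepspeed.init_inference(
    model,
    mp_size=1,  # tensor-parallel degree; the commit passes args.world_size
    dtype=torch.float,
    injection_policy={T5Block: ("SelfAttention.o", "EncDecAttention.o", "DenseReluDense.wo")},
)
model = engine.module  # the module to run forward()/generate() on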
20 changes: 18 additions & 2 deletions src/transformers/trainer.py
@@ -59,7 +59,7 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_inference_init, deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .file_utils import (
    CONFIG_NAME,
@@ -359,6 +359,7 @@ def __init__(
        if (
            self.is_model_parallel
            or args.deepspeed
+            or args.deepspeed_inference
            or (args.fp16_full_eval and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
        ):
@@ -1797,7 +1798,12 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
-            if self.deepspeed and data.dtype != torch.int64:
+            if self.args.deepspeed_inference:
+                print(data.dtype)
+                print(kwargs)
+                return data.to("cuda:0")
+
+            elif self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
@@ -2231,6 +2237,12 @@ def evaluation_loop(
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

+        if self.args.deepspeed_inference:
+            deepspeed_inference_engine = deepspeed_inference_init(self, self.args.deepspeed_inference)
+            self.model = deepspeed_inference_engine.module
+            self.model_wrapped = deepspeed_inference_engine
+            self.deepspeed = deepspeed_inference_engine
+
        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
@@ -2391,6 +2403,10 @@ def _pad_across_processes(self, tensor, pad_index=-100):
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
+        # XXX: hangs here with 2 gpus if we don't return
+        if self.args.deepspeed_inference:
+            return tensor
+
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
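Note: a hypothetical end-to-end sketch, not part of this commit, of how evaluation would route through DeepSpeed Inference given the Trainer changes above (assumes this branch is installed; eval_dataset stands in for any tokenized dataset):

from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments

model = T5ForConditionalGeneration.from_pretrained("t5-small")
args = TrainingArguments(
    output_dir="out",
    do_eval=True,
    deepspeed_inference="t5",  # the new architecture-name field, added in training_args.py below
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
metrics = trainer.evaluate()  # evaluation_loop wraps self.model in the inference engine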
12 changes: 10 additions & 2 deletions src/transformers/training_args.py
@@ -302,6 +302,10 @@ class TrainingArguments:
            Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
            evolve in the future. The value is either the location of the DeepSpeed json config file (e.g.,
            ``ds_config.json``) or an already loaded json file as a :obj:`dict`.
+        deepspeed_inference (:obj:`str`, `optional`):
+            Use `Deepspeed Inference <https://www.deepspeed.ai/tutorials/inference-tutorial/>`__. This is an
+            experimental feature and its API may evolve in the future. The value is the model architecture name (e.g.,
+            `t5`, `gpt_neo`, `electra`).
        label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0):
            The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
            labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
@@ -609,9 +613,13 @@ class TrainingArguments:
    deepspeed: Optional[str] = field(
        default=None,
        metadata={
-            "help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
+            "help": "Enable DeepSpeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
        },
    )
+    deepspeed_inference: Optional[str] = field(
+        default=None,
+        metadata={"help": "Enable DeepSpeed Inference using the name of the architecture as the value"},
+    )
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
@@ -920,7 +928,7 @@ def _setup_devices(self) -> "torch.device":
            self.local_rank = sm_dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
-        elif self.deepspeed:
+        elif self.deepspeed or self.deepspeed_inference:
            # deepspeed inits torch.distributed internally
            from .deepspeed import is_deepspeed_available

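Note: a sketch, not from this commit, of how the new flag surfaces through HfArgumentParser, e.g. for a script launched as deepspeed --num_gpus 2 run_translation.py ... --deepspeed_inference t5:

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# parse_args_into_dataclasses() reads sys.argv by default; args= makes the demo self-contained
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--deepspeed_inference", "t5"]
)
assert training_args.deepspeed_inference == "t5"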
