Tensor-Parallelism general support #1512

Merged: 17 commits, Nov 12, 2021
9 changes: 7 additions & 2 deletions deepspeed/__init__.py
@@ -237,7 +237,9 @@ def init_inference(model,
dtype=None,
injection_policy=None,
replace_method='auto',
quantization_setting=None):
quantization_setting=None,
replace_with_kernel_inject=False,
return_tuple=True):
"""Initialize the DeepSpeed InferenceEngine.

Arguments:
@@ -267,6 +269,7 @@ def init_inference(model,
of groups used in quantization. A tuple is passed in if extra grouping is wanted for the MLP part of a
Transformer layer (e.g. (True, 8) means the model is quantized with 8 groups for the whole network,
except the MLP part, which uses 8 extra groups).
replace_with_kernel_inject: If set, kernels are injected as the inference engine is initialized.

Returns:
A deepspeed.InferenceEngine wrapped model.
@@ -286,7 +289,9 @@ def init_inference(model,
checkpoint,
dtype,
injection_policy,
return_tuple,
replace_method,
quantization_setting)
quantization_setting,
replace_with_kernel_inject)

return engine
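For orientation, a minimal usage sketch of the extended entry point. The model choice, GPU count, and the layer names in the injection policy are placeholders and assumptions, not something this diff prescribes:

```python
# Hedged usage sketch -- illustrative names only.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Kernel-injection path: fused inference kernels plus tensor slicing.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  replace_method='auto',
                                  replace_with_kernel_inject=True)

# Tensor-parallelism-only path: no custom kernels; the tuple names the two
# projections to all-reduce (attention output projection, transformer output projection).
# from transformers.models.gpt2.modeling_gpt2 import GPT2Block   # hypothetical target layer
# engine = deepspeed.init_inference(model,
#                                   mp_size=2,
#                                   dtype=torch.half,
#                                   injection_policy={GPT2Block: ('attn.c_proj', 'mlp.c_proj')},
#                                   replace_with_kernel_inject=False)

# engine.module is the wrapped model; call engine(...) as you would the original model.
```

Launching with `deepspeed --num_gpus 2 script.py` (or an equivalent distributed launcher) is assumed whenever mp_size > 1.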
20 changes: 13 additions & 7 deletions deepspeed/inference/engine.py
@@ -25,7 +25,8 @@ def __init__(self,
injection_dict=None,
return_tuple=True,
replace_method='auto',
quantization_setting=None):
quantization_setting=None,
replace_with_kernel_inject=False):
"""
Args:
model: torch.nn.Module
@@ -74,15 +75,17 @@ def __init__(self,
self.mp_group = self.mpu.get_model_parallel_group()
elif self.mp_world_size > 1:
self._create_model_parallel_group()

# apply injection policy
if self.injection_dict:
for client_module, injection_policy in self.injection_dict.items():
self._apply_injection_policy(client_module,
injection_policy,
return_tuple)
elif replace_method == "auto":
self._apply_injection_policy()
return_tuple,
replace_with_kernel_inject)
elif replace_method == 'auto':
self._apply_injection_policy(
return_tuple=return_tuple,
replace_with_kernel_inject=replace_with_kernel_inject)

device = torch.cuda.current_device()
logger.info(f"Place model to device: {device}")
@@ -152,7 +155,9 @@ def _validate_args(self, mpu):
def _apply_injection_policy(self,
client_module=None,
injection_policy=None,
return_tuple=True):
return_tuple=True,
replace_with_kernel_inject=False):

replace_transformer_layer(client_module,
self.module,
policy=injection_policy,
@@ -166,7 +171,8 @@ def _apply_injection_policy(self,
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups))
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject)

def _load_checkpoint(self, load_dir, load_module_strict=True):
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
170 changes: 156 additions & 14 deletions deepspeed/module_inject/replace_module.py
@@ -2,10 +2,40 @@
import torch
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
from .replace_policy import replace_policies
from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
from ..runtime.weight_quantizer import WeightQuantization
from torch import nn


class LinearAllreduce(nn.Module):
def __init__(self, weight, bias=None, mp_group=None):
super(LinearAllreduce, self).__init__()
self.weight = weight
self.bias = bias
self.mp_group = mp_group

def forward(self, input):
output = torch.matmul(input, self.weight)
if self.mp_group is not None:
torch.distributed.all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output


class LinearLayer(nn.Module):
def __init__(self, weight, bias=None):
super(LinearLayer, self).__init__()
self.weight = weight
self.bias = bias

def forward(self, input):
output = torch.matmul(input, self.weight)
if self.bias is not None:
output += self.bias
return output
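
Editor's note: a single-process sketch of the arithmetic these two wrappers rely on, under the assumption (visible in the forward methods above) that weights are stored as [in_features, out_features] and multiplied with torch.matmul. Row-parallel shards need the all-reduce, column-parallel shards only need concatenation:

```python
# Illustration only: simulate two "ranks" in one process, no torch.distributed required.
import torch

torch.manual_seed(0)
mp_size = 2
x = torch.randn(4, 8)        # [batch, in_features]
w = torch.randn(8, 6)        # [in_features, out_features], matching torch.matmul(input, weight)
full = torch.matmul(x, w)

# LinearAllreduce pattern (row-parallel): each rank holds a slice of the input
# dimension; the partial outputs must be summed, which is what all_reduce does.
partials = [torch.matmul(xs, ws)
            for xs, ws in zip(x.chunk(mp_size, dim=1), w.chunk(mp_size, dim=0))]
assert torch.allclose(sum(partials), full, atol=1e-5)

# LinearLayer pattern (column-parallel): each rank holds a slice of the output
# dimension; concatenation reconstructs the result, so no all-reduce is needed.
cols = [torch.matmul(x, wc) for wc in w.chunk(mp_size, dim=1)]
assert torch.allclose(torch.cat(cols, dim=1), full, atol=1e-5)
```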


class ReplaceWithTensorSlicing:
@@ -103,13 +133,17 @@ def replace_transformer_layer(orig_layer_impl,
training=True,
quantize=False,
quantize_settings=None,
return_tuple=False):
return_tuple=True,
replace_with_kernel_inject=False,
linear_layer_setting=None):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
policy: the policy for mapping from orig_layer_impl to transformer parameters when
replace_with_kernel_inject is set; otherwise, a tuple giving the names of the two linear layers:
(attention output projection, transformer output projection)
micro_batch_size (int): micro batch size per gpu used during training/eval
config (dict): model config containing hidden size, attention heads, etc.
seed (int): random seed value
@@ -127,7 +161,12 @@ def replace_transformer_layer(orig_layer_impl,
It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
return_tuple (bool): if set, transformer layer returns a tuple as the output.
Note: this flag needs to be set for huggingface models.

replace_with_kernel_inject (bool): injection mode; if true, kernels will be added along with configuring
Tensor-Parallelism
linear_layer_setting (tuple of modules) [Optional]: specifies which two classes are used for the linear
and embedding layers
attention_params (list of strings) [Optional]: the parameters in the attention part that need to
be adjusted based on the model-parallelism
Returns:
Updated nn.module with replaced transformer layers
"""
@@ -299,18 +338,126 @@ def transpose(data):
new_module.output_b.data = _4hh_b
return new_module

def replace_wo_policy(module, all_reduce_linears):
@hyunwoongko (Contributor), Jan 22, 2022:
Is the strategy of this method to apply column slicing by default, and to apply row slicing to the specific layers whose names are passed in?

Contributor:
Can't we automate this a bit more? It would be nice to have a strategy that doesn't require parameter names at all.

@hyunwoongko (Contributor), Jan 22, 2022:
@RezaYazdaniAminabadi You probably thought about this more than I did when you were making it. I'm curious about your opinion.

@hyunwoongko (Contributor), Jan 22, 2022:
I've been thinking briefly: what about a profiling strategy?

  • First, replace the forward function of each Linear or Conv1D module with a profiling_forward function that measures when the layer is forwarded, and add a get_first_forwared_time function that returns the time of the first forward; once that value exists, profiling_forward no longer measures time.
  • Second, forward the module and read the forward times back from get_first_forwared_time.
  • All forwarded Linear or Conv1D layers except the last forwarded layer are considered columns.
  • The last forwarded Linear or Conv1D layer is considered the row.
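
Editor's note: a rough, self-contained sketch of the profiling idea described in the comment above; the helper and variable names are hypothetical and nothing here is part of the PR.

```python
# Hypothetical sketch of the proposed profiling strategy.
import torch
import torch.nn as nn

forward_order = []  # names of linear layers, in the order they first run

def attach_profiler(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            orig_forward = module.forward
            def profiled(inp, _name=name, _orig=orig_forward):
                if _name not in forward_order:   # record only the first forward
                    forward_order.append(_name)
                return _orig(inp)
            module.forward = profiled

block = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))
attach_profiler(block)
block(torch.randn(2, 8))

# Heuristic from the comment: every linear except the last-forwarded one is
# column-sliced; the last one is row-sliced and therefore needs the all-reduce.
column_parallel, row_parallel = forward_order[:-1], forward_order[-1]
print(column_parallel, row_parallel)   # ['0'] 2
```

As the PR author notes in the reply below, the last-forwarded heuristic is architecture-dependent, so a sketch like this would only be a starting point.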

@hyunwoongko (Contributor), Jan 23, 2022:
It seems more flexible to just use torch.fx than this. I'll start automating the whole process of tensor & pipeline parallelization using torch.fx.

@stas00 (Collaborator), Jan 25, 2022:
I'm missing the full context. Do you suggest having a policy record for each model like in the example you have shown here: #1512 (comment)

It'd help to see several full examples; then it's much easier to see how they can be integrated.

For example, I started integrating Deepspeed-Inference in huggingface/transformers#14426 after studying a few examples here: microsoft/DeepSpeedExamples#144

So I can see what's common, what's unique, and which code sections are the driver and need to go into the Trainer loop.

Monkey-see, monkey-do style is the easiest without needing to figure out all the low-level details.

Does it make sense?

@stas00 (Collaborator), Jan 25, 2022:
> I will write some example code after deployment so that you can easily apply it.

Yes, please.

As I reported to you originally, it didn't appear that different OSLO components can be integrated separately; they require all the other OSLO components to work.

So Deepspeed-Inference I can relatively easily integrate into the HF Trainer, since it doesn't require me to use anything else other than wrapping the model. We just need to figure out a few MPU quirks. With OSLO I have no idea how to do it, because what I tried didn't work.

But let's not derail this PR and discuss OSLO either on OSLO or HF Transformers issues. This PR is about Deepspeed-Inference.

@hyunwoongko (Contributor), Jan 25, 2022:
This has nothing to do with deepspeed, so let's talk about it on the transformers issue:
huggingface/transformers#13690

Contributor (author):
> I've been thinking briefly: what about a profiling strategy? […]

I think it is still not so easy to find which one should use all-reduce, as it can be dependent on the architecture. But I may be missing something here. Maybe we can have an offline chat about this? Thanks

Contributor:
@RezaYazdaniAminabadi Yes, an offline chat would be better. When works for you?

def _replace(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
if name in all_reduce_linears:
new_weight = torch.empty(
(child.weight.shape[0]
if conv_linear_layer else child.weight.shape[1] // mp_size,
child.weight.shape[1]
if conv_linear_layer else child.weight.shape[0]),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
if not conv_linear_layer:
child.weight.data.view(-1).copy_(
child.weight.data.transpose(-1,
-2).contiguous().view(-1))
child.weight.data = child.weight.data.reshape(
child.weight.data.shape[-1],
child.weight.data.shape[-2])
data = mp_replace.copy(new_weight,
child.weight.data).to(torch.cuda.current_device())
return LinearAllreduce(data, child.bias if child.bias is None else \
child.bias.to(torch.cuda.current_device()), mp_group)
else:
new_weight = torch.empty(
(child.weight.shape[0] //
mp_size if conv_linear_layer else child.weight.shape[1],
child.weight.shape[1]
if conv_linear_layer else child.weight.shape[0] // mp_size),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
if not conv_linear_layer:
child.weight.data.view(-1).copy_(
child.weight.data.transpose(-1,
-2).contiguous().view(-1))
child.weight.data = child.weight.data.reshape(
child.weight.data.shape[-1],
child.weight.data.shape[-2])
data = mp_replace.copy(new_weight, child.weight.data)
new_bias = torch.empty((child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(torch.cuda.current_device())
return LinearLayer(data.to(torch.cuda.current_device()), bias_data)

def _slice_embedding(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0],
child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
data = mp_replace.copy(new_weight, child.weight.data)
new_embedding = nn.Embedding(child.weight.shape[0],
child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
return new_embedding
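
Editor's note: a shape-only illustration of what `_slice_embedding` produces; the vocabulary and hidden sizes below are just GPT-2-small-like examples.

```python
# Each rank keeps the full vocabulary but only hidden // mp_size of the embedding width.
import torch.nn as nn

vocab, hidden, mp_size = 50257, 768, 2
full = nn.Embedding(vocab, hidden)

shard = full.weight.data.chunk(mp_size, dim=1)[0]      # rank 0's slice of the hidden dim
sliced = nn.Embedding(vocab, hidden // mp_size)
sliced.weight.data.copy_(shard)
print(tuple(full.weight.shape), tuple(sliced.weight.shape))   # (50257, 768) (50257, 384)
```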

def update_mp_params(child):
Contributor:
@RezaYazdaniAminabadi @stas00 This code works well for a few cases, but I don't think it's a good structure to scale to 70 models. Is there any more efficient way?

@hyunwoongko (Contributor), Jan 22, 2022:
cc @jaketae do you have any nice ideas for this?

Contributor (author):
It is true that this part needs some refactoring. Please let me know if you have some ideas.

Contributor:
I left a comment on the issue, please note:
huggingface/transformers#13690 (comment)

if hasattr(child, 'n_heads'):
child.n_heads = child.n_heads // mp_size
if hasattr(child, 'inner_dim'):
child.inner_dim = child.inner_dim // mp_size
if hasattr(child, 'num_heads'):
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'all_head_size'):
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
child.embed_dim = child.embed_dim // mp_size

conv_linear_layer = False
if linear_layer_setting is not None:
linear_policies = {linear_layer_setting[0]: _replace}
if len(linear_layer_setting) == 2:
linear_policies.update({linear_layer_setting[1]: _slice_embedding})
else:
if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
Contributor:
Rather than doing this, how about loading that module object and checking that it is a Conv1D? In the future, models other than GPT-2 may use Conv1D modules.

Contributor (author):
That's a good point. Thanks @hyunwoongko
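
Editor's note: a hedged sketch of the reviewer's suggestion, keying the decision on the module classes found in the model rather than on the GPT-2 policy. The Conv1D import path has moved between transformers releases, and the helper names are illustrative.

```python
# Illustrative alternative, not the code merged in this PR.
import torch.nn as nn

try:
    from transformers.modeling_utils import Conv1D   # location varies across transformers versions
except ImportError:
    Conv1D = None

def pick_linear_policies(model, replace_fn, slice_embedding_fn):
    """Choose slicing hooks by inspecting which module classes the model actually uses."""
    if Conv1D is not None and any(isinstance(m, Conv1D) for m in model.modules()):
        # Conv1D stores its weight transposed relative to nn.Linear, so the caller
        # should take the conv_linear_layer path and skip the extra transpose.
        return {Conv1D: replace_fn}, True
    return {nn.Linear: replace_fn, nn.Embedding: slice_embedding_fn}, False
```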

try:
import transformers
conv_linear_layer = True
linear_policies = {transformers.model_utils.Conv1D: _replace}
Contributor:
Why did you set GPT-2 not to slice embeddings?

Contributor:
Any special reason?

Contributor:
Please tell me if I misunderstood.

Contributor (author):
That embedding is not part of the layer, but of the model. What I am slicing here are the transformer layers; basically, the embedding is just a small part of the model.

except ImportError:
linear_policies = {nn.Linear: _replace}
else:
linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}
Contributor:
How do you differentiate between the positional embedding and the token embedding?

Contributor:
I used model.get_input_embeddings(). This is useful in this case.
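
Editor's note: a tiny sketch of the get_input_embeddings() approach mentioned above, which singles out the token embedding so that other nn.Embedding modules (e.g. positional embeddings) are left unsliced; the model name is only an example.

```python
# Distinguish the token embedding from any other embedding modules.
import torch.nn as nn
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2")         # example model
token_embedding = model.get_input_embeddings()    # GPT-2: the 'wte' module

for name, module in model.named_modules():
    if module is token_embedding:
        print("token embedding (candidate for slicing):", name)   # wte
    elif isinstance(module, nn.Embedding):
        print("other embedding (left as-is):", name)              # wpe
```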


def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
if child.__class__ in linear_policies:
setattr(
r_module,
name,
linear_policies[child.__class__](child,
prev_name + '.' + name,
conv_linear_layer))
else:
update_mp_params(child)
_replace_module(child, name)
return r_module

return _replace_module(module)

def replace_fn(child, _policy, layer_id=0):
if training:
# copy relevant state from child -> new module
new_module = replace_with_policy(child, _policy, preln=preln)

else:
# copy relevant state from child -> new module
new_module = replace_with_policy(child,
_policy,
inference=True,
preln=(policy is not HFBertLayerPolicy),
layer_id=layer_id)
if replace_with_kernel_inject:
new_module = replace_with_policy(
child,
_policy,
inference=True,
preln=(_policy is not HFBertLayerPolicy),
layer_id=layer_id)
else:
new_module = replace_wo_policy(child, _policy)

return new_module

@@ -327,7 +474,6 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
config (dict): model config containing hidden size, attention heads, etc.

Returns:
Updated nn.module with original bert-style transformer layers
"""
@@ -396,7 +542,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
orig_class (torch.nn.Module): the module to search for
replace_fn (method): a method to convert instances of ``orig_class`` to the
desired type and return a new instance.

Returns:
A modified ``model``.
"""
@@ -422,20 +567,17 @@ def _replace_module(model, policies, layer_id=0):
Arguments:
model (torch.nn.Module): model to augment
policies (dict): Mapping of source class to replacement function.

Returns:
Modified ``model``.
"""
for name, child in model.named_children():
if child.__class__ in policies:
orig = repr(child)
setattr(
model,
name,
policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id))
new = getattr(model, name)
layer_id += 1
else:
_, layer_id = _replace_module(child, policies, layer_id=layer_id)
@@ -620,6 +620,6 @@ def forward(self,
output = (output, presents)

if self.config.return_tuple:
return (output, )
return output if type(output) is tuple else (output, )
else:
return output
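
Editor's note on this last hunk: the guard prevents an output that is already a tuple from being wrapped a second time, which would break HuggingFace-style unpacking. A minimal illustration, with plain Python values standing in for tensors:

```python
output = ("hidden_states", "presents")                     # already a tuple, as in the presents branch above

naive   = (output, )                                       # ((h, p),)  -- nested one level too deep
guarded = output if type(output) is tuple else (output, )  # (h, p)     -- unchanged
assert guarded == output
```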