
Tensor-Parallelism general support #1512

Merged: 17 commits merged into master on Nov 12, 2021

Conversation

@RezaYazdaniAminabadi (Contributor) commented Nov 2, 2021:

This PR provides support for model parallelism during inference without the need for kernel injection.

Add the companion PR to the DeepSpeed-Examples branch and verify the tensor-parallelism functionality on different model architectures (a minimal usage sketch follows the list):

  • Bert
  • Roberta
  • GPT2
  • GPT-Neo
  • GPT-J
  • Wav2vec2
  • T5
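
A minimal usage sketch, assuming the `deepspeed.init_inference` signature of that era (`mp_size`, `dtype`, `replace_with_kernel_inject`); the model name and world size below are placeholders, not part of this PR:

```python
# Hedged sketch: generic tensor-parallel inference without kernel injection.
# Launch with e.g. `deepspeed --num_gpus 2 run_inference.py` so that each
# rank receives its shard of the sliced weights.
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # any of the architectures in the checklist above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# replace_with_kernel_inject=False exercises the generic tensor-parallel path
# discussed in this PR rather than the fused inference kernels.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=False)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = engine(**inputs).logits
```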

cc: @stas00 @hyunwoongko

@jeffra merged commit 9ce00a2 into master on Nov 12, 2021
@hyunwoongko (Contributor) commented Jan 22, 2022:

I will review this today and apply it to OSLO.

cc @stas00 @RezaYazdaniAminabadi

    new_embedding.weight.data.copy_(data)
    return new_embedding

def update_mp_params(child):

@hyunwoongko (Contributor):

@RezaYazdaniAminabadi @stas00 This code works well for a few cases, but I don't think it's a good structure to scale to 70 models. Is there a more efficient way?

@hyunwoongko (Contributor) commented Jan 22, 2022:

cc @jaketae, do you have any good ideas for this?

@RezaYazdaniAminabadi (Contributor Author):

It is true that this part needs some refactoring. Please let me know if you have any ideas.

@hyunwoongko (Contributor):

I left a comment on the issue; please take a look: huggingface/transformers#13690 (comment)

@@ -299,18 +338,126 @@ def transpose(data):
    new_module.output_b.data = _4hh_b
    return new_module

def replace_wo_policy(module, all_reduce_linears):

@hyunwoongko (Contributor) commented Jan 22, 2022:

Is the strategy of this method to apply a column slice by default, and to apply a row slice to the layers whose names are passed in?
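
For reference, a minimal sketch of the two slicing modes being asked about (the function name is illustrative, not an actual helper in this PR): linears are split column-wise by default, while the layers listed in `all_reduce_linears` are split row-wise, since their partial outputs must be all-reduced.

```python
import torch

def shard_linear_weight(weight: torch.Tensor, rank: int, world_size: int,
                        row_parallel: bool) -> torch.Tensor:
    # nn.Linear stores its weight as [out_features, in_features].
    # Column-parallel: split out_features; each rank produces a slice of the
    # output, so no reduction is needed until a later row-parallel layer.
    # Row-parallel: split in_features; each rank produces a partial sum over
    # its input slice, so the outputs must be all-reduced afterwards.
    dim = 1 if row_parallel else 0
    return torch.chunk(weight, world_size, dim=dim)[rank].contiguous()
```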

@hyunwoongko (Contributor):

Can't we automate this a bit more? It would be nice to have a strategy that doesn't require parameter names at all.

@hyunwoongko (Contributor) commented Jan 22, 2022:

@RezaYazdaniAminabadi You have probably thought about this more than I have while building it. I'm curious about your opinion.

@hyunwoongko (Contributor) commented Jan 22, 2022:

I've been thinking briefly: what about a profiling strategy? (A rough sketch follows this list.)

  • First, replace the forward function of each Linear or Conv1D module with a profiling_forward function that records the time at which the layer is first forwarded, and add a get_first_forward_time function that returns that time; once a time has been recorded, profiling_forward no longer measures it.
  • Second, forward the module and read the times back via get_first_forward_time.
  • All forwarded Linear or Conv1D layers except the last forwarded one are considered column-parallel.
  • The last forwarded Linear or Conv1D layer is considered row-parallel.
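
A rough, self-contained sketch of that profiling idea (purely illustrative; profiling_forward and the helper names are hypothetical, and in practice this would likely be applied per transformer block rather than to the whole model):

```python
import time
import torch.nn as nn

first_forward_time = {}  # layer name -> time of its first forward

def attach_profiler(module: nn.Module, name: str):
    original_forward = module.forward

    def profiling_forward(*args, **kwargs):
        # record only the first forward of this layer
        first_forward_time.setdefault(name, time.time())
        return original_forward(*args, **kwargs)

    module.forward = profiling_forward

def infer_parallel_roles(model: nn.Module, dummy_inputs: dict):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):  # Conv1D could be handled likewise
            attach_profiler(module, name)
    model(**dummy_inputs)  # one profiling pass
    ordered = sorted(first_forward_time, key=first_forward_time.get)
    # every layer except the last-forwarded one is treated as column-parallel
    return {name: ("row" if name == ordered[-1] else "column") for name in ordered}
```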

@hyunwoongko (Contributor) commented Jan 23, 2022:

It seems more flexible to just use torch.fx for this. I'll start automating the whole process of tensor and pipeline parallelization using torch.fx.
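
For context, a tiny torch.fx sketch of recovering the call order of linear layers from a traced graph, which is the same information the profiling idea above tries to measure at runtime (illustrative only; Hugging Face models usually need their dedicated HFTracer rather than plain symbolic_trace):

```python
import torch.nn as nn
from torch.fx import symbolic_trace

def linear_call_order(model: nn.Module):
    # Trace the model and list its nn.Linear submodules in the order they
    # are called, without running a real forward pass.
    traced = symbolic_trace(model)
    order = []
    for node in traced.graph.nodes:
        if node.op == "call_module" and isinstance(
                traced.get_submodule(node.target), nn.Linear):
            order.append(node.target)
    return order
```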

@stas00 (Collaborator) commented Jan 25, 2022:

I'm missing the full context. Are you suggesting having a policy record for each model, like in the example you have shown here: #1512 (comment)?

It would help to see several full examples; then it's much easier to see how it can be integrated.

For example, I started integrating Deepspeed-Inference in huggingface/transformers#14426 after studying a few examples here: microsoft/DeepSpeedExamples#144

That way I can see what's common, what's unique, and which code sections are the driver and need to go into the Trainer loop.

Monkey-see, monkey-do style is the easiest, without needing to figure out all the low-level details.

Does it make sense?

@stas00 (Collaborator) commented Jan 25, 2022:

> I will write some example code after deployment so that you can easily apply it.

Yes, please.

As I reported to you originally, it didn't appear that different OSLO components can be integrated separately; they seem to require all the other OSLO components to work.

Deepspeed-Inference, on the other hand, I can relatively easily integrate into the HF Trainer, since it doesn't require anything other than wrapping the model; we just need to figure out a few MPU quirks. With OSLO I have no idea how to do it, because what I tried didn't work.

But let's not derail this PR; let's discuss OSLO either in the OSLO or the HF Transformers issues. This PR is about Deepspeed-Inference.

@hyunwoongko (Contributor) commented Jan 25, 2022:

This has nothing to do with DeepSpeed, so let's talk in the transformers issue: huggingface/transformers#13690

@RezaYazdaniAminabadi (Contributor Author):

> I've been thinking briefly: what about a profiling strategy?
>
>   • First, replace the forward function of each Linear or Conv1D module with a profiling_forward function that records the time at which the layer is first forwarded, and add a get_first_forward_time function that returns that time; once a time has been recorded, profiling_forward no longer measures it.
>   • Second, forward the module and read the times back via get_first_forward_time.
>   • All forwarded Linear or Conv1D layers except the last forwarded one are considered column-parallel.
>   • The last forwarded Linear or Conv1D layer is considered row-parallel.

I think it is still not so easy to find which layers should use all-reduce, as it can depend on the architecture. But I may be missing something here. Maybe we can have an offline chat about this? Thanks.

@hyunwoongko (Contributor):

@RezaYazdaniAminabadi Yes, an offline chat would be better. When would suit you?

if len(linear_layer_setting) == 2:
    linear_policies.update({linear_layer_setting[1]: _slice_embedding})
else:
    if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:

@hyunwoongko (Contributor):

Rather than doing this, how about loading that module object and checking whether it is a Conv1D? In the future, models other than GPT2 may use Conv1D modules.
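
A hedged sketch of that suggestion (the import location of Conv1D has moved between transformers versions, so this is illustrative rather than the PR's eventual code):

```python
def is_conv1d(module) -> bool:
    # Detect transformers' Conv1D (a transposed Linear used by GPT2-style
    # models) by type instead of hard-coding the GPT2 layer policy.
    try:
        from transformers.pytorch_utils import Conv1D
    except ImportError:
        try:
            from transformers.modeling_utils import Conv1D  # older versions
        except ImportError:
            return False
    return isinstance(module, Conv1D)
```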

@RezaYazdaniAminabadi (Contributor Author):

That's a good point. Thanks @hyunwoongko

try:
    import transformers
    conv_linear_layer = True
    linear_policies = {transformers.model_utils.Conv1D: _replace}

@hyunwoongko (Contributor):

Why did you set GPT2 not to slice embeddings?

@hyunwoongko (Contributor):

Any special reason?

@hyunwoongko (Contributor):

Please tell me if I misunderstood.

@RezaYazdaniAminabadi (Contributor Author):

That embedding is not part of the layer but of the model. What I am slicing here are the transformer layers; the embedding is just a small part of the model.

except ImportError:
    linear_policies = {nn.Linear: _replace}
else:
    linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

@hyunwoongko (Contributor):

How do you differentiate between positional embedding and token embedding?

@hyunwoongko (Contributor):

I used model.get_input_embeddings(); it is useful in this case.
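
A small sketch of that approach (illustrative; the actual slicing is elided): compare each nn.Embedding against model.get_input_embeddings() by identity, so only the token embedding is marked for vocabulary slicing and positional embeddings are left alone.

```python
import torch.nn as nn

def split_embeddings(model: nn.Module):
    # get_input_embeddings() returns the token-embedding table, so an identity
    # check separates it from positional (and any other) nn.Embedding modules.
    token_embedding = model.get_input_embeddings()
    to_slice, to_keep = [], []
    for name, module in model.named_modules():
        if isinstance(module, nn.Embedding):
            (to_slice if module is token_embedding else to_keep).append(name)
    return to_slice, to_keep
```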

@RezaYazdaniAminabadi (Contributor Author) commented:

> I will analyze this today and apply it to OSLO.
>
> cc @stas00 @RezaYazdaniAminabadi

Thanks a lot @hyunwoongko for the thorough review of this PR. I will use some of your feedback to make it stronger.
