
Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 #11475

Merged May 14, 2021 (16 commits)

Conversation

michaelbenayoun (Member) commented Apr 27, 2021

What does this PR do?

This PR provides a function called "symbolic_trace" which enables symbolic tracing of the library's models using the new and still experimental torch.fx feature. Our models can't be symbolically traced directly with torch.fx, so this wrapper function works around the various issues.

This new feature makes it possible to perform many kinds of transformations on the graph.

It's also needed for projects like https://github.com/flexflow/FlexFlow/

As an experiment, only three models are currently supported: BERT, ELECTRA and T5 (support for other models will follow soon).
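For orientation, a minimal usage sketch follows. It assumes the module path (modeling_fx_utils) and the signature as they stood when this PR merged - both changed later - so treat it as illustrative rather than canonical:

# Usage sketch only: import path and signature reflect this PR and later
# changed; the argument values here are arbitrary examples.
from transformers import BertConfig, BertModel
from transformers.modeling_fx_utils import symbolic_trace

model = BertModel(BertConfig())
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=1,
    sequence_length=128,
)
print(traced.graph)  # the torch.fx.Graph recorded for the forward pass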

@@ -1465,8 +1465,9 @@ def is_tensor(x):
"""Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
if is_torch_available():
import torch
import torch.fx
Contributor:

this one is a bit tricky since it's a recent addition to pytorch, so it can only be imported if the pytorch version is right.

Do you know if we want pt-1.8.0 or higher? I think it should be 1.8.0 - we can adjust later if need be.

And of course it'd impact the isinstance below - so probably split the isinstance into 2 parts and condition the isinstance(x, torch.fx.Proxy) check on the pytorch version.

typically we do it by implementing is_torch_fx_available - see a whole bunch of those in file_utils.py

So

if is_torch_fx_available():
    import torch.fx
...

if is_torch_fx_available() and isinstance(x, torch.fx.Proxy):
    return True

and the version you get from:

if version.parse(torch.__version__) >= version.parse("1.8"):

so now you can implement is_torch_fx_available in file_utils.py and import it here.
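Putting those pieces together, a minimal sketch of such a helper could look like this (it assumes the packaging dependency and the existing is_torch_available helper; the actual merged implementation may differ):

# Sketch of a version-gated availability check, modeled on the
# is_*_available helpers in file_utils.py.
from packaging import version

def is_torch_fx_available():
    # torch.fx shipped with pytorch 1.8, so gate on the version.
    if not is_torch_available():
        return False
    import torch

    return version.parse(torch.__version__) >= version.parse("1.8")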

Two resolved review threads on src/transformers/models/bert/modeling_bert.py (outdated).
Comment on lines 772 to 786
if isinstance(input_ids, torch.fx.Proxy):
    # Item assignment is not supported natively for proxies.
    shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
    shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1].clone()], dim=-1)
else:
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = decoder_start_token_id
Contributor:

Do we know if the fx-friendly version is slower, and thus whether we need both?

And here, as with the import at the top, we need to add an if is_torch_fx_available() check ...

stas00 (Contributor) left a comment:

So torch.fx can't work with the module interface and it has to be replaced with the functional one? Asking since you replaced many CrossEntropyLoss() instances with F.cross_entropy.

stas00 (Contributor) commented Apr 27, 2021

Let's also:

  1. add some basic usage doc - we can start with just a docstring
  2. add one test for one of the models, polish it, and then see how to replicate it for the other models.

shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if is_torch_fx_available() and isinstance(input_ids, torch.fx.Proxy):
stas00 (Contributor), May 3, 2021:

If we are going to use this combo a lot, then down the road we could consider combining these 2 with a helper, so we won't need to repeat the code:

if is_torch_fx_proxy(input_ids): ...

but it's probably perfect as it is for now.
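For what it's worth, such a helper could be as small as the sketch below (illustrative only; whether and where it lands is up to a follow-up):

def is_torch_fx_proxy(x):
    # Combine the availability and the isinstance checks in one place.
    if is_torch_fx_available():
        import torch.fx

        return isinstance(x, torch.fx.Proxy)
    return False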

stas00 self-requested a review May 5, 2021 17:19
stas00 (Contributor) left a comment:

Looks ready! Great work, @michaelbenayoun

stas00 requested reviews from sgugger and LysandreJik May 5, 2021 17:20
sgugger (Collaborator) left a comment:

Thanks for adding this experimental feature. My main problem with the PR is the tests: we should refactor all those new tests into one common test, which will also make it easier to add tracing support for new architectures in the future.

Comment on lines 28 to 30
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
Collaborator:

We don't merge empty asserts (except in tests) so please add a helpful error message :-)

Contributor:

These are a verbatim copy from the original pytorch implementation, but yes, it'd definitely be helpful to improve those.

@michaelbenayoun, if you think this code is a keeper, let's do better error handling then.
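For illustration, the kind of fix being asked for might look like this (the messages are hypothetical):

import inspect

frame = inspect.currentframe()
assert frame is not None, "Could not retrieve the current frame."
calling_frame = frame.f_back
assert calling_frame is not None, "The current frame has no calling frame."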

Comment on lines 95 to 97
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
Collaborator:

Empty asserts

@@ -405,6 +409,309 @@ def create_and_check_for_multiple_choice(
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

def create_and_check_tracing_for_causal_lm(
Collaborator:

This looks like it could be a common test instead of repeating the same thing for all classes. I suggest adding a test_tracing attribute in the common tester and a common test for tracing that loops through self.all_model_classes in the test_modeling_common file.
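Concretely, the shape being suggested might look like the sketch below (names and assertions are hypothetical - the merged PR ultimately settled on fx_ready_model_classes and a _create_and_check_torch_fx_tracing helper instead):

# Sketch of a common tracing test in ModelTesterMixin (illustrative only).
def test_tracing(self):
    if not getattr(self, "fx_ready_model_classes", ()):
        return  # tracing not yet supported for this architecture
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.fx_ready_model_classes:
        model = model_class(config)
        model.eval()
        traced = symbolic_trace(model, input_names=list(inputs_dict.keys()))
        self.assertIsNotNone(traced.graph)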

stas00 (Contributor) commented May 6, 2021

My understanding was that this is experimental, and that as we start using this side of the library we will generalize and improve things. Hence the more relaxed approach.

Same for tests: I thought it was good to start with unique tests, because the workarounds are unique, and then come up with common tests over time as more models are ported.

@michaelbenayoun, one way to approach this puzzle is to create common tests for what's the same in all of them, and if something is unique to a given model, have just that tested in that model's test file. If you need help with that, please don't hesitate to ask.

sgugger (Collaborator) commented May 6, 2021

Even for experimental features like model parallelism, we are using common tests. This should not be different IMO.

@@ -21,6 +21,7 @@
import os

import torch
import torch.nn.functional as F
Contributor:

So this one is no longer needed, is it?

michaelbenayoun (Member, Author):

No, it shouldn't be needed anymore, I was wondering if we should keep it or not.

Contributor:

probably remove it then, if the tests pass.

We are going to get rid of F after we merge this - we will use nn.functional everywhere consistently.

stas00 (Contributor) commented May 12, 2021

@sgugger, Michael merged the custom tests into the common tests and significantly simplified the mods to the models - yay!

So it looks ready for your review whenever you have a chance.

Thank you!

sgugger (Collaborator) left a comment:

This is great! I just have a few last comments on documentation/naming, and then this should be good to merge!

A resolved review thread on src/transformers/file_utils.py (outdated).


class CustomProxy(Proxy):
    def __init__(self, node: Node, tracer: Optional[Tracer] = None):
Collaborator:

A small docstring on what that object does would go a long way for the code's maintainability in the future :-)
What does this custom proxy do that the torch.fx one does not?

Also, should it be HFProxy instead of CustomProxy?



class CustomTracer(Tracer):
    def __init__(self, batch_size=1, seqlen=[128, 128], num_choices=-1):
Collaborator:

Same here for the docstring + HFTracer ?

    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    seqlen: Union[int, List[int]] = [128, 128],
Collaborator:

Use sequence_length to go with batch_size please.

Comment on lines 203 to 210
model (:obj:`PretrainedModel`): The model to trace.
input_names (:obj:`Optional[List[str]]`): The names of the inputs of the traced model.
    If input_names is None, the model dummy_inputs keys are used instead.
batch_size (:obj:`int`): The batch size of the traced model inputs.
seqlen (:obj:`Union[int, List[int]]`): The sequence length of the traced model inputs.
    For Seq2Seq models with differents sequence length between the encoder and the decoder inputs, seqlen must
    be [encoder_sequence_length, decoder_sequence_length].
num_choices (:obj:`int`): The number of possible choices for MultipleChoice task.
Collaborator:

If we detail the args, let's use the same style as in other places then:

Suggested change (replacing the docstring quoted above with):
model (:obj:`PretrainedModel`):
    The model to trace.
input_names (:obj:`List[str]`, `optional`):
    The names of the inputs of the traced model. If unset, the model's dummy_inputs keys are used instead.
batch_size (:obj:`int`, `optional`, defaults to 1):
    The batch size of the traced model inputs.
sequence_length (:obj:`int` or :obj:`List[int]`):
    The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
    lengths between the encoder and the decoder inputs, this must be
    :obj:`[encoder_sequence_length, decoder_sequence_length]`.
num_choices (:obj:`int`, `optional`, defaults to -1):
    The number of possible choices for a multiple choice task.

Collaborator:

Also, what is dummy_inputs in that docstring?

@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
Collaborator:

Great refactor!

A resolved review thread on tests/test_modeling_common.py (outdated).
michaelbenayoun and others added 2 commits May 14, 2021 10:26
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
LysandreJik (Member) left a comment:

Very nice implementation! Love the common tests - it's clean.

Before adding other models, I would prioritize writing a small piece of documentation (which can probably live under "Advanced guides" in the transformers docs) explaining what this is and how it can be used; this will help understandability and maintainability in the future. (This can be done in a future PR.)

Of course, it looks as if this is still early in development, so no need for a very thorough doc - just enough to get a grasp of what's happening without necessarily playing with the code first.

Comment on lines +467 to +473
def test_torch_fx(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    self._create_and_check_torch_fx_tracing(config, inputs_dict)

def test_torch_fx_output_loss(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
Member:

How heavy in terms of memory/processing power are those tests? It makes me think of the Keras tests that we eventually had to disable, as they took more than a couple of minutes per model class, which wasn't feasible for CI.

stas00 (Contributor), May 14, 2021:

From: https://circle-production-customer-artifacts.s3.amazonaws.com/picard/forks/5bdabdd888af1f000130874a/360197618/609eb688e2ebd83fbb07349a-0-build/artifacts/~/transformers/reports/tests_torch_durations.txt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20210514T180650Z&X-Amz-SignedHeaders=host&X-Amz-Expires=60&X-Amz-Credential=AKIAJR3Q6CR467H7Z55A%2F20210514%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=2f132f54ff812769cb205d414bf497061226c7c6fd2260ef02711b19e540050c

3.71s call     tests/test_modeling_bert.py::BertModelTest::test_torch_fx
0.71s call     tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx
0.62s call     tests/test_modeling_t5.py::T5ModelTest::test_torch_fx

3.56s call     tests/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss
0.86s call     tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss
0.58s call     tests/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss

It's interesting that BERT is 6x slower than T5. Any idea why?

michaelbenayoun (Member, Author):

Yes, I would say that's because there are more model classes for BERT than for T5.

Contributor:

Ah! the answer was simple then ;) Thank you!

michaelbenayoun merged commit 86d5fb0 into huggingface:master May 14, 2021
nbcsm commented Jun 3, 2021

Sorry for jumping in.
Out of curiosity, what is the scenario for using this symbolic tracing feature? I didn't find any example/doc...
Thanks.

stas00 (Contributor) commented Jun 3, 2021

Well, I initially wanted this in order to be able to try https://github.com/flexflow/FlexFlow, which requires symbolic tracing - but I haven't had a chance to do so yet.

nbcsm commented Jun 3, 2021

Got it, thanks for the explanation.

Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
… and T5 (huggingface#11475)

Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
HamidShojanazeri (Contributor) commented:

> Sorry for jumping in.
> Out of curiosity, what is the scenario for using this symbolic tracing feature? I didn't find any example/doc...
> Thanks.

This would also be helpful for quantizing models with FX Graph Mode Quantization, which automates the quantization process in PyTorch.
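For context, that workflow on an fx-traceable model looks roughly like the sketch below (API as it existed around torch 1.8-1.12 - later versions also require example_inputs - and the model plus the calibration step are placeholders):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

model.eval()  # `model` stands in for any fx-traceable module
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)  # inserts observers
# ... feed representative calibration batches through `prepared` ...
quantized = convert_fx(prepared)  # produces the quantized module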

JackFram commented:

Are these updates still functional? I ask because no modeling_fx_utils.py can be seen in the source code directory.

7 participants