Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 #11475
Conversation
src/transformers/file_utils.py (outdated)
@@ -1465,8 +1465,9 @@ def is_tensor(x):
    """Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
    if is_torch_available():
        import torch
        import torch.fx
This one is a bit tricky since torch.fx is a recent addition to PyTorch, so it can only be imported if the installed PyTorch version is recent enough. Do you know if we want pt-1.8.0 or higher? I think it should be 1.8.0; we can adjust later if need be.
And of course it would impact the isinstance checks below, so probably split the isinstance into two parts and condition the isinstance(x, torch.fx.Proxy) check on the PyTorch version.
Typically we do it by implementing is_torch_fx_available (see a whole bunch of those in file_utils.py). So:

    if is_torch_fx_available():
        import torch.fx
        ...

    if is_torch_fx_available() and isinstance(x, torch.fx.Proxy):
        return True

and the version you get from:

    if version.parse(torch.__version__) >= version.parse("1.8"):

So now you can implement is_torch_fx_available in file_utils.py and import it here.
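A minimal sketch of the suggested helper, following the is_*_available pattern in file_utils.py (the real implementation may differ):

```python
import importlib.util

from packaging import version


def is_torch_fx_available():
    # Sketch of the suggested helper; mirrors the is_*_available pattern
    # in file_utils.py. Returns False when torch is not installed at all.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    # torch.fx shipped with PyTorch 1.8.
    return version.parse(torch.__version__) >= version.parse("1.8")
```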
    if isinstance(input_ids, torch.fx.Proxy):
        # Item assignment is not supported natively for proxies.
        shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
        shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1].clone()], dim=-1)
    else:
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id
Do we know if the fx-friendly version is slower, and is that why we need both?
Also, both here and for the import at the top we need to add the is_torch_fx_available check.
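For reference, a self-contained sketch of the fx-friendly branch above (the function name is illustrative): it builds the decoder-start column with torch.full and concatenates, instead of the item assignment that torch.fx proxies don't support.

```python
import torch


def shift_tokens_right_fx_friendly(input_ids, decoder_start_token_id):
    # Build the leading decoder-start column, then concatenate the shifted
    # tokens after it; no in-place item assignment is needed.
    start = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
    return torch.cat([start, input_ids[..., :-1].clone()], dim=-1)


ids = torch.tensor([[5, 6, 7]])
shifted = shift_tokens_right_fx_friendly(ids, 0)  # tensor([[0, 5, 6]])
```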
So torch.fx can't work with the modular interface and it has to be replaced with the functional one? Asking since you replaced many CrossEntropyLoss() instances with F.cross_entropy.
Let's also:
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = decoder_start_token_id
    if is_torch_fx_available() and isinstance(input_ids, torch.fx.Proxy):
If we are going to use this combo a lot, then down the road we could consider combining these two checks into a helper so we won't need to repeat the code: if is_torch_fx_proxy(input_ids): ...
But it's probably fine as it is for now.
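A self-contained sketch of the proposed is_torch_fx_proxy helper (the name comes from the comment above; the implementation is illustrative). It folds the availability check into the type check so call sites stay short:

```python
def is_torch_fx_proxy(x):
    # Returns True only when torch.fx is importable AND x is a Proxy;
    # on environments without torch.fx it simply returns False.
    try:
        import torch.fx
    except ImportError:
        return False
    return isinstance(x, torch.fx.Proxy)
```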
…pdated the models that were causing utils/check_copies.py to complain
Looks ready! Great work, @michaelbenayoun
Thanks for adding this experimental feature. My main problem with the PR is the tests: we should have all those new tests refactored into one common test, which will also make it easier to add support for tracing new architectures in the future.
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
We don't merge empty asserts (except in tests) so please add a helpful error message :-)
These are a verbatim copy from the original PyTorch implementation, but yes, it'd definitely be helpful to improve them.
@michaelbenayoun, if you think this code is a keeper, let's do better error handling then.
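A sketch of what the reviewer asks for: the same frame-inspection asserts, but with helpful error messages (the function name and messages are illustrative, not from the PR):

```python
import inspect


def get_calling_frame():
    # Same frame introspection as in the PR, with descriptive assert messages.
    frame = inspect.currentframe()
    assert frame is not None, (
        "inspect.currentframe() returned None; frame introspection is not "
        "supported on this Python implementation"
    )
    calling_frame = frame.f_back
    assert calling_frame is not None, (
        "get_calling_frame() must be called from within another frame"
    )
    return calling_frame
```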
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
Empty asserts
tests/test_modeling_bert.py (outdated)
@@ -405,6 +409,309 @@ def create_and_check_for_multiple_choice(
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_tracing_for_causal_lm(
This looks like it could be a common test instead of repeating the same thing for all classes. I suggest adding a test_tracing attribute in the common tester and a common test for tracing that will loop through self.all_model_classes in the test_modeling_common file.
My understanding was that this is experimental, and as we start using this side of the library we will generalize and improve things; hence the more relaxed approach. Same for the tests: I thought it was good to start with unique tests, because the workarounds are unique, and then come up with common tests over time as more models are ported. @michaelbenayoun, one way to approach this puzzle is to create common tests for what's the same in all of them, and if something is unique to a given model, test just that in the model's test file. If you need help with that, please don't hesitate to ask.
Even for experimental features like model parallelism, we are using common tests. This should not be different IMO.
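A minimal, self-contained illustration (all names hypothetical) of the common-test pattern suggested above: the shared mixin defines one check that loops over whatever model classes each per-model tester declares.

```python
class ModelTesterMixin:
    fx_ready_model_classes = ()  # overridden by per-model test classes

    def run_fx_tracing_checks(self):
        checked = []
        for model_class in self.fx_ready_model_classes:
            # The real suite would symbolically trace the model here and
            # compare traced vs. eager outputs; we just record the class.
            checked.append(model_class.__name__)
        return checked


class FakeBertModel:  # stand-in for a real model class
    pass


class BertModelTest(ModelTesterMixin):
    fx_ready_model_classes = (FakeBertModel,)
```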
…ing of modules defined in the forward pass, making the whole nn.functional renaming not needed anymore
…les instantiated in the forward pass
@@ -21,6 +21,7 @@
import os

import torch
import torch.nn.functional as F
So this one is no longer needed, is it?
No, it shouldn't be needed anymore, I was wondering if we should keep it or not.
Probably remove it then, if tests pass.
We are going to get rid of F after we merge this and will use nn.functional everywhere consistently.
@sgugger, Michael merged the custom tests into common_tests and significantly simplified the mods to the models - yay! So it looks ready for your review whenever you have a chance. Thank you!
This is great! I just have a few last comments on documentation/naming and this should be good to be merged!
    class CustomProxy(Proxy):
        def __init__(self, node: Node, tracer: Optional[Tracer] = None):
A small docstring on what that object does would go a long way for the code maintainability in the future :-) What does this custom proxy do that the torch.fx one does not?
Also, should it be HFProxy instead of CustomProxy?
    class CustomTracer(Tracer):
        def __init__(self, batch_size=1, seqlen=[128, 128], num_choices=-1):
Same here for the docstring + HFTracer?
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    seqlen: Union[int, List[int]] = [128, 128],
Use sequence_length to go with batch_size, please.
    model (:obj:`PretrainedModel`): The model to trace.
    input_names (:obj:`Optional[List[str]]`): The names of the inputs of the traced model.
        If input_names is None, the model dummy_inputs keys are used instead.
    batch_size (:obj:`int`): The batch size of the traced model inputs.
    seqlen (:obj:`Union[int, List[int]]`): The sequence length of the traced model inputs.
        For Seq2Seq models with different sequence lengths between the encoder and the decoder inputs,
        seqlen must be [encoder_sequence_length, decoder_sequence_length].
    num_choices (:obj:`int`): The number of possible choices for the MultipleChoice task.
If we detail the args, let's use the same style as in other places then:

    model (:obj:`PretrainedModel`):
        The model to trace.
    input_names (:obj:`List[str]`, `optional`):
        The names of the inputs of the traced model. If unset, the model dummy_inputs keys are used instead.
    batch_size (:obj:`int`, `optional`, defaults to 1):
        The batch size of the traced model inputs.
    sequence_length (:obj:`int` or :obj:`List[int]`):
        The sequence length of the traced model inputs. For sequence-to-sequence models with different
        sequence lengths between the encoder and the decoder inputs, this must be
        :obj:`[encoder_sequence_length, decoder_sequence_length]`.
    num_choices (:obj:`int`, `optional`, defaults to -1):
        The number of possible choices for a multiple choice task.
Also, what is dummy_inputs in that docstring?
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
Great refactor!
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Very nice implementation! Love the common tests, it's clean.
Before adding other models, I would prioritize writing some short documentation (which can probably live under "Advanced guides" in the transformers docs) explaining what this is and how it can be used; this will help understandability and maintainability in the future. (This can be done in a future PR.)
Of course, it looks as if it is still early in the development, so no need for a very thorough doc; just enough to get a grasp of what's happening without necessarily playing with the code first.
    def test_torch_fx(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict)

    def test_torch_fx_output_loss(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
How heavy in terms of memory/processing power are those tests? It makes me think of the Keras tests that we had to eventually disable as it took more than a couple of minutes per model class, which wasn't feasible for CI.
3.71s call tests/test_modeling_bert.py::BertModelTest::test_torch_fx
0.71s call tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx
0.62s call tests/test_modeling_t5.py::T5ModelTest::test_torch_fx
3.56s call tests/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss
0.86s call tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss
0.58s call tests/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss
It's interesting that BERT is 6x slower than T5. Any idea why?
Yes, I would say that's because there are more model classes for BERT than for T5.
Ah! the answer was simple then ;) Thank you!
Sorry for jumping in.
Well, I initially wanted this in order to be able to try https://github.com/flexflow/FlexFlow, which requires symbolic tracing - but I haven't had a chance to do so yet.
Got it, thanks for the explanation.
… and T5 (huggingface#11475) Symbolic tracing feature for BERT, ELECTRA and T5 Co-authored-by: Michael Benayoun <michael@huggingface.co> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This would also be helpful for quantizing models using FX Graph Mode Quantization, which automates the quantization process in PyTorch.
Are these updates still functional currently? No modeling_fx_utils.py can be seen in the source code directory.
What does this PR do?
This PR provides a function called symbolic_trace, which enables symbolic tracing for models of the library using the new and still experimental torch.fx feature. Our models can't be symbolically traced directly with torch.fx, so this is a wrapper function that overcomes various issues. This new feature allows one to perform many kinds of transformations on the graph.
It's also needed for projects like https://github.com/flexflow/FlexFlow/
As an experiment, currently only three models are supported: BERT, ELECTRA and T5 (support for other models will follow soon).
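To illustrate what symbolic tracing produces, here is a minimal, self-contained plain torch.fx example (not the PR's wrapper, whose module path may since have changed, as a commenter above notes): tracing yields a GraphModule whose graph of nodes can be inspected and transformed.

```python
import torch
import torch.fx


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


# symbolic_trace records the ops as graph nodes instead of executing them.
traced = torch.fx.symbolic_trace(TinyBlock())
ops = [node.op for node in traced.graph.nodes]
# ops: input placeholder, the Linear submodule call, the relu call, output
```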