Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile compatibilty for decoder-only models #32617

Merged
merged 6 commits into from
Sep 9, 2024

Conversation

zucchini-nlp
Copy link
Member

@zucchini-nlp zucchini-nlp commented Aug 12, 2024

What does this PR do?

Recently we merged a few PRs deprecating old-style cache in all decoder-only models. This PR is a continuation of it, here we verify that all newly deprecated models can support static cache and are compatible with torch.compile. The main change is in RoPE to get rid of dynamic control flow

A few exception that cannot be supported yet: MoE models and some other with dynamic control flow like Phi3 or Chameleon.

Ran test_generate_compile_fullgraph and test_static_cache_matches_dynamic on all models + ran slow tests on models touched by this PR.

In the next PR I can start deprecating old cache in encoder-decoder models starting from Bart and GPT models

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few comments, mostly about aligning with llama

Ran test_generate_compile_fullgraph and test_static_cache_matches_dynamic on all models + ran slow tests on models touched by this PR.

💛

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/models/bloom/modeling_bloom.py Outdated Show resolved Hide resolved
@@ -899,9 +895,24 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing #Copied from ... ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really, bloom has alibi and needs 2D attention for that. So we can't expand it to 4D, and choose to append zeros to attn to make it static shape.

src/transformers/models/falcon/modeling_falcon.py Outdated Show resolved Hide resolved
src/transformers/models/gpt_neox/modeling_gpt_neox.py Outdated Show resolved Hide resolved
src/transformers/models/mixtral/modeling_mixtral.py Outdated Show resolved Hide resolved
src/transformers/models/phi/modeling_phi.py Outdated Show resolved Hide resolved
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py Outdated Show resolved Hide resolved
src/transformers/models/stablelm/modeling_stablelm.py Outdated Show resolved Hide resolved
src/transformers/models/starcoder2/modeling_starcoder2.py Outdated Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

Updated with @gante comments and used the new RoPE modeling in all models. Ready for review!

@zucchini-nlp
Copy link
Member Author

Failing tests are not related

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💎 thanks so much for this tedious work, well done 🥳
What is left is to make sure the compile tests pass !

src/transformers/models/bloom/modeling_bloom.py Outdated Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it support compile ? (not seeing the supports_static_cache

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -273,9 +380,29 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is potentially breaking no? (no more offset)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm right, lemme check this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update: just verified we don't need to slice anymore, because we apply rope directly on the curretn position. Prev we applied Rope for all positions up to the current and had to slice out cached positions


if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
kv_seq_len = key_states.shape[-2] + cache_position[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't remember why we don't use cache_position[-1]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the last position is the whole past kv length, which causes incorrect length in pre-fill or uncached generation. Maybe we should switch to simply past_length = cache_position[-1] everywhere?

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for these very laborious changes 🙏

src/transformers/models/falcon/modeling_falcon.py Outdated Show resolved Hide resolved
src/transformers/models/falcon/modeling_falcon.py Outdated Show resolved Hide resolved
src/transformers/models/falcon/modeling_falcon.py Outdated Show resolved Hide resolved
tests/models/falcon/test_modeling_falcon.py Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

@simonJJJ I added the new RoPE embedding for Qwen2-VL in this PR. Since I changes Qwen2, the changes were automatically propagated with copy statements. I remember you had a PR to fix RoPE for FA2 can you check if the current version works as you expect?

@zucchini-nlp
Copy link
Member Author

@ArthurZucker @gante changed deprecation to v4.46 and added qwen2-VL. Ran the tests again to check everything is okey. Let me know if you have any comments

@@ -870,7 +870,7 @@ def _update_causal_mask(
# to infer the attention mask.

# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as in Llama, using cache_position is a dynamic control flow which is not supported currently by compile. The fullgraph-compile test fails without this change

@gante
Copy link
Member

gante commented Sep 6, 2024

@zucchini-nlp happy with the changes, feel free to merge! (given that you mentioned that you re-ran the tests 💛 )

@zucchini-nlp
Copy link
Member Author

Yes, was exactly thinking to rebase main and re-ran tests one more time

@zucchini-nlp
Copy link
Member Author

Test are passing, including slow. So, merging

@zucchini-nlp zucchini-nlp merged commit 65bb284 into huggingface:main Sep 9, 2024
23 checks passed
@anijain2305
Copy link
Contributor

Can we update the tracker in #28981

itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* squash into one commit

* add qwen2-vl for rope standardization

* fix mistral compile

* fix qwen2-vl

* fix-copies
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants