Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GQA support for ROCm #21032

Merged
merged 18 commits into from
Jul 3, 2024
Merged

Add GQA support for ROCm #21032

merged 18 commits into from
Jul 3, 2024

Conversation

cloudhan
Copy link
Member

@cloudhan cloudhan commented Jun 13, 2024

@cloudhan cloudhan force-pushed the guangyunhan/rocm-gqa branch 5 times, most recently from b6be9bd to 14d1a1a Compare June 20, 2024 14:43
@cloudhan
Copy link
Member Author

CI test revealed something like the following

kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

.local/lib/python3.9/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py:58: in test_gqa_past_flash_attention
    parity_check_gqa_past(
/onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_cuda.py:1702: in parity_check_gqa_past
    numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7f28fe08e280>, array([[[[-8.6060e-03,  4.1046e-02, -2.5604e-02, ..., ...n, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]]], dtype=float16))
kwds = {'equal_nan': True, 'err_msg': ' with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_l...ue, rotary_interleaved=False, packed=True', 'header': 'Not equal to tolerance rtol=0.002, atol=0.002', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.002, atol=0.002
E            with Config(batch_size=5, sequence_length=1, kv_sequence_length=2048, past_sequence_length=227, num_heads=32, kv_num_heads=8, head_size=256, ep=rocm), causal=True, local=False, past_format=1, rotary=True, rotary_interleaved=False, packed=True
E           x and y nan location mismatch:
E            x: array([[[[-8.6060e-03,  4.1046e-02, -2.5604e-02, ..., -7.4829e-02,
E                      5.8060e-03, -2.0828e-03],
E                    [ 4.0207e-03,  7.6523e-03,  1.5244e-02, ..., -4.6326e-02,...
E            y: array([[[[nan, nan, nan, ..., nan, nan, nan],
E                    [nan, nan, nan, ..., nan, nan, nan],
E                    [nan, nan, nan, ..., nan, nan, nan],...

/opt/miniconda/envs/rocm-ci/lib/python3.9/contextlib.py:79: AssertionError

and some sparse 'inf' in other tests. This however, happened to the y value, aka, the reference value. I locally reproduced many of these issue and update torch (along with torch triton) to 2.3.1 eliminate all of them.

@cloudhan cloudhan marked this pull request as ready for review July 1, 2024 04:18
@cloudhan cloudhan requested a review from a team as a code owner July 1, 2024 04:18
@tianleiwu
Copy link
Contributor

tianleiwu commented Jul 1, 2024

LGTM except there is a build error:

https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1423537&view=logs&j=7536d2cd-87d4-54fe-4891-bfbbf2741d83&t=66420422-c7d6-5f71-625c-4b7851c9b9ba&l=3997

CMakeFiles/onnxruntime_providers_rocm.dir/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu.o
/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu:5:10: fatal error: 'ck_tile/core/numeric/integer.hpp' file not found
#include "ck_tile/core/numeric/integer.hpp"

@cloudhan
Copy link
Member Author

cloudhan commented Jul 3, 2024

@snnn need an es approve. The some packages in CI are updated due to some nan and inf are produced from the reference impl, see my previous comment.

@cloudhan cloudhan merged commit f39ee14 into main Jul 3, 2024
92 of 100 checks passed
@cloudhan cloudhan deleted the guangyunhan/rocm-gqa branch July 3, 2024 06:55
tianleiwu added a commit that referenced this pull request Jul 17, 2024
The test_flash_attn_rocm.py from
#21032 failed frequently.
For example, I saw two failed jobs today:
E           Max absolute difference: 0.002167
E           Max absolute difference: 0.002686

Adjust the abs threshold from 0.002 to 0.005, and use default relative tolerance rtol=0.001.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants