Unexpected behavior of memory_efficient_attention with BlockDiagonalMask #1122

xiangxu-google opened this issue Oct 10, 2024 · 0 comments

xiangxu-google commented Oct 10, 2024

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

torch.manual_seed(0)

# Make q, k, v
B, M, H, K = 1, 10, 4, 8
Mk = 20
q = torch.rand((B, M, H, K)).cuda()
k = torch.rand((B, Mk, H, K)).cuda()
v = torch.rand((B, Mk, H, K)).cuda()

# Try slicing K and V down to a shorter length
for new_Mk in [1, 2, 10, 17, 20]:
    new_k = k[:, :new_Mk, :, :]
    new_v = v[:, :new_Mk, :, :]

    # NOTE: here we intentionally create the mask with key length Mk
    # rather than new_Mk, to trigger the issue.
    a = BlockDiagonalMask.from_seqlens([M], [Mk])
    result = xops.memory_efficient_attention(q, new_k, new_v, attn_bias=a)

    print(torch.sum(result).tolist())

Running this code on an H100 (Python 3.10, CUDA 12.4, torch 2.4.0) prints:

156.18984985351562
156.18984985351562
156.18984985351562
156.18984985351562
156.18984985351562

This raises two questions:

  1. While K and V are sliced to a shorter length on each iteration, the mask still works without any error or warning about the shape mismatch. Does broadcasting happen implicitly? This is confusing because torch.nn.functional.scaled_dot_product_attention raises an error on a shape mismatch (see the sketch after this list).
  2. Although the code runs without raising any error, the results are wrong: K and V have a different length on each iteration, so the outputs should differ.
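
For reference, here is a minimal sketch of the comparison made in question 1, using torch.nn.functional.scaled_dot_product_attention. The shapes are borrowed from the reproduction above (SDPA uses a (B, H, seq_len, head_dim) layout), and the exact error message may differ across PyTorch versions:

import torch
import torch.nn.functional as F

B, M, H, K = 1, 10, 4, 8
Mk, new_Mk = 20, 10

# SDPA expects (B, H, seq_len, head_dim) tensors.
q = torch.rand((B, H, M, K)).cuda()
k = torch.rand((B, H, new_Mk, K)).cuda()
v = torch.rand((B, H, new_Mk, K)).cuda()

# Mask built for the full key length Mk, while k/v only have new_Mk steps.
mask = torch.zeros((B, H, M, Mk), dtype=torch.bool).cuda()

# Expected to raise a shape/broadcast error: the (M, Mk) mask cannot be
# applied to attention weights of shape (M, new_Mk).
F.scaled_dot_product_attention(q, k, v, attn_mask=mask)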

I did some investigation without digging deeply into the source code; here is my guess:

There is some broadcasting happening to allow the mask to be applied:

attention_score: (B, H, M, new_Mk)
attention_mask: (B, H, M, Mk)

When Mk != new_Mk, the attention_score is broadcasted to match attention_mask.

The reason each iteration yields the same result is that the sliced K and V are views that share GPU memory with the original K and V, which stays allocated across all iterations. So whenever the new K and V need to be broadcast, the kernel reuses the values from the original K and V to compute the attention scores, which means the new K and V are never really cut off.
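
The memory-sharing part of this hypothesis is easy to check in isolation; a small standalone sketch (independent of xformers):

import torch

k = torch.rand((1, 20, 4, 8), device="cuda")
new_k = k[:, :10, :, :]

# The slice is a view: it starts at the same address and keeps the full
# 20-step buffer alive, so a kernel that reads 20 rows still sees the
# original values past step 10.
print(new_k.data_ptr() == k.data_ptr())                                   # True
print(new_k.untyped_storage().nbytes() == k.untyped_storage().nbytes())   # True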

To prove my hypothesis, I changed these two lines:

    new_k = k[:, :new_Mk, :, :].clone()
    new_v = v[:, :new_Mk, :, :].clone()

to force the new K and V into their own memory instead of reusing the original buffers, and got the expected results:

16.857088088989258
40.805782318115234
89.96614837646484
144.32723999023438
156.18984985351562

So can we add a check on the attn_bias argument that explicitly raises an error when its shape does not match the shape of the attention scores (i.e. the K/V sequence length)? A user-side sketch of such a guard is included below.
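
For illustration only, a hypothetical user-side guard of the kind being requested (the function name check_kv_len and the reuse of the seqlens list passed to from_seqlens are my own naming, not part of the xformers API):

def check_kv_len(kv_seqlens, k):
    """Guard against an attn_bias built for a different key length than k/v."""
    expected = sum(kv_seqlens)   # total key length the BlockDiagonalMask encodes
    actual = k.shape[1]          # actual key sequence length of the tensor passed in
    if expected != actual:
        raise ValueError(
            f"attn_bias was built for a total key length of {expected}, "
            f"but k/v have sequence length {actual}"
        )

# In the reproduction loop above, this would raise for every new_Mk != Mk:
# check_kv_len([Mk], new_k)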
