Fix bug of xformer prefill for encoder-decoder
xiangxu-google committed Oct 2, 2024
1 parent f58d4fc commit 68650d2
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions vllm/attention/backends/xformers.py
@@ -562,25 +562,27 @@ def forward(
                                                 self.kv_cache_dtype,
                                                 k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape & and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
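For readers who prefer the result of the change to the diff itself, the sketch below restates the token-count selection that the fixed branch chain implements. It is a minimal, self-contained illustration: the AttentionType enum and the SimpleAttnMetadata dataclass are simplified stand-ins invented for this example, not vLLM's actual classes, and split_token_counts is a hypothetical helper rather than a function in the codebase.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple


class AttentionType(Enum):
    ENCODER = auto()
    DECODER = auto()
    ENCODER_DECODER = auto()


@dataclass
class SimpleAttnMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int
    num_encoder_tokens: Optional[int] = None


def split_token_counts(attn_type: AttentionType,
                       meta: SimpleAttnMetadata) -> Tuple[int, int]:
    """Return (num_prefill_tokens, num_decode_tokens) for an attention type."""
    if attn_type == AttentionType.ENCODER:
        # Encoder attention: chunked prefill does not apply, so every
        # encoder token is treated as a prefill token.
        assert meta.num_encoder_tokens is not None
        return meta.num_encoder_tokens, 0
    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention supports chunked prefill, so both counts
        # come from the decoder-side metadata.
        return meta.num_prefill_tokens, meta.num_decode_tokens
    else:  # AttentionType.ENCODER_DECODER
        # Cross-attention: keys/values come from the encoder, so the
        # prefill count must be the encoder token count (the old code
        # took the decoder-side prefill count here).
        return meta.num_encoder_tokens, meta.num_decode_tokens


# Example: a pure-prefill cross-attention step over a 7-token encoder input.
print(split_token_counts(
    AttentionType.ENCODER_DECODER,
    SimpleAttnMetadata(num_prefill_tokens=3,
                       num_decode_tokens=0,
                       num_encoder_tokens=7)))  # -> (7, 0)

The last branch is the behavior this commit changes: cross-attention prefill now sizes its prefill slice by the encoder token count instead of the decoder-side prefill count.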
