From 68650d2dfb8ba339080b6b7ef3dedfdab0a720c0 Mon Sep 17 00:00:00 2001
From: Xiang Xu
Date: Wed, 2 Oct 2024 11:16:06 -0700
Subject: [PATCH] Fix bug of xformer prefill for encoder-decoder

---
 vllm/attention/backends/xformers.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index a3f9ff64f8b8b..3dfc10c26bada 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -562,25 +562,27 @@ def forward(
                                                     self.kv_cache_dtype,
                                                     k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape & and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+
         output = torch.empty_like(query)
 
         # Query for decode. KV is not needed because it is already cached.
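
Note (editorial, not part of the patch applied by git am): the sketch below is a condensed, self-contained restatement of the token-count selection this diff introduces. AttentionType and the num_prefill_tokens / num_decode_tokens / num_encoder_tokens fields mirror names from the diff; the Meta dataclass and the token_counts helper are illustrative stand-ins, not vLLM's real classes.

from dataclasses import dataclass
from enum import Enum, auto


class AttentionType(Enum):
    # Mirrors the three cases handled in the patched branch.
    ENCODER = auto()
    DECODER = auto()
    ENCODER_DECODER = auto()


@dataclass
class Meta:
    # Illustrative stand-in for the attention metadata fields used in the diff.
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_encoder_tokens: int = 0


def token_counts(attn_type: AttentionType, m: Meta) -> tuple:
    """Return (num_prefill_tokens, num_decode_tokens) as the patched branch selects them."""
    if attn_type == AttentionType.ENCODER:
        # Encoder attention: chunked prefill does not apply; all tokens are prefill.
        return m.num_encoder_tokens, 0
    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention: chunked prefill may mix prefill and decode tokens.
        return m.num_prefill_tokens, m.num_decode_tokens
    else:  # AttentionType.ENCODER_DECODER
        # Cross-attention: the prefill token count comes from the encoder side.
        return m.num_encoder_tokens, m.num_decode_tokens


# Example: cross-attention prefill over a 7-token encoder sequence.
print(token_counts(AttentionType.ENCODER_DECODER, Meta(num_encoder_tokens=7)))  # (7, 0)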