From c3be89a4070287cb98fded112a48a3d295564dea Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Fri, 29 Mar 2024 16:09:37 -0700
Subject: [PATCH] [KVCache] Support forking a sequence at a specific position
 (#16813)

This PR enables the KVCache to fork a sequence at a specific position.
---
 src/runtime/relax_vm/kv_state.h               |   5 +-
 src/runtime/relax_vm/paged_kv_cache.cc        | 127 ++++++++++++++----
 src/runtime/relax_vm/rnn_state.cc             |   2 +-
 ...tin_paged_attention_kv_cache_flashinfer.py | 102 ++++++++++++--
 ...me_builtin_paged_attention_kv_cache_tir.py | 101 +++++++++++---
 .../relax/test_runtime_builtin_rnn_state.py   |   2 +-
 6 files changed, 283 insertions(+), 56 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f6857a9dceae..e3c6e9608c3f 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -59,9 +59,12 @@ class KVStateObj : public Object {
    * \param parent_seq_id The parent (source) of the fork.
    * \param child_seq_id The child (destination) of the fork.
    * The child sequence id should not exist in cache prior to fork.
+   * \param fork_pos The position in the parent sequence to fork at. The valid range is
+   * [0, parent_seq_length], and -1 (the default) denotes the last position. Forking at
+   * position 0 is equivalent to adding a new sequence with the child sequence id.
    * \throws Error if the given sequence ids are not valid.
    */
-  virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
+  virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;

   /*!
    * \brief Pop out the trailing `n` tokens from the KV cache for the
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 9c3ee5d427c2..3ccab3826df9 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_merge_inplace_;
   PackedFunc f_split_rotary_;
+  PackedFunc f_copy_single_page_;
   Optional<PackedFunc> f_debug_get_kv_;
   /*! \brief Number of fork depth in the current round of forward.
*/ @@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Optional f_attention_prefill_end_forward, Optional f_attention_decode_begin_forward, Optional f_attention_decode_end_forward, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, Optional f_debug_get_kv) + PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), num_qo_heads_(num_qo_heads), @@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)), f_merge_inplace_(std::move(f_merge_inplace)), f_split_rotary_(std::move(f_split_rotary)), + f_copy_single_page_(std::move(f_copy_single_page)), f_debug_get_kv_(std::move(f_debug_get_kv)), device_(device) { pages_.reserve(num_layers); @@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void RemoveSequence(int64_t seq_id) final { auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; - const Block& block = global_block_pool_[it->second.last_block_idx]; - CHECK_EQ(block.external_ref_cnt, 0) + int32_t block_idx = it->second.last_block_idx; + CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0) << "The sequence is currently referenced by other sequence and thus cannot be removed."; - - // - Decrease the external reference of the parent block. - if (block.parent_idx != -1) { - Block& parent_block = global_block_pool_[block.parent_idx]; - ICHECK_GT(parent_block.external_ref_cnt, 0); - --parent_block.external_ref_cnt; + while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) { + // - Free pages in the last block. + for (int32_t page_id : global_block_pool_[block_idx].page_ids) { + free_page_ids_.push_back(page_id); + } + free_block_idx_.push_back(block_idx); + block_idx = global_block_pool_[block_idx].parent_idx; } - // - Free pages in the last block. - for (int32_t page_id : block.page_ids) { - free_page_ids_.push_back(page_id); + // - Decrease the external reference of the parent block. + if (block_idx != -1) { + ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0); + --global_block_pool_[block_idx].external_ref_cnt; } - // - Remove the sequence from seq_map. - free_block_idx_.push_back(it->second.last_block_idx); seq_map_.erase(it); dirty_aux_data_device_ = true; } - void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final { + void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final { auto parent_it = seq_map_.find(parent_seq_id); CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache."; @@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK_EQ(parent_it->second.sliding_window_size, -1) << "The parent sequence \"" << parent_seq_id << "\" is enabled with sliding window and thus cannot be forked."; + CHECK_GE(fork_pos, -1) + << "The forked position should be non-negative, or -1 for last position as default."; + CHECK_LE(fork_pos, parent_it->second.seq_length) + << "The forked position should not exceed the total length of parent sequence."; - int32_t parent_block_idx = parent_it->second.last_block_idx; - ++global_block_pool_[parent_block_idx].external_ref_cnt; - // Create a child block with the parent block pointer. 
int32_t child_block_idx = GetFreeBlock(); - global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length; - global_block_pool_[child_block_idx].parent_idx = parent_block_idx; + if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) { + // Fork at last by appending a new block directly + int32_t parent_block_idx = parent_it->second.last_block_idx; + ++global_block_pool_[parent_block_idx].external_ref_cnt; + // Update child block start position and parent index + global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length; + global_block_pool_[child_block_idx].parent_idx = parent_block_idx; + } else { + // Locate the block to fork from and calculate in-block offset + std::vector trace = parent_it->second.GetBlockTrace(global_block_pool_); + int64_t in_block_offset = fork_pos; + int32_t forked_block_idx = -1; + for (int32_t block_idx : trace) { + if (in_block_offset < global_block_pool_[block_idx].seq_length) { + forked_block_idx = block_idx; + break; + } + in_block_offset -= global_block_pool_[block_idx].seq_length; + } + int32_t in_page_offset = in_block_offset % page_size_; + int32_t moved_offset = in_block_offset - in_page_offset; + if (moved_offset == 0) { + // Forked at the first page in block + int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx; + if (parent_block_idx != -1) { + ++global_block_pool_[parent_block_idx].external_ref_cnt; + } + // Update child block start position and parent index + global_block_pool_[child_block_idx].parent_idx = parent_block_idx; + } else { + // Forked at the second or latter page in block + int32_t parent_block_idx = GetFreeBlock(); + // Insert new parent block before forked block and link child block + global_block_pool_[parent_block_idx].parent_idx = + global_block_pool_[forked_block_idx].parent_idx; + global_block_pool_[forked_block_idx].parent_idx = parent_block_idx; + global_block_pool_[child_block_idx].parent_idx = parent_block_idx; + global_block_pool_[parent_block_idx].external_ref_cnt = 1; + + // Move common leading pages to new parent block + auto first_page = global_block_pool_[forked_block_idx].page_ids.begin(); + auto last_page = + global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_; + global_block_pool_[parent_block_idx].page_ids = {first_page, last_page}; + global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page); + + // Update start position per blocks + global_block_pool_[parent_block_idx].start_pos = + global_block_pool_[forked_block_idx].start_pos; + global_block_pool_[forked_block_idx].start_pos += moved_offset; + + // Update in-block sequence length per blocks + global_block_pool_[parent_block_idx].seq_length = moved_offset; + global_block_pool_[forked_block_idx].seq_length -= moved_offset; + } + global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset; + global_block_pool_[child_block_idx].seq_length = in_page_offset; + + if (in_page_offset > 0) { + // Fork within a page and copy common page to child block partially + int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0]; + int32_t tgt_page_id = GetFreePage(); + global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id); + CopySinglePage(src_page_id, tgt_page_id, in_page_offset); + } + } // Create the child sequence with the child block. 
seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)}); dirty_aux_data_device_ = true; } + void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) { + for (int layer = 0; layer < num_layers_; ++layer) { + f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id, copy_length); + } + } + void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { CHECK(support_sliding_window_) << "The KV cache does not support sliding window."; @@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // - Reset the dirty flag to false. dirty_aux_data_device_ = false; } -}; +}; // namespace relax_vm TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); @@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") PackedFunc f_attention_prefill_end_forward, PackedFunc f_attention_decode_begin_forward, PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, Optional f_debug_get_kv) { + PackedFunc f_split_rotary, PackedFunc f_copy_single_page, + Optional f_debug_get_kv) { CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_ragged_end_forward), std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), - std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv)); + std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), + std::move(f_debug_get_kv)); return AttentionKVCache(std::move(n)); }); @@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, Optional f_debug_get_kv) { + PackedFunc f_split_rotary, PackedFunc f_copy_single_page, + Optional f_debug_get_kv) { CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // - std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv)); + std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), + std::move(f_debug_get_kv)); return AttentionKVCache(std::move(n)); }); diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index 09873ba5f735..69225d6b2c47 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj { dirty_aux_data_device_ = true; } - void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final { + void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final { auto parent_it = seq_map_.find(parent_seq_id); CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id << "\" cannot be found in space state 
storage."; diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index d30ccd022432..c71b0dde3e61 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -66,6 +66,7 @@ ftranspose_append = None fsplit_rotary = None +fcopy_single_page = None fcopy_cache = None @@ -222,6 +223,46 @@ def copy_cache( ] +def _copy_single_page(num_heads, page_size, head_dim, dtype, target): + tx = 256 if str(target.kind) == "webgpu" else 1024 + + @T.prim_func + def copy_single_page( + pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.thread_binding( + (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for t in T.thread_binding(tx, thread="threadIdx.x"): + with T.block("copy"): + vh = T.axis.spatial( + num_heads, + T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + ) + vp = T.axis.spatial( + copy_length, + (b * tx + t) % (copy_length * head_dim) // head_dim, + ) + vd = T.axis.spatial( + head_dim, + T.Cast( + "int32", + (b * tx + t) % head_dim, + ), + ) + P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd] + P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd] + + return copy_single_page + + def set_global_func(): global fclear, fcreate, fadd_sequence, fremove_sequence, ffork_sequence, fpopn global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv @@ -230,7 +271,7 @@ def set_global_func(): global fattention_prefill_ragged global fattention_prefill_ragged_begin_forward global fattention_prefill_ragged_end_forward - global fattention_merge_state, fsplit_rotary + global fattention_merge_state, fsplit_rotary, fcopy_single_page global ftranspose_append, fcopy_cache fclear = tvm.get_global_func("vm.builtin.kv_state_clear") @@ -282,6 +323,7 @@ def set_global_func(): llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype ), + _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), copy_cache, ]: mod = tvm.IRModule({"main": tir_func}) @@ -290,7 +332,7 @@ def set_global_func(): f = tvm.build(mod["main"], target=target) builts.append(f.entry_func) - ftranspose_append, fsplit_rotary, fcopy_cache = builts + ftranspose_append, fsplit_rotary, fcopy_single_page, fcopy_cache = builts def create_kv_cache(rope_mode): @@ -327,6 +369,7 @@ def create_kv_cache(rope_mode): fattention_decode_end_forward, fattention_merge_state, fsplit_rotary, + fcopy_single_page, fcopy_cache, ) return cache @@ -384,7 +427,7 @@ def f_apply_rotary(x, offset, scale, theta): def apply_attention( kv_cache, rope_mode: RopeMode, - batch: List[Tuple[Union[int, Tuple[int, int]], int]], + batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], cached_k: Dict[int, np.ndarray], cached_v: Dict[int, np.ndarray], ) -> None: @@ -394,16 +437,20 @@ def apply_attention( fork_parent_id = None if isinstance(seq_id, tuple): # Fork sequence - seq_id, fork_parent_id = seq_id + seq_id, fork_parent_id, fork_pos = seq_id batch[i] = (seq_id, append_length) seq_ids.append(seq_id) append_lengths.append(append_length) if fork_parent_id is not None: assert fork_parent_id in cached_k assert seq_id not 
in cached_k - ffork_sequence(kv_cache, fork_parent_id, seq_id) - cached_k[seq_id] = cached_k[fork_parent_id] - cached_v[seq_id] = cached_v[fork_parent_id] + ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos) + if fork_pos == -1: + cached_k[seq_id] = cached_k[fork_parent_id] + cached_v[seq_id] = cached_v[fork_parent_id] + else: + cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos] + cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos] elif seq_id not in cached_k: fadd_sequence(kv_cache, seq_id) cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) @@ -563,12 +610,15 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode): batch = [(0, 60), (1, 88), (2, 17), (3, 4)] apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) # Fork existing sequences. - apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k, cached_v) + # 0 <- 5 <- 6,8,9 + # 0 <- 7 + # 3 <- 4 # Mixture of decode and prefill. 
operation_seq = [ [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], @@ -579,6 +629,32 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode): for batch in operation_seq: apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k, cached_v) + + operation_seq = [ + [(6, 1), (11, 1), (13, 1), (9, 1)], + [(10, 1), (16, 1), (18, 1), (19, 1)], + [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)], + [(10, 10), (6, 2), (8, 3), (19, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + for i in range(19, -1, -1): + fremove_sequence(kv_cache, i) + cached_k.pop(i) + cached_v.pop(i) + verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v) + @pytest.mark.skip(reason="Require FlashInfer enabled") def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode): diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index c33686d16e77..3ed89ecd0fee 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -66,6 +66,7 @@ fmerge_state = None fsplit_rotary = None fattention_rotary = None +fcopy_single_page = None def set_global_func(head_dim, dtype): @@ -73,7 +74,7 @@ def set_global_func(head_dim, dtype): global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged global fattn_prefill_sliding_window, fattn_decode_sliding_window - global fmerge_state, fsplit_rotary, fattention_rotary + global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page fclear = tvm.get_global_func("vm.builtin.kv_state_clear") fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") @@ -104,6 +105,7 @@ def set_global_func(head_dim, dtype): llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype ), + _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), ]: mod = tvm.IRModule({"main": tir_func}) with target: @@ -121,6 +123,7 @@ def set_global_func(head_dim, dtype): fattn_prefill_ragged, fmerge_state, fsplit_rotary, + fcopy_single_page, ) = builts @@ -152,6 +155,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fattn_prefill_ragged, fmerge_state, fsplit_rotary, + fcopy_single_page, fcopy_cache, ) return cache @@ -226,7 +230,7 @@ def f_apply_rotary(x, offset, scale, theta): def apply_attention( kv_cache, rope_mode: RopeMode, - batch: List[Tuple[Union[int, Tuple[int, int]], int]], + batch: List[Tuple[Union[int, Tuple[int, int, 
int]], int]], cached_k: Dict[int, np.ndarray], cached_v: Dict[int, np.ndarray], sliding_window_sizes: Optional[List[int]] = None, @@ -238,16 +242,20 @@ def apply_attention( fork_parent_id = None if isinstance(seq_id, tuple): # Fork sequence - seq_id, fork_parent_id = seq_id + seq_id, fork_parent_id, fork_pos = seq_id batch[i] = (seq_id, append_length) seq_ids.append(seq_id) append_lengths.append(append_length) if fork_parent_id is not None: assert fork_parent_id in cached_k assert seq_id not in cached_k - ffork_sequence(kv_cache, fork_parent_id, seq_id) - cached_k[seq_id] = cached_k[fork_parent_id] - cached_v[seq_id] = cached_v[fork_parent_id] + ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos) + if fork_pos == -1: + cached_k[seq_id] = cached_k[fork_parent_id] + cached_v[seq_id] = cached_v[fork_parent_id] + else: + cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos] + cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos] elif seq_id not in cached_k: fadd_sequence(kv_cache, seq_id) cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) @@ -442,12 +450,15 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): batch = [(0, 60), (1, 88), (2, 17), (3, 4)] apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) # Fork existing sequences. - apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k, cached_v) + # 0 <- 5 <- 6,8,9 + # 0 <- 7 + # 3 <- 4 # Mixture of decode and prefill. 
operation_seq = [ [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], @@ -458,7 +469,27 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): for batch in operation_seq: apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) - for i in range(9, -1, -1): + apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k, cached_v) + + operation_seq = [ + [(6, 1), (11, 1), (13, 1), (9, 1)], + [(10, 1), (16, 1), (18, 1), (19, 1)], + [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)], + [(10, 10), (6, 2), (8, 3), (19, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) + + for i in range(19, -1, -1): fremove_sequence(kv_cache, i) cached_k.pop(i) cached_v.pop(i) @@ -477,7 +508,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config): cached_v = {} batch = [(0, 35), (1, 88), (2, 17), (3, 4)] apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) - apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)] for seq_id, pop_length in popn_operations: @@ -539,7 +570,7 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): sliding_window_sizes += [0, 18] attn_sink_sizes += [0, 12] apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v) - ffork_sequence(kv_cache, 5, 6) + ffork_sequence(kv_cache, 5, 6, -1) cached_k[6] = cached_k[5] cached_v[6] = cached_v[5] fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1]) @@ -1845,6 +1876,46 @@ def merge_state_inplace( return merge_state_inplace +def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): + tx = 256 if str(target.kind) == "webgpu" else 1024 + + @T.prim_func + def copy_single_page( + pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.thread_binding( + (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for t in T.thread_binding(tx, thread="threadIdx.x"): + with T.block("copy"): + vh = T.axis.spatial( + num_heads, + T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + ) + vp = T.axis.spatial( + copy_length, + (b * tx + t) % (copy_length * head_dim) // head_dim, + ) + vd = T.axis.spatial( + head_dim, + T.Cast( + "int32", + (b * tx + t) % head_dim, + ), + ) + P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd] + P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd] + + return copy_single_page + + if __name__ == "__main__": HEAD_DIMS = [64, 128] DTYPES = ["float16", 
"float32"] diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index 28f370bca037..de35ad5d7793 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -172,7 +172,7 @@ def test_rnn_state_fork_sequence(rnn_state): # pylint: disable=redefined-outer- f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device)) f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device)) f_end_forward(state) - f_fork_sequence(state, 0, 1) + f_fork_sequence(state, 0, 1, -1) verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]]) # Verify popn for the forked sequence f_popn(state, 1, 1)