diff --git a/3rdparty/tvm b/3rdparty/tvm
index 6d47d37dfe..5400532c4b 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 6d47d37dfe0e8f7bd079859d2aa744531887dacb
+Subproject commit 5400532c4ba37e8a30fcaac488c2ecb05a307e4f
diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py
index 206e5d4958..4a058c6e03 100644
--- a/python/mlc_llm/nn/kv_cache.py
+++ b/python/mlc_llm/nn/kv_cache.py
@@ -244,6 +244,7 @@ def __init__(  # pylint: disable=too-many-locals
                 rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"),
                 rx.extern("flashinfer.merge_state_in_place"),
                 bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
+                bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
                 bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
                 # fmt: on
                 # pylint: enable=line-too-long
@@ -347,6 +348,7 @@ def __init__(  # pylint: disable=too-many-locals
                 bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"),
                 bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"),
                 bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
+                bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
                 bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
                 # fmt: on
                 # pylint: enable=line-too-long
@@ -1539,3 +1541,47 @@ def apply_to_md(sch, block):
     apply_to_md(sch, sch.get_block("lse_store"))
 
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+
+
+def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
+    tx = get_max_num_threads_per_block(target)
+
+    @T.prim_func
+    def copy_single_page(
+        var_pages: T.handle,
+        src_page_id: T.int64,
+        tgt_page_id: T.int64,
+        copy_length: T.int64,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+
+        # One thread per (head, position, dim) element of the region being copied.
+        for b in T.thread_binding(
+            (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
+        ):
+            for t in T.thread_binding(tx, thread="threadIdx.x"):
+                with T.block("copy"):
+                    # Guard the tail block when the element count is not a multiple of tx.
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
+                    vh = T.axis.spatial(
+                        num_heads,
+                        T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
+                    )
+                    vp = T.axis.spatial(
+                        copy_length,
+                        (b * tx + t) % (copy_length * head_dim) // head_dim,
+                    )
+                    vd = T.axis.spatial(
+                        head_dim,
+                        T.Cast(
+                            "int32",
+                            (b * tx + t) % head_dim,
+                        ),
+                    )
+                    # Copy both the K (index 0) and V (index 1) planes of the page.
+                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
+                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]
+
+    return copy_single_page
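
Note on the new kernel (not part of the patch): copy_single_page flattens the copy
into copy_length * num_heads * head_dim scalar elements and assigns one GPU thread
per element, decoding the flat index i = b * tx + t as
vh = i // (copy_length * head_dim), vp = (i % (copy_length * head_dim)) // head_dim,
and vd = i % head_dim. The sketch below restates those semantics in plain NumPy;
copy_single_page_ref is a hypothetical name used only for illustration.

    import numpy as np

    def copy_single_page_ref(pages, src_page_id, tgt_page_id, copy_length):
        """Reference semantics of kv_cache_copy_single_page.

        `pages` has shape (num_pages, 2, num_heads, page_size, head_dim),
        matching the match_buffer in the TIR kernel; axis 1 holds the
        K (0) and V (1) planes. Copies the first `copy_length` token slots
        of both planes, across all heads and the full head dimension.
        """
        pages[tgt_page_id, :, :, :copy_length, :] = pages[src_page_id, :, :, :copy_length, :]

    # Example: copy 16 token slots of page 0 into page 1.
    pages = np.zeros((4, 2, 8, 16, 128), dtype=np.float16)
    copy_single_page_ref(pages, src_page_id=0, tgt_page_id=1, copy_length=16)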