[KVCache] Introducing single page copy func for KV cache fork (mlc-ai#2060)

This PR introduces a single-page copy TIR function for the KV cache.
This function is helpful for forking a sequence at a specified position.
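
As a rough illustration (hypothetical, not part of this PR), a runtime that forks a child sequence at a position falling in the middle of a page could use the kernel roughly as follows:

# Hypothetical sketch: forking a child sequence from a parent at `fork_pos`.
# Pages fully covered by the shared prefix can be shared by reference; the page
# containing `fork_pos` is only partially shared, so its valid prefix is copied
# into a fresh page for the child. All names here are illustrative.
in_page_offset = fork_pos % page_size
if in_page_offset != 0:
    tgt_page_id = allocate_page()  # hypothetical page allocator
    # Invokes the compiled "kv_cache_copy_single_page" kernel registered in this PR.
    copy_single_page(pages, src_page_id, tgt_page_id, in_page_offset)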

NOTE: this PR is a breaking change, so you will need to recompile
your model and update TVM or the MLC-AI pip package to the latest version.

Related PR: apache/tvm#16813

Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com>
MasterJH5574 and cyx-6 authored Mar 30, 2024
1 parent 2600a70 commit 9ecc00e
Showing 2 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 153 files
42 changes: 42 additions & 0 deletions python/mlc_llm/nn/kv_cache.py
@@ -244,6 +244,7 @@ def __init__( # pylint: disable=too-many-locals
rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"),
rx.extern("flashinfer.merge_state_in_place"),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
# fmt: on
# pylint: enable=line-too-long
@@ -347,6 +348,7 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"),
bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
# fmt: on
# pylint: enable=line-too-long
@@ -1539,3 +1541,43 @@ def apply_to_md(sch, block):

apply_to_md(sch, sch.get_block("lse_store"))
return sch.mod["main"].with_attr("tir.is_scheduled", 1)


def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
    tx = get_max_num_threads_per_block(target)  # threads per block; each thread copies one element

    @T.prim_func
    def copy_single_page(
        var_pages: T.handle,
        src_page_id: T.int64,
        tgt_page_id: T.int64,
        copy_length: T.int64,
    ):
        T.func_attr({"tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)

        # Flatten (head, in-page position, feature dim) into a 1-D launch of
        # ceil(copy_length * num_heads * head_dim / tx) blocks of tx threads each.
        for b in T.thread_binding(
            (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
        ):
            for t in T.thread_binding(tx, thread="threadIdx.x"):
                with T.block("copy"):
                    vh = T.axis.spatial(
                        num_heads,
                        T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
                    )
                    vp = T.axis.spatial(
                        copy_length,
                        (b * tx + t) % (copy_length * head_dim) // head_dim,
                    )
                    vd = T.axis.spatial(
                        head_dim,
                        T.Cast(
                            "int32",
                            (b * tx + t) % head_dim,
                        ),
                    )
                    # Copy both the K plane (index 0) and the V plane (index 1) of the page.
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    return copy_single_page
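
For intuition, the kernel's effect matches this NumPy sketch (illustrative only, not part of the commit); the pages buffer has layout (num_pages, 2, num_heads, page_size, head_dim), where the second axis selects the K or V plane:

import numpy as np

def copy_single_page_ref(pages: np.ndarray, src_page_id: int, tgt_page_id: int, copy_length: int) -> None:
    # Copy the first `copy_length` token positions of both the K plane (index 0)
    # and the V plane (index 1) from the source page to the target page.
    pages[tgt_page_id, :, :, :copy_length, :] = pages[src_page_id, :, :, :copy_length, :]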
