Skip to content

Commit

Permalink
[KVCache] Support forking sequence at specific position (#16813)
Browse files Browse the repository at this point in the history
This PR enables KVCache to fork a sequence at a specific position.
  • Loading branch information
cyx-6 authored Mar 29, 2024
1 parent 5daa303 commit c3be89a
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 56 deletions.
5 changes: 4 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ class KVStateObj : public Object {
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
* The child sequence id should not exist in cache prior to fork.
* \param fork_pos The position in the parent sequence to fork at. Legal values are within
* [0, parent_seq_length], or -1 (the default) to fork at the last position. Forking at
* position 0 is equivalent to adding a new sequence with the child sequence id.
* \throws Error if the given sequence ids are not valid.
*/
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;

/*!
* \brief Pop out the trailing `n` tokens from the KV cache for the
Expand Down
127 changes: 102 additions & 25 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Optional<PackedFunc> f_attention_decode_end_forward_;
PackedFunc f_merge_inplace_;
PackedFunc f_split_rotary_;
PackedFunc f_copy_single_page_;
Optional<PackedFunc> f_debug_get_kv_;

/*! \brief Number of fork depth in the current round of forward. */
Expand Down Expand Up @@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Optional<PackedFunc> f_attention_prefill_end_forward,
Optional<PackedFunc> f_attention_decode_begin_forward,
Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
num_qo_heads_(num_qo_heads),
Expand Down Expand Up @@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
Expand Down Expand Up @@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
void RemoveSequence(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
const Block& block = global_block_pool_[it->second.last_block_idx];
CHECK_EQ(block.external_ref_cnt, 0)
int32_t block_idx = it->second.last_block_idx;
CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
<< "The sequence is currently referenced by other sequence and thus cannot be removed.";

// - Decrease the external reference of the parent block.
if (block.parent_idx != -1) {
Block& parent_block = global_block_pool_[block.parent_idx];
ICHECK_GT(parent_block.external_ref_cnt, 0);
--parent_block.external_ref_cnt;
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
// - Free pages in the last block.
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
}
// - Free pages in the last block.
for (int32_t page_id : block.page_ids) {
free_page_ids_.push_back(page_id);
// - Decrease the external reference of the parent block.
if (block_idx != -1) {
ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
--global_block_pool_[block_idx].external_ref_cnt;
}
// - Remove the sequence from seq_map.
free_block_idx_.push_back(it->second.last_block_idx);
seq_map_.erase(it);
dirty_aux_data_device_ = true;
}

void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end())
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
Expand All @@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
CHECK_EQ(parent_it->second.sliding_window_size, -1)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus cannot be forked.";
CHECK_GE(fork_pos, -1)
<< "The forked position should be non-negative, or -1 for last position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent sequence.";

int32_t parent_block_idx = parent_it->second.last_block_idx;
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Create a child block with the parent block pointer.
int32_t child_block_idx = GetFreeBlock();
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at last by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
int64_t in_block_offset = fork_pos;
int32_t forked_block_idx = -1;
for (int32_t block_idx : trace) {
if (in_block_offset < global_block_pool_[block_idx].seq_length) {
forked_block_idx = block_idx;
break;
}
in_block_offset -= global_block_pool_[block_idx].seq_length;
}
int32_t in_page_offset = in_block_offset % page_size_;
int32_t moved_offset = in_block_offset - in_page_offset;
if (moved_offset == 0) {
// Forked at the first page in block
int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
if (parent_block_idx != -1) {
++global_block_pool_[parent_block_idx].external_ref_cnt;
}
// Update child block start position and parent index
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
} else {
// Forked at the second or later page in block
int32_t parent_block_idx = GetFreeBlock();
// Insert new parent block before forked block and link child block
global_block_pool_[parent_block_idx].parent_idx =
global_block_pool_[forked_block_idx].parent_idx;
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
global_block_pool_[parent_block_idx].external_ref_cnt = 1;

// Move common leading pages to new parent block
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
auto last_page =
global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_;
global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);

// Update start position per blocks
global_block_pool_[parent_block_idx].start_pos =
global_block_pool_[forked_block_idx].start_pos;
global_block_pool_[forked_block_idx].start_pos += moved_offset;

// Update in-block sequence length per blocks
global_block_pool_[parent_block_idx].seq_length = moved_offset;
global_block_pool_[forked_block_idx].seq_length -= moved_offset;
}
global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
global_block_pool_[child_block_idx].seq_length = in_page_offset;

if (in_page_offset > 0) {
// Fork within a page and copy common page to child block partially
int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
int32_t tgt_page_id = GetFreePage();
global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
}
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
dirty_aux_data_device_ = true;
}

/*!
 * \brief Invoke the single-page copy function once per layer.
 * \param src_page_id The id of the source page.
 * \param tgt_page_id The id of the target page.
 * \param copy_length The length forwarded unchanged to the copy function.
 * NOTE(review): presumably the number of leading positions of the page to copy — confirm
 * against the implementation of the registered copy function.
 */
void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
  // Each layer has its own entry in pages_, so the copy is issued per layer.
  int layer_id = 0;
  while (layer_id < num_layers_) {
    f_copy_single_page_(pages_[layer_id], src_page_id, tgt_page_id, copy_length);
    ++layer_id;
  }
}

void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) final {
CHECK(support_sliding_window_) << "The KV cache does not support sliding window.";
Expand Down Expand Up @@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// - Reset the dirty flag to false.
dirty_aux_data_device_ = false;
}
};
}; // class PagedAttentionKVCacheObj

TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

Expand All @@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_attention_prefill_end_forward,
PackedFunc f_attention_decode_begin_forward,
PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
Expand All @@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_prefill_ragged_end_forward),
std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});

Expand All @@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_attention_prefill_sliding_window,
PackedFunc f_attention_decode_sliding_window,
PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
Expand All @@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/relax_vm/rnn_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj {
dirty_aux_data_device_ = true;
}

void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id
<< "\" cannot be found in space state storage.";
Expand Down
Loading

0 comments on commit c3be89a

Please sign in to comment.