
Commit

fix(src): build compassed
tpoisonooo committed Jul 28, 2023
1 parent 77ba9d6 commit 1e77761
Showing 8 changed files with 44 additions and 49 deletions.
24 changes: 10 additions & 14 deletions src/turbomind/kernels/unfused_attention_kernels.cu
@@ -781,8 +781,8 @@ void invokeAttentionScoreSum(AttentionScoreSumParam<T>& param, cudaStream_t stre
attention_score_sum<<<param.batch_size, param.k_length, 0, stream>>>(param.attn_score, param.score_sum, param.batch_size, param.num_heads, param.q_length, param.k_length, param.stride);
}

__global__ void attention_score_bottom_k(int64_t* score_ptrs,
int64_t* index_ptrs,
__global__ void attention_score_bottom_k(uint64_t* score_ptrs,
uint64_t* index_ptrs,
const int* window_ptr,
const int* bottom_k_ptr,
const int group,
@@ -834,21 +834,21 @@ __global__ void attention_score_bottom_k(int64_t* score_ptrs,
}
}

template void removeOrderedIndicesAsync(float* k_ptr, float* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, const AttentionScoreSortParam<float>& param, cudaStream_t stream);
template void removeOrderedIndicesAsync(half* k_ptr, half* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam<half>& param, cudaStream_t stream);
template void removeOrderedIndicesAsync(int8_t* k_ptr, int8_t* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam<int8_t>& param, cudaStream_t stream);
template void removeOrderedIndicesAsync(float* k_ptr, float* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam& param, cudaStream_t stream);
template void removeOrderedIndicesAsync(half* k_ptr, half* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam& param, cudaStream_t stream);
template void removeOrderedIndicesAsync(int8_t* k_ptr, int8_t* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam& param, cudaStream_t stream);

// shape [head_num_, max_seq_len_, size_per_head_]
template<typename T>
void removeOrderedIndicesAsync(T* k_ptr, T* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam<T>& param, cudaStream_t stream) {
void removeOrderedIndicesAsync(T* k_ptr, T* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam& param, cudaStream_t stream) {
// keep index ascend
const int head_stride = param.max_seq_len * param.size_per_head;
for (int head_id = 0; head_id < param.num_heads; ++head_id) {
int indexToRemove = 0;
size_t indexToRemove = 0;
int shift = 0;

T* k_base = k_ptr + param.max_seq_len * param.size_per_head;
T* v_base = v_ptr + param.max_seq_len * param.size_per_head;
T* k_base = k_ptr + head_stride;
T* v_base = v_ptr + head_stride;

for (int i = 0; i < window; ++i) {
if (indexToRemove < indexes.size() && i == indexes[indexToRemove]) {
@@ -880,11 +880,7 @@ void removeOrderedIndicesAsync(T* k_ptr, T* v_ptr, int window, int bottom_k, con
}
}

template void invokeCacheKVTrim(AttentionScoreSortParam<float>& param, cudaStream_t stream);
template void invokeCacheKVTrim(AttentionScoreSortParam<half>& param, cudaStream_t stream);

template<typename T>
void invokeCacheKVTrim(AttentionScoreSortParam<T>& param, cudaStream_t stream) {
void invokeCacheKVTrim(AttentionScoreSortParam& param, cudaStream_t stream) {
dim3 grid(param.batch_size), block(param.layer_num);
attention_score_bottom_k<<<grid, block, 0, stream>>>(param.score_device_ptrs, param.index_device_ptrs,
param.window_device_ptr, param.bottom_k_device_ptr,
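Before moving on to the header changes, the intent of `removeOrderedIndicesAsync` in the hunks above is worth spelling out: it drops an ascending list of row indices from the first `window` rows of each head's `[max_seq_len, size_per_head]` cache and shifts the surviving rows forward. Below is a minimal host-only sketch of that compaction, written against assumed names (`compact_rows`, `head_base`) and using `std::memmove` in place of the stream-ordered copies the real helper issues — an illustration, not the committed implementation.

    #include <cstring>
    #include <vector>

    // Illustrative sketch only: remove the rows listed in `indexes` (ascending)
    // from the first `window` rows of one head's cache, laid out as
    // [max_seq_len, size_per_head], by shifting later rows toward the front.
    template<typename T>
    void compact_rows(T* head_base, int window, int size_per_head,
                      const std::vector<int>& indexes)
    {
        size_t next_to_remove = 0;  // cursor into the ascending index list
        int    shift          = 0;  // rows dropped so far

        for (int row = 0; row < window; ++row) {
            if (next_to_remove < indexes.size() && row == indexes[next_to_remove]) {
                ++next_to_remove;
                ++shift;  // this row is discarded, nothing to copy
                continue;
            }
            if (shift > 0) {
                // move the surviving row `shift` slots toward the front
                std::memmove(head_base + (size_t)(row - shift) * size_per_head,
                             head_base + (size_t)row * size_per_head,
                             sizeof(T) * size_per_head);
            }
        }
    }

Per the diff, the committed version simply hoists `head_stride = param.max_seq_len * param.size_per_head` and reuses it for the K and V base pointers instead of recomputing the product.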
18 changes: 8 additions & 10 deletions src/turbomind/kernels/unfused_attention_kernels.h
@@ -66,23 +66,22 @@ struct AttentionScoreSumParam {
int k_length = 0;
int num_heads = 0;
int stride = 0;
int layerid = 0;
int layer_id = 0;
};

template<typename T>
void invokeAttentionScoreSum(AttentionScoreSumParam<T>& param, cudaStream_t stream);

template<typename T>
struct AttentionScoreSortParam {
// shape [batch_size, layer_num, 1, 1, max_seq_len]
int64_t* score_device_ptrs = nullptr;
uint64_t* score_device_ptrs = nullptr;
// shape [batch_size, layer_num, window], window = cur_input_seq_len - GROUP
int64_t* index_device_ptrs = nullptr;
int64_t* index_host_ptrs = nullptr;
uint64_t* index_device_ptrs = nullptr;
uint64_t* index_host_ptrs = nullptr;

// shape [batch_size, layer_num, num_head, max_seq_len, max_seq_len]
int64_t* k_ptrs = nullptr;
int64_t* v_ptrs = nullptr;
uint64_t* k_ptrs = nullptr;
uint64_t* v_ptrs = nullptr;

// shape [batch_size], value < 128
int* window_device_ptr = nullptr;
@@ -106,10 +105,9 @@ struct AttentionScoreSortParam {
};

template<typename T>
void removeOrderedIndicesAsync(T* k_ptr, T* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam<T>& param, cudaStream_t stream);
void removeOrderedIndicesAsync(T* k_ptr, T* v_ptr, int window, int bottom_k, const std::vector<int>& indexes, AttentionScoreSortParam& param, cudaStream_t stream);

template<typename T>
void invokeCacheKVTrim(AttentionScoreSortParam<T>& param, cudaStream_t stream);
void invokeCacheKVTrim(AttentionScoreSortParam& param, cudaStream_t stream);

template<typename T, typename T_IN>
void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream);
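A recurring change in this header is that `AttentionScoreSortParam` stops being a class template: device addresses of whatever element type are carried as 64-bit integer handles and cast back where they are consumed, which is why `removeOrderedIndicesAsync` keeps its `template<typename T>` while the param struct loses its. A hedged sketch of that handle pattern follows; the struct and function names (`TrimParamSketch`, `consume`) are simplified stand-ins, not the real API.

    #include <cstdint>

    // Simplified stand-in for the de-templated parameter struct: element-typed
    // device pointers travel as 64-bit handles, one per batch entry.
    struct TrimParamSketch {
        uint64_t* k_ptrs     = nullptr;
        uint64_t* v_ptrs     = nullptr;
        int       batch_size = 0;
    };

    template<typename T>
    void consume(const TrimParamSketch& p, int slot)
    {
        // the element type is reintroduced only where the data is touched
        T* k = reinterpret_cast<T*>(p.k_ptrs[slot]);
        T* v = reinterpret_cast<T*>(p.v_ptrs[slot]);
        (void)k;
        (void)v;  // ... per-type work would go here ...
    }

The call site in LlamaBatch.cc below follows the same shape: it packs the per-sequence cache and score addresses into plain integer vectors with `reinterpret_cast` before filling the struct.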
39 changes: 19 additions & 20 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -2,6 +2,7 @@

#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h"
@@ -58,7 +59,7 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r

auto drop_invalid = [](std::vector<std::shared_ptr<Request>>& rs) {
int count = 0;
for (int i = 0; i < rs.size(); ++i) {
for (size_t i = 0; i < rs.size(); ++i) {
if (rs[i]) {
rs[count++] = std::move(rs[i]);
}
@@ -549,7 +550,7 @@ bool LlamaBatch<T>::generate()
}

std::stringstream scurr;
for (int k = 0; k < curr.size(); ++k) {
for (size_t k = 0; k < curr.size(); ++k) {
scurr << std::setw(6) << curr[k];
}
TM_LOG_INFO("[generate] step = %d, [%s]", step_ - 1, scurr.str().c_str());
@@ -622,16 +623,14 @@ void LlamaBatch<T>::trimHookRequest(std::vector<std::shared_ptr<Request>>& infer
}

template<typename T>
bool LlamaBatch<T>::trimUpdateKV(std::vector<std::shared_ptr<Request>>& infer_requests) {
void LlamaBatch<T>::trimUpdateKV(std::vector<std::shared_ptr<Request>>& infer_requests) {
std::vector<int64_t> score_ptrs;
std::vector<int64_t> bottom_ptrs;
std::vector<int64_t> k_ptrs;
std::vector<int64_t> v_ptrs;
std::vector<int> windows;
std::vector<int> bottoms_k_;

ptrs.reserve(infer_requests.size());

for (auto r: infer_requests) {
// update cache meta info
const CacheKVTrimParam& param = r->cache_kv_param;
@@ -652,10 +651,10 @@ bool LlamaBatch<T>::trimUpdateKV(std::vector<std::shared_ptr<Request>>& infer_re
const int bottom_k = seq.cache_len - 1024;
const int window = seq.cache_len - param.GROUP;

score_ptrs.push_back(reinterpret_cast<in64_t>(seq.attn_score_sum));
bottom_ptrs.push_back(reinterpret_cast<in64_t>(seq.attn_score_bottom_index));
k_ptrs.push_back(reinterpret_cast<in64_t>(seq.k_cache));
v_ptrs.push_back(reinterpret_cast<in64_t>(seq.v_cache));
score_ptrs.push_back(reinterpret_cast<int64_t>(seq.attn_score_sum));
bottom_ptrs.push_back(reinterpret_cast<int64_t>(seq.attn_score_bottom_index));
k_ptrs.push_back(reinterpret_cast<int64_t>(seq.k_cache));
v_ptrs.push_back(reinterpret_cast<int64_t>(seq.v_cache));
bottoms_k_.push_back(bottom_k);
windows.push_back(window);

@@ -665,21 +664,21 @@ bool LlamaBatch<T>::trimUpdateKV(std::vector<std::shared_ptr<Request>>& infer_re
llama_->kv_cache_mgr_->update(seq, stream_);
}

if (not ptrs.empty()) {
if (not score_ptrs.empty()) {

const int batch_size = infer_requests.size();
check_cuda_error(cudaMemcpyAsync(trim_score_ptrs_, score_ptrs.data(), sizeof(int64_t&) * batch_size, cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpyAsync(trim_index_ptrs_, bottom_ptrs.data(), sizeof(int64_t&) * batch_size, cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpyAsync(trim_window_ptr_, windows.data(), sizeof(int) * ptrs.size(), cudaMemcpyHostToDevice, stream_));
check_cuda_error(cudaMemcpyAsync(trim_bottom_k_ptr_, bottoms_k_.data(), sizeof(int) * ptrs.size(), cudaMemcpyHostToDevice, stream_));
check_cuda_error(cudaMemcpyAsync(trim_window_ptr_, windows.data(), sizeof(int) * batch_size, cudaMemcpyHostToDevice, stream_));
check_cuda_error(cudaMemcpyAsync(trim_bottom_k_ptr_, bottoms_k_.data(), sizeof(int) * batch_size, cudaMemcpyHostToDevice, stream_));
check_cuda_error(cudaStreamSynchronize(stream_));

AttentionScoreSortParam param;
param.score_device_ptrs = trim_score_ptrs_;
param.index_device_ptrs = trim_index_ptrs_;
param.index_host_ptrs = (int64_t*)bottom_ptrs.data();
param.k_ptrs = (int64_t*)k_ptrs.data();
param.v_ptrs = (int64_t*)v_ptrs.data();
param.index_host_ptrs = (uint64_t*)bottom_ptrs.data();
param.k_ptrs = (uint64_t*)k_ptrs.data();
param.v_ptrs = (uint64_t*)v_ptrs.data();

param.window_device_ptr = trim_window_ptr_;
param.window_host_ptr = (int*)windows.data();
@@ -689,8 +688,8 @@ bool LlamaBatch<T>::trimUpdateKV(std::vector<std::shared_ptr<Request>>& infer_re
param.layer_num = llama_->num_layer_;
param.num_heads = llama_->head_num_;
param.size_per_head = llama_->size_per_head_;
param.max_seq_len = llama_->session_len;
param.stride = llama_->session_len / 2 + 128;
param.max_seq_len = llama_->session_len_;
param.stride = llama_->session_len_ / 2 + 128;
invokeCacheKVTrim(param, stream_);
}
check_cuda_error(cudaStreamSynchronize(stream_));
@@ -722,7 +721,7 @@ void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infe
{
FT_CHECK(batch_size_ + infer_requests.size() <= max_batch_size_);

const int infer_request_count = infer_requests.size();
const uint32_t infer_request_count = infer_requests.size();

allocateBuffer(batch_size_ + infer_request_count, session_len_);

@@ -745,7 +744,7 @@ void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infe

const int step = r.inputs[rank_].getVal<int>("step", -1);
if (step >= 0) {
if (step <= seq.token_ids.size()) {
if (step <= static_cast<int>(seq.token_ids.size())) {
seq.token_ids.resize(step);
seq.cache_len = std::min(seq.cache_len, (size_t)step);
}
@@ -779,7 +778,7 @@ void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infe
std::vector<int> idxs(tmp_input_length.size());
std::iota(idxs.begin(), idxs.end(), 0);
std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return tmp_input_length[i] < tmp_input_length[j]; });
for (int i = 0; i < idxs.size(); ++i) {
for (size_t i = 0; i < idxs.size(); ++i) {
requests_[batch_size_ + i] = infer_requests[idxs[i]];
cached_seq_[batch_size_ + i] = tmp_cached_seq[idxs[i]];
}
3 changes: 1 addition & 2 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -207,7 +207,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
if (use_fmha_) {
fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs,
attn_sum_ptrs,
layer_offset,
attention_mask,
cu_seqlens,
@@ -378,7 +377,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
// sum attention_score
// from shape [batch_size, local_head_num_, max_q_len, max_k_len]
// to [batch_size, 1, 1, max_k_len]
AttentionScoreSumParam param;
AttentionScoreSumParam<T> param;
param.attn_score = qk_buf_;
param.score_sum = attn_sum_ptrs;
param.batch_size = batch_size;
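The shape comment above the `AttentionScoreSumParam` setup is the one piece of arithmetic in this file worth spelling out: the per-head attention probabilities are collapsed into a single score per key position for each batch entry. Below is a CPU reference of that reduction, under the assumption that the sum runs over both the head and the query dimensions (the actual kernel lives in unfused_attention_kernels.cu); the function name is a stand-in.

    #include <cstddef>
    #include <vector>

    // Reference reduction: attn[b][h][q][k] -> score[b][k],
    // with attn stored contiguously as [batch, heads, q_len, k_len].
    std::vector<float> score_sum_reference(const std::vector<float>& attn,
                                           int batch, int heads, int q_len, int k_len)
    {
        std::vector<float> score((size_t)batch * k_len, 0.f);
        for (int b = 0; b < batch; ++b)
            for (int h = 0; h < heads; ++h)
                for (int q = 0; q < q_len; ++q)
                    for (int k = 0; k < k_len; ++k)
                        score[(size_t)b * k_len + k] +=
                            attn[(((size_t)b * heads + h) * q_len + q) * k_len + k];
        return score;
    }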
1 change: 0 additions & 1 deletion src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -67,7 +67,6 @@ class LlamaContextAttentionLayer {

void fusedMultiHeadAttention(T** key_cache_ptrs,
T** val_cache_ptrs,
float** attn_sum_ptrs,
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaDecoder.h
@@ -58,6 +58,7 @@ class LlamaDecoder: public BaseLayer {
size_t max_memory_len;
Tensor* k_cache;
Tensor* v_cache;
Tensor* attn_score_sum;
const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
};

6 changes: 4 additions & 2 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -69,19 +69,21 @@ LlamaV2<T>::LlamaV2(size_t head_num,
num_layer_(num_layer),
vocab_size_(vocab_size),
rotary_embedding_dim_(rotary_embedding_dim),
session_len_(session_len),
rmsnorm_eps_(norm_eps),
quant_policy_(quant_policy),
start_id_(start_id),
end_id_(end_id),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_),
weights_(weights),
tensor_para_(tensor_para),
stream_(stream),
cublas_wrapper_(cublas_wrapper),
allocator_(allocator),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
cuda_device_prop_(cuda_device_prop),
debug_(isDebug()),
weights_(weights),
step_length_(step_length),
batch_(max_batch_size, max_context_token_num, session_len, this),
shared_state_(shared_state)
@@ -603,7 +605,7 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,

if (rank == 0 && has_error) {
std::stringstream ss;
for (int i = 0; i < error_codes.size(); ++i) {
for (size_t i = 0; i < error_codes.size(); ++i) {
ss << (i ? "" : " ") << error_codes[i];
}
throw std::runtime_error(ss.str());
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -152,6 +152,7 @@ class LlamaV2 {
const size_t num_layer_;
const size_t vocab_size_;
const size_t rotary_embedding_dim_;
const size_t session_len_;
float rmsnorm_eps_ = 1e-6f;
const int quant_policy_;


