feat(src): add kv cache int8 quantization #22

Merged (9 commits) on Jun 28, 2023
Changes from 1 commit
fix(kernel): k_cache use int8_t pointer
tpoisonooo committed Jun 28, 2023
commit 2ea4e4af20e6597aefaaf28d2be60e9791490cff
9 changes: 8 additions & 1 deletion README.md
@@ -158,12 +158,19 @@ bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fast
python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1
```

-## User Guide
+## Quantization
+
+In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more users.
+First run the quantization script; the resulting quantization parameters are stored in the weight directory produced by `deploy.py`.
+Then adjust `config.ini` as shown below:
+* set `use_context_fmha` to 0 to disable context FMHA
+* set `quant_policy` to 4; this parameter defaults to 0, which means the feature is disabled
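
For reference, the two settings above might look like the following fragment of `config.ini`. This is only a minimal sketch: the `[llama]` section name matches the key group read by `LlamaTritonModel` later in this diff, and every other required key is omitted.

```ini
[llama]
; disable context FMHA so the int8 KV-cache path is used
use_context_fmha = 0
; 4 enables kv_cache int8 quantization; the default 0 leaves it disabled
quant_policy = 4
```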

## Contributing

We appreciate all contributions to LLMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.


## Acknowledgement

- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)
19 changes: 13 additions & 6 deletions README_zh-CN.md
@@ -68,7 +68,7 @@ pip install -e .
```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-7B /path/to/llama-7b llama \
--tokenizer_path /path/to/tokenizer/model
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -79,7 +79,7 @@ bash workspace/service_docker_up.sh
```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-13B /path/to/llama-13b llama \
--tokenizer_path /path/to/tokenizer/model --tp 2
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -90,7 +90,7 @@ bash workspace/service_docker_up.sh
```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-33B /path/to/llama-33b llama \
--tokenizer_path /path/to/tokenizer/model --tp 4
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -101,7 +101,7 @@ bash workspace/service_docker_up.sh
```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-65B /path/to/llama-65b llama \
--tokenizer_path /path/to/tokenizer/model --tp 8
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -119,7 +119,7 @@ python3 -m fastchat.model.apply_delta \
--delta-path lmsys/vicuna-7b-delta-v1.1

python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-7B /path/to/vicuna-7b hf
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -135,7 +135,7 @@ python3 -m fastchat.model.apply_delta \
--delta-path lmsys/vicuna-13b-delta-v1.1

python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-13B /path/to/vicuna-13b hf
-bash workspace/service_docker_up.sh
+bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
```

</details>
@@ -146,6 +146,13 @@ bash workspace/service_docker_up.sh
python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1
```

+## Quantized Deployment
+In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more users.
+First run the quantization script; the quantization parameters are stored in the weight directory produced by `deploy.py`.
+Then adjust `config.ini`:
+* set `use_context_fmha` to 0 to disable it
+* set `quant_policy` to 4; this parameter defaults to 0, which means the feature is disabled

## Contributing

We appreciate all the contributors' efforts to improve and enhance LLMDeploy. Please refer to the [contributing guide](.github/CONTRIBUTING.md) for guidance on contributing to the project.
@@ -1052,13 +1052,13 @@ inline __device__ int32_t quant(float4 a, const float scale)
}

// float16 to int8
-// inline __device__ int8_t quant(uint16_t a, const float scale)
-// {
-//     int8_t int8;
-//     float  b = half_to_float(a);
-//     int8     = round(max(-128.f, min(127.f, b.x / scale)));
-//     return int8;
-// }
+inline __device__ int8_t quant(uint16_t a, const float scale)
+{
+    int8_t int8;
+    float  b = half_to_float(a);
+    int8     = round(max(-128.f, min(127.f, b / scale)));
+    return int8;
+}
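
For readers skimming the kernel, the overload above is plain symmetric quantization: divide by a per-tensor scale, clamp to [-128, 127], round, and narrow to `int8_t`; dequantization multiplies back by the same scale. Below is a host-side C++ sketch of that arithmetic; the scale and the sample value are made up for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirror of the kernel's scalar quant(): clamp x/scale to [-128, 127], round, narrow.
static int8_t quant_f32(float x, float scale)
{
    return static_cast<int8_t>(std::lround(std::max(-128.f, std::min(127.f, x / scale))));
}

// Dequantization: the reconstruction error is at most scale / 2.
static float dequant_i8(int8_t q, float scale)
{
    return static_cast<float>(q) * scale;
}

int main()
{
    const float k_scale = 0.05f;   // hypothetical per-tensor scale from the quantization script
    const float sample  = 1.337f;  // hypothetical key/value element (already widened from fp16)
    const int8_t q = quant_f32(sample, k_scale);
    std::printf("quantized: %d  dequantized: %.4f\n", q, dequant_i8(q, k_scale));
    return 0;
}
```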
// float16x2 to int8x2
inline __device__ int16_t quant(uint a, const float scale)
{
@@ -1460,8 +1460,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);

-Packed_Int8_t* dst_ptr = reinterpret_cast<Packed_Int8_t*>(params.k_cache);
-dst_ptr[offset]        = k_int8;
+int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
+*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
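
The replacement above appears to be the point of this commit: `offset` is counted in int8 elements, so indexing a `Packed_Int8_t*` by it would scale the address by `sizeof(Packed_Int8_t)` and write to the wrong slot, whereas casting to `int8_t*` first keeps the offset in bytes and reinterprets only the final address as the packed type. A host-side C++ sketch of the two address computations follows; the packed width, buffer, and offset are invented for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for the kernel's packed_type<int8_t, N>; assume a 4-wide pack here.
struct Packed_Int8_t { int8_t v[4]; };

int main()
{
    int8_t cache[64] = {};          // stands in for params.k_cache
    const std::size_t offset = 8;   // offset counted in int8 elements

    // Old code: indexing a Packed_Int8_t* scales the offset by sizeof(Packed_Int8_t),
    // so the write would land at byte 32 instead of byte 8.
    Packed_Int8_t* wrong = reinterpret_cast<Packed_Int8_t*>(cache) + offset;

    // Fixed code: advance an int8_t* by the element offset, then reinterpret only the
    // resulting address as the packed type, so the write lands at byte 8 as intended.
    int8_t*        base  = cache;
    Packed_Int8_t* right = reinterpret_cast<Packed_Int8_t*>(&base[offset]);

    std::printf("old byte offset: %td  fixed byte offset: %td\n",
                reinterpret_cast<int8_t*>(wrong) - cache,
                reinterpret_cast<int8_t*>(right) - cache);
    return 0;
}
```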
@@ -1481,8 +1481,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);

-Packed_Int8_t** dst_ptr = reinterpret_cast<Packed_Int8_t**>(params.k_cache_per_sample);
-dst_ptr[bi][offset]     = k_int8;
+int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
+*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k);
@@ -1571,8 +1571,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
// convert k_cache_per_sample to int8
if (params.k_cache_per_sample) {
-int8_t** ptr       = reinterpret_cast<int8_t**>(params.k_cache_per_sample);
-k_cache_batch_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
+int8_t* ptr        = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
+k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
@@ -1755,8 +1755,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>

if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
if (params.v_cache_per_sample) {
-int8_t** ptr = reinterpret_cast<int8_t**>(params.v_cache_per_sample);
-v_cache_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
+int8_t* ptr  = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
+v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi];
40 changes: 0 additions & 40 deletions src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc
@@ -19,8 +19,6 @@
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc

-#include <iostream>
-#include <filesystem>
#include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h"
#include "src/fastertransformer/kernels/bert_preprocess_kernels.h"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
@@ -202,22 +200,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
weights->past_kv_scale.data());


-const std::string path_base = "/workspace/save/context_";
-
-sync_check_cuda_error();
-std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy";
-std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy";
-if (not std::filesystem::exists(kpath)) {
-    // auto kptr = k_cache_ptrs[0];
-    // auto vptr = v_cache_ptrs[0];
-
-    // Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, batch_size, local_head_num_}, kptr);
-    // k.saveNpy(kpath);
-
-    // Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, batch_size, local_head_num_}, vptr);
-    // v.saveNpy(vpath);
-}

if (use_fmha_) {
fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs,
@@ -351,23 +334,6 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac
kv_scale);
sync_check_cuda_error();

-const std::string path_base = "/workspace/save-fp16/context_";
-std::string hkpath = path_base + "hk" + ".npy";
-std::string hvpath = path_base + "hv" + ".npy";
-std::string hspath = path_base + "hs" + ".npy";
-
-if (not std::filesystem::exists(hkpath)) {
-
-    Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, k_cache_buf_);
-    k.saveNpy(hkpath);
-
-    Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, v_cache_buf_);
-    v.saveNpy(hvpath);
-
-    Tensor s(MemoryType::MEMORY_CPU, DataType::TYPE_FP32, {2}, kv_scale);
-    s.saveNpy(hspath);
-}

const T qk_scale = static_cast<T>(1.f / sqrtf(size_per_head_ * 1.f));

//////////////////////////////////////////////
@@ -437,12 +403,6 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac
0,
stream_);
sync_check_cuda_error();

-std::string hapath = path_base + "ha" + "policy" + std::to_string(quant_policy_) + ".npy";
-if (not std::filesystem::exists(hapath)) {
-    Tensor attn(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, max_q_len}, qkv_buf_3_);
-    attn.saveNpy(hapath);
-}
}

template class LlamaContextAttentionLayer<float>;
@@ -17,8 +17,6 @@

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
-#include <iostream>
-#include <filesystem>
#include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/models/llama/LlamaNcclGuard.h"
@@ -241,6 +239,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len;

+if (kv_cache_layer_offset >= 192937984) {
+    fprintf(stderr, "gdb");
+}

fusedQKV_masked_attention_dispatch<T>(
qkv_buf_,
weights->qkv.bias, // query_weight.bias,
@@ -280,22 +282,6 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
stream_);
sync_check_cuda_error();

-const std::string path_base = "/workspace/save-fp16/normal_";
-std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy";
-std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy";
-// if (not std::filesystem::exists(kpath)) {
-//     auto kptr = key_cache_ptrs[0];
-//     auto vptr = value_cache_ptrs[0];
-
-//     Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, kptr);
-//     k.saveNpy(kpath);
-
-//     Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, vptr);
-//     v.saveNpy(vpath);
-// }



linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);

if (tensor_para_.world_size_ > 1) {
@@ -125,12 +125,11 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
session_len_ = reader.GetInteger("llama", "session_len", 0);
step_length_ = reader.GetInteger("llama", "step_length", 0);
cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0);
-use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 0);
+use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0);
prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0);
attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
-quant_policy_ = 0;
-// quant_policy_ = reader.GetInteger("llama", "quant_policy", 4);
+quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);

handleMissingParams();
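
With this change `quant_policy` is read from `config.ini` rather than hard-coded to 0, and the kernel hunks above gate the int8 path with `params.int8_mode & QuantPolicy::kCacheKVInt8`. The sketch below assumes `kCacheKVInt8` is the 0x04 bit, which is consistent with the README asking for `quant_policy = 4`; the real enum is defined elsewhere in the repository.

```cpp
#include <cstdio>

// Assumed flag values for illustration only; the actual QuantPolicy enum lives in the repo.
enum QuantPolicy : int {
    kNone        = 0x00,
    kCacheKVInt8 = 0x04,  // presumably the bit that quant_policy = 4 switches on
};

int main()
{
    const int quant_policy = 4;  // value read from config.ini by LlamaTritonModel
    if (quant_policy & kCacheKVInt8) {
        std::printf("int8 KV-cache path enabled\n");
    }
    else {
        std::printf("fp16 KV-cache path\n");
    }
    return 0;
}
```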

Expand Down