From 54ef0b986a82230ae17a83b29dfccb0033bc30eb Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 27 Jun 2023 03:22:07 +0000 Subject: [PATCH 1/9] feat(src): add int8 and compile passed --- CMakeLists.txt | 24 +- .../decoder_masked_multihead_attention.h | 2 + ...er_masked_multihead_attention_template.cuh | 385 ++++++++++++++---- .../llama/LlamaContextAttentionLayer.cc | 16 +- .../models/llama/LlamaContextAttentionLayer.h | 31 +- .../models/llama/LlamaContextDecoder.cc | 10 +- .../models/llama/LlamaContextDecoder.h | 5 +- .../models/llama/LlamaDecoder.cc | 10 +- .../models/llama/LlamaDecoder.h | 6 +- .../models/llama/LlamaDecoderLayerWeight.cc | 5 + .../llama/LlamaDecoderSelfAttentionLayer.cc | 7 +- .../llama/LlamaDecoderSelfAttentionLayer.h | 7 +- .../models/llama/LlamaDenseWeight.h | 1 + src/fastertransformer/models/llama/LlamaV2.cc | 23 +- src/fastertransformer/models/llama/LlamaV2.h | 3 +- .../models/llama/llama_kernels.cu | 250 ++++++++++-- .../models/llama/llama_kernels.h | 44 +- .../models/llama/llama_utils.h | 8 + .../triton_backend/llama/LlamaTritonModel.cc | 4 +- .../triton_backend/llama/LlamaTritonModel.h | 1 + src/fastertransformer/utils/memory_utils.cu | 7 +- src/fastertransformer/utils/memory_utils.h | 2 + 22 files changed, 650 insertions(+), 201 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b81abb3b8..56eb32775 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,18 +17,18 @@ project(FasterTransformer LANGUAGES CXX CUDA) find_package(CUDA 10.2 REQUIRED) -if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") - add_definitions("-DENABLE_BF16") - message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag") -endif() - -if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12")) - add_definitions("-DENABLE_FP8") - option(ENABLE_FP8 "ENABLE_FP8" OFF) - if(ENABLE_FP8) - message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.8, enable -DENABLE_FP8 flag") - endif() -endif() +# if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") +# add_definitions("-DENABLE_BF16") +# message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag") +# endif() + +# if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12")) +# add_definitions("-DENABLE_FP8") +# option(ENABLE_FP8 "ENABLE_FP8" OFF) +# if(ENABLE_FP8) +# message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.8, enable -DENABLE_FP8 flag") +# endif() +# endif() set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h index c56e87358..97c008652 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h @@ -116,6 +116,8 @@ struct Multihead_attention_params_base { const float* qkv_scale_out = nullptr; const float* attention_out_scale = nullptr; int int8_mode = 0; + float attention_k_scale = 0.f; + float attention_v_scale = 0.f; }; template diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh 
b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index ddbbe446e..9f60918fd 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -20,6 +20,7 @@ #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" #include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include "src/fastertransformer/models/llama/llama_utils.h" #include #include #include @@ -548,6 +549,43 @@ __inline__ __device__ Tout vec_conversion(const Tin& x) { return x; } +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} +template<> +__inline__ __device__ uint2 vec_conversion(const float4& a) +{ + uint2 b; + float2 val; + val.x = a.x; + val.y = a.y; + b.x = vec_conversion(val); + + val.x = a.z; + val.y = a.w; + b.y = vec_conversion(val); + + return b; +} +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + #ifdef ENABLE_FP8 // fp8_t template<> @@ -980,6 +1018,139 @@ inline __device__ Float8_ float_from_int8(int64_t u) } // clang-format on +inline __device__ int8_t quant(float a, const float scale) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, a / scale))); + return int8; +} + +inline __device__ short quant(float2 a, const float scale) +{ + union { + int8_t int8[2]; + short int16; + }; + + int8[0] = round(max(-128.f, min(127.f, a.x / scale))); + int8[1] = round(max(-128.f, min(127.f, a.y / scale))); + return int16; +} + +inline __device__ int32_t quant(float4 a, const float scale) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + + int8[0] = round(max(-128.f, min(127.f, a.x / scale))); + int8[1] = round(max(-128.f, min(127.f, a.y / scale))); + int8[2] = round(max(-128.f, min(127.f, a.z / scale))); + int8[3] = round(max(-128.f, min(127.f, a.w / scale))); + return int32; +} + +// float16 to int8 +// inline __device__ int8_t quant(uint16_t a, const float scale) +// { +// int8_t int8; +// float b = half_to_float(a); +// int8 = round(max(-128.f, min(127.f, b.x / scale))); +// return int8; +// } +// float16x2 to int8x2 +inline __device__ int16_t quant(uint a, const float scale) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = half2_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, b.x / scale))); + int8[1] = round(max(-128.f, min(127.f, b.y / scale))); + return int16; +} +// float16x4 to int8x4 +inline __device__ int32_t quant(uint2 a, const float scale) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale); + int16[1] = quant(a.y, scale); + return int32; +} +// float16x8 to int8x8 +inline __device__ int64_t quant(uint4 a, const float scale) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale); + int16[1] = quant(a.y, scale); + int16[2] = quant(a.z, scale); + int16[3] = quant(a.w, scale); + return int64; +} +// int8 to float32, then `vec_conversion` to target format +inline __device__ float dequant(int8_t a, const float scale) +{ + float b = a * scale; + return b; +} +// int8x2 to float32x2 +inline __device__ float2 
dequant(int16_t a, const float scale) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + int16 = a; + + float2 b; + b.x = int8[0] * scale; + b.y = int8[1] * scale; + return b; +} +// int8x4 to float32x4 +inline __device__ float4 dequant(int32_t a, const float scale) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int32 = a; + + float4 b; + b.x = (int8[0] * scale); + b.y = (int8[1] * scale); + b.z = (int8[2] * scale); + b.w = (int8[3] * scale); + return b; +} + +inline __device__ Float8_ dequant(int64_t a, const float scale) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + int64 = a; + + Float8_ b; + b.x = dequant(int16[0], scale); + b.y = dequant(int16[1], scale); + b.z = dequant(int16[2], scale); + b.w = dequant(int16[3], scale); + return b; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ int8_t cast_to_int8(float val) @@ -1208,44 +1379,23 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // The offset in the bias buffer. int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + // past kv quant param + const float k_scale = params.attention_k_scale; + const float v_scale = params.attention_v_scale; // Trigger the loads from the Q and K buffers. Qk_vec_k q; zero(q); if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); - } + q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); } Qk_vec_k k; zero(k); { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : - k; - } + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; } // Trigger the loads from the Q and K bias buffers. @@ -1270,12 +1420,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (handle_kv) { k = add(k, k_bias); } - if (do_ia3 && !is_masked) { - k = mul( - k, - vec_conversion(*reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); - } // Padded len const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; @@ -1311,7 +1455,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
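// [Editor's note] Illustrative host-side sketch, not part of this patch (quant_ref / dequant_ref
// are hypothetical names): the quant()/dequant() device helpers added above implement symmetric
// per-tensor int8 quantization -- clamp a / scale to [-128, 127] and round on store, multiply the
// int8 value by scale on load. A scalar C++ reference of the same math:

#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t quant_ref(float a, float scale)
{
    // mirror the device helper: clamp first, then round to nearest
    float clamped = std::max(-128.f, std::min(127.f, a / scale));
    return static_cast<int8_t>(std::round(clamped));
}

inline float dequant_ref(int8_t q, float scale)
{
    return static_cast<float>(q) * scale;
}

// The CUDA versions in the patch do the same per element but pack 2/4/8 lanes into
// int16/int32/int64 so that one 8- or 16-byte access moves a whole Qk_vec / V_vec of cache data.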
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; - *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + using Packed_Int8_t = typename packed_type::value>::type; + Packed_Int8_t k_int8 = quant(k, k_scale); + + Packed_Int8_t* dst_ptr = reinterpret_cast(params.k_cache); + dst_ptr[offset] = k_int8; + } else { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } } else { int offset; @@ -1323,8 +1476,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh + co * QK_ELTS_IN_16B + ci; } - *reinterpret_cast(¶ms.k_cache_per_sample[bi][offset]) = - vec_conversion(k); + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + using Packed_Int8_t = typename packed_type::value>::type; + Packed_Int8_t k_int8 = quant(k, k_scale); + + Packed_Int8_t** dst_ptr = reinterpret_cast(params.k_cache_per_sample); + dst_ptr[bi][offset] = k_int8; + } else { + *reinterpret_cast(¶ms.k_cache_per_sample[bi][offset]) = + vec_conversion(k); + } + } } } @@ -1402,13 +1565,28 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = - params.k_cache_per_sample ? - (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : - ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - // T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - T* k_cache_batch = k_cache; + T* k_cache_batch = nullptr; + int8_t* k_cache_batch_int8 = nullptr; + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + T* k_cache = + params.k_cache_per_sample ? + (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : + ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + // T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + k_cache_batch = k_cache; + } else { + // convert k_cache_per_sample to int8 + if (params.k_cache_per_sample) { + int8_t** ptr = reinterpret_cast(params.k_cache_per_sample); + k_cache_batch_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; + } else { + int8_t* ptr = reinterpret_cast(params.k_cache); + k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki]; + } + } + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). 
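// [Editor's note] Illustrative sketch, not part of this patch (use_int8_kv_cache is a
// hypothetical helper name): int8_mode / quant_policy is treated as a bit field, so the int8
// KV-cache branches above are selected with a bitwise AND against QuantPolicy::kCacheKVInt8
// (0x04, defined later in this series in llama_utils.h) rather than an equality test.
// Minimal standalone illustration of the gating pattern:

enum QuantPolicy : int {
    kNone        = 0x00,
    kCacheKVInt8 = 0x04,  // quantize the past key/value cache to int8
};

inline bool use_int8_kv_cache(int quant_policy)
{
    // true when e.g. "quant_policy = 4" is set in the model's config.ini
    return (quant_policy & kCacheKVInt8) != 0;
}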
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; @@ -1439,14 +1617,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params k[ii] = k_vec_zero; } else { + int beam_offset = 0; if (HAS_BEAMS) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; } - else { - k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + + Packed_Int8_t k_vec_m_int8 = *reinterpret_cast(&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]); + Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale); + + k[ii] = vec_conversion(k_vec_m_float); + } else { + k[ii] = vec_conversion((*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } } } @@ -1562,13 +1747,32 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - T* v_cache = - params.v_cache_per_sample ? - (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) : - ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - // T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; - T* v_cache_batch = v_cache; + T* v_cache = nullptr; + T* v_cache_batch = nullptr; + + int8_t* v_cache_int8 = nullptr; + int8_t* v_cache_batch_int8 = nullptr; + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + if (params.v_cache_per_sample) { + int8_t** ptr = reinterpret_cast(params.v_cache_per_sample); + v_cache_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; + } else { + int8_t* ptr = reinterpret_cast(params.v_cache); + v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi]; + } + + v_cache_batch_int8 = v_cache_int8; + } else { + + v_cache = + params.v_cache_per_sample ? + (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) : + ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + // T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; + v_cache_batch = v_cache; + } // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1606,7 +1810,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // Loop over the timesteps to compute the partial outputs. // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { if (Dh == Dh_MAX || vi < Dh) { - + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; // Separate the ti < memory_max_len and ti > memory_max_len // to prevent ti % memory_len when ti < memory_len, and // the compiler cannot optimize the codes automatically. @@ -1616,8 +1821,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params const int beam_src = HAS_BEAMS ? 
params.cache_indir[bi_seq_len_offset + ti] : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. - V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + V_vec_k v; + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + Packed_Int8_t v_vec_m_int8 = *reinterpret_cast(&v_cache_batch_int8[beam_offset + ti * Dh]); + Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale); + + v = vec_conversion(v_vec_m_float); + } else { + v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + } + // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1652,8 +1866,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. - V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + V_vec_k v; + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + Packed_Int8_t v_vec_m_int8 = *reinterpret_cast(&v_cache_batch_int8[beam_offset + ti_circ * Dh]); + Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale); + + v = vec_conversion(v_vec_m_float); + } else { + v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + } + // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1687,18 +1910,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // Trigger the loads from the V buffer. const auto v_offset = qkv_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); - } + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); // Trigger the loads from the V bias buffer. // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); @@ -1706,16 +1918,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (handle_kv) { v = add(v, v_bias); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + + if (params.int8_mode & QuantPolicy::kCacheKVInt8) { + using Packed_Int8_t = typename packed_type::value>::type; + Packed_Int8_t v_int8 = quant(v, v_scale); + *reinterpret_cast(&v_cache_int8[tlength_circ * Dh]) = v_int8; + } else { + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + } } // Initialize the output value with the current timestep. 
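// [Editor's note] Illustrative sketch, not part of this patch (symmetric_scale is a hypothetical
// helper name): the k_scale / v_scale used above are not computed at runtime. Each decoder layer
// loads a two-element tensor ("<layer>.past_kv_scale.<rank>.weight", see LlamaDecoderLayerWeight
// later in this series) and the attention layers pass it through as {k_scale, v_scale} into
// params.attention_k_scale / params.attention_v_scale. How those two floats are produced is
// outside this patch; a hypothetical offline calibration step could use the common symmetric
// convention scale = absmax / 127 over sampled K/V activations, for example:

#include <algorithm>
#include <cmath>
#include <vector>

inline float symmetric_scale(const std::vector<float>& samples)
{
    float absmax = 0.f;
    for (float v : samples) {
        absmax = std::max(absmax, std::fabs(v));
    }
    // avoid a zero scale when the sampled range is empty or all zeros
    return absmax > 0.f ? absmax / 127.f : 1.f;
}

// past_kv_scale = { symmetric_scale(k_samples), symmetric_scale(v_samples) };
// which later becomes params.attention_k_scale and params.attention_v_scale in the kernel.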
@@ -1782,14 +1994,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), mul(result_scale, out)); #endif // FP8_MHA - } - else if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { + } else { convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); } #else // MMHA_USE_FP32_ACUM_FOR_OUT diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc index 745993aba..651857493 100644 --- a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc @@ -195,7 +195,9 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* max_seq_len, size_per_head_, local_head_num_, - stream_); + stream_, + quant_policy_, + weights->past_kv_scale.data()); sync_check_cuda_error(); if (use_fmha_) { @@ -220,7 +222,9 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* num_token, max_q_len, max_k_len, - max_seq_len); + max_seq_len, + quant_policy_, + weights->past_kv_scale.data()); } ////////////////////////////////////////////// @@ -303,7 +307,9 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac int num_token, int max_q_len, int max_k_len, - int max_seq_len) + int max_seq_len, + int quant, + const float* kv_scale) { // key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] // val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] @@ -318,7 +324,9 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac max_seq_len, size_per_head_, local_head_num_, - stream_); + stream_, + quant, + kv_scale); sync_check_cuda_error(); const T qk_scale = static_cast(1.f / sqrtf(size_per_head_ * 1.f)); diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h index 2572a696b..8daae35b7 100644 --- a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h @@ -42,7 +42,8 @@ class LlamaContextAttentionLayer { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - bool use_fmha): + bool use_fmha, + int quant_policy): head_num_(head_num), size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), @@ -56,7 +57,8 @@ class LlamaContextAttentionLayer { linear_(cublas_wrapper, stream), allocator_(allocator), is_free_buffer_after_forward_(is_free_buffer_after_forward), - use_fmha_(use_fmha) + use_fmha_(use_fmha), + quant_policy_(quant_policy) { } @@ -72,17 +74,19 @@ class LlamaContextAttentionLayer { int max_k_len, int max_seq_len); - void unfusedMultiHeadAttention(T** key_cache_ptrs, - T** val_cache_ptrs, - size_t cache_layer_offset, - const T* attention_mask, - const int* padding_offset, - const int* context_length, - int batch_size, - int num_token, - int max_q_len, - int max_k_len, - int max_seq_len); + void unfusedMultiHeadAttention(T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + const T* attention_mask, + const int* padding_offset, + const int* context_length, + int batch_size, + int num_token, + int max_q_len, + int max_k_len, + int max_seq_len, + int quant_policy, + const float* kv_scale); private: const size_t 
head_num_; @@ -96,6 +100,7 @@ class LlamaContextAttentionLayer { const bool neox_rotary_style_; const bool use_fmha_; + const int quant_policy_; NcclParam tensor_para_; diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 3fae4d990..eaa2247e2 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -62,7 +62,7 @@ void LlamaContextDecoder::freeBuffer() } template -void LlamaContextDecoder::initialize(bool use_fmha) +void LlamaContextDecoder::initialize(bool use_fmha, int quant_policy) { h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); @@ -75,7 +75,8 @@ void LlamaContextDecoder::initialize(bool use_fmha) cublas_wrapper_, allocator_, is_free_buffer_after_forward_, - use_fmha); + use_fmha, + quant_policy); silu_ffn_layer_ = new LlamaFfnLayer(head_num_, size_per_head_, @@ -133,7 +134,8 @@ LlamaContextDecoder::LlamaContextDecoder(size_t head_num, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - bool use_fmha): + bool use_fmha, + int quant_policy): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), head_num_(head_num), size_per_head_(size_per_head), @@ -145,7 +147,7 @@ LlamaContextDecoder::LlamaContextDecoder(size_t head_num, tensor_para_(tensor_para), data_type_(getTensorType()) { - initialize(use_fmha); + initialize(use_fmha, quant_policy); } template diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.h b/src/fastertransformer/models/llama/LlamaContextDecoder.h index 82ecc2574..5370e8418 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.h +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.h @@ -42,7 +42,7 @@ class LlamaContextDecoder: public BaseLayer { void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void freeBuffer() override; - void initialize(bool use_fmha); + void initialize(bool use_fmha, int quant_policy); size_t head_num_; size_t size_per_head_; @@ -97,7 +97,8 @@ class LlamaContextDecoder: public BaseLayer { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - bool use_fmha); + bool use_fmha, + int quant_policy); ~LlamaContextDecoder() override; diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc index f02e8ae18..62e1d9eda 100644 --- a/src/fastertransformer/models/llama/LlamaDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -37,7 +37,8 @@ LlamaDecoder::LlamaDecoder(size_t head_num, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, - bool is_free_buffer_after_forward): + bool is_free_buffer_after_forward, + int quant_policy): BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), head_num_(head_num), size_per_head_(size_per_head), @@ -50,7 +51,7 @@ LlamaDecoder::LlamaDecoder(size_t head_num, data_type_(getTensorType()) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); - initialize(); + initialize(quant_policy); } template @@ -62,7 +63,7 @@ LlamaDecoder::~LlamaDecoder() } template -void LlamaDecoder::initialize() +void LlamaDecoder::initialize(int quant_policy) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -74,7 +75,8 @@ void LlamaDecoder::initialize() stream_, cublas_wrapper_, allocator_, - is_free_buffer_after_forward_); + is_free_buffer_after_forward_, + 
quant_policy); silu_ffn_layer_ = new LlamaFfnLayer(head_num_, size_per_head_, diff --git a/src/fastertransformer/models/llama/LlamaDecoder.h b/src/fastertransformer/models/llama/LlamaDecoder.h index dd5956788..6113e06d5 100644 --- a/src/fastertransformer/models/llama/LlamaDecoder.h +++ b/src/fastertransformer/models/llama/LlamaDecoder.h @@ -35,7 +35,7 @@ class LlamaDecoder: public BaseLayer { void allocateBuffer() override; // deprecated void allocateBuffer(size_t batch_size); void freeBuffer() override; - void initialize(); + void initialize(int quant_policy); size_t head_num_; size_t size_per_head_; @@ -79,7 +79,9 @@ class LlamaDecoder: public BaseLayer { cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, - bool is_free_buffer_after_forward); + bool is_free_buffer_after_forward, + int quant_policy), + ~LlamaDecoder() override; diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc index dd41cb509..5f8d31459 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -158,6 +158,11 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type); loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type); loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type); + + // load kv_cache quant scale + // if file not exist, get empty vector + std::string scale_path = dir_path + ".past_kv_scale." + rank_spec + ".weight"; + self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); } template struct LlamaDecoderLayerWeight; diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc index e132a5a51..26babcd1c 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -75,6 +75,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const float* qkv_scale_out, const float* attention_out_scale, const int int8_mode, + const float* attention_kv_scale, cudaStream_t stream) { using DataType = typename SATypeConverter::Type; @@ -151,6 +152,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, if (int8_mode == 2) { params.qkv_scale_out = qkv_scale_out; params.attention_out_scale = attention_out_scale; + } else if (int8_mode == QuantPolicy::kCacheKVInt8) { + params.attention_k_scale = attention_kv_scale[0]; + params.attention_v_scale = attention_kv_scale[1]; } PUSH_RANGE("scaled dot-product fusion"); @@ -269,7 +273,8 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o nullptr, // ia3_value_weights nullptr, // qkv_scale_out nullptr, // attention_out_scale - 0, // int8_mode + quant_policy_, // int8_mode + weights->past_kv_scale.data(), // attention kv scale stream_); sync_check_cuda_error(); diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h index 1365fa703..629d577b4 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h @@ -40,7 +40,8 @@ class LlamaDecoderSelfAttentionLayer { cudaStream_t stream, 
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, - bool is_free_buffer_after_forward): + bool is_free_buffer_after_forward, + int quant_policy): head_num_(head_num), size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), @@ -52,7 +53,8 @@ class LlamaDecoderSelfAttentionLayer { stream_(stream), linear_(cublas_wrapper, stream), allocator_(allocator), - is_free_buffer_after_forward_(is_free_buffer_after_forward) + is_free_buffer_after_forward_(is_free_buffer_after_forward), + quant_policy_(quant_policy) { } @@ -71,6 +73,7 @@ class LlamaDecoderSelfAttentionLayer { const size_t local_hidden_units_; const size_t rotary_embedding_dim_; const bool is_free_buffer_after_forward_; + const int quant_policy_; const bool neox_rotary_style_; diff --git a/src/fastertransformer/models/llama/LlamaDenseWeight.h b/src/fastertransformer/models/llama/LlamaDenseWeight.h index fba6f1df2..1cfe7ea6a 100644 --- a/src/fastertransformer/models/llama/LlamaDenseWeight.h +++ b/src/fastertransformer/models/llama/LlamaDenseWeight.h @@ -66,6 +66,7 @@ template struct LlamaAttentionWeight { LlamaDenseWeight qkv; LlamaDenseWeight output; + std::vector past_kv_scale; }; template diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc index 6c8c39247..df7fa694e 100644 --- a/src/fastertransformer/models/llama/LlamaV2.cc +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -52,6 +52,7 @@ LlamaV2::LlamaV2(size_t head_num, int end_id, int cache_max_entry_count, int cache_chunk_size, + int quant_policy, bool use_context_fmha, std::shared_ptr shared_state, LlamaWeight* weights, @@ -89,16 +90,26 @@ LlamaV2::LlamaV2(size_t head_num, FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0); FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); + size_t elem_bits = 0; + if (quant_policy & QuantPolicy::kCacheKVInt8) { + elem_bits = sizeof(int8_t) * 8; + if (use_context_fmha) { + FT_LOG_ERROR("use_context_fmha not support int8"); + assert(0); + } + } else { + elem_bits = sizeof(T) * 8; + } kv_cache_mgr_ = std::make_unique(num_layer_, local_head_num_, size_per_head_, session_len, - sizeof(T) * 8, + elem_bits, cache_max_entry_count, cache_chunk_size, tensor_para.rank_, allocator); - initialize(use_context_fmha); + initialize(use_context_fmha, quant_policy); start(); } @@ -113,7 +124,7 @@ LlamaV2::~LlamaV2() } template -void LlamaV2::initialize(bool use_context_fmha) +void LlamaV2::initialize(bool use_context_fmha, int quant_policy) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -128,7 +139,8 @@ void LlamaV2::initialize(bool use_context_fmha) cublas_wrapper_, allocator_, is_free_buffer_after_forward_, - use_context_fmha); + use_context_fmha, + quant_policy); decoder_ = new LlamaDecoder(head_num_, size_per_head_, @@ -140,7 +152,8 @@ void LlamaV2::initialize(bool use_context_fmha) stream_, cublas_wrapper_, allocator_, - is_free_buffer_after_forward_); + is_free_buffer_after_forward_, + quant_policy); dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, vocab_size_, // vocab_size_padded, diff --git a/src/fastertransformer/models/llama/LlamaV2.h b/src/fastertransformer/models/llama/LlamaV2.h index fb41a91ff..3be088c57 100644 --- a/src/fastertransformer/models/llama/LlamaV2.h +++ b/src/fastertransformer/models/llama/LlamaV2.h @@ -62,6 +62,7 @@ class LlamaV2 { int end_id, int cache_max_entry_count, int cache_chunk_size, + int quant_policy, bool use_context_fmha, std::shared_ptr shared_state, LlamaWeight* weights, @@ -88,7 +89,7 @@ class LlamaV2 { void 
internalThreadEntry(int device_id); - void initialize(bool use_context_fmha); + void initialize(bool use_context_fmha, int quant_policy); void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); diff --git a/src/fastertransformer/models/llama/llama_kernels.cu b/src/fastertransformer/models/llama/llama_kernels.cu index 967b0a658..5884eb7dd 100644 --- a/src/fastertransformer/models/llama/llama_kernels.cu +++ b/src/fastertransformer/models/llama/llama_kernels.cu @@ -2,7 +2,9 @@ #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" #include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/models/llama/llama_utils.h" #include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" namespace fastertransformer { @@ -283,6 +285,102 @@ __global__ void extend_value_cache(T** v_dst, } } +inline __device__ float2 float2div(float a, float2 b) +{ + float2 c; + c.x = b.x / a; + c.y = b.y / a; + return c; +} + +static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale) { + half4 dst; + dst.x = __float2half(value.x * scale); + dst.y = __float2half(value.y * scale); + dst.z = __float2half(value.z * scale); + dst.w = __float2half(value.w * scale); + return dst; +} + +static inline __device__ uint32_t float4_to_char4(float x, + float y, + float z, + float w) { + uint32_t dst; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 + uint32_t a; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); +#else + char4 tmp; + tmp.x = x; + tmp.y = y; + tmp.z = z; + tmp.w = w; + dst = reinterpret_cast(tmp); +#endif + return dst; +} + +template +__global__ void extend_value_cache_int8(int8_t** v_dst, + const size_t dst_offset, + const T* v_src, + const int head_num, + const int size_per_head, + const int* query_length, + const int* history_length, + const int max_q_len, + const int max_seq_len, + const float v_scale) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto val_src = reinterpret_cast(v_src); + const auto val_dst = reinterpret_cast(v_dst[batch_id] + dst_offset); + + const auto seq_len = query_length[batch_id]; + const auto t_offset = history_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] -> [H, S[t:t+s], D/x] + const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H + (v_seq_len_id + t_offset) * size_per_head_div_x + // s + offset + v_head_size_id; // D/x + + const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B + head_id * size_per_head_div_x * max_q_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + // scale to int8 and write + const auto value = val_src[src_idx]; + auto to_ptr = reinterpret_cast(val_dst + dst_idx); + + float2 float2_0 = float2div(v_scale, mmha::half2_to_float2(value.x)); + float2 float2_1 = float2div(v_scale, mmha::half2_to_float2(value.y)); + to_ptr[0] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); + + float2_0 = float2div(v_scale, mmha::half2_to_float2(value.z)); + float2_1 = float2div(v_scale, mmha::half2_to_float2(value.w)); + to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); + } +} + + template void invokeExtendKVCache(T** k_dst, T** v_dst, @@ -296,18 +394,29 @@ void invokeExtendKVCache(T** k_dst, int max_seq_len, int size_per_head, int local_head_num, - cudaStream_t stream) + cudaStream_t stream, + int quant, + const float* kv_scale) { constexpr int block_sz = 128; constexpr int x = (sizeof(T) == 4) ? 
4 : 8; dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); - extend_value_cache<<>>( - k_dst, dst_offset, k_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len); + if (quant & QuantPolicy::kCacheKVInt8) { + extend_value_cache_int8<<>>( + reinterpret_cast(k_dst), dst_offset, k_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len, kv_scale[0]); + + extend_value_cache_int8<<>>( + reinterpret_cast(v_dst), dst_offset, v_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len, kv_scale[1]); + + } else { + extend_value_cache<<>>( + k_dst, dst_offset, k_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len); - extend_value_cache<<>>( - v_dst, dst_offset, v_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len); + extend_value_cache<<>>( + v_dst, dst_offset, v_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len); + } } template void invokeExtendKVCache(float**, @@ -322,7 +431,9 @@ template void invokeExtendKVCache(float**, int, int, int, - cudaStream_t stream); + cudaStream_t stream, + int, + const float*); template void invokeExtendKVCache(half**, half**, @@ -336,17 +447,61 @@ template void invokeExtendKVCache(half**, int, int, int, - cudaStream_t stream); + cudaStream_t stream, + int, + const float*); + +// template +// __global__ void transpose_key_cache(T* k_dst, +// const T** k_src, +// const size_t src_offset, +// const int head_num, +// const int size_per_head, +// const int* seq_length, +// const int max_kv_len, +// const int max_seq_len) +// { +// const int batch_id = blockIdx.y; +// const int head_id = blockIdx.z; +// constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + +// const int idx = blockIdx.x * blockDim.x + threadIdx.x; +// int size_per_head_div_x = size_per_head / X_ELEMS; + +// // x dim is now handled by uint4 type +// const auto key_src = reinterpret_cast(k_src[batch_id] + src_offset); +// const auto key_dst = reinterpret_cast(k_dst); + +// const auto seq_len = seq_length[batch_id]; + +// const int k_head_size_id = idx % size_per_head_div_x; +// const int k_seq_len_id = idx / size_per_head_div_x; + +// if (k_seq_len_id < seq_len) { +// // [B, H, s, D/x] <- [B, H, D/x, S[:s]] + +// const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H +// k_head_size_id * max_seq_len + // D/x +// k_seq_len_id; // s + +// const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B +// head_id * size_per_head_div_x * max_kv_len + // H +// k_seq_len_id * size_per_head_div_x + // s +// k_head_size_id; // D/x + +// key_dst[dst_idx] = key_src[src_idx]; +// } +// } template -__global__ void transpose_key_cache(T* k_dst, - const T** k_src, - const size_t src_offset, - const int head_num, - const int size_per_head, - const int* seq_length, - const int max_kv_len, - const int max_seq_len) +__global__ void transpose_value_cache(T* v_dst, // + const T** v_src, + const size_t src_offset, + const int head_num, + const int size_per_head, + const int* seq_length, + const int max_kv_len, + const int max_seq_len) { const int batch_id = blockIdx.y; const int head_id = blockIdx.z; @@ -356,39 +511,40 @@ __global__ void transpose_key_cache(T* k_dst, int size_per_head_div_x = size_per_head / X_ELEMS; // x dim is now handled by uint4 type - const auto key_src = reinterpret_cast(k_src[batch_id] + src_offset); - const auto key_dst = reinterpret_cast(k_dst); + const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); + const auto val_dst = reinterpret_cast(v_dst); const auto seq_len = seq_length[batch_id]; - const int k_head_size_id = idx % size_per_head_div_x; - const int k_seq_len_id = idx / size_per_head_div_x; - - if (k_seq_len_id < seq_len) { - // [B, H, s, D/x] <- [B, H, D/x, S[:s]] + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] <- [B, H, S[:s], D/x] const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H - k_head_size_id * max_seq_len + // D/x - k_seq_len_id; // s + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B head_id * size_per_head_div_x * max_kv_len + // H - k_seq_len_id * size_per_head_div_x + // s - k_head_size_id; // D/x + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x - key_dst[dst_idx] = key_src[src_idx]; + val_dst[dst_idx] = val_src[src_idx]; } } + template -__global__ void transpose_value_cache(T* v_dst, // - const T** v_src, +__global__ void transpose_value_cache_int8(T* v_dst, // + const int8_t** v_src, const size_t src_offset, const int head_num, const int size_per_head, const int* seq_length, const int max_kv_len, - const int max_seq_len) + const int max_seq_len, + const float v_scale) { const int batch_id = blockIdx.y; const int head_id = blockIdx.z; @@ -398,7 +554,7 @@ __global__ void transpose_value_cache(T* v_dst, // int size_per_head_div_x = size_per_head / X_ELEMS; // x dim is now handled by uint4 type - const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); + const auto val_src = reinterpret_cast(v_src[batch_id] + 
src_offset); const auto val_dst = reinterpret_cast(v_dst); const auto seq_len = seq_length[batch_id]; @@ -417,7 +573,12 @@ __global__ void transpose_value_cache(T* v_dst, // v_seq_len_id * size_per_head_div_x + // s v_head_size_id; // D/x - val_dst[dst_idx] = val_src[src_idx]; + // int8x8 -> fp16x8 + const auto from_ptr = reinterpret_cast(val_src + src_idx); + auto to_ptr = reinterpret_cast(val_dst + dst_idx); + + to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale); + to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale); } } @@ -433,24 +594,35 @@ void invokeTransposeKVCache(T* key_cache_trans, int max_seq_len, int size_per_head, int head_num, - cudaStream_t stream) + cudaStream_t stream, + int quant, + const float* kv_scale) { constexpr int block_sz = 128; constexpr int x = (sizeof(T) == 4) ? 4 : 8; dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num); - transpose_value_cache<<>>( - key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); + if (quant & QuantPolicy::kCacheKVInt8) { + transpose_value_cache_int8<<>>( + key_cache_trans, reinterpret_cast(key_cache), src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len, kv_scale[0]); - transpose_value_cache<<>>( - val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); + transpose_value_cache_int8<<>>( + val_cache_trans, reinterpret_cast(val_cache), src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len, kv_scale[1]); + + } else { + transpose_value_cache<<>>( + key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); + + transpose_value_cache<<>>( + val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); + } } template void invokeTransposeKVCache( - float*, float*, const float**, const float**, size_t, int, const int*, int, int, int, int, cudaStream_t stream); + float*, float*, const float**, const float**, size_t, int, const int*, int, int, int, int, cudaStream_t stream, int, const float*); template void invokeTransposeKVCache( - half*, half*, const half**, const half**, size_t, int, const int*, int, int, int, int, cudaStream_t stream); + half*, half*, const half**, const half**, size_t, int, const int*, int, int, int, int, cudaStream_t stream, int, const float*); __global__ void gatherOutput(int* output_ids, const int* ids, diff --git a/src/fastertransformer/models/llama/llama_kernels.h b/src/fastertransformer/models/llama/llama_kernels.h index 492ab2dfc..3e3443bd7 100644 --- a/src/fastertransformer/models/llama/llama_kernels.h +++ b/src/fastertransformer/models/llama/llama_kernels.h @@ -46,7 +46,9 @@ void invokeExtendKVCache(T** k_dst, int max_seq_len, int size_per_head, int local_head_num, - cudaStream_t stream); + cudaStream_t stream, + int quant, + const float* kv_scale); template void invokeTransposeKVCache(T* key_cache_trans, @@ -60,7 +62,9 @@ void invokeTransposeKVCache(T* key_cache_trans, int max_seq_len, int size_per_head, int head_num, - cudaStream_t stream); + cudaStream_t stream, + int quant_policy, + const float* kv_scale); void invokeGatherOutput(int* output_ids, const int* ids, @@ -151,24 +155,24 @@ struct TempBuffer { T* data; }; -template -inline T* -transpose_key_cache(T* key_cache, size_t head_num, size_t size_per_head_by_x, size_t mem_len, size_t x, cudaStream_t st) -{ - static TempBuffer buf(8192 * 8192); - // from: H Dx, S, x - // to : S, H Dx, x - 
invokeTransposeAxis01(buf.data, key_cache, head_num * size_per_head_by_x, mem_len, x, st); - return buf.data; -} - -template -inline T* transpose_value_cache(T* value_cache, size_t head_num, size_t mem_len, size_t size_per_head, cudaStream_t st) -{ - static TempBuffer buf(8192 * 8192); - invokeTransposeAxis01(buf.data, value_cache, head_num, mem_len, size_per_head, st); - return buf.data; -} +// template +// inline T* +// transpose_key_cache(T* key_cache, size_t head_num, size_t size_per_head_by_x, size_t mem_len, size_t x, cudaStream_t st) +// { +// static TempBuffer buf(8192 * 8192); +// // from: H Dx, S, x +// // to : S, H Dx, x +// invokeTransposeAxis01(buf.data, key_cache, head_num * size_per_head_by_x, mem_len, x, st); +// return buf.data; +// } + +// template +// inline T* transpose_value_cache(T* value_cache, size_t head_num, size_t mem_len, size_t size_per_head, cudaStream_t st) +// { +// static TempBuffer buf(8192 * 8192); +// invokeTransposeAxis01(buf.data, value_cache, head_num, mem_len, size_per_head, st); +// return buf.data; +// } inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t st) { diff --git a/src/fastertransformer/models/llama/llama_utils.h b/src/fastertransformer/models/llama/llama_utils.h index b7889fab0..6071291ee 100644 --- a/src/fastertransformer/models/llama/llama_utils.h +++ b/src/fastertransformer/models/llama/llama_utils.h @@ -9,6 +9,14 @@ namespace fastertransformer { +enum QuantPolicy { + kNone = 0x00, + // reserve 0x01 and 0x02 + // quantize cache kv + kCacheKVInt8 = 0x04, + kWeightInt4 = 0x08, +}; + enum CmpMode { kCmpNone, diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc index 9f84b7d17..c9a6ec9e5 100644 --- a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -125,10 +125,11 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, session_len_ = reader.GetInteger("llama", "session_len", 0); step_length_ = reader.GetInteger("llama", "step_length", 0); cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0); - use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1); + use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 0); cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0); attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); + quant_policy_ = reader.GetInteger("llama", "quant_policy", 4); handleMissingParams(); @@ -224,6 +225,7 @@ std::unique_ptr> LlamaTritonModel::createSh end_id_, cache_max_entry_count_, cache_chunk_size_, + quant_policy_, use_context_fmha_, shared_state_, shared_weights_[device_id].get(), diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h index 013b0f27c..960027607 100644 --- a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h @@ -93,6 +93,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { size_t pipeline_para_size_; ft::WeightType weight_type_; bool attn_bias_; + int quant_policy_; size_t prefix_cache_len_{}; diff --git a/src/fastertransformer/utils/memory_utils.cu b/src/fastertransformer/utils/memory_utils.cu index 3c7c8e731..27cc8c28e 100644 --- a/src/fastertransformer/utils/memory_utils.cu +++ 
b/src/fastertransformer/utils/memory_utils.cu @@ -344,6 +344,11 @@ std::vector loadWeightFromBinHelper(std::vector shape, std::string fi return host_array; } +std::vector loadArrayFromBin(std::vector shape, std::string filename) +{ + return loadWeightFromBinHelper(shape, filename); +} + template int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) { @@ -523,7 +528,7 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename) std::vector h_ptr(size); cudaD2Hcpy(h_ptr.data(), ptr, size); std::vector float_ptr(size); - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { float_ptr[i] = (float)h_ptr[i]; } diff --git a/src/fastertransformer/utils/memory_utils.h b/src/fastertransformer/utils/memory_utils.h index 316a4dd2b..68087a787 100644 --- a/src/fastertransformer/utils/memory_utils.h +++ b/src/fastertransformer/utils/memory_utils.h @@ -55,6 +55,8 @@ int loadWeightFromBin(T* ptr, std::string filename, FtCudaDataType model_file_type = FtCudaDataType::FP32); +std::vector loadArrayFromBin(std::vector shape, std::string filename); + // template // int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, // T* scale_ptr, From b5623850906a20825382317b0886bfe2222be3bf Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 27 Jun 2023 04:27:35 +0000 Subject: [PATCH 2/9] feat(kernels): fix --- ...coder_masked_multihead_attention_template.cuh | 16 ++++++++-------- src/fastertransformer/models/llama/llama_utils.h | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index 9f60918fd..bd969addf 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -1569,14 +1569,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params int8_t* k_cache_batch_int8 = nullptr; if (params.int8_mode & QuantPolicy::kCacheKVInt8) { - T* k_cache = - params.k_cache_per_sample ? - (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : - ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - // T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - k_cache_batch = k_cache; - } else { // convert k_cache_per_sample to int8 if (params.k_cache_per_sample) { int8_t** ptr = reinterpret_cast(params.k_cache_per_sample); @@ -1585,6 +1577,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params int8_t* ptr = reinterpret_cast(params.k_cache); k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki]; } + } else { + T* k_cache = + params.k_cache_per_sample ? 
+ (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : + ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + // T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + k_cache_batch = k_cache; } diff --git a/src/fastertransformer/models/llama/llama_utils.h b/src/fastertransformer/models/llama/llama_utils.h index 6071291ee..72457c532 100644 --- a/src/fastertransformer/models/llama/llama_utils.h +++ b/src/fastertransformer/models/llama/llama_utils.h @@ -11,10 +11,11 @@ namespace fastertransformer { enum QuantPolicy { kNone = 0x00, - // reserve 0x01 and 0x02 + // reserve 0x01 and 0x02 for backward compatibility + kReserve1 = 0x01, + kReserve2 = 0x02, // quantize cache kv kCacheKVInt8 = 0x04, - kWeightInt4 = 0x08, }; enum CmpMode From 7799608adfb1221e2dcffb3230a7fb1a8c201ae7 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 27 Jun 2023 08:14:24 +0000 Subject: [PATCH 3/9] feat(llama): update kernel --- llmdeploy/serve/fastertransformer/deploy.py | 6 ++++-- .../models/llama/LlamaDecoderLayerWeight.cc | 8 +++++++- src/fastertransformer/models/llama/LlamaV2.cc | 1 + .../triton_backend/llama/LlamaTritonModel.cc | 2 +- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/llmdeploy/serve/fastertransformer/deploy.py b/llmdeploy/serve/fastertransformer/deploy.py index 2b2756477..4baa87841 100644 --- a/llmdeploy/serve/fastertransformer/deploy.py +++ b/llmdeploy/serve/fastertransformer/deploy.py @@ -147,7 +147,7 @@ def save_bin(param: torch.Tensor, name): step_length=1, cache_max_entry_count=48, cache_chunk_size=8, - use_context_fmha=1)) + use_context_fmha=0)) config = configparser.ConfigParser() for section, key_values in cfg.items(): @@ -302,6 +302,8 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, _files = [file for file in os.listdir(model_path) if file.endswith('.bin')] _files = sorted(_files) + print(_files) + _params = {} for _file in _files: _tmp = torch.load(osp.join(model_path, _file), map_location='cpu') @@ -369,7 +371,7 @@ def get_tensor_transposed(name): for ft, hf in other: model_params[ft] = get_tensor(hf) - return export(model_name, i + 1, norm_eps, model_params, tokenizer_path, + return export(model_name, num_layer, norm_eps, model_params, tokenizer_path, triton_models_path, tp) diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc index 5f8d31459..0d8e1f6c5 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -162,7 +162,13 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType // load kv_cache quant scale // if file not exist, get empty vector std::string scale_path = dir_path + ".past_kv_scale." 
+ rank_spec + ".weight"; - self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); + std::ifstream in(scale_path, std::ios::in); + if (in.is_open()) { + in.close(); + self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); + } else { + self_attn_weights.past_kv_scale = {}; + } } template struct LlamaDecoderLayerWeight; diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc index df7fa694e..0d25ea54e 100644 --- a/src/fastertransformer/models/llama/LlamaV2.cc +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -91,6 +91,7 @@ LlamaV2::LlamaV2(size_t head_num, FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); size_t elem_bits = 0; + fprintf(stdout, "******** quant_policy %d\n", quant_policy); if (quant_policy & QuantPolicy::kCacheKVInt8) { elem_bits = sizeof(int8_t) * 8; if (use_context_fmha) { diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc index c9a6ec9e5..842bdd72e 100644 --- a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -129,7 +129,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0); attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); - quant_policy_ = reader.GetInteger("llama", "quant_policy", 4); + quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); handleMissingParams(); From 619ad9b07a12d3f2e22471fa35bd32767bbaf8f8 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 28 Jun 2023 04:37:34 +0000 Subject: [PATCH 4/9] feat(src): add debug --- .../llama/LlamaContextAttentionLayer.cc | 45 +++++++++++++++++++ .../models/llama/LlamaDecoderLayerWeight.cc | 3 +- .../llama/LlamaDecoderSelfAttentionLayer.cc | 44 ++++++++++++------ src/fastertransformer/models/llama/LlamaV2.cc | 1 + .../triton_backend/llama/LlamaTritonModel.cc | 5 ++- 5 files changed, 82 insertions(+), 16 deletions(-) diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc index 651857493..9fc8871bf 100644 --- a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc @@ -19,6 +19,8 @@ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc +#include +#include #include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h" #include "src/fastertransformer/kernels/bert_preprocess_kernels.h" #include "src/fastertransformer/kernels/unfused_attention_kernels.h" @@ -198,7 +200,23 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* stream_, quant_policy_, weights->past_kv_scale.data()); + + + const std::string path_base = "/workspace/save/context_"; + sync_check_cuda_error(); + std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy"; + std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy"; + if (not std::filesystem::exists(kpath)) { + // auto kptr = k_cache_ptrs[0]; + // auto vptr = v_cache_ptrs[0]; + + // Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, batch_size, local_head_num_}, kptr); + // k.saveNpy(kpath); + + // Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, 
{size_per_head_, batch_size, local_head_num_}, vptr); + // v.saveNpy(vpath); + } if (use_fmha_) { fusedMultiHeadAttention(k_cache_ptrs, @@ -225,6 +243,8 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* max_seq_len, quant_policy_, weights->past_kv_scale.data()); + + } ////////////////////////////////////////////// @@ -237,6 +257,8 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* sync_check_cuda_error(); } + + if (is_free_buffer_after_forward_ == true) { freeBuffer(); } @@ -329,6 +351,23 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac kv_scale); sync_check_cuda_error(); + const std::string path_base = "/workspace/save-fp16/context_"; + std::string hkpath = path_base + "hk" + ".npy"; + std::string hvpath = path_base + "hv" + ".npy"; + std::string hspath = path_base + "hs" + ".npy"; + + if (not std::filesystem::exists(hkpath)) { + + Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, k_cache_buf_); + k.saveNpy(hkpath); + + Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, v_cache_buf_); + v.saveNpy(hvpath); + + Tensor s(MemoryType::MEMORY_CPU, DataType::TYPE_FP32, {2}, kv_scale); + s.saveNpy(hspath); + } + const T qk_scale = static_cast(1.f / sqrtf(size_per_head_ * 1.f)); ////////////////////////////////////////////// @@ -398,6 +437,12 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac 0, stream_); sync_check_cuda_error(); + + std::string hapath = path_base + "ha" + "policy" + std::to_string(quant_policy_) + ".npy"; + if (not std::filesystem::exists(hapath)) { + Tensor attn(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, max_q_len}, qkv_buf_3_); + attn.saveNpy(hapath); + } } template class LlamaContextAttentionLayer; diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc index 0d8e1f6c5..5141f0cbb 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -166,6 +166,7 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType if (in.is_open()) { in.close(); self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); + fprintf(stdout, "****** %s %f %f %ld\n", scale_path.c_str(), self_attn_weights.past_kv_scale[0], self_attn_weights.past_kv_scale[1], self_attn_weights.past_kv_scale.size()); } else { self_attn_weights.past_kv_scale = {}; } @@ -174,4 +175,4 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType template struct LlamaDecoderLayerWeight; template struct LlamaDecoderLayerWeight; -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc index 26babcd1c..a59934da3 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -17,7 +17,8 @@ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc - +#include +#include #include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include 
"src/fastertransformer/models/llama/LlamaNcclGuard.h" @@ -99,14 +100,14 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // Set the input buffers. params.q = reinterpret_cast(qkv_buf); - if (int8_mode != 2) { - params.k = reinterpret_cast(qkv_buf) + hidden_units; - params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; - } - else { - params.k = reinterpret_cast(reinterpret_cast(qkv_buf) + hidden_units); - params.v = reinterpret_cast(reinterpret_cast(qkv_buf) + 2 * hidden_units); - } + // if (int8_mode != 2) { + params.k = reinterpret_cast(qkv_buf) + hidden_units; + params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + // } + // else { + // params.k = reinterpret_cast(reinterpret_cast(qkv_buf) + hidden_units); + // params.v = reinterpret_cast(reinterpret_cast(qkv_buf) + 2 * hidden_units); + // } params.stride = 3 * hidden_units; params.finished = const_cast(finished); @@ -149,10 +150,11 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.ia3_value_weights = reinterpret_cast(ia3_value_weights); params.int8_mode = int8_mode; - if (int8_mode == 2) { - params.qkv_scale_out = qkv_scale_out; - params.attention_out_scale = attention_out_scale; - } else if (int8_mode == QuantPolicy::kCacheKVInt8) { + // if (int8_mode == 2) { + // params.qkv_scale_out = qkv_scale_out; + // params.attention_out_scale = attention_out_scale; + // } + if (int8_mode & QuantPolicy::kCacheKVInt8) { params.attention_k_scale = attention_kv_scale[0]; params.attention_v_scale = attention_kv_scale[1]; } @@ -278,6 +280,22 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o stream_); sync_check_cuda_error(); + const std::string path_base = "/workspace/save-fp16/normal_"; + std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy"; + std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy"; + // if (not std::filesystem::exists(kpath)) { + // auto kptr = key_cache_ptrs[0]; + // auto vptr = value_cache_ptrs[0]; + + // Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, kptr); + // k.saveNpy(kpath); + + // Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, vptr); + // v.saveNpy(vpath); + // } + + + linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output); if (tensor_para_.world_size_ > 1) { diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc index 0d25ea54e..6ded3c6b4 100644 --- a/src/fastertransformer/models/llama/LlamaV2.cc +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -94,6 +94,7 @@ LlamaV2::LlamaV2(size_t head_num, fprintf(stdout, "******** quant_policy %d\n", quant_policy); if (quant_policy & QuantPolicy::kCacheKVInt8) { elem_bits = sizeof(int8_t) * 8; + // elem_bits = sizeof(T) * 8; if (use_context_fmha) { FT_LOG_ERROR("use_context_fmha not support int8"); assert(0); diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc index 842bdd72e..3fea89601 100644 --- a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -129,7 +129,8 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0); attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); - quant_policy_ = 
reader.GetInteger("llama", "quant_policy", 0);
+    quant_policy_ = 0;
+    // quant_policy_ = reader.GetInteger("llama", "quant_policy", 4);

     handleMissingParams();

@@ -309,7 +310,7 @@ std::string LlamaTritonModel::toString()
        << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
        << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
        << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
-       << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << std::endl;
+       << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
     return ss.str();
 }

From 2ea4e4af20e6597aefaaf28d2be60e9791490cff Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Wed, 28 Jun 2023 08:20:19 +0000
Subject: [PATCH 5/9] fix(kernel): k_cache use int8_t pointer

---
 README.md                                      |  9 ++++-
 README_zh-CN.md                                | 19 ++++---
 ...er_masked_multihead_attention_template.cuh  | 30 +++++++-------
 .../llama/LlamaContextAttentionLayer.cc        | 40 -------------------
 .../llama/LlamaDecoderSelfAttentionLayer.cc    | 22 ++--------
 .../triton_backend/llama/LlamaTritonModel.cc   |  5 +--
 6 files changed, 42 insertions(+), 83 deletions(-)

diff --git a/README.md b/README.md
index 6a189a8dc..f64fd90b5 100644
--- a/README.md
+++ b/README.md
@@ -158,12 +158,19 @@ bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fast
 python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1
 ```

-## User Guide
+## Quantization
+
+In fp16 mode, kv_cache int8 quantization can be enabled so that a single card can serve more users.
+First run the quantization script; it stores the quantization parameters in the weight directory produced by `deploy.py`.
+Then adjust `config.ini`:
+* set `use_context_fmha` to 0, which turns the fused context attention off (it does not support the int8 kv cache)
+* set `quant_policy` to 4; it defaults to 0, which means the int8 kv cache is disabled

 ## Contributing

 We appreciate all contributions to LLMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.

+
 ## Acknowledgement

 - [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 613b5f1a6..d3ceeb7de 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -68,7 +68,7 @@ pip install -e .
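For reference, the `quant_policy` value described in the README section above maps onto the `QuantPolicy` bit flags introduced in PATCH 2/9 (`llama_utils.h`). Below is a minimal sketch of how the flag is interpreted; the enum values are copied from the patch, while the helper function and the standalone `main` are illustrative only and not part of the series:

```cpp
#include <cstdio>

// Bit flags as defined in src/fastertransformer/models/llama/llama_utils.h (PATCH 2/9).
enum QuantPolicy {
    kNone        = 0x00,
    kReserve1    = 0x01,  // reserved for backward compatibility
    kReserve2    = 0x02,  // reserved for backward compatibility
    kCacheKVInt8 = 0x04,  // store the kv cache as int8
};

// Illustrative helper: the runtime tests the flag the same way,
// e.g. `quant_policy & QuantPolicy::kCacheKVInt8` in LlamaV2.cc.
inline bool useInt8KvCache(int quant_policy)
{
    return (quant_policy & kCacheKVInt8) != 0;
}

int main()
{
    // quant_policy = 4 in config.ini turns the int8 kv cache on; 0 (the default) leaves it off.
    const int policies[] = {0, 4};
    for (int policy : policies) {
        std::printf("quant_policy=%d -> int8 kv cache %s\n",
                    policy, useInt8KvCache(policy) ? "on" : "off");
    }
    return 0;
}
```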
```shell python3 llmdeploy/serve/fastertransformer/deploy.py llama-7B /path/to/llama-7b llama \ --tokenizer_path /path/to/tokenizer/model -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -79,7 +79,7 @@ bash workspace/service_docker_up.sh ```shell python3 llmdeploy/serve/fastertransformer/deploy.py llama-13B /path/to/llama-13b llama \ --tokenizer_path /path/to/tokenizer/model --tp 2 -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -90,7 +90,7 @@ bash workspace/service_docker_up.sh ```shell python3 llmdeploy/serve/fastertransformer/deploy.py llama-33B /path/to/llama-33b llama \ --tokenizer_path /path/to/tokenizer/model --tp 4 -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -101,7 +101,7 @@ bash workspace/service_docker_up.sh ```shell python3 llmdeploy/serve/fastertransformer/deploy.py llama-65B /path/to/llama-65b llama \ --tokenizer_path /path/to/tokenizer/model --tp 8 -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -119,7 +119,7 @@ python3 -m fastchat.model.apply_delta \ --delta-path lmsys/vicuna-7b-delta-v1.1 python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-7B /path/to/vicuna-7b hf -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -135,7 +135,7 @@ python3 -m fastchat.model.apply_delta \ --delta-path lmsys/vicuna-13b-delta-v1.1 python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-13B /path/to/vicuna-13b hf -bash workspace/service_docker_up.sh +bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer ``` @@ -146,6 +146,13 @@ bash workspace/service_docker_up.sh python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1 ``` +## 量化部署 +在 fp16 模式下,可以开启 kv_cache int8 量化,单卡可服务更多用户。 +首先执行量化脚本,量化参数存放到 `deploy.py` 转换的 weight 目录下。 +然后调整 `config.ini` +* `use_context_fmha` 改为 0,表示关闭 +* `quant_policy` 设置为 4。此参数默认为 0,表示不开启 + ## 贡献指南 我们感谢所有的贡献者为改进和提升 LLMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。 diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index bd969addf..d42ac8d6d 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -1052,13 +1052,13 @@ inline __device__ int32_t quant(float4 a, const float scale) } // float16 to int8 -// inline __device__ int8_t quant(uint16_t a, const float scale) -// { -// int8_t int8; -// float b = half_to_float(a); -// int8 = round(max(-128.f, min(127.f, b.x / scale))); -// return int8; -// } +inline __device__ int8_t quant(uint16_t a, const float scale) +{ + int8_t int8; + float b = half_to_float(a); + int8 = round(max(-128.f, min(127.f, b / scale))); + return int8; +} // float16x2 to int8x2 inline __device__ int16_t quant(uint a, const float scale) { @@ -1460,8 +1460,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params using Packed_Int8_t = typename 
packed_type::value>::type; Packed_Int8_t k_int8 = quant(k, k_scale); - Packed_Int8_t* dst_ptr = reinterpret_cast(params.k_cache); - dst_ptr[offset] = k_int8; + int8_t* dst_ptr = reinterpret_cast(params.k_cache); + *reinterpret_cast(&dst_ptr[offset]) = k_int8; } else { *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); } @@ -1481,8 +1481,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params using Packed_Int8_t = typename packed_type::value>::type; Packed_Int8_t k_int8 = quant(k, k_scale); - Packed_Int8_t** dst_ptr = reinterpret_cast(params.k_cache_per_sample); - dst_ptr[bi][offset] = k_int8; + int8_t* dst_ptr = reinterpret_cast(params.k_cache_per_sample[bi]); + *reinterpret_cast(&dst_ptr[offset]) = k_int8; } else { *reinterpret_cast(¶ms.k_cache_per_sample[bi][offset]) = vec_conversion(k); @@ -1571,8 +1571,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (params.int8_mode & QuantPolicy::kCacheKVInt8) { // convert k_cache_per_sample to int8 if (params.k_cache_per_sample) { - int8_t** ptr = reinterpret_cast(params.k_cache_per_sample); - k_cache_batch_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; + int8_t* ptr = reinterpret_cast(params.k_cache_per_sample[bi]); + k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; } else { int8_t* ptr = reinterpret_cast(params.k_cache); k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki]; @@ -1755,8 +1755,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.v_cache_per_sample) { - int8_t** ptr = reinterpret_cast(params.v_cache_per_sample); - v_cache_int8 = ptr[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; + int8_t* ptr = reinterpret_cast(params.v_cache_per_sample[bi]); + v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; } else { int8_t* ptr = reinterpret_cast(params.v_cache); v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi]; diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc index 9fc8871bf..bca1027c2 100644 --- a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc @@ -19,8 +19,6 @@ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -#include -#include #include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h" #include "src/fastertransformer/kernels/bert_preprocess_kernels.h" #include "src/fastertransformer/kernels/unfused_attention_kernels.h" @@ -202,22 +200,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* weights->past_kv_scale.data()); - const std::string path_base = "/workspace/save/context_"; - sync_check_cuda_error(); - std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy"; - std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy"; - if (not std::filesystem::exists(kpath)) { - // auto kptr = k_cache_ptrs[0]; - // auto vptr = v_cache_ptrs[0]; - - // Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, batch_size, local_head_num_}, kptr); - // k.saveNpy(kpath); - - // Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, batch_size, 
local_head_num_}, vptr); - // v.saveNpy(vpath); - } - if (use_fmha_) { fusedMultiHeadAttention(k_cache_ptrs, v_cache_ptrs, @@ -351,23 +334,6 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac kv_scale); sync_check_cuda_error(); - const std::string path_base = "/workspace/save-fp16/context_"; - std::string hkpath = path_base + "hk" + ".npy"; - std::string hvpath = path_base + "hv" + ".npy"; - std::string hspath = path_base + "hs" + ".npy"; - - if (not std::filesystem::exists(hkpath)) { - - Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, k_cache_buf_); - k.saveNpy(hkpath); - - Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, v_cache_buf_); - v.saveNpy(hvpath); - - Tensor s(MemoryType::MEMORY_CPU, DataType::TYPE_FP32, {2}, kv_scale); - s.saveNpy(hspath); - } - const T qk_scale = static_cast(1.f / sqrtf(size_per_head_ * 1.f)); ////////////////////////////////////////////// @@ -437,12 +403,6 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cac 0, stream_); sync_check_cuda_error(); - - std::string hapath = path_base + "ha" + "policy" + std::to_string(quant_policy_) + ".npy"; - if (not std::filesystem::exists(hapath)) { - Tensor attn(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, max_q_len}, qkv_buf_3_); - attn.saveNpy(hapath); - } } template class LlamaContextAttentionLayer; diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc index a59934da3..4987e9fe2 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -17,8 +17,6 @@ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc -#include -#include #include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/models/llama/LlamaNcclGuard.h" @@ -241,6 +239,10 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; const int memory_len = max_seq_len; + if (kv_cache_layer_offset >= 192937984) { + fprintf(stderr, "gdb"); + } + fusedQKV_masked_attention_dispatch( qkv_buf_, weights->qkv.bias, // query_weight.bias, @@ -280,22 +282,6 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o stream_); sync_check_cuda_error(); - const std::string path_base = "/workspace/save-fp16/normal_"; - std::string kpath = path_base + "k" + std::to_string(layer_id) + ".npy"; - std::string vpath = path_base + "v" + std::to_string(layer_id) + ".npy"; - // if (not std::filesystem::exists(kpath)) { - // auto kptr = key_cache_ptrs[0]; - // auto vptr = value_cache_ptrs[0]; - - // Tensor k(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, kptr); - // k.saveNpy(kpath); - - // Tensor v(MemoryType::MEMORY_GPU, DataType::TYPE_FP16, {size_per_head_, local_head_num_}, vptr); - // v.saveNpy(vpath); - // } - - - linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output); if (tensor_para_.world_size_ > 1) { diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc index 3fea89601..b94c062bb 100644 
--- a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -125,12 +125,11 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, session_len_ = reader.GetInteger("llama", "session_len", 0); step_length_ = reader.GetInteger("llama", "step_length", 0); cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0); - use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 0); + use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1); cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0); attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); - quant_policy_ = 0; - // quant_policy_ = reader.GetInteger("llama", "quant_policy", 4); + quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); handleMissingParams(); From 7a5c236c018ecaf31dba85288f6f07fd64007fdb Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 28 Jun 2023 08:40:11 +0000 Subject: [PATCH 6/9] style(llama): clean code --- .../models/llama/LlamaDecoderLayerWeight.cc | 1 - .../llama/LlamaDecoderSelfAttentionLayer.cc | 16 +------ src/fastertransformer/models/llama/LlamaV2.cc | 1 - .../models/llama/llama_kernels.cu | 42 ------------------- .../models/llama/llama_kernels.h | 19 --------- 5 files changed, 2 insertions(+), 77 deletions(-) diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc index 5141f0cbb..18587b30e 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -166,7 +166,6 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType if (in.is_open()) { in.close(); self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); - fprintf(stdout, "****** %s %f %f %ld\n", scale_path.c_str(), self_attn_weights.past_kv_scale[0], self_attn_weights.past_kv_scale[1], self_attn_weights.past_kv_scale.size()); } else { self_attn_weights.past_kv_scale = {}; } diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc index 4987e9fe2..0b5fbd76a 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -98,14 +98,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // Set the input buffers. 
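Aside: when `QuantPolicy::kCacheKVInt8` is set, the attention dispatch in this series forwards the two `past_kv_scale` values as `attention_k_scale` and `attention_v_scale`, and K/V are written to the cache as symmetric, per-tensor int8. A minimal host-side sketch of that round trip follows, assuming the same clamp-and-round formula as the device-side `quant()` overloads; the function names and sample scales are illustrative:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Symmetric per-tensor int8 quantization: round(max(-128, min(127, x / scale))),
// mirroring the device-side quant() helpers added in this series.
inline int8_t quantToInt8(float x, float scale)
{
    return static_cast<int8_t>(std::lround(std::max(-128.f, std::min(127.f, x / scale))));
}

// Dequantization is the inverse: multiply the stored int8 value by the scale.
inline float dequantToFloat(int8_t q, float scale)
{
    return static_cast<float>(q) * scale;
}

int main()
{
    // Sample scales standing in for past_kv_scale[0] (K) and past_kv_scale[1] (V).
    const float k_scale = 0.02f;
    const float v_scale = 0.05f;

    const float k_val = 1.234f;  // a K-cache element before quantization
    const float v_val = -3.1f;   // a V-cache element before quantization

    const int8_t kq = quantToInt8(k_val, k_scale);
    const int8_t vq = quantToInt8(v_val, v_scale);

    std::printf("k: %.4f -> %d -> %.4f\n", k_val, kq, dequantToFloat(kq, k_scale));
    std::printf("v: %.4f -> %d -> %.4f\n", v_val, vq, dequantToFloat(vq, v_scale));
    return 0;
}
```

Within the representable range the rounding error is at most half a scale step per element, which is why the per-model scales produced by the quantization step matter.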
params.q = reinterpret_cast(qkv_buf); - // if (int8_mode != 2) { params.k = reinterpret_cast(qkv_buf) + hidden_units; params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; - // } - // else { - // params.k = reinterpret_cast(reinterpret_cast(qkv_buf) + hidden_units); - // params.v = reinterpret_cast(reinterpret_cast(qkv_buf) + 2 * hidden_units); - // } + params.stride = 3 * hidden_units; params.finished = const_cast(finished); @@ -148,10 +143,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.ia3_value_weights = reinterpret_cast(ia3_value_weights); params.int8_mode = int8_mode; - // if (int8_mode == 2) { - // params.qkv_scale_out = qkv_scale_out; - // params.attention_out_scale = attention_out_scale; - // } + if (int8_mode & QuantPolicy::kCacheKVInt8) { params.attention_k_scale = attention_kv_scale[0]; params.attention_v_scale = attention_kv_scale[1]; @@ -239,10 +231,6 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; const int memory_len = max_seq_len; - if (kv_cache_layer_offset >= 192937984) { - fprintf(stderr, "gdb"); - } - fusedQKV_masked_attention_dispatch( qkv_buf_, weights->qkv.bias, // query_weight.bias, diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc index 6ded3c6b4..0d25ea54e 100644 --- a/src/fastertransformer/models/llama/LlamaV2.cc +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -94,7 +94,6 @@ LlamaV2::LlamaV2(size_t head_num, fprintf(stdout, "******** quant_policy %d\n", quant_policy); if (quant_policy & QuantPolicy::kCacheKVInt8) { elem_bits = sizeof(int8_t) * 8; - // elem_bits = sizeof(T) * 8; if (use_context_fmha) { FT_LOG_ERROR("use_context_fmha not support int8"); assert(0); diff --git a/src/fastertransformer/models/llama/llama_kernels.cu b/src/fastertransformer/models/llama/llama_kernels.cu index 5884eb7dd..26cf72cef 100644 --- a/src/fastertransformer/models/llama/llama_kernels.cu +++ b/src/fastertransformer/models/llama/llama_kernels.cu @@ -451,48 +451,6 @@ template void invokeExtendKVCache(half**, int, const float*); -// template -// __global__ void transpose_key_cache(T* k_dst, -// const T** k_src, -// const size_t src_offset, -// const int head_num, -// const int size_per_head, -// const int* seq_length, -// const int max_kv_len, -// const int max_seq_len) -// { -// const int batch_id = blockIdx.y; -// const int head_id = blockIdx.z; -// constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; - -// const int idx = blockIdx.x * blockDim.x + threadIdx.x; -// int size_per_head_div_x = size_per_head / X_ELEMS; - -// // x dim is now handled by uint4 type -// const auto key_src = reinterpret_cast(k_src[batch_id] + src_offset); -// const auto key_dst = reinterpret_cast(k_dst); - -// const auto seq_len = seq_length[batch_id]; - -// const int k_head_size_id = idx % size_per_head_div_x; -// const int k_seq_len_id = idx / size_per_head_div_x; - -// if (k_seq_len_id < seq_len) { -// // [B, H, s, D/x] <- [B, H, D/x, S[:s]] - -// const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H -// k_head_size_id * max_seq_len + // D/x -// k_seq_len_id; // s - -// const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B -// head_id * size_per_head_div_x * max_kv_len + // H -// k_seq_len_id * size_per_head_div_x + // s -// k_head_size_id; // D/x - -// key_dst[dst_idx] = key_src[src_idx]; -// } -// } - template __global__ void transpose_value_cache(T* v_dst, // const T** v_src, diff --git a/src/fastertransformer/models/llama/llama_kernels.h b/src/fastertransformer/models/llama/llama_kernels.h index 3e3443bd7..87e71c82a 100644 --- a/src/fastertransformer/models/llama/llama_kernels.h +++ b/src/fastertransformer/models/llama/llama_kernels.h @@ -155,25 +155,6 @@ struct TempBuffer { T* data; }; -// template -// inline T* -// transpose_key_cache(T* key_cache, size_t head_num, size_t size_per_head_by_x, size_t mem_len, size_t x, cudaStream_t st) -// { -// static TempBuffer buf(8192 * 8192); -// // from: H Dx, S, x -// // to : S, H Dx, x -// invokeTransposeAxis01(buf.data, key_cache, head_num * size_per_head_by_x, mem_len, x, st); -// return buf.data; -// } - -// template -// inline T* transpose_value_cache(T* value_cache, size_t head_num, size_t mem_len, size_t size_per_head, cudaStream_t st) -// { -// static TempBuffer buf(8192 * 8192); -// invokeTransposeAxis01(buf.data, value_cache, head_num, mem_len, size_per_head, st); -// return buf.data; -// } - inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t st) { int h_seq_len = -1; From bce6779b60a62b578ff152568a0ca76330159573 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 28 Jun 2023 08:47:06 +0000 Subject: [PATCH 7/9] feat(deploy.py): revert to enable fmha --- llmdeploy/serve/fastertransformer/deploy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llmdeploy/serve/fastertransformer/deploy.py b/llmdeploy/serve/fastertransformer/deploy.py index 4baa87841..6fc7db966 100644 --- a/llmdeploy/serve/fastertransformer/deploy.py +++ b/llmdeploy/serve/fastertransformer/deploy.py @@ -147,7 +147,7 @@ def save_bin(param: torch.Tensor, name): step_length=1, cache_max_entry_count=48, cache_chunk_size=8, - use_context_fmha=0)) + use_context_fmha=1)) config = configparser.ConfigParser() for section, key_values in cfg.items(): @@ -301,7 +301,6 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, _files = [file for file in os.listdir(model_path) if file.endswith('.bin')] _files = sorted(_files) - print(_files) _params = {} From 39f803401e51b89683e1042439510542d78a86d9 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 28 Jun 2023 08:56:03 +0000 Subject: [PATCH 8/9] style(LlamaV2): clean code --- src/fastertransformer/models/llama/LlamaV2.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc index 0d25ea54e..df7fa694e 100644 --- 
a/src/fastertransformer/models/llama/LlamaV2.cc +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -91,7 +91,6 @@ LlamaV2::LlamaV2(size_t head_num, FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); size_t elem_bits = 0; - fprintf(stdout, "******** quant_policy %d\n", quant_policy); if (quant_policy & QuantPolicy::kCacheKVInt8) { elem_bits = sizeof(int8_t) * 8; if (use_context_fmha) { From c3706d438f23729be645eb3bffd8ed2cae2d91aa Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Wed, 28 Jun 2023 09:08:17 +0000 Subject: [PATCH 9/9] feat(deploy.py): add default quant policy --- llmdeploy/serve/fastertransformer/deploy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmdeploy/serve/fastertransformer/deploy.py b/llmdeploy/serve/fastertransformer/deploy.py index 6fc7db966..a69d6b43a 100644 --- a/llmdeploy/serve/fastertransformer/deploy.py +++ b/llmdeploy/serve/fastertransformer/deploy.py @@ -147,7 +147,8 @@ def save_bin(param: torch.Tensor, name): step_length=1, cache_max_entry_count=48, cache_chunk_size=8, - use_context_fmha=1)) + use_context_fmha=1, + quant_policy=0)) config = configparser.ConfigParser() for section, key_values in cfg.items():
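Taken together, the series expects the quantization step to leave a small per-layer file next to the converted weights, named `<prefix>.past_kv_scale.<rank>.weight`, which `LlamaDecoderLayerWeight::loadModel` reads back via `loadArrayFromBin({2}, scale_path)` as the K and V scales (if the file is absent, the loader simply leaves `past_kv_scale` empty). Below is a minimal sketch of a tool that writes such a file; it assumes the two scales are stored as raw float32 values, consistent with the `TYPE_FP32`, shape `{2}` tensor used where the scales are consumed, and the file name prefix is only an example since the real layout comes from the `deploy.py` conversion:

```cpp
#include <cstdio>
#include <fstream>
#include <string>

// Hypothetical helper: write the two kv-cache scales (K scale first, V scale second,
// matching attention_kv_scale[0]/[1] in the attention dispatch) as raw float32 so
// that loadArrayFromBin({2}, path) can read them back.
bool writePastKvScale(const std::string& path, float k_scale, float v_scale)
{
    std::ofstream out(path, std::ios::binary);
    if (!out.is_open()) {
        return false;
    }
    const float scales[2] = {k_scale, v_scale};
    out.write(reinterpret_cast<const char*>(scales), sizeof(scales));
    return out.good();
}

int main()
{
    // Example file name for layer 0, tensor-parallel rank 0; the exact prefix depends
    // on the weight directory layout produced by deploy.py.
    const std::string path = "layers.0.attention.past_kv_scale.0.weight";
    if (!writePastKvScale(path, 0.02f, 0.05f)) {
        std::fprintf(stderr, "failed to write %s\n", path.c_str());
        return 1;
    }
    std::printf("wrote %s\n", path.c_str());
    return 0;
}
```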