[Kernel] [FP8] Improve FP8 linear layer performance (vllm-project#4691)
This PR improves the FP8 performance of linear layers, which had been lagging behind previously (see vllm-project#4118 (comment) and vllm-project#4118 (comment)).

We noticed that cuBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR pads the first (batch) dimension of the matrix accordingly during quantization. This improves FP8 performance and removes the performance regression relative to FP16, in many cases even exceeding FP16 performance.
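
A rough illustration of the trick (not the actual vLLM code; padded_matmul and the plain matmul stand in for the FP8 path): pad the batch dimension up to 17 rows before the GEMM and slice the extra rows off afterwards.

import torch

def padded_matmul(x: torch.Tensor, w: torch.Tensor,
                  batch_dim_padding: int = 17) -> torch.Tensor:
    # Pad the batch (first) dimension so the GEMM sees more than 16 rows,
    # which lets cuBLASLt pick a faster algorithm for FP8 inputs.
    num_rows = x.shape[0]
    if num_rows < batch_dim_padding:
        padded = torch.zeros((batch_dim_padding, *x.shape[1:]),
                             dtype=x.dtype, device=x.device)
        padded[:num_rows] = x
        x = padded
    out = x @ w  # stand-in for the fused FP8 GEMM (torch._scaled_mm in this PR)
    # Trim the padded rows so callers see the original batch size.
    return torch.narrow(out, 0, 0, num_rows)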

Here are benchmarks on Llama 3 70B (ITL, i.e. inter-token latency, for 1000 input and 50 output tokens at fixed qps and at TP 4); all FP8 measurements use dynamic quantization:

qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16)
qps = 2: 26 ms (FP8, this PR), 34 ms (FP8, previous main), 28 ms (FP16)
qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16)
qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16)
qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16)
pcmoritz committed May 9, 2024
1 parent ebce310 commit 379da6d
Showing 2 changed files with 36 additions and 5 deletions.
28 changes: 27 additions & 1 deletion vllm/_custom_ops.py
@@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
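A minimal usage sketch of the updated helper, following the docstring above (assumes a CUDA device with FP8 support; the import path mirrors how fp8.py below refers to it, and the shapes/scale values are example choices):

import torch
from vllm import _custom_ops as ops

x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)

# Dynamic quantization: no scale given, output padded to at least 17 rows.
qinput, x_scale = ops.scaled_fp8_quant(x, batch_dim_padding=17)
assert qinput.shape[0] == 17 and qinput.dtype == torch.float8_e4m3fn

# Static quantization: pass a precomputed scale; no padding requested.
act_scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
qinput2, _ = ops.scaled_fp8_quant(x, act_scale)
assert qinput2.shape == x.shape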
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -231,9 +231,14 @@ def apply(self,
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
-
-        # Fused GEMM_DQ
+        qinput, x_scale = ops.scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
+
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -243,7 +248,7 @@ def apply(self,
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
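For reference on the torch.narrow call above: it returns a view of the leading rows rather than a copy, so trimming the padded output is essentially free. A small standalone illustration:

import torch

padded = torch.randn(17, 8)              # e.g. a padded GEMM output
trimmed = torch.narrow(padded, 0, 0, 3)  # keep only the 3 "real" rows
assert trimmed.shape == (3, 8)
assert trimmed.data_ptr() == padded.data_ptr()  # shares storage, no copy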
