Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: dtype mismatch during quantization #43

Merged
merged 1 commit on May 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Fix: dtype mismatch during quantization
  • Loading branch information
junrushao committed May 3, 2023
commit 990b9b5e9147d81351d55fbe29354f969acf3e64
24 changes: 15 additions & 9 deletions mlc_llm/transform/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from typing import List

import tvm
from tvm import relax
from tvm import te, tir, topi
from tvm import relax, te, tir, topi
from tvm.ir.module import IRModule
from tvm.relax.expr_functor import mutator, PyExprMutator
from tvm.relax.analysis import remove_all_unused
from tvm.relax.expr_functor import PyExprMutator, mutator
from tvm.relax.op.builtin import stop_lift_params

from tvm.script import tir as T


Expand Down Expand Up @@ -152,7 +150,7 @@ def te_encode_sym(weight: te.Tensor):
max_abs_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(tir.if_then_else(j * group_size + k < weight.shape[1], te.abs(weight[i, j * group_size + k]), tir.min_value(dtype)), axis=k), name="max_abs_value")

def f_compute_scale(i, j):
max_value = tir.Max(max_abs_value[i, j], tir.const(1e-4, dtype))
max_value = tir.max(max_abs_value[i, j], tir.const(1e-4, dtype))
return (max_value / tir.const(max_int_value, dtype)) if mode.startswith("int") else max_value

scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale")
Expand Down Expand Up @@ -295,7 +293,9 @@ def __init__(
self.storage_nbit = storage_nbit
self.dtype = dtype

def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
def transform_module(
self, mod: IRModule, ctx: tvm.transform.PassContext
) -> IRModule:
@mutator
class QuantizeMutator(PyExprMutator):
def __init__(
Expand Down Expand Up @@ -347,10 +347,16 @@ def emit_encoding(self, x: relax.Expr, transpose: bool) -> List[relax.Expr]:
)

decode_args = []
decode_args.append(self.builder_.emit(relax.TupleGetItem(encoded_data, 0)))
decode_args.append(self.builder_.emit(relax.TupleGetItem(encoded_data, 1)))
decode_args.append(
self.builder_.emit(relax.TupleGetItem(encoded_data, 0))
)
decode_args.append(
self.builder_.emit(relax.TupleGetItem(encoded_data, 1))
)
if self.dtype == "float16" and not self.sym:
decode_args.append(self.builder_.emit(relax.TupleGetItem(encoded_data, 2)))
decode_args.append(
self.builder_.emit(relax.TupleGetItem(encoded_data, 2))
)
for i, arg in enumerate(decode_args):
decode_args[i] = self.builder_.emit(stop_lift_params(arg))
return decode_args
Expand Down