fix tritonv2 in make_sure_no_tensor_in_meta_device (AutoGPTQ#617)
LaaZa committed Mar 28, 2024
1 parent 9fd7a6c commit ca187fb
Showing 1 changed file with 2 additions and 2 deletions.
auto_gptq/modeling/_utils.py (2 additions, 2 deletions)
@@ -504,9 +504,9 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in


 def make_sure_no_tensor_in_meta_device(
-    model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False,
+    model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False, use_tritonv2: bool = False,
 ):
-    QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin)
+    QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin, use_tritonv2=use_tritonv2)
     for n, m in model.named_modules():
         if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
             m.register_buffer("bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
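For context: this helper scans the model for QuantLinear layers whose bias buffer was left on PyTorch's meta device and re-registers it as a real zero tensor. The fix threads the new use_tritonv2 flag through to dynamically_import_QuantLinear, so the scan matches the TritonV2 QuantLinear class instead of the wrong one. Below is a minimal, self-contained sketch of that meta-device check; ToyQuantLinear is a hypothetical stand-in for the dynamically imported QuantLinear class and is not part of AutoGPTQ.

import torch
import torch.nn as nn

class ToyQuantLinear(nn.Module):
    """Hypothetical stand-in for the QuantLinear class that
    dynamically_import_QuantLinear would return."""
    def __init__(self, outfeatures: int):
        super().__init__()
        self.outfeatures = outfeatures
        # Simulate a bias left on the meta device by deferred weight loading.
        self.register_buffer("bias", torch.empty(outfeatures, device="meta"))

model = nn.Sequential(ToyQuantLinear(8))

# The same pattern the patched helper uses: find meta-device biases and
# re-register them as real zero tensors so later device moves succeed.
for n, m in model.named_modules():
    if isinstance(m, ToyQuantLinear) and m.bias.device == torch.device("meta"):
        m.register_buffer("bias", torch.zeros(m.outfeatures, dtype=torch.float16, device="cpu"))

print(model[0].bias.device)  # cpu

Note that register_buffer with an existing buffer name overwrites the old entry, which is why the helper can swap the meta placeholder for a real tensor in place.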
