Skip to content

Commit

Permalink
Untie weights for safetensors serialization (AutoGPTQ#536)
Browse files Browse the repository at this point in the history
* safetensors can not handle tied weights

* add recursive getattr/setattr
  • Loading branch information
fxmarty committed Feb 12, 2024
1 parent 1c6a62c commit 682ceb0
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 7 deletions.
3 changes: 1 addition & 2 deletions auto_gptq/modeling/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def from_pretrained(cls, save_dir: str, **kwargs):
field_names = [field.name for field in fields(cls)]
with open(resolved_config_file, "r", encoding="utf-8") as f:
args_from_json = json.load(f)

if transformers_config:
args_from_json = args_from_json["quantization_config"]

Expand Down Expand Up @@ -1029,7 +1029,6 @@ def skip(*args, **kwargs):
gptq_layers = set()
non_gptq_params = set()
with safe_open(model_save_name, framework="pt") as f:
state_dict_keys = list(f.keys())
for state_dict_key in f.keys():
if "qweight" not in state_dict_key and "qzeros" not in state_dict_key and "scales" not in state_dict_key:
non_gptq_params.add(state_dict_key)
Expand Down
4 changes: 2 additions & 2 deletions auto_gptq/nn_modules/qlinear/qlinear_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = getLogger(__name__)

try:
import marlin_cuda
import autogptq_marlin_cuda

_marlin_available = True
except ImportError:
Expand All @@ -40,7 +40,7 @@ def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1):
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
"""
marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms)
autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms)


# Precompute permutations for Marlin weight and scale shuffling
Expand Down
2 changes: 1 addition & 1 deletion auto_gptq/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
QIGEN_EXCEPTION = e

try:
import marlin_cuda
import autogptq_marlin_cuda

MARLIN_AVAILABLE = True
except Exception as e:
Expand Down
14 changes: 13 additions & 1 deletion auto_gptq/utils/marlin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from safetensors.torch import save_file as safe_save
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import dequantize_weight

from accelerate.utils import find_tied_parameters
import gc, os, copy
from logging import getLogger
from tqdm import tqdm

from .modeling_utils import recurse_setattr, recurse_getattr

logger = getLogger(__name__)

def prepare_model_for_marlin_load(
Expand Down Expand Up @@ -51,6 +53,16 @@ def prepare_model_for_marlin_load(
# Convert model to marlin, repacking weights into Marlin format.
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)

# Safetensors is unable to save tied weights, so we untie them here. Reference: https://github.com/huggingface/safetensors/issues/202
tied_params = find_tied_parameters(model)

for weight_group in tied_params:
for param_name in weight_group:
if isinstance(recurse_getattr(model, param_name), torch.nn.Parameter):
recurse_setattr(model, param_name, torch.nn.Parameter(recurse_getattr(model, param_name).clone()))
else:
recurse_setattr(model, param_name, recurse_getattr(model, param_name).clone())

# Cache the converted model.
safe_save(model.state_dict(), model_save_name)

Expand Down
26 changes: 26 additions & 0 deletions auto_gptq/utils/modeling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import functools

def recurse_getattr(obj, attr: str):
"""
Recursive `getattr`.
Args:
obj:
A class instance holding the attribute.
attr (`str`):
The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
"""

def _getattr(obj, attr):
return getattr(obj, attr)

return functools.reduce(_getattr, [obj] + attr.split("."))


def recurse_setattr(module, name, value):
"""A function to recursively set attributes to a module."""
if "." not in name:
setattr(module, name, value)
else:
name, rest = name.split(".", 1)
recurse_setattr(getattr(module, name), rest, value)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@

extensions.append(
cpp_extension.CUDAExtension(
'marlin_cuda',
'autogptq_marlin_cuda',
[
'autogptq_extension/marlin/marlin_cuda.cpp',
'autogptq_extension/marlin/marlin_cuda_kernel.cu'
Expand Down

0 comments on commit 682ceb0

Please sign in to comment.