Commit
Default to safetensors for quantized models (#151)
casper-hansen authored Nov 4, 2023
1 parent 958678d commit 84a2686
Showing 7 changed files with 11 additions and 38 deletions.
2 changes: 1 addition & 1 deletion awq/models/auto.py
@@ -43,7 +43,7 @@ def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, safetensors=False,
+                       batch_size=1, safetensors=True,
                        max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
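With the default flipped, loading a quantized model no longer needs an explicit safetensors flag. A minimal sketch of the new call site (model repo taken from examples/basic_generate.py below; the commented-out flag is the opt-out for legacy .bin weights):

from awq import AutoAWQForCausalLM

# safetensors now defaults to True, so this resolves model.safetensors weights.
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-OpenOrca-AWQ",
    fuse_layers=True,
    # safetensors=False,  # only for repos that still ship pytorch_model.bin
)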
12 changes: 7 additions & 5 deletions awq/models/base.py
@@ -53,7 +53,7 @@ def quantize(self, tokenizer=None, quant_config={},
     def fuse_layers(model):
         pass

-    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
+    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
         save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

         # Save model
@@ -67,7 +67,9 @@ def forward(self, x): return x
         self.quant_config.save_pretrained(save_dir)

         # Remove empty state dict
-        os.remove(f'{save_dir}/pytorch_model.bin')
+        default_path = f'{save_dir}/model.safetensors'
+        if os.path.exists(default_path):
+            os.remove(default_path)

         # model_name has no extension, add it when saving state_dict
         model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
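For context on the cleanup above: the save step leaves behind an empty placeholder state dict, which transformers >= 4.35 writes as model.safetensors by default rather than pytorch_model.bin (hence the setup.py bump below). The real quantized weights are written afterwards; a rough sketch of such a write using the safetensors library, not AutoAWQ's exact save path:

import torch
from safetensors.torch import save_file

# Illustrative stand-in for a quantized state dict.
state_dict = {"layers.0.q_proj.qweight": torch.zeros(128, 16, dtype=torch.int32)}
save_file(state_dict, "model.safetensors")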
@@ -130,7 +132,7 @@ def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, torch_dtype=torch.float16,
-                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       trust_remote_code=True, safetensors=True, is_quantized=True,
                        fuse_layers=False, version='GEMM',
                        max_memory=None, offload_folder=None):
         # [STEP 1-2] Load weights path and configs
@@ -180,11 +182,11 @@ def from_quantized(self, model_path, model_type, model_filename='',

         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

-    def _load_config(self, model_path, model_filename, safetensors=False,
+    def _load_config(self, model_path, model_filename, safetensors=True,
                      version="GEMM", trust_remote_code=True, max_new_tokens=4096):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
-            ignore_patterns = ["*msgpack*", "*h5*"]
+            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
             if safetensors:
                 ignore_patterns.extend(["*.pt*", "*.bin*"])
             else:
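With safetensors as the default, the broadened ignore_patterns keeps torch checkpoints (and now optimizer.pt) from being downloaded at all. A sketch of the resulting filter, assuming the download goes through huggingface_hub's snapshot_download and that the elided else branch excludes safetensors files:

from huggingface_hub import snapshot_download

safetensors = True
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
if safetensors:
    ignore_patterns.extend(["*.pt*", "*.bin*"])  # skip torch weights entirely
else:
    ignore_patterns.append("*.safetensors*")     # assumed inverse branch
model_path = snapshot_download("TheBloke/Mistral-7B-OpenOrca-AWQ",
                               ignore_patterns=ignore_patterns)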
2 changes: 1 addition & 1 deletion examples/basic_generate.py
@@ -4,7 +4,7 @@
 quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"

 # Load model
-model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
+model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
 tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
1 change: 0 additions & 1 deletion examples/basic_quant.py
@@ -14,7 +14,6 @@
 model.quantize(tokenizer, quant_config=quant_config)

 # Save quantized model
-# NOTE: pass safetensors=True to save quantized model weights as safetensors
 model.save_quantized(quant_path)
 tokenizer.save_pretrained(quant_path)
28 changes: 0 additions & 28 deletions examples/basic_safetensors_generate.py

This file was deleted.

2 changes: 1 addition & 1 deletion examples/benchmark.py
@@ -129,7 +129,7 @@ def main(args):
     parser.add_argument("--model_path", type=str, default="casperhansen/mistral-7b-instruct-v0.1-awq", help="path to the model")
     parser.add_argument("--quant_file", type=str, default="", help="weights filename")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
-    parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
+    parser.add_argument("--safetensors", default=True, action="store_false", help="Use for disabling safetensors")
     args = parser.parse_args()

     main(args)
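Note the inverted semantics: with default=True and action="store_false", passing --safetensors now disables safetensors. A standalone sketch of the behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--safetensors", default=True, action="store_false",
                    help="Use for disabling safetensors")

print(parser.parse_args([]).safetensors)                  # True (default)
print(parser.parse_args(["--safetensors"]).safetensors)   # False (opt out)

A --no_safetensors flag (or argparse.BooleanOptionalAction) would read more naturally; keeping the old name presumably avoids breaking existing invocations.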
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@

 requirements = [
     "torch>=2.1.0",
-    "transformers>=4.34.0",
+    "transformers>=4.35.0",
     "tokenizers>=0.12.1",
     "accelerate",
     "sentencepiece",
