Support wildcard for source files in convert utilities #399

Merged · 2 commits · Jul 2, 2023
Changes from 1 commit
Support wildcard for source files
guillaume-be committed Jun 25, 2023
commit a1e4beab7b5a0e12e00c77945b7563521aaca207
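
With this change, each positional `source_file` argument is passed through Python's `glob` module, so sharded checkpoints can be converted in a single call. A minimal sketch of the expansion logic (the file names below are hypothetical):

```python
import glob
from pathlib import Path

# Hypothetical sharded checkpoint pattern; any pattern understood by
# Python's glob module can be passed as a positional source_file argument.
patterns = ["checkpoints/pytorch_model-*.bin"]

for pattern in patterns:
    # glob.glob also accepts a literal path: it is returned as a
    # one-element list when the file exists and contains no wildcard.
    for source_file in glob.glob(pattern):
        print(f"would convert {Path(source_file).name}")
```

As in the diff below, the weights from every matched file are accumulated into a single `nps` dictionary and written to one `model.npz` in the parent folder of the first argument.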
96 changes: 55 additions & 41 deletions utils/convert_model.py
@@ -12,12 +12,14 @@


import argparse
import numpy as np
import glob
import logging
import subprocess
import sys
import torch

from pathlib import Path

import numpy as np
import torch
from torch import Tensor


@@ -32,15 +34,22 @@ def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
nan_mask = np.logical_and(byte_array, 0x7FFF_FFFF) > 0x7F80_0000
round_bit = 0x0000_8000
output_val = np.right_shift(byte_array, 16)
threshold_mask = (np.logical_and(byte_array, round_bit) != 0) & (np.logical_and(byte_array, (3*round_bit-1)) != 0)
output = np.where(nan_mask, nan_value, np.where(threshold_mask, output_val+1, output_val)).astype(np.uint16)
threshold_mask = (np.logical_and(byte_array, round_bit) != 0) & (
np.logical_and(byte_array, (3 * round_bit - 1)) != 0
)
output = np.where(
nan_mask, nan_value, np.where(threshold_mask, output_val + 1, output_val)
).astype(np.uint16)
return output


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"source_file", nargs="+", help="Absolute path to the Pytorch weights file to convert"
"source_file",
nargs="+",
help="""Absolute path (or file pattern) to the Pytorch weights file(s) to convert.
A single file, list of files, glob pattern or list of glob patterns can be provided.""",
)
parser.add_argument(
"--skip_embeddings",
@@ -70,41 +79,46 @@ def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
nps = {}
target_folder = Path(args.source_file[0]).parent

for source_file in args.source_file:
source_file = Path(source_file)
weights = torch.load(str(source_file), map_location="cpu")

for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
if args.skip_embeddings:
if k in {
"model.encoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
}:
continue
if args.skip_lm_head:
if k in {
"lm_head.weight",
}:
continue
if args.prefix:
k = args.prefix + k
if args.suffix:
k = k.split(".")[-1]
if isinstance(v, Tensor):
if v.dtype == torch.bfloat16:
tensor = get_bf16_repr(v)
else:
tensor = v.cpu().numpy()
if args.dtype is not None:
nps[k] = np.ascontiguousarray(tensor.astype(np.dtype(args.dtype)))
for source_file_or_pattern in args.source_file:
source_files = glob.glob(source_file_or_pattern)
for source_file in source_files:
logging.info(f"Processing source file {source_file}...")
source_file = Path(source_file)
weights = torch.load(str(source_file), map_location="cpu")

for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
if args.skip_embeddings:
if k in {
"model.encoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
}:
continue
if args.skip_lm_head:
if k in {
"lm_head.weight",
}:
continue
if args.prefix:
k = args.prefix + k
if args.suffix:
k = k.split(".")[-1]
if isinstance(v, Tensor):
if v.dtype == torch.bfloat16:
tensor = get_bf16_repr(v)
else:
tensor = v.cpu().numpy()
if args.dtype is not None:
nps[k] = np.ascontiguousarray(
tensor.astype(np.dtype(args.dtype))
)
else:
nps[k] = np.ascontiguousarray(tensor)
logging.info(f"converted {k} - {str(sys.getsizeof(nps[k]))} bytes")
else:
nps[k] = np.ascontiguousarray(tensor)
print(f"converted {k} - {str(sys.getsizeof(nps[k]))} bytes")
else:
print(f"skipped non-tensor object: {k}")
logging.info(f"skipped non-tensor object: {k}")
np.savez(target_folder / "model.npz", **nps)

source = str(target_folder / "model.npz")
@@ -119,7 +133,7 @@ def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
"--",
source,
target,
]
]
if args.download_libtorch:
cargo_args += ["--features", "download-libtorch"]
subprocess.run(cargo_args)
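
For context on the `get_bf16_repr` helper above: bfloat16 keeps the upper 16 bits of the IEEE-754 float32 bit pattern, and the helper stores that half as `uint16` (adding rounding and a NaN guard on top). A standalone sketch of just the truncation step, not the helper itself:

```python
import numpy as np
import torch

# Reinterpret a float32 tensor's raw bits as uint32, then keep the upper
# 16 bits; this is the truncation underlying the bf16 representation that
# the converter stores as uint16 arrays (half the size of float32).
x = torch.tensor([1.2345, -3.5], dtype=torch.float32)
bits = x.numpy().view(np.uint32)
upper16 = np.right_shift(bits, 16).astype(np.uint16)
print([hex(int(v)) for v in upper16])
```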