Skip to content

Commit

Permalink
[Improvement] accelerate T5 model conversion and fix bloom model on m…
Browse files Browse the repository at this point in the history
…ulti-process (NVIDIA#447)

* accelerate T5 model conversion on large models

* fix the bloom error

* convert to use the same setup
  • Loading branch information
Qing Lan committed Feb 13, 2023
1 parent 7634698 commit 9b6d718
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 52 deletions.
29 changes: 15 additions & 14 deletions examples/pytorch/gpt/utils/huggingface_bloom_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ def handle_exceptions(model_config: PretrainedConfig,
return param


def convert_and_save_parameter(config: PretrainedConfig,
name: str,
param: torch.nn.Parameter,
def convert_and_save_parameter(param_name: str,
param,
tensor_para_size: Optional[int],
save_dir: PathLike):
""" Convert and save to FT parameter
Expand All @@ -228,12 +227,6 @@ def convert_and_save_parameter(config: PretrainedConfig,
save_dir: str or Path, a base directory of binary files.
"""

# Preprocess
param_name = convert_parameter_name(name)
param = safe_transpose(param)
param = handle_exceptions(config, param_name, param)

param = param.detach().cpu().numpy()
save_dir = Path(save_dir)

if not is_split_param(param_name):
Expand Down Expand Up @@ -325,15 +318,23 @@ def main():
f'{len(list(model.parameters()))} params')
if args.processes > 1:
pool = multiprocessing.Pool(args.processes)
pool.starmap_async(
convert_and_save_parameter,
[(model.config, name, param, tp_size, save_dir)
for name, param in model.named_parameters()])
star_args = []
for name, param in model.named_parameters():
# Preprocess
param_name = convert_parameter_name(name)
param = safe_transpose(param)
param = handle_exceptions(model.config, param_name, param)
star_args.append((param_name, param.detach().cpu().numpy(), tp_size, save_dir))
pool.starmap_async(convert_and_save_parameter, star_args)
pool.close()
pool.join()
else:
for name, param in model.named_parameters():
convert_and_save_parameter(model.config, name, param, tp_size, save_dir)
# Preprocess
param_name = convert_parameter_name(name)
param = safe_transpose(param)
param = handle_exceptions(model.config, param_name, param)
convert_and_save_parameter(param_name, param.detach().cpu().numpy(), tp_size, save_dir)
elapsed_time = time.time() - start_time
logger.info(f'Checkpoint conversion (HF >> FT) has done '
f'(elapsed time: {elapsed_time:.2f} sec)')
Expand Down
91 changes: 53 additions & 38 deletions examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

import argparse
import configparser
import multiprocessing
from datetime import datetime
import logging
from pathlib import Path

import sys
import os

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path + "/../../../../3rdparty/transformers/src/")

Expand All @@ -30,8 +32,10 @@

LOGGER = logging.getLogger(__name__)

rename_mapping={"relative_attention_num_buckets":"relative_attention_num_buckets_or_max_pos_seq_len"}
new_configs={"structure":{"t5_with_bias":"false", "use_gated_activation":"false", "position_embedding_type":"relative"}}
rename_mapping = {"relative_attention_num_buckets": "relative_attention_num_buckets_or_max_pos_seq_len"}
new_configs = {
"structure": {"t5_with_bias": "false", "use_gated_activation": "false", "position_embedding_type": "relative"}}


def get_weight_data_type(data_type):
if data_type == "fp32":
Expand All @@ -41,12 +45,13 @@ def get_weight_data_type(data_type):
else:
assert False, f"Invalid weight data type {data_type}"


def fuse_decoder_qkv(model, factor, saved_dir, np_weight_data_type):
model_dict = {}
for name, param in model.named_parameters():
if name.find("decoder") != -1 and name.find("SelfAttention") != -1:
model_dict[name] = param

for i in range(model.decoder.config.num_layers):
shape = model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].T.shape
qkv = torch.cat([model_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"].T,
Expand All @@ -55,67 +60,67 @@ def fuse_decoder_qkv(model, factor, saved_dir, np_weight_data_type):

qkv = qkv.reshape([shape[0], 3, shape[1]])
qkv = qkv.cpu().detach().numpy().astype(np_weight_data_type)

split_vals = np.split(qkv, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir / f"decoder.block.{i}.layer.0.SelfAttention.qkv.weight.{j}.bin"
split_vals[j].tofile(saved_path.as_posix())

def split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type):


def split_and_convert_process(key, val, factor, saved_dir):
if val.dim() == 2:
val = val.transpose(1, 0)
val = val.cpu().detach().numpy().astype(np_weight_data_type)
saved_key = key
LOGGER.debug(f"key: {key}, val.shape: {val.shape}")

if key.find("shared.weight") != -1:
# shared weights, only need to convert the weights of rank 0
saved_path = saved_dir / f"{saved_key}.bin"
val.tofile(saved_path.as_posix())

saved_path = saved_dir / f"{saved_key}_T.bin"
val.T.tofile(saved_path.as_posix())
elif key.find("lm_head.weight") != -1:
# lm_head weights, only need to convert the weights of rank 0
val = val.transpose(1, 0) # For lm_head, we use TN gemm to compute, so we don't need to transpose
val = val.transpose(1, 0) # For lm_head, we use TN gemm to compute, so we don't need to transpose
saved_path = saved_dir / f"{saved_key}.bin"
val.tofile(saved_path.as_posix())

elif key.find("layer_norm.weight") != -1:
# shared weights, only need to convert the weights of rank 0
saved_path = saved_dir / f"{saved_key}.bin"
val.tofile(saved_path.as_posix())

elif (
key.find("SelfAttention.o.weight") != -1
or key.find("EncDecAttention.o.weight") != -1
or key.find("DenseReluDense.wo.weight") != -1
):
key.find("SelfAttention.o.weight") != -1
or key.find("EncDecAttention.o.weight") != -1
or key.find("DenseReluDense.wo.weight") != -1
):
split_vals = np.split(val, factor, axis=0)
for j in range(factor):
saved_path = saved_dir / f"{saved_key}.{j:d}.bin"
split_vals[j].tofile(saved_path.as_posix())

elif (
key.find("DenseReluDense.wi.weight") != -1
or (key.find("encoder") != -1 and (
key.find("DenseReluDense.wi.weight") != -1
or (key.find("encoder") != -1 and (
key.find("SelfAttention.q.weight") != -1
or key.find("SelfAttention.k.weight") != -1
or key.find("SelfAttention.v.weight") != -1
)
)
)
or key.find("EncDecAttention.q.weight") != -1
or key.find("EncDecAttention.k.weight") != -1
or key.find("EncDecAttention.v.weight") != -1
):
or key.find("EncDecAttention.q.weight") != -1
or key.find("EncDecAttention.k.weight") != -1
or key.find("EncDecAttention.v.weight") != -1
):
split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir / f"{saved_key}.{j:d}.bin"
split_vals[j].tofile(saved_path.as_posix())
elif (
key.find("DenseReluDense.wi_0.weight") != -1
or key.find("DenseReluDense.wi_1.weight") != -1
):
key.find("DenseReluDense.wi_0.weight") != -1
or key.find("DenseReluDense.wi_1.weight") != -1
):
# For gated activation.
if key.find("DenseReluDense.wi_0.weight") != -1:
saved_key = key.replace("wi_0", "wi")
Expand All @@ -131,20 +136,21 @@ def split_and_convert_process(key, val, factor, saved_dir, np_weight_data_type):
saved_path = saved_dir / f"{saved_key}.{j:d}.bin"
split_vals[j].tofile(saved_path.as_posix())
elif (
key.find("decoder") != -1 and
(
key.find("SelfAttention.q.weight") != -1
or key.find("SelfAttention.k.weight") != -1
or key.find("SelfAttention.v.weight") != -1
)
):
key.find("decoder") != -1 and
(
key.find("SelfAttention.q.weight") != -1
or key.find("SelfAttention.k.weight") != -1
or key.find("SelfAttention.v.weight") != -1
)
):
pass
elif key.find("encoder.embed_tokens.weight") != -1 or \
key.find("decoder.embed_tokens.weight") != -1:
key.find("decoder.embed_tokens.weight") != -1:
LOGGER.warning(f"Not save {key}, using shared.weight directly.")
else:
LOGGER.warning(f"cannot find key '{key}' with shape {val.shape}")


def convert_checkpoint(args):
saved_dir = Path(args.saved_dir) / f"{args.inference_tensor_para_size:d}-gpu"
saved_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -180,20 +186,29 @@ def convert_checkpoint(args):
with open((saved_dir / f"config.ini").as_posix(), 'w') as configfile:
config.write(configfile)
np_weight_data_type = get_weight_data_type(args.weight_data_type)

i_gpu_num = args.inference_tensor_para_size

for name, param in t5_model.state_dict().items():
split_and_convert_process(name, param, i_gpu_num, saved_dir, np_weight_data_type)

pool = multiprocessing.Pool(args.processes)
pool.starmap_async(split_and_convert_process,
[(name, param.cpu().detach().numpy().astype(np_weight_data_type), i_gpu_num, saved_dir)
for name, param in t5_model.state_dict().items()])

pool.close()
pool.join()

if not args.encoder_only:
fuse_decoder_qkv(t5_model, i_gpu_num, saved_dir, np_weight_data_type)


if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True)
parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True)
parser.add_argument("-inference_tensor_para_size", "-i_g", type=int, help="How many gpus for inference", required=True)
parser.add_argument("-inference_tensor_para_size", "-i_g", type=int, help="How many gpus for inference",
required=True)
parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)",
default=4)
parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"])
parser.add_argument("--encoder_only", "-e", action="store_true")
parser.add_argument("--verbose", action="store_true", help="Provide verbose messages")
Expand All @@ -205,7 +220,7 @@ def convert_checkpoint(args):
LOGGER.info(f"{key}: {vars(args)[key]}")
LOGGER.info("========================================")

start_time = datetime.now()
start_time = datetime.now()
convert_checkpoint(args)
stop_time = datetime.now()
run_time = (stop_time - start_time)
Expand Down

0 comments on commit 9b6d718

Please sign in to comment.