[NPU] support npu llama2-13B export & inference
ronny1996 committed May 16, 2024
1 parent 05acad5 commit a59da09
Showing 7 changed files with 236 additions and 13 deletions.
14 changes: 14 additions & 0 deletions csrc_npu/README.md
@@ -0,0 +1,14 @@
# PaddleNLP Custom OPs

This document describes how to build and install the PaddleNLP NPU custom OPs.

# 1. Install PaddleCustomDevice

Follow the [PaddleCustomDevice NPU installation guide](https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/npu/README_cn.md) to install it.

# 2. Install paddlenlp_ops
```shell
python setup.py build bdist_wheel

pip install dist/paddlenlp_ops*.whl
```
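As a quick smoke test of the install (assuming PaddleCustomDevice from step 1 is in the same Python environment), the package should import cleanly; it simply re-exports `paddle_custom_device.npu.ops`:

```shell
python -c "import paddlenlp_ops"
```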
15 changes: 15 additions & 0 deletions csrc_npu/python/paddlenlp_ops/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle_custom_device.npu.ops import *
59 changes: 59 additions & 0 deletions csrc_npu/setup.py
@@ -0,0 +1,59 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from setuptools import Distribution, setup

packages = []
package_data = {}


class BinaryDistribution(Distribution):
    def has_ext_modules(self):
        return True


def main():
    setup(
        name="paddlenlp_ops",
        version="0.0.0",
        description="PaddleNLP NPU CustomOps",
        long_description="",
        long_description_content_type="text/markdown",
        author_email="Paddle-better@baidu.com",
        maintainer="PaddlePaddle",
        maintainer_email="Paddle-better@baidu.com",
        project_urls={},
        license="Apache Software License",
        packages=[
            "paddlenlp_ops",
        ],
        include_package_data=True,
        package_data={
            "": ["*.py"],
        },
        package_dir={
            "": "python",
        },
        zip_safe=False,
        distclass=BinaryDistribution,
        entry_points={"console_scripts": []},
        classifiers=[],
        keywords="PaddleNLP NPU CustomOps",
    )


if __name__ == "__main__":
    main()
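Note that `BinaryDistribution.has_ext_modules()` returning `True` makes setuptools treat the package as binary, so `bdist_wheel` emits a platform-tagged wheel rather than a pure-Python one, which fits a package that fronts prebuilt NPU kernels.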
14 changes: 11 additions & 3 deletions llm/export_model.py
@@ -40,9 +40,12 @@ def load_inference_model(model_path, model_name, param_name, exe):
     return paddle.static.io.load_inference_model(model_path, exe)


-def validate_pdmodel(model_path, model_prefix):
+def validate_pdmodel(model_path, model_prefix, device):
     paddle.enable_static()
-    place = paddle.CUDAPlace(0)
+    if device == "gpu":
+        place = paddle.CUDAPlace(0)
+    else:
+        place = paddle.CustomPlace(device, 0)
     exe = paddle.static.Executor(place)
     scope = paddle.static.Scope()
@@ -95,7 +98,12 @@ def main():

     if tensor_parallel_degree > 1:
         export_args.output_path = os.path.join(export_args.output_path, f"rank_{tensor_parallel_rank}")
-    validate_pdmodel(export_args.output_path, predictor_args.model_prefix)
+    validate_pdmodel(export_args.output_path, predictor_args.model_prefix, predictor_args.device)
+
+    if predictor_args.device == "npu":
+        from llama.npu.export_utils import process_params
+
+        process_params(os.path.join(export_args.output_path, predictor_args.model_prefix))


 if __name__ == "__main__":
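With this hook in place, an NPU export run might look like the sketch below. The flag names follow PaddleNLP's `PredictorArgument`/`ExportArgument` at the time of this commit and may differ across versions; treat them as an assumption, not a verified invocation:

```shell
python export_model.py \
    --model_name_or_path meta-llama/Llama-2-13b-chat \
    --output_path ./inference \
    --dtype float16 \
    --inference_model \
    --device npu
```

When `--device npu` is given, `process_params` rewrites the exported weights in place right after `validate_pdmodel` passes.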
110 changes: 110 additions & 0 deletions llm/llama/npu/export_utils.py
@@ -0,0 +1,110 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import numpy as np
import paddle
from tqdm import tqdm


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default="inference/model", help="The directory of exported model.")
    return parser.parse_args()


def trans_weight(var):
    shape = var.desc.shape()
    new_shape = [shape[1], shape[0]]
    var.desc.set_shape(new_shape)

    var_data = np.array(var.get_value())
    var.get_value().set(var_data.T, paddle.CPUPlace())


def convert_dequant_scale(var):
    deq_scale = np.array(var.get_value()).astype(np.float32)
    new_deq_scale = np.stack([deq_scale.reshape(-1, 1), np.zeros_like(deq_scale).reshape(-1, 1)], axis=-1).reshape(-1)
    var.get_value().set(np.frombuffer(new_deq_scale.tobytes(), dtype=np.int64), paddle.CPUPlace())


def process_params(model_path):
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    scope = paddle.static.Scope()
    with paddle.base.scope_guard(scope):
        with paddle.base.program_guard(prog, startup_prog):
            [program, feed_target_names, fetch_targets] = paddle.static.io.load_inference_model(model_path, exe)

            feed_targets = []
            for var in program.list_vars():
                if var.name in feed_target_names:
                    feed_targets.append(var)

            block = program.global_block()

            linear_suffixes = ("qkv_weight", "out_proj_weight", "ffn1_weight", "ffn2_weight")
            for op in tqdm(block.ops, desc="processing the linear layer for NPU"):
                if op.type != "matmul_v2":
                    continue
                w_name = op.input_arg_names[-1]
                if (w_name.endswith(linear_suffixes) or w_name == "llama_lm_head_0.w_0") and not op.attr("trans_y"):
                    op._set_attr("trans_y", True)
                    trans_weight(block.var(w_name))

            scale_suffixes = ("qkv_out_scale", "linear_out_scale", "ffn1_out_scale", "ffn2_out_scale")
            for var_name in tqdm(block.vars, desc="processing the dequant layer for NPU"):
                if var_name.endswith(scale_suffixes):
                    convert_dequant_scale(block.var(var_name))

            paddle.static.save_inference_model(
                model_path, feed_targets, fetch_targets, exe, program=program, skip_prune_program=True
            )


def main():
    args = parse_arguments()
    process_params(args.model_path)


if __name__ == "__main__":
    main()
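For intuition, here is a numpy-only sketch of the two rewrites above, on made-up values rather than real model parameters. The transpose lets the NPU matmul consume weights with `trans_y=True`; the scale packing widens each float32 dequant scale to a 64-bit slot, presumably the encoding the NPU dequant kernels expect:

```python
import numpy as np

# trans_weight: store W transposed and flip trans_y, so x @ W stays x @ W_T.T
w = np.arange(6, dtype=np.float32).reshape(2, 3)
w_t = w.T  # shape (3, 2), written back into the program in place of w

# convert_dequant_scale: pad each float32 scale with a zero float32, then
# reinterpret every 8-byte (scale, 0.0) pair as a single int64 bit pattern
deq_scale = np.array([0.5, 0.25], dtype=np.float32)
padded = np.stack(
    [deq_scale.reshape(-1, 1), np.zeros_like(deq_scale).reshape(-1, 1)], axis=-1
).reshape(-1)  # [0.5, 0.0, 0.25, 0.0]
as_int64 = np.frombuffer(padded.tobytes(), dtype=np.int64)

print(w_t.shape, as_int64.shape)  # (3, 2) (2,)
```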
35 changes: 26 additions & 9 deletions llm/predictor.py
@@ -656,8 +656,12 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.dtype == "bfloat16":
             config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")

-        device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
-        config.enable_use_gpu(100, device_id)
+        if predictor_args.device in paddle.device.get_all_custom_device_type():
+            device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
+            config.enable_custom_device(predictor_args.device, device_id)
+        else:
+            device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
+            config.enable_use_gpu(100, device_id)
         config.enable_new_executor()

         if self.tensor_parallel_degree > 1:
@@ -1076,11 +1080,23 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams")

         config.switch_ir_optim(False)
-        device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
-        config.enable_use_gpu(100, device_id)
+        if predictor_args.device in paddle.device.get_all_custom_device_type():
+            device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
+            config.enable_custom_device(predictor_args.device, device_id)
+        else:
+            device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
+            config.enable_use_gpu(100, device_id)
         # config.disable_glog_info()
         # config.enable_memory_optim()

+        if predictor_args.device == "npu":
+            import paddle_custom_device.npu.passes as passes
+
+            config.switch_ir_optim(True)
+            pass_builder = config.pass_builder()
+            passes.addPasses(pass_builder, self.model_config.model_type, self.model_config.quant_type)
+            pass_builder.turn_on_debug()
+
         if self.tensor_parallel_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
             current_endpoint = trainer_endpoints[self.tensor_parallel_rank]
@@ -1527,6 +1543,11 @@ def predict():
         fleet.init(is_collective=True, strategy=strategy)

     predictor = create_predictor(predictor_args, model_args)
+
+    if predictor_args.benchmark:
+        benchmark(predictor, predictor_args, model_args)
+        return
+
     source_texts = []
     target_texts = []
     if model_args.data_file:
@@ -1570,14 +1591,10 @@ def predict():
                 out = {"src": source, "tgt": target, "output": output}
                 f.write(json.dumps(out, ensure_ascii=False) + "\n")

-    if predictor_args.benchmark:
-        benchmark(predictor, predictor_args, model_args)
-

 def benchmark(predictor, predictor_args, model_args):
     # Just construct a simple benchmark input. We pad input to the src_length.
-    test_texts = "hello world, how are you?"
-    benchmark_texts = [test_texts + "<pad>" * predictor_args.src_length for _ in range(predictor_args.batch_size)]
+    benchmark_texts = [predictor.tokenizer.pad_token * (predictor_args.src_length - 1) for _ in range(predictor_args.batch_size)]

     batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
     print("***********Start Benchmark**********")
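After an NPU export, inference can then be launched along these lines. As with the export sketch, the flags are assumptions based on the predictor arguments in this file and may vary by PaddleNLP version:

```shell
python predictor.py \
    --model_name_or_path ./inference \
    --dtype float16 \
    --mode static \
    --inference_model \
    --device npu
```

Note the benchmark change above: `--benchmark` now runs on the synthetic padded batch and returns before any data file or interactive input is touched.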
2 changes: 1 addition & 1 deletion paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -570,7 +570,7 @@ def compute_layernorm_before_qkv(self, src, i):
         return ln_out

     def compute_qkv_linear(self, ln_out, i):
-        if float(paddle.version.cuda()) < 11.6:
+        if paddle.version.cuda() == "False" or float(paddle.version.cuda()) < 11.6:
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
             if self.qkv_biases[i] is not None:
                 qkv_out = paddle.add(qkv_out, self.qkv_biases[i])
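The extra clause guards builds without CUDA: there, `paddle.version.cuda()` returns the string `"False"`, so the old bare `float(...)` conversion raised a `ValueError` on NPU (and CPU) builds before the fallback matmul path could even be reached. A minimal pure-Python illustration of the short-circuit:

```python
def needs_plain_matmul(cuda_version: str) -> bool:
    # "False" means Paddle was built without CUDA (e.g. an NPU build);
    # such builds must take the unfused matmul path unconditionally.
    return cuda_version == "False" or float(cuda_version) < 11.6

assert needs_plain_matmul("False")      # non-CUDA build
assert needs_plain_matmul("11.2")       # CUDA too old for the fused kernel
assert not needs_plain_matmul("11.8")   # new enough, keep the fused path
```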
