ERNIE 3.0 serving deploy update (PaddlePaddle#2146)
* optimized code

* optimized code

* optimized code2

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
heliqi and ZeyuChen committed May 15, 2022
1 parent 9f00755 commit ba57969
Showing 6 changed files with 118 additions and 44 deletions.
122 changes: 96 additions & 26 deletions model_zoo/ernie-3.0/deploy/serving/README.md
@@ -1,16 +1,26 @@
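# ERNIE-3.0 Serving Deployment
# Serving Deployment with Paddle Serving

This document describes how to use [Paddle Serving](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md) to deploy pipeline online services for the ERNIE 3.0 news classification and named entity recognition models.

## Contents
- [Environment Setup](#environment-setup)
- [Model Conversion](#model-conversion)
- [Deploying the Models](#deploying-the-models)

## Environment Setup
You need to [set up the PaddleNLP runtime environment]() and the Paddle Serving runtime environment.

### Install Paddle Serving
### Install Paddle Serving
Install with the commands below. For more wheel packages, see the [Serving documentation](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Latest_Packages_CN.md).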
```
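# Install the client and serving app, used to send requests to the service
pip install paddle_serving_app paddle_serving_client
# Install serving, used to start the service
# CPU server
pip install paddle_serving_server
# GPU server; check your environment before choosing which command to run:
# GPU server; choose the command that matches your local environment:
# CUDA10.2 + Cudnn7 + TensorRT6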
pip install paddle-serving-server-gpu==0.8.3.post102 -i https://pypi.tuna.tsinghua.edu.cn/simple
# CUDA10.1 + TensorRT6
@@ -22,41 +32,67 @@ pip install paddle-serving-server-gpu==0.8.3.post112 -i https://pypi.tuna.tsingh
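The Tsinghua mirror in China is enabled by default to speed up downloads. If you use an HTTP proxy, you can turn the mirror off by dropping `-i https://pypi.tuna.tsinghua.edu.cn/simple`.


### Install Paddle
For more on downloading and installing Paddle, see the [Paddle documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/index_cn.html)
### Install the FasterTokenizer text-processing acceleration library (optional)
If the deployment environment is Linux, installing faster_tokenizer is recommended: it speeds up text processing and further improves service performance. Installation on Windows is not yet supported and is planned for the next release.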
```
# For CPU environments, run
pip3 install paddlepaddle
# For GPU CUDA environments (CUDA 10.2 by default)
pip3 install paddlepaddle-gpu
pip install faster_tokenizers
```

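## Prepare the Model and Data
Download the [ERNIE-3.0 model](TODO)

### Convert the Model
A deployment model downloaded from the link, or a static-graph inference model exported from training (containing `xx.pdmodel` and `xx.pdiparams`), must be converted into a serving model
## Model Conversion

When deploying with Paddle Serving, the saved inference model must first be converted into a format that Serving can deploy.

Download the ERNIE 3.0 news classification and named entity recognition models: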

```bash
# Download and unzip the news classification model
wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/tnews_pruned_infer_model.zip
unzip tnews_pruned_infer_model.zip
# Download and unzip the named entity recognition model
wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/msra_ner_pruned_infer_model.zip
unzip msra_ner_pruned_infer_model.zip
```

Use the installed paddle_serving_client to convert the inference models into the serving format.

```bash
# Fill in the model path to match your setup
python -m paddle_serving_client.convert --dirname models/erinie-3.0 --model_filename infer.pdmodel --params_filename infer.pdiparams
# Convert the news classification model
python -m paddle_serving_client.convert --dirname tnews_pruned_infer_model --model_filename float32.pdmodel --params_filename float32.pdiparams

# Convert the named entity recognition model
python -m paddle_serving_client.convert --dirname msra_ner_pruned_infer_model --model_filename float32.pdmodel --params_filename float32.pdiparams

# Run this instruction to see what each parameter means
# Run this command to see what each parameter means
python -m paddle_serving_client.convert --help
```
After a successful conversion, the directory looks like this:
```
serving_server
├── infer.pdiparams
├── infer.pdmodel
├── float32.pdiparams
├── float32.pdmodel
├── serving_server_conf.prototxt
└── serving_server_conf.stream.prototxt
```
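
The `fetch_list` entries in the config files and the `fetch_names` in the service scripts have to match the output node name recorded in the converted `serving_server_conf.prototxt`. The following is a minimal sketch of reading that name, assuming the file uses a flat `fetch_var { ... alias_name: "..." }` text layout; the helper `fetch_alias_names` and the sample text are illustrative and not part of Paddle Serving.

```python
import re


def fetch_alias_names(prototxt_text):
    """Return the alias_name of every fetch_var block in a serving conf prototxt."""
    names = []
    # Assumption: each fetch_var block is flat (no nested braces).
    for block in re.findall(r"fetch_var\s*\{([^}]*)\}", prototxt_text):
        match = re.search(r'alias_name:\s*"([^"]+)"', block)
        if match:
            names.append(match.group(1))
    return names


# Hypothetical sample that mirrors the assumed layout.
SAMPLE_CONF = '''
feed_var {
  name: "input_ids"
  alias_name: "input_ids"
}
fetch_var {
  name: "linear_113.tmp_1"
  alias_name: "linear_113.tmp_1"
}
'''

print(fetch_alias_names(SAMPLE_CONF))
```

In practice the text would be read from the converted model directory, for example `open("serving_server/serving_server_conf.prototxt").read()`.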

## 部署模型

The serving directory contains the code for starting the pipeline service and sending prediction requests, including:

```
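seq_cls_config.yml # Server configuration file for the news classification task
seq_cls_rpc_client.py # Script that sends pipeline prediction requests for the news classification task
seq_cls_service.py # Script that starts the server for the news classification task
token_cls_config.yml # Server configuration file for the named entity recognition task
token_cls_rpc_client.py # Script that sends pipeline prediction requests for the named entity recognition task
token_cls_service.py # Script that starts the server for the named entity recognition task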
```


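## Serving Deployment
### Modify the Configuration File
The `xx_config.yml` file in the directory explains what each parameter means; adjust the settings as needed. For example:
The `seq_cls_config.yml` and `token_cls_config.yml` files in the directory explain what each parameter means; adjust the settings as needed. For example: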
```
# Change the model directory to the downloaded model directory or your own:
model_config: no_task_emb/serving_server => model_config: erine-3.0-tiny/serving_server
@@ -70,10 +106,25 @@ device_type: 1 => device_type: 0

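### Classification Task
#### Start the Service
After modifying the configuration file, run the following instruction to start the service:
After modifying the configuration file, run the following command to start the service: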
```
python seq_cls_service.py
```
The output looks like this:
```
[DAG] Succ init
[PipelineServicer] succ init
--- Running analysis [ir_graph_build_pass]
......
--- Running analysis [ir_graph_to_program_pass]
I0515 05:36:48.316895 62364 analysis_predictor.cc:714] ======= optimize end =======
I0515 05:36:48.320442 62364 naive_executor.cc:98] --- skip [feed], feed -> token_type_ids
I0515 05:36:48.320463 62364 naive_executor.cc:98] --- skip [feed], feed -> input_ids
I0515 05:36:48.321842 62364 naive_executor.cc:98] --- skip [linear_113.tmp_1], fetch -> fetch
[2022-05-15 05:36:49,316] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'ernie-3.0-medium-zh'.
[2022-05-15 05:36:49,317] [ INFO] - Already cached /vdb1/home/heliqi/.paddlenlp/models/ernie-3.0-medium-zh/ernie_3.0_medium_zh_vocab.txt
[OP Object] init success
```

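#### Run the Client Test
When sending client requests, make sure any proxy is disabled, and change the IP address in the init_client function to that of the machine running the service.
@@ -82,32 +133,51 @@ python seq_cls_rpc_client.py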
```
The output looks like this:
```
{'label': array([6, 2]), 'confidence': array([4.9473147, 5.7493963], dtype=float32)}
acc: 0.5745
{'label': array([6, 2]), 'confidence': array([0.5543532, 0.9495907], dtype=float32)}
acc: 0.5745
```
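
The `confidence` values come from the `postprocess` step in `seq_cls_service.py`, which this commit changes from the raw maximum logit to a softmax probability, subtracting the per-row maximum before exponentiating for numerical stability. A standalone sketch of that computation follows; the function name `softmax_confidence` and the sample logits are made up for illustration.

```python
import numpy as np


def softmax_confidence(logits):
    """Return (labels, confidences) for logits of shape [batch, num_classes]."""
    logits = np.asarray(logits, dtype=np.float32)
    # Subtract the per-row max so np.exp cannot overflow.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_data = np.exp(shifted)
    probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
    return probs.argmax(axis=-1), probs.max(axis=-1)


# Illustrative logits; the second row would overflow without the max shift.
labels, confidences = softmax_confidence([[1.0, 3.0, 0.5], [1000.0, 0.0, -5.0]])
print(labels, confidences)
```

Because softmax preserves the ordering of the logits, the predicted label is the same whether `argmax` is taken over the logits or the probabilities; only the reported confidence changes.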

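### Entity Recognition Task
#### Start the Service
After modifying the configuration file, run the following instruction to start the service:
After modifying the configuration file, run the following command to start the service: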
```
python token_cls_service.py
```
The output looks like this:
```
[DAG] Succ init
[PipelineServicer] succ init
--- Running analysis [ir_graph_build_pass]
......
--- Running analysis [ir_graph_to_program_pass]
I0515 05:36:48.316895 62364 analysis_predictor.cc:714] ======= optimize end =======
I0515 05:36:48.320442 62364 naive_executor.cc:98] --- skip [feed], feed -> token_type_ids
I0515 05:36:48.320463 62364 naive_executor.cc:98] --- skip [feed], feed -> input_ids
I0515 05:36:48.321842 62364 naive_executor.cc:98] --- skip [linear_113.tmp_1], fetch -> fetch
[2022-05-15 05:36:49,316] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'ernie-3.0-medium-zh'.
[2022-05-15 05:36:49,317] [ INFO] - Already cached /vdb1/home/heliqi/.paddlenlp/models/ernie-3.0-medium-zh/ernie_3.0_medium_zh_vocab.txt
[OP Object] init success
```
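
Once the service is running, the `postprocess` step in `token_cls_service.py` turns per-token label ids into entity spans. The sketch below is a simplified variant of that idea, not the service's exact code: it assumes the labels are already aligned one-to-one with the input tokens (the service itself shifts indices by one to skip the leading special token), and the names `decode_entities` and `LABEL_NAMES` are illustrative.

```python
LABEL_NAMES = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']


def decode_entities(tokens, label_ids, label_names=LABEL_NAMES):
    """Group B-/I- tagged tokens into entities with inclusive [start, end] positions."""
    entities = []
    start, kind = -1, ""

    def close(end):
        # Emit the entity that is currently open and spans start..end.
        entities.append({
            "entity": "".join(tokens[start:end + 1]),
            "label": kind,
            "pos": [start, end],
        })

    for i, label_id in enumerate(label_ids):
        name = label_names[label_id]
        # An "O" tag or a new "B-" tag ends any entity that is still open.
        if (name == "O" or name.startswith("B-")) and start >= 0:
            close(i - 1)
            start = -1
        if name.startswith("B-"):
            start, kind = i, name[2:]
    if start >= 0:
        # The sequence ended while an entity was still open.
        close(len(label_ids) - 1)
    return entities


tokens = list("北京的涮肉,重庆的火锅")
label_ids = [5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0]
for item in decode_entities(tokens, label_ids):
    print("entity:", item["entity"], " label:", item["label"], " pos:", item["pos"])
```

With these hand-written labels the sketch reports 北京 at `[0, 1]` and 重庆 at `[6, 7]`, both tagged `LOC`. Looking labels up by name rather than by hard-coded ids keeps the decoding valid when a model is trained with a different label set.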

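#### Run the Client Test
When sending client requests, make sure any proxy is disabled, and change the IP address in the init_client function to that of the machine running the service.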
```
python seq_cls_rpc_client.py
python token_cls_rpc_client.py
```
The output looks like this:
```
input data: 古老的文明,使我们引以为豪,彼此钦佩
input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食
The model detects all entities:
entity: 北京 label: LOC pos: [0, 1]
entity: 重庆 label: LOC pos: [6, 7]
entity: 成都 label: LOC pos: [12, 13]
-----------------------------
input data: 原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。
The model detects all entities:
entity: 玛雅 label: LOC pos: [2, 3]
entity: 华夏 label: LOC pos: [14, 15]
-----------------------------
PipelineClient::predict pack_data time:1652593013.713769
PipelineClient::predict before time:1652593013.7141528
input data: ['从', '首', '都', '利', '隆', '圭', '乘', '车', '向', '湖', '边', '小', '镇', '萨', '利', '马', '进', '发', '时', ',', '不', '到', '1', '0', '0', '公', '里', '的', '道', '路', '上', '坑', '坑', '洼', '洼', ',', '又', '逢', '阵', '雨', '迷', '蒙', ',', '令', '人', '不', '时', '发', '出', '路', '难', '行', '的', '慨', '叹', '。']
The model detects all entities:
entity: 利隆圭 label: LOC pos: [3, 5]
5 changes: 2 additions & 3 deletions model_zoo/ernie-3.0/deploy/serving/seq_cls_config.yml
@@ -35,11 +35,10 @@ op:
client_type: local_predictor

# Model path
model_config: /vdb1/home/heliqi/tnews_0421/original_fp32_serving_server
# model_config: no_task_emb/serving_server
model_config: serving_server

# Fetch result list; use the alias_name of fetch_var in client_config
fetch_list: ["linear_75.tmp_1"]
fetch_list: ["linear_113.tmp_1"]

# device_type, 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
device_type: 1
11 changes: 8 additions & 3 deletions model_zoo/ernie-3.0/deploy/serving/seq_cls_service.py
@@ -25,6 +25,9 @@ class ErnieSeqClsOp(Op):
    def init_op(self):
        from paddlenlp.transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
        # Output nodes may differ from model to model
        # You can see the output node name in the conf.prototxt file of serving_server
        self.fetch_names = ["linear_113.tmp_1", ]

    def preprocess(self, input_dicts, data_id, log_id):
        # convert input format
@@ -64,11 +67,13 @@ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
                It is handled in the same way as exception.
            prod_errinfo: "" default
        """
        result = fetch_dict["linear_75.tmp_1"]
        # np.argpartition
        result = fetch_dict[self.fetch_names[0]]
        max_value = np.max(result, axis=1, keepdims=True)
        exp_data = np.exp(result - max_value)
        probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
        out_dict = {
            "label": result.argmax(axis=-1),
            "confidence": result.max(axis=-1)
            "confidence": probs.max(axis=-1)
        }
        return out_dict, None, ""

5 changes: 2 additions & 3 deletions model_zoo/ernie-3.0/deploy/serving/token_cls_config.yml
@@ -36,11 +36,10 @@ op:
client_type: local_predictor

# Model path
model_config: /vdb1/home/heliqi/ner_0506/original_model/serving_server
# model_config: no_task_emb/serving_server
model_config: serving_server

# Fetch result list; use the alias_name of fetch_var in client_config
fetch_list: ["linear_75.tmp_1"]
fetch_list: ["linear_113.tmp_1"]

# device_type, 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
device_type: 1
3 changes: 1 addition & 2 deletions model_zoo/ernie-3.0/deploy/serving/token_cls_rpc_client.py
@@ -97,7 +97,7 @@ def init_client():

def test_demo(client):
    text1 = [
        "古老的文明,使我们引以为豪,彼此钦佩。",
        "北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。",
        "原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。",
    ]
    ret = client.predict(feed_dict={"tokens": text1})
@@ -118,4 +118,3 @@ def test_demo(client):
if __name__ == "__main__":
    client = init_client()
    test_demo(client)
    # test_ner_dataset(client)
16 changes: 9 additions & 7 deletions model_zoo/ernie-3.0/deploy/serving/token_cls_service.py
@@ -26,9 +26,13 @@ class ErnieTokenClsOp(Op):
    def init_op(self):
        from paddlenlp.transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
        self.labele_names = [
        # The label names of NER models trained by different data sets may be different
        self.label_names = [
            'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'
        ]
        # Output nodes may differ from model to model
        # You can see the output node name in the conf.prototxt file of serving_server
        self.fetch_names = ["linear_113.tmp_1", ]

    def get_input_data(self, input_dicts):
        (_, input_dict), = input_dicts.items()
@@ -69,7 +73,6 @@ def preprocess(self, input_dicts, data_id, log_id):
is_split_into_words=is_split_into_words)

input_ids = data["input_ids"]
# print("input shape:", len(input_ids), len(input_ids[0]))
token_type_ids = data["token_type_ids"]
return {
"input_ids": np.array(
@@ -93,17 +96,16 @@ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
            prod_errinfo: "" default
        """
        input_data = self.get_input_data(input_dicts)
        result = fetch_dict["linear_75.tmp_1"]
        result = fetch_dict[self.fetch_names[0]]
        tokens_label = result.argmax(axis=-1).tolist()
        # Extract the entities from the tokens of each sequence in the batch
        value = []
        for batch, token_label in enumerate(tokens_label):
            # print("label:", token_label)
            start = -1
            label_name = ""
            items = []
            for i, label in enumerate(token_label):
                if label == 0 and start >= 0:
                if self.label_names[label] == "O" and start >= 0:
                    entity = input_data[batch][start:i - 1]
                    if isinstance(entity, list):
                        entity = "".join(entity)
@@ -113,9 +115,9 @@ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
                        "label": label_name,
                    })
                    start = -1
                elif label in [1, 3, 5]:
                elif "B-" in self.label_names[label]:
                    start = i - 1
                    label_name = self.labele_names[label][2:]
                    label_name = self.label_names[label][2:]
            if start >= 0:
                items.append({
                    "pos": [start, len(token_label) - 1],