-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model compression API #2777
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,69 +14,45 @@ | |
|
||
import os | ||
import sys | ||
import yaml | ||
from functools import partial | ||
import distutils.util | ||
import os.path as osp | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
import paddlenlp | ||
from paddlenlp.data import DataCollatorWithPadding | ||
|
||
from paddlenlp.trainer import ( | ||
PdArgumentParser, | ||
TrainingArguments, | ||
Trainer, | ||
) | ||
from paddlenlp.data import DataCollatorWithPadding | ||
from paddlenlp.trainer import PdArgumentParser, CompressionArguments, Trainer | ||
from paddlenlp.trainer import EvalPrediction, get_last_checkpoint | ||
from paddlenlp.transformers import ( | ||
AutoTokenizer, | ||
AutoModelForQuestionAnswering, | ||
) | ||
from compress_trainer import CompressConfig, PTQConfig | ||
from paddlenlp.transformers import AutoTokenizer, AutoModelForQuestionAnswering | ||
from paddlenlp.utils.log import logger | ||
from datasets import load_metric, load_dataset | ||
|
||
sys.path.append("../ernie-1.0/finetune") | ||
from question_answering import ( | ||
QuestionAnsweringTrainer, | ||
CrossEntropyLossForSQuAD, | ||
prepare_train_features, | ||
prepare_validation_features, | ||
) | ||
from utils import ( | ||
ALL_DATASETS, | ||
DataArguments, | ||
ModelArguments, | ||
) | ||
from question_answering import QuestionAnsweringTrainer, CrossEntropyLossForSQuAD, prepare_train_features, prepare_validation_features | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里有个疑问,为什么类似 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 可以考虑进入框架,这里没有放的主要原因是参考了 huggingface 的做法。 当时可能的考虑有:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
from utils import ALL_DATASETS, DataArguments, ModelArguments | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 有个疑问哈: 为什么这3个数据类型 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这些是 custom 用户自定义的东西。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 跟数据,任务类型关系比较大。这里应该是 ernie-3.0 和 ernie-1.0 任务比较相似,所以共用。但是对于其他模型来讲,可能不一定适用。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 好的,清楚了。 |
||
|
||
|
||
def main(): | ||
parser = PdArgumentParser( | ||
(ModelArguments, DataArguments, TrainingArguments)) | ||
model_args, data_args, training_args = parser.parse_args_into_dataclasses() | ||
(ModelArguments, DataArguments, CompressionArguments)) | ||
model_args, data_args, compression_args = parser.parse_args_into_dataclasses( | ||
) | ||
|
||
paddle.set_device(training_args.device) | ||
paddle.set_device(compression_args.device) | ||
data_args.dataset = data_args.dataset.strip() | ||
|
||
if data_args.dataset in ALL_DATASETS: | ||
# If you customize your hyper-parameters in the yaml config, they will overwrite all args. | ||
config = ALL_DATASETS[data_args.dataset] | ||
for args in (model_args, data_args, training_args): | ||
for args in (model_args, data_args, compression_args): | ||
for arg in vars(args): | ||
if arg in config.keys(): | ||
setattr(args, arg, config[arg]) | ||
Comment on lines
42
to
48
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这一段逻辑能否加上注释说明? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 第43行有注释,如果有自定义的yaml文件,则会替换args传递来的参数。 |
||
|
||
training_args.per_device_train_batch_size = config["batch_size"] | ||
training_args.per_device_eval_batch_size = config["batch_size"] | ||
compression_args.per_device_train_batch_size = config["batch_size"] | ||
compression_args.per_device_eval_batch_size = config["batch_size"] | ||
|
||
# Log model and data config | ||
training_args.print_config(model_args, "Model") | ||
training_args.print_config(data_args, "Data") | ||
compression_args.print_config(model_args, "Model") | ||
compression_args.print_config(data_args, "Data") | ||
|
||
dataset_config = data_args.dataset.split(" ") | ||
raw_datasets = load_dataset( | ||
|
@@ -102,7 +78,7 @@ def main(): | |
|
||
train_dataset = raw_datasets["train"] | ||
# Create train feature from dataset | ||
with training_args.main_process_first( | ||
with compression_args.main_process_first( | ||
desc="train dataset map pre-processing"): | ||
# Dataset pre-process | ||
train_dataset = train_dataset.map( | ||
|
@@ -115,7 +91,7 @@ def main(): | |
desc="Running tokenizer on train dataset", | ||
) | ||
eval_examples = raw_datasets["validation"] | ||
with training_args.main_process_first( | ||
with compression_args.main_process_first( | ||
desc="evaluate dataset map pre-processing"): | ||
eval_dataset = eval_examples.map( | ||
partial(prepare_validation_features, | ||
|
@@ -151,25 +127,20 @@ def post_processing_function(examples, features, predictions, stage="eval"): | |
|
||
trainer = QuestionAnsweringTrainer( | ||
model=model, | ||
args=training_args, | ||
args=compression_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
eval_examples=eval_examples, | ||
data_collator=data_collator, | ||
post_process_function=post_processing_function, | ||
tokenizer=tokenizer) | ||
|
||
output_dir = os.path.join(model_args.model_name_or_path, "compress") | ||
if not os.path.exists(output_dir): | ||
os.makedirs(output_dir) | ||
|
||
prune = True | ||
compress_config = CompressConfig(quantization_config=PTQConfig( | ||
algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16])) | ||
trainer.compress(output_dir, | ||
pruning=prune, | ||
quantization=True, | ||
compress_config=compress_config) | ||
if not os.path.exists(compression_args.output_dir): | ||
os.makedirs(compression_args.output_dir) | ||
|
||
compression_args.print_config() | ||
|
||
trainer.compress() | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么这里会有 ernie-1.0 路径?