diff --git a/model_zoo/ernie-3.0/compress_qa.py b/model_zoo/ernie-3.0/compress_qa.py index 4fc101e2bf68..5cef05942400 100644 --- a/model_zoo/ernie-3.0/compress_qa.py +++ b/model_zoo/ernie-3.0/compress_qa.py @@ -14,69 +14,45 @@ import os import sys -import yaml from functools import partial -import distutils.util -import os.path as osp -from typing import Optional -import numpy as np import paddle -import paddle.nn as nn -import paddle.nn.functional as F -import paddlenlp -from paddlenlp.data import DataCollatorWithPadding -from paddlenlp.trainer import ( - PdArgumentParser, - TrainingArguments, - Trainer, -) +from paddlenlp.data import DataCollatorWithPadding +from paddlenlp.trainer import PdArgumentParser, CompressionArguments, Trainer from paddlenlp.trainer import EvalPrediction, get_last_checkpoint -from paddlenlp.transformers import ( - AutoTokenizer, - AutoModelForQuestionAnswering, -) -from compress_trainer import CompressConfig, PTQConfig +from paddlenlp.transformers import AutoTokenizer, AutoModelForQuestionAnswering from paddlenlp.utils.log import logger from datasets import load_metric, load_dataset sys.path.append("../ernie-1.0/finetune") -from question_answering import ( - QuestionAnsweringTrainer, - CrossEntropyLossForSQuAD, - prepare_train_features, - prepare_validation_features, -) -from utils import ( - ALL_DATASETS, - DataArguments, - ModelArguments, -) +from question_answering import QuestionAnsweringTrainer, CrossEntropyLossForSQuAD, prepare_train_features, prepare_validation_features +from utils import ALL_DATASETS, DataArguments, ModelArguments def main(): parser = PdArgumentParser( - (ModelArguments, DataArguments, TrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + (ModelArguments, DataArguments, CompressionArguments)) + model_args, data_args, compression_args = parser.parse_args_into_dataclasses( + ) - paddle.set_device(training_args.device) + paddle.set_device(compression_args.device) data_args.dataset = data_args.dataset.strip() if data_args.dataset in ALL_DATASETS: # if you custom you hyper-parameters in yaml config, it will overwrite all args. 
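# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the patch): once CompressionArguments is
# wired in, compress_qa.py is driven entirely by CLI flags generated from the
# dataclass fields. The dataset name and flag values below are illustrative
# assumptions only:
#
#   python compress_qa.py \
#       --dataset "clue cmrc2018" \
#       --model_name_or_path ./trained_qa_model \
#       --output_dir ./compress \
#       --strategy "dynabert+ptq" \
#       --width_mult_list 0.75 \
#       --algo_list hist mse \
#       --batch_size_list 4 8 16
# ---------------------------------------------------------------------------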
config = ALL_DATASETS[data_args.dataset] - for args in (model_args, data_args, training_args): + for args in (model_args, data_args, compression_args): for arg in vars(args): if arg in config.keys(): setattr(args, arg, config[arg]) - training_args.per_device_train_batch_size = config["batch_size"] - training_args.per_device_eval_batch_size = config["batch_size"] + compression_args.per_device_train_batch_size = config["batch_size"] + compression_args.per_device_eval_batch_size = config["batch_size"] # Log model and data config - training_args.print_config(model_args, "Model") - training_args.print_config(data_args, "Data") + compression_args.print_config(model_args, "Model") + compression_args.print_config(data_args, "Data") dataset_config = data_args.dataset.split(" ") raw_datasets = load_dataset( @@ -102,7 +78,7 @@ def main(): train_dataset = raw_datasets["train"] # Create train feature from dataset - with training_args.main_process_first( + with compression_args.main_process_first( desc="train dataset map pre-processing"): # Dataset pre-process train_dataset = train_dataset.map( @@ -115,7 +91,7 @@ def main(): desc="Running tokenizer on train dataset", ) eval_examples = raw_datasets["validation"] - with training_args.main_process_first( + with compression_args.main_process_first( desc="evaluate dataset map pre-processing"): eval_dataset = eval_examples.map( partial(prepare_validation_features, @@ -151,7 +127,7 @@ def post_processing_function(examples, features, predictions, stage="eval"): trainer = QuestionAnsweringTrainer( model=model, - args=training_args, + args=compression_args, train_dataset=train_dataset, eval_dataset=eval_dataset, eval_examples=eval_examples, @@ -159,17 +135,12 @@ def post_processing_function(examples, features, predictions, stage="eval"): post_process_function=post_processing_function, tokenizer=tokenizer) - output_dir = os.path.join(model_args.model_name_or_path, "compress") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - prune = True - compress_config = CompressConfig(quantization_config=PTQConfig( - algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16])) - trainer.compress(output_dir, - pruning=prune, - quantization=True, - compress_config=compress_config) + if not os.path.exists(compression_args.output_dir): + os.makedirs(compression_args.output_dir) + + compression_args.print_config() + + trainer.compress() if __name__ == "__main__": diff --git a/model_zoo/ernie-3.0/compress_seq_cls.py b/model_zoo/ernie-3.0/compress_seq_cls.py index 82c3b5e53963..8fa0250bf4e2 100644 --- a/model_zoo/ernie-3.0/compress_seq_cls.py +++ b/model_zoo/ernie-3.0/compress_seq_cls.py @@ -14,61 +14,46 @@ import os import sys -import yaml from functools import partial import paddle from paddlenlp.data import DataCollatorWithPadding from paddlenlp.datasets import load_dataset -from paddlenlp.trainer import ( - PdArgumentParser, - TrainingArguments, - Trainer, -) - -from paddlenlp.transformers import ( - AutoTokenizer, - AutoModelForSequenceClassification, -) +from paddlenlp.trainer import PdArgumentParser, Trainer, CompressionArguments +from paddlenlp.transformers import AutoTokenizer, AutoModelForSequenceClassification from paddlenlp.utils.log import logger -from compress_trainer import CompressConfig, PTQConfig - sys.path.append("../ernie-1.0/finetune") - from sequence_classification import seq_trans_fn, clue_trans_fn -from utils import ( - ALL_DATASETS, - DataArguments, - ModelArguments, -) +from utils import ALL_DATASETS, DataArguments, ModelArguments def main(): 
parser = PdArgumentParser( - (ModelArguments, DataArguments, TrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + (ModelArguments, DataArguments, CompressionArguments)) + model_args, data_args, compression_args = parser.parse_args_into_dataclasses( + ) - paddle.set_device(training_args.device) + paddle.set_device(compression_args.device) data_args.dataset = data_args.dataset.strip() if data_args.dataset in ALL_DATASETS: - # if you custom you hyper-parameters in yaml config, it will overwrite all args. + # If you customize your hyper-parameters in the yaml config, it will overwrite all args. config = ALL_DATASETS[data_args.dataset] - logger.info("Over-writing training config by yaml config!") - for args in (model_args, data_args, training_args): + logger.info("Over-writing compression config by yaml config!") + for args in (model_args, data_args, compression_args): for arg in vars(args): if arg in config.keys(): setattr(args, arg, config[arg]) - training_args.per_device_train_batch_size = config["batch_size"] - training_args.per_device_eval_batch_size = config["batch_size"] + compression_args.per_device_train_batch_size = config["batch_size"] + compression_args.per_device_eval_batch_size = config["batch_size"] # Log model and data config - training_args.print_config(model_args, "Model") - training_args.print_config(data_args, "Data") + compression_args.print_config(model_args, "Model") + compression_args.print_config(data_args, "Data") dataset_config = data_args.dataset.split(" ") raw_datasets = load_dataset( @@ -81,42 +66,37 @@ def main(): raw_datasets['train'].label_list) criterion = paddle.nn.CrossEntropyLoss() - # Define tokenizer, model, loss function. + # Defines tokenizer, model, loss function. tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) model = AutoModelForSequenceClassification.from_pretrained( model_args.model_name_or_path, num_classes=num_classes) - # Define dataset pre-process function + # Defines dataset pre-process function if "clue" in data_args.dataset: trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args) else: trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args) - # Define data collector + # Defines data collator data_collator = DataCollatorWithPadding(tokenizer) train_dataset = raw_datasets["train"].map(trans_fn) eval_dataset = raw_datasets["dev"].map(trans_fn) - trainer = Trainer(model=model, - args=training_args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - criterion=criterion) - - output_dir = os.path.join(model_args.model_name_or_path, "compress") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - compress_config = CompressConfig(quantization_config=PTQConfig( - algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16])) - - trainer.compress(output_dir, - pruning=True, - quantization=True, - compress_config=compress_config) + trainer = Trainer( + model=model, + args=compression_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + criterion=criterion) # Strategy `dynabert` needs the `criterion` argument + + compression_args.print_config() + + if not os.path.exists(compression_args.output_dir): + os.makedirs(compression_args.output_dir) + + trainer.compress() if __name__ == "__main__": diff --git a/model_zoo/ernie-3.0/compress_token_cls.py b/model_zoo/ernie-3.0/compress_token_cls.py index fbefc20cd8a4..529376f9b321 100644 ---
a/model_zoo/ernie-3.0/compress_token_cls.py +++ b/model_zoo/ernie-3.0/compress_token_cls.py @@ -14,50 +14,66 @@ import os import sys -import yaml from functools import partial -import distutils.util -import os.path as osp -from typing import Optional -import numpy as np import paddle import paddle.nn as nn -import paddle.nn.functional as F from datasets import load_dataset -import paddlenlp from paddlenlp.data import DataCollatorForTokenClassification -from paddlenlp.trainer import ( - PdArgumentParser, - TrainingArguments, - Trainer, -) - -from paddlenlp.transformers import ( - AutoTokenizer, - AutoModelForTokenClassification, -) +from paddlenlp.trainer import PdArgumentParser, CompressionArguments, Trainer +from paddlenlp.transformers import AutoTokenizer, AutoModelForTokenClassification from paddlenlp.utils.log import logger -from compress_trainer import CompressConfig, PTQConfig - sys.path.append("../ernie-1.0/finetune") -from token_classification import ner_trans_fn -from utils import ( - ALL_DATASETS, - DataArguments, - ModelArguments, -) +from utils import ALL_DATASETS, DataArguments, ModelArguments + + +def tokenize_and_align_labels(example, + tokenizer, + no_entity_id, + max_seq_len=512): + if example['tokens'] == []: + tokenized_input = { + 'labels': [], + 'input_ids': [], + 'token_type_ids': [], + 'seq_len': 0, + 'length': 0, + } + return tokenized_input + tokenized_input = tokenizer( + example['tokens'], + max_seq_len=max_seq_len, + # We use this argument because the texts in our dataset are lists of words (with a label for each word). + is_split_into_words=True, + return_length=True) + label_ids = example['ner_tags'] + if len(tokenized_input['input_ids']) - 2 < len(label_ids): + label_ids = label_ids[:len(tokenized_input['input_ids']) - 2] + label_ids = [no_entity_id] + label_ids + [no_entity_id] + + label_ids += [no_entity_id + ] * (len(tokenized_input['input_ids']) - len(label_ids)) + tokenized_input["labels"] = label_ids + return tokenized_input + + +def ner_trans_fn(example, tokenizer, args): + return tokenize_and_align_labels(example, + tokenizer=tokenizer, + no_entity_id=args.no_entity_id, + max_seq_len=args.max_seq_length) def main(): parser = PdArgumentParser( - (ModelArguments, DataArguments, TrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + (ModelArguments, DataArguments, CompressionArguments)) + model_args, data_args, compression_args = parser.parse_args_into_dataclasses( + ) - paddle.set_device(training_args.device) + paddle.set_device(compression_args.device) data_args.dataset = data_args.dataset.strip() if data_args.dataset not in ALL_DATASETS: @@ -66,17 +82,17 @@ def main(): if data_args.dataset in ALL_DATASETS: # if you custom you hyper-parameters in yaml config, it will overwrite all args. 
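# ---------------------------------------------------------------------------
# A hedged walkthrough of `tokenize_and_align_labels` above, assuming a
# tokenizer that adds one special token at each end ([CLS]/[SEP]), one
# sub-token per word, and no_entity_id=0 (toy values, for illustration only):
#
#   tokens:    ["EU", "rejects", "call"]   ->  input_ids of length 5
#   ner_tags:  [3, 0, 0]
#   labels:    [0, 3, 0, 0, 0]             #  [CLS] and [SEP] get no_entity_id
#
# If the tokenizer truncates, ner_tags is first clipped to
# len(input_ids) - 2, so labels always ends up the same length as input_ids.
# ---------------------------------------------------------------------------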
config = ALL_DATASETS[data_args.dataset] - for args in (model_args, data_args, training_args): + for args in (model_args, data_args, compression_args): for arg in vars(args): if arg in config.keys(): setattr(args, arg, config[arg]) - training_args.per_device_train_batch_size = config["batch_size"] - training_args.per_device_eval_batch_size = config["batch_size"] + compression_args.per_device_train_batch_size = config["batch_size"] + compression_args.per_device_eval_batch_size = config["batch_size"] # Log model and data config - training_args.print_config(model_args, "Model") - training_args.print_config(data_args, "Data") + compression_args.print_config(model_args, "Model") + compression_args.print_config(data_args, "Data") dataset_config = data_args.dataset.split(" ") raw_datasets = load_dataset( @@ -100,7 +116,7 @@ class criterion(nn.Layer): def __init__(self): super(criterion, self).__init__() - self.loss_fn = paddle.nn.loss.CrossEntropyLoss( + self.loss_fn = nn.CrossEntropyLoss( ignore_index=data_args.ignore_label) def forward(self, *args, **kwargs): @@ -124,27 +140,20 @@ def forward(self, *args, **kwargs): eval_dataset = raw_datasets["test"].map(trans_fn, remove_columns=column_names) - trainer = Trainer(model=model, criterion=loss_fct, - args=training_args, + args=compression_args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer) - output_dir = os.path.join(model_args.model_name_or_path, "compress") - - if not os.path.exists(output_dir): - os.makedirs(output_dir) + if not os.path.exists(compression_args.output_dir): + os.makedirs(compression_args.output_dir) - compress_config = CompressConfig(quantization_config=PTQConfig( - algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16])) + compression_args.print_config() - trainer.compress(output_dir, - pruning=True, - quantization=True, - compress_config=compress_config) + trainer.compress() if __name__ == "__main__": diff --git a/model_zoo/ernie-3.0/config.yml b/model_zoo/ernie-3.0/config.yml index 68f077512f56..b44179cdec1c 100644 --- a/model_zoo/ernie-3.0/config.yml +++ b/model_zoo/ernie-3.0/config.yml @@ -16,6 +16,11 @@ DefaultArgs: do_train: True save_steps: 100 max_answer_length: 50 + strategy: "dynabert+ptq" + algo_list: ["hist", "mse"] + batch_num_list: [1] + batch_size_list: [4, 8, 16] + width_mult_list: [0.75] # Datasets which used for sequence classfication SequenceClassification: clue afqmc: diff --git a/paddlenlp/trainer/__init__.py b/paddlenlp/trainer/__init__.py index 445f516d7e13..1c5bdad43c3c 100644 --- a/paddlenlp/trainer/__init__.py +++ b/paddlenlp/trainer/__init__.py @@ -14,6 +14,8 @@ from .argparser import * from .training_args import * +from .compression_args import * from .trainer_base import * from .trainer_callback import * from .trainer_utils import * +from .trainer_compress import * \ No newline at end of file diff --git a/paddlenlp/trainer/compression_args.py b/paddlenlp/trainer/compression_args.py new file mode 100644 index 000000000000..d2f3f85d6eb4 --- /dev/null +++ b/paddlenlp/trainer/compression_args.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import types +from typing import List, Optional + +import paddle + +from ..utils.log import logger +from .training_args import TrainingArguments + +__all__ = [ + "CompressionArguments", +] + + +@dataclass +class CompressionArguments(TrainingArguments): + """ + CompressionArguments is the subset of the arguments we use in our example + scripts **which relate to model compression**. + + Using [`PdArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) + arguments that can be specified on the command line. + + Parameters: + strategy (`str`): + Compression strategy. It supports 'dynabert+ptq', 'dynabert' and 'ptq' now. + """ + strategy: Optional[str] = field( + default="dynabert+ptq", + metadata={ + "help": + "Compression strategy. It supports 'dynabert+ptq', 'dynabert' and 'ptq' now." + }, + ) + # dynabert + width_mult_list: Optional[List[float]] = field( + default=None, + metadata={ + "help": + ("List of width multipliers for pruning with the DynaBERT strategy.") + }, + ) + # ptq: + algo_list: Optional[List[str]] = field( + default=None, + metadata={ + "help": + "Algorithm list for post-training quantization. It supports 'hist', 'KL', " \ + "'mse', 'avg', 'abs_max' and 'emd'. 'KL' uses the KL-divergence method to get " \ + "the KL threshold for quantized activations and gets the abs_max value " \ + "for quantized weights. 'abs_max' gets the abs max value for activations " \ + "and weights. 'min_max' gets the min and max value for quantized " \ + "activations and weights. 'avg' gets the average value among the max " \ + "values for activations. 'hist' gets the value of the 'hist_percent' " \ + "quantile as the threshold. 'mse' gets the value which makes the " \ + "quantization mse loss minimal. Defaults to ['hist']." + }, ) + + batch_num_list: Optional[List[int]] = field( + default=None, + metadata={ + "help": + "List of batch_num. 'batch_num' is the number of batches sampled for " \ + "calibration; the number of calibration samples is batch_size * batch_nums. " \ + "If batch_nums is None, all data provided by the data loader is used for calibration." + }, ) + batch_size_list: Optional[List[int]] = field( + default=None, + metadata={ + "help": + "List of batch_size. 'batch_size' is the batch size of the data loader." + }, + ) + weight_quantize_type: Optional[str] = field( + default='channel_wise_abs_max', + metadata={ + "help": + "Supports 'abs_max' and 'channel_wise_abs_max'. This param only specifies " \ + "the fake ops used when saving the quantized model; the scale obtained " \ + "by post-training quantization is saved in these fake ops. Compared to 'abs_max', " \ + "model accuracy is usually higher with 'channel_wise_abs_max'." + }, ) + round_type: Optional[str] = field( + default='round', + metadata={ + "help": + "The method for converting quantized weight values from float to int. " \ + "Currently supports ['round', 'adaround']. Default is `round`, " \ + "which rounds to the nearest integer; 'adaround' refers to " \ + "https://arxiv.org/abs/2004.10568."
+ }, ) + bias_correction: Optional[bool] = field( + default=False, + metadata={ + "help": + "If set to True, use the bias correction method of " \ + "https://arxiv.org/abs/1810.05723. Default is False." + }, + ) + infer_model_path: Optional[str] = field( + default=None, + metadata={ + "help": + "If you only have an inference model, quantization is also supported." \ + " The format is `dirname/file_prefix` or `file_prefix`. Default " \ + "is None." + }, + ) + + def print_config(self, args=None, key=""): + """ + Prints all config values. + """ + + compression_arg_name = [ + 'width_mult_list', 'batch_num_list', 'bias_correction', + 'round_type', 'algo_list', 'batch_size_list', 'strategy', + 'weight_quantize_type', 'infer_model_path' + ] + + logger.info("=" * 60) + if args is None: + args = self + key = "Compression" + + logger.info('{:^40}'.format("{} Configuration Arguments".format(key))) + if key == "Compression": + logger.info("Compression Suggestions: `strategy` supports 'dynabert+ptq', " \ + "'dynabert' and 'ptq'. `width_mult_list` is needed in " \ + "`dynabert`, and `algo_list`, `batch_num_list`, `batch_size_list`," \ + " `round_type`, `bias_correction`, `weight_quantize_type`, " \ + "`infer_model_path` are needed in 'ptq'. " + ) + logger.info('{:30}:{}'.format("paddle commit id", + paddle.version.commit)) + + for arg in dir(args): + if key == "Compression" and arg not in compression_arg_name: + continue + if arg[:2] != "__": # don't print double-underscore members + v = getattr(args, arg) + if not isinstance(v, types.MethodType): + logger.info('{:30}:{}'.format(arg, v)) + + logger.info("") diff --git a/model_zoo/ernie-3.0/compress_trainer.py b/paddlenlp/trainer/trainer_compress.py similarity index 54% rename from model_zoo/ernie-3.0/compress_trainer.py rename to paddlenlp/trainer/trainer_compress.py index ec73a354eb93..2fda00effc5b 100644 --- a/model_zoo/ernie-3.0/compress_trainer.py +++ b/paddlenlp/trainer/trainer_compress.py @@ -15,213 +15,100 @@ import time import os import copy - +import math import numpy as np + import paddle +from paddle.utils import try_import import paddle.nn as nn import paddle.nn.functional as F from paddle.metric import Accuracy from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization -nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward -nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv - -from paddlenlp.trainer import Trainer -from paddlenlp.utils.log import logger -from paddlenlp.data import Pad - -from paddlenlp.transformers import AutoModelForSequenceClassification -from paddlenlp.transformers import AutoModelForQuestionAnswering -from paddlenlp.transformers import AutoModelForTokenClassification -from paddlenlp.transformers import export_model - -from paddlenlp.metrics import ChunkEvaluator -from paddlenlp.metrics.squad import squad_evaluate, compute_prediction - - -def try_import_paddleslim(): - try: - import paddleslim - except ImportError: - raise ImportError( - 'Cannot import paddleslim, please install paddleslim.') - - -class DynabertConfig: - - def __init__(self, - width_mult_list=[3 / 4], - output_filename_prefix="float32"): - """ - Pruning class config of DynaBERT stratedy. - Args: - width_mult_list (list of float): - Width mult list for DynaBERT. - Defaults to `[3/4]`. - output_filename_prefix (str): - Prefix of pruned model's filename. - Defaults to `float32`.
- """ - self.compress_type = "dynabert" - self.width_mult_list = width_mult_list - self.output_filename_prefix = output_filename_prefix - - -class PTQConfig: - - def __init__(self, - algo_list=["hist"], - batch_size_list=[4], - input_dir=None, - input_filename_prefix="float32", - output_filename_prefix="int8"): - """ - Quantization class config of Post-Training method. - Args: - algo_list (list of str): - Algorithm name list of `PostTrainingQuantization`. Each - algorithm would be performed to input model. Supported - algorithms are `KL`, `abs_max`, `min_max`, `avg`, `hist` - and `mse`. - Defaults to `["hist"]`. - batch_size_list (list of int): - Number of calibration samples. - Defaults to `[4]`. - input_dir (str): - Directory name of model to be quantized. - Defaults to `None`. - input_filename_prefix (str): - Prefix of model filename after quantization. - Defaults to `"float32"`. - output_filename_prefix (str): - Prefix of model filename after quantization. - Defaults to `"int8"`. - """ - self.compress_type = "ptq" - self.algo_list = algo_list - self.batch_size_list = batch_size_list - self.input_dir = input_dir - self.input_filename_prefix = input_filename_prefix - self.output_filename_prefix = output_filename_prefix - - -class CompressConfig: - - def __init__(self, - prune_config=DynabertConfig(), - quantization_config=PTQConfig()): - """ - Model compression Config class. It accepts Hyperparameters of - pruning and quantization. - Args: - prune_config (`DynabertConfig`): - Accepts Hyperparameters of pruning. More prune strategies would - be supported in the future. - Defaults to `DynabertConfig()`. - quantization_config (`PTQConfig`): - Accepts Hyperparameters of pruning. More quantization methods - would be supported in the future. - Defaults to `PTQConfig()`. - - """ - assert isinstance(prune_config, (DynabertConfig)), \ - "`prune_config` should be an instance of `DynabertConfig`." - assert isinstance(quantization_config, (PTQConfig)), \ - "`quantization_config` shoule be an instance of `PTQConfig`." - self.prune_config = prune_config - self.quantization_config = quantization_config - - -def compress(self, - output_dir, - pruning=True, - quantization=True, - compress_config=CompressConfig()): +from ..utils.log import logger +from ..data import Pad +from ..transformers import AutoModelForSequenceClassification +from ..transformers import AutoModelForQuestionAnswering +from ..transformers import AutoModelForTokenClassification +from ..transformers import export_model +from ..transformers.ofa_utils import * +from ..transformers.model_outputs import BaseModelOutputWithPoolingAndCrossAttentions +from ..metrics import ChunkEvaluator +from ..metrics.squad import squad_evaluate, compute_prediction + +from .trainer_base import Trainer + + +def compress(self): """ - Supports pruning and quantization. If both are needed, pruning would be - performed before quantizaton. - Args: - output_dir (str): - Directory name of Pruning or quantized models. - pruning (bool): - Whether to prune. - Defaults to `True`. - quantization (bool): - Whether to quantize. - Defaults to `True`. - compress_config (`CompressConfig`): - Compress config instance to pass parameters for pruning or - quantization. - Defaults to `CompressConfig()`. + Supports pruning DynaBERT and post-training quantization. If both are + needed, pruning DynaBERT would be performed before quantizaton. 
""" - if pruning: - try_import_paddleslim() - self.prune(output_dir, compress_config.prune_config) - if quantization: - for width_mult in compress_config.prune_config.width_mult_list: - output_dir_width = os.path.join(output_dir, str(width_mult)) - self.quant(output_dir_width, output_dir_width, - compress_config.quantization_config) - elif quantization: - input_dir = compress_config.quantization_config.input_dir - if input_dir is None: - compress_config.quantization_config.input_filename_prefix = "model" + args = self.args + if "dynabert" in args.strategy: + try_import('paddleslim') + _dynabert(self, self.model, args.output_dir) + if "ptq" in args.strategy: + self.args.input_filename_prefix = "pruned_model" + for width_mult in args.width_mult_list: + output_dir_width = os.path.join(args.output_dir, + "width_mult_" + str(width_mult)) + self.quant(output_dir_width, "ptq") + elif args.strategy == "ptq": + # Input model is an inference model + if args.infer_model_path is not None: + model_dir = os.path.dirname(args.infer_model_path) + self.args.input_filename_prefix = os.path.basename( + args.infer_model_path) + self.quant(model_dir, args.strategy) + # Input model is load from Trainer API in dygraph. + else: + # Prefix of `export_model` is 'model' + self.args.input_filename_prefix = "model" input_spec = [ paddle.static.InputSpec(shape=[None, None], dtype="int64"), # input_ids paddle.static.InputSpec(shape=[None, None], - dtype="int64") # segment_ids + dtype="int64") # token_type_ids ] - original_inference_model_dir = os.path.join(output_dir, "inference") + input_dir = args.output_dir export_model(model=self.model, input_spec=input_spec, - path=original_inference_model_dir) - self.quant(original_inference_model_dir, output_dir, - compress_config.quantization_config) - + path=input_dir) + self.quant(input_dir, args.strategy) -def prune(self, output_dir, prune_config=DynabertConfig()): - """ - Supports DynaBERT strategy now. - """ - assert isinstance(prune_config, (DynabertConfig)), \ - "`prune_config` should be an instance of `DynabertConfig`." - if prune_config.compress_type == "dynabert": - _dynabert(self, self.model, output_dir, prune_config) - -def quant(self, input_dir, output_dir, quantization_config=PTQConfig()): +def quant(self, input_dir, strategy): """ Supports Post-Training Quantization now. """ - assert isinstance(quantization_config, (PTQConfig)), \ - "`quantization_config` shoule be an instance of `PTQConfig`." - eval_dataloader = self.get_eval_dataloader(self.eval_dataset) - nn.MultiHeadAttention._prepare_qkv = nn.MultiHeadAttention._ori_prepare_qkv - _post_training_quantization_grid_search(eval_dataloader, self.eval_dataset, - input_dir, output_dir, - quantization_config) + if strategy == "ptq": + eval_dataloader = self.get_eval_dataloader(self.eval_dataset) + _post_training_quantization_grid_search(eval_dataloader, + self.eval_dataset, + self.args.device, input_dir, + self.args) -def _dynabert(self, model, output_dir, dynabert_config): - model.base_model_class._ori_forward = model.base_model_class.forward - model.base_model_class.forward = auto_model_forward +def _dynabert(self, model, output_dir): + args = self.args + model = _replace_auto_model_forward(model) # Each batch is a dict. 
train_dataloader = self.get_train_dataloader() - eval_dataloader = self.get_eval_dataloader(self.eval_dataset) + if "QuestionAnswering" in model.__class__.__name__: eval_dataloader_with_label = self.get_eval_dataloader( self.eval_examples) - ofa_model, teacher_model = _dynabert_init( - model, eval_dataloader_with_label, self.criterion, - dynabert_config.width_mult_list) + ofa_model, teacher_model = _dynabert_init(model, + eval_dataloader_with_label, + self.criterion, + args.width_mult_list) else: - ofa_model, teacher_model = _dynabert_init( - model, eval_dataloader, self.criterion, - dynabert_config.width_mult_list) - args = self.args + ofa_model, teacher_model = _dynabert_init(model, eval_dataloader, + self.criterion, + args.width_mult_list) # TODO: args.gradient_accumulation_steps if args.max_steps > 0: @@ -230,27 +117,88 @@ def _dynabert(self, model, output_dir, dynabert_config): len(train_dataloader)) else: args.num_training_steps = len(train_dataloader) * args.num_train_epochs + args.num_train_epochs = math.ceil(args.num_train_epochs) self.create_optimizer_and_scheduler( num_training_steps=args.num_training_steps) ofa_model = _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, - dynabert_config.width_mult_list, - self.criterion, args.num_train_epochs, - output_dir) + args.width_mult_list, self.criterion, + args.num_train_epochs, args.output_dir) # Each width_mult best model would be exported. - _dynabert_export(ofa_model, dynabert_config, output_dir) + _dynabert_export(ofa_model, args.width_mult_list, args.output_dir) + + ofa_model, ofa_model.model = _recover_transformer_func( + ofa_model, True), _recover_transformer_func(ofa_model.model, True) + ofa_model.model = _recover_auto_model_forward(ofa_model.model) + logger.info("Pruning is finished using DynaBERT strategy.") + + +def _replace_transformer_func(self): + nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward + nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv + + nn.MultiHeadAttention._forward = mha_ofa_forward + nn.MultiHeadAttention.__prepare_qkv = prepare_qkv_ofa + nn.TransformerEncoder._forward = encoder_ofa_forward + nn.TransformerEncoderLayer._forward = encoder_layer_ofa_forward + + def init_func(layer): + if isinstance(layer, nn.MultiHeadAttention): + layer.forward = layer._forward + layer._prepare_qkv = layer.__prepare_qkv + elif isinstance(layer, nn.TransformerEncoderLayer): + layer.forward = layer._forward + elif isinstance(layer, nn.TransformerEncoder): + layer.forward = layer._forward + + for layer in self.children(): + layer.apply(init_func) + return self + + +def _recover_transformer_func(self, all_recover=False): + + def init_func(layer): + if isinstance(layer, nn.MultiHeadAttention): + layer.forward = layer._ori_forward + elif isinstance(layer, nn.TransformerEncoderLayer): + layer.forward = layer._ori_forward + elif isinstance(layer, nn.TransformerEncoder): + layer.forward = layer._ori_forward + if all_recover: + if isinstance(layer, nn.MultiHeadAttention): + layer._prepare_qkv = layer._ori_prepare_qkv + + for layer in self.children(): + layer.apply(init_func) + + return self + + +def _replace_auto_model_forward(self): + self.base_model_class._forward = auto_model_forward + self.base_model_class._ori_forward = self.base_model_class.forward + + def init_func(layer): + if isinstance(layer, self.base_model_class): + layer.forward = layer._forward + + for layer in self.children(): + layer.apply(init_func) + return self - 
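# ---------------------------------------------------------------------------
# The replace/recover helpers above share one stash/patch/restore pattern:
# the original bound method is kept under an `_ori_*` attribute, an OFA-aware
# version is swapped in while compressing, and the original is put back
# afterwards so the patched classes keep their normal behavior elsewhere.
# Schematically:
#
#   cls._ori_forward = cls.forward   # stash once, before patching
#   cls.forward = ofa_forward        # patched while pruning/exporting
#   ...
#   cls.forward = cls._ori_forward   # recovered when compression is done
# ---------------------------------------------------------------------------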
model.base_model_class.forward = model.base_model_class._ori_forward - logger.info("DynaBERT training finished.") +def _recover_auto_model_forward(self): -def _recover_transormer_func(): - nn.TransformerEncoder.forward = paddle.nn.TransformerEncoder._ori_forward - nn.TransformerEncoderLayer.forward = paddle.nn.TransformerEncoderLayer._ori_forward - nn.MultiHeadAttention.forward = paddle.nn.MultiHeadAttention._ori_forward - # nn.MultiHeadAttention._prepare_qkv = nn.MultiHeadAttention._ori_prepare_qkv + def init_func(layer): + if isinstance(layer, self.base_model_class): + layer.forward = layer._ori_forward + + for layer in self.children(): + layer.apply(init_func) + return self def _dynabert_init(model, eval_dataloader, criterion, width_mult_list): @@ -291,11 +239,8 @@ def _dynabert_init(model, eval_dataloader, criterion, width_mult_list): # Step6: Calculate the importance of neurons and head, # and then reorder them according to the importance. - # NOTE: Importing `nlp_utils` would rewrite `forward` function of - # TransformerEncoder, TransformerEncoderLayer, MultiHeadAttention and - # `_prepare_qkv` function of MultiHeadAttention. - from paddleslim.nas.ofa.utils import nlp_utils - + ofa_model.model, ofa_model = _replace_transformer_func( + ofa_model.model), _replace_transformer_func(ofa_model) head_importance, neuron_importance = compute_neuron_head_importance( model=ofa_model.model, data_loader=eval_dataloader, @@ -349,10 +294,10 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): all_end_logits.append(end_logits_tensor.numpy()[idx]) else: - input_ids, segment_ids, labels = batch['input_ids'], batch[ + input_ids, token_type_ids, labels = batch['input_ids'], batch[ 'token_type_ids'], batch['labels'] logits = model(input_ids, - segment_ids, + token_type_ids, attention_mask=[None, None]) if isinstance(logits, tuple): logits = logits[0] @@ -413,14 +358,13 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): global_step = 0 lambda_logit = 1.0 tic_train = time.time() - best_acc = 0.0 + best_acc = [0.0] * len(width_mult_list) acc = 0.0 logger.info("DynaBERT training starts. This period will cost some time.") for epoch in range(num_train_epochs): # Step7: Set current epoch and task. 
ofa_model.set_epoch(epoch) ofa_model.set_task('width') - for step, batch in enumerate(train_dataloader): global_step += 1 if "QuestionAnswering" in model.__class__.__name__: @@ -474,11 +418,11 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): tic_eval = time.time() acc = evaluate(ofa_model, criterion, eval_dataloader, width_mult) - if acc > best_acc: - best_acc = acc + if acc > best_acc[idx]: + best_acc[idx] = acc if paddle.distributed.get_rank() == 0: output_dir_width = os.path.join( - output_dir, str(width_mult)) + output_dir, "width_mult_" + str(width_mult)) if not os.path.exists(output_dir_width): os.makedirs(output_dir_width) # need better way to get inner model of DataParallel @@ -488,15 +432,17 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): logger.info("eval done total : %s s" % (time.time() - tic_eval)) if global_step > self.args.num_training_steps: - if best_acc == 0.0: - output_dir_width = os.path.join(output_dir, str(width_mult)) + if best_acc[idx] == 0.0: + output_dir_width = os.path.join( + output_dir, "width_mult_" + str(width_mult)) if not os.path.exists(output_dir_width): os.makedirs(output_dir_width) # need better way to get inner model of DataParallel model_to_save = model._layers if isinstance( model, paddle.DataParallel) else model model_to_save.save_pretrained(output_dir_width) - logger.info("Best acc: %.4f" % (best_acc)) + logger.info("Best acc of width_mult %s: %.4f" % + (width_mult, best_acc[idx])) return ofa_model if "QuestionAnswering" in model.__class__.__name__: @@ -509,11 +455,11 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): tic_eval = time.time() acc = evaluate(ofa_model, criterion, eval_dataloader, width_mult) - if acc > best_acc: - best_acc = acc + if acc > best_acc[idx]: + best_acc[idx] = acc if paddle.distributed.get_rank() == 0: - output_dir_width = os.path.join(output_dir, - str(width_mult)) + output_dir_width = os.path.join( + output_dir, "width_mult_" + str(width_mult)) if not os.path.exists(output_dir_width): os.makedirs(output_dir_width) # need better way to get inner model of DataParallel @@ -522,18 +468,22 @@ def evaluate(model, criterion, data_loader, width_mult=1.0): model_to_save.save_pretrained(output_dir_width) logger.info("eval done total : %s s" % (time.time() - tic_eval)) - logger.info("Best acc: %.4f" % (best_acc)) + for idx, width_mult in enumerate(width_mult_list): + logger.info("Best acc of width_mult %s: %.4f" % + (width_mult, best_acc[idx])) return ofa_model -def _dynabert_export(ofa_model, dynabert_config, output_dir): +def _dynabert_export(ofa_model, width_mult_list, output_dir): from paddleslim.nas.ofa import OFA, DistillConfig, utils - ofa_model.model.base_model_class.forward = auto_model_forward ofa_model._add_teacher = False - _recover_transormer_func() + ofa_model, ofa_model.model = _recover_transformer_func( + ofa_model), _recover_transformer_func(ofa_model.model) - for width_mult in dynabert_config.width_mult_list: - model_dir = os.path.join(output_dir, str(width_mult)) + ori_num_heads = ofa_model.model.base_model.encoder.layers[ + 0].self_attn.num_heads + for width_mult in width_mult_list: + model_dir = os.path.join(output_dir, "width_mult_" + str(width_mult)) state_dict = paddle.load(os.path.join(model_dir, "model_state.pdparams")) if "QuestionAnswering" in ofa_model.model.__class__.__name__: @@ -551,29 +501,31 @@ def _dynabert_export(ofa_model, dynabert_config, output_dir): input_shapes=[[1, 1], [1, 1]], input_dtypes=['int64', 'int64'], origin_model=origin_model) - for name, 
sublayer in origin_model_new.named_sublayers(): if isinstance(sublayer, paddle.nn.MultiHeadAttention): sublayer.num_heads = int(width_mult * sublayer.num_heads) - input_shape = [ paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None, None], dtype='int64') ] - pruned_infer_model_dir = os.path.join( - model_dir, dynabert_config.output_filename_prefix) + pruned_infer_model_dir = os.path.join(model_dir, "pruned_model") + net = paddle.jit.to_static(origin_model_new, input_spec=input_shape) paddle.jit.save(net, pruned_infer_model_dir) + # Recover num_heads of ofa_model.model + for layer in ofa_model.model.base_model.encoder.layers: + layer.self_attn.num_heads = ori_num_heads def _post_training_quantization_grid_search(eval_dataloader, eval_dataset, - input_dir, output_dir, - quantization_config): + device, model_dir, args): paddle.enable_static() - place = paddle.set_device("gpu") + place = paddle.set_device(device) exe = paddle.static.Executor(place) - def _post_training_quantization(algo, batch_size): + args.output_filename_prefix = "int8" + + def _post_training_quantization(algo, batch_size, batch_nums): def _batch_generator_func(): batch_data = [[], []] @@ -582,44 +534,42 @@ def _batch_generator_func(): batch_data[1].append(data['token_type_ids']) if len(batch_data[0]) == batch_size: input_ids = Pad(axis=0, pad_val=0)(batch_data[0]) - segment_ids = Pad(axis=0, pad_val=0)(batch_data[1]) - yield [input_ids, segment_ids] + token_type_ids = Pad(axis=0, pad_val=0)(batch_data[1]) + yield [input_ids, token_type_ids] batch_data = [[], []] post_training_quantization = PostTrainingQuantization( executor=exe, batch_generator=_batch_generator_func, - model_dir=input_dir, - model_filename=quantization_config.input_filename_prefix + - ".pdmodel", - params_filename=quantization_config.input_filename_prefix + - ".pdiparams", + model_dir=model_dir, + model_filename=args.input_filename_prefix + ".pdmodel", + params_filename=args.input_filename_prefix + ".pdiparams", batch_size=batch_size, - batch_nums=1, + batch_nums=batch_nums, scope=None, algo=algo, hist_percent=0.9999, - bias_correction=False, + round_type=args.round_type, + bias_correction=args.bias_correction, quantizable_op_type=['matmul', 'matmul_v2'], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', - weight_quantize_type='channel_wise_abs_max', - onnx_format=True, + weight_quantize_type=args.weight_quantize_type, + onnx_format=False, optimize_model=False) post_training_quantization.quantize() post_training_quantization.save_quantized_model( - save_model_path=os.path.join(output_dir, algo + str(batch_size)), - model_filename=quantization_config.output_filename_prefix + - ".pdmodel", - params_filename=quantization_config.output_filename_prefix + - ".pdiparams") + save_model_path=os.path.join(model_dir, algo + str(batch_size)), + model_filename=args.output_filename_prefix + ".pdmodel", + params_filename=args.output_filename_prefix + ".pdiparams") logger.info("Post training quantization starts.") - for algo in quantization_config.algo_list: - for batch_size in quantization_config.batch_size_list: - _post_training_quantization(algo, batch_size) + for algo in args.algo_list: + for batch_size in args.batch_size_list: + for batch_nums in args.batch_num_list: + _post_training_quantization(algo, batch_size, batch_nums) paddle.disable_static() logger.info("Post training quantization ends.") @@ -629,131 +579,52 @@ def auto_model_forward(self, input_ids, 
token_type_ids=None, position_ids=None, - attention_mask=[None, None]): + attention_mask=[None, None], + task_type_ids=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): wtype = self.pooler.dense.fn.weight.dtype if hasattr( self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype if attention_mask is None: attention_mask = paddle.unsqueeze( - (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) - if attention_mask[0] is None: + (input_ids == self.pad_token_id).astype(wtype) * -1e4, axis=[1, 2]) + elif isinstance(attention_mask, paddle.Tensor) and attention_mask.ndim == 2: + attention_mask = paddle.unsqueeze(attention_mask, + axis=[1, 2]).astype(wtype) + attention_mask = (1.0 - attention_mask) * -1e4 + elif attention_mask[0] is None: attention_mask[0] = paddle.unsqueeze( - (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) - embedding_output = self.embeddings(input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids) - encoder_outputs = self.encoder(embedding_output, attention_mask) - sequence_output = encoder_outputs - pooled_output = self.pooler(sequence_output) - return sequence_output, pooled_output - - -def reorder_neuron_head(model, head_importance, neuron_importance): - """ - Reorders weights according head importance and neuron importance - """ - from paddleslim.nas.ofa.utils import nlp_utils - # Reorders heads and ffn neurons - for layer, current_importance in enumerate(neuron_importance): - # Reorders heads - idx = paddle.argsort(head_importance[layer], descending=True) - nlp_utils.reorder_head(model.base_model.encoder.layers[layer].self_attn, - idx) - # Reorders neurons - idx = paddle.argsort(paddle.to_tensor(current_importance), - descending=True) - nlp_utils.reorder_neuron( - model.base_model.encoder.layers[layer].linear1.fn, idx, dim=1) - - nlp_utils.reorder_neuron( - model.base_model.encoder.layers[layer].linear2.fn, idx, dim=0) - - -def compute_neuron_head_importance(model, - data_loader, - num_layers, - num_heads, - loss_fct=nn.loss.CrossEntropyLoss(), - intermediate_name='linear1', - output_name='linear2'): - """ - Compute the importance of multi-head attention and feed-forward neuron in - each transformer layer. - - Args: - model(paddle.nn.Layer): - The instance of transformer model. - data_loader (DataLoader): - An iterable data loader is used for evaluate. An instance of - `paddle.io.Dataloader`. - num_layers (int): - Number of transformer layers. - num_heads (int): - Number of heads in each multi-head attention. - loss_fct (Loss|optional): - Loss function can be a `paddle.nn.Layer` instance. Default: `nn.loss.CrossEntropyLoss()`. - intermediate_name (str|optional): - The name of intermediate `Linear` layer in feed-forward. - Defaults to `linear1`. - output_name (str|optional): - The name of output `Linear` layer in feed-forward. - Defaults to `linear2`. 
- """ - head_importance = paddle.zeros(shape=[num_layers, num_heads], - dtype='float32') - head_mask = paddle.ones(shape=[num_layers, num_heads], dtype='float32') - head_mask.stop_gradient = False - - intermediate_weight = [] - intermediate_bias = [] - output_weight = [] - - for name, w in model.named_parameters(): - if intermediate_name in name: - if len(w.shape) > 1: - intermediate_weight.append(w) - else: - intermediate_bias.append(w) - - if output_name in name: - if len(w.shape) > 1: - output_weight.append(w) - - neuron_importance = [] - for w in intermediate_weight: - neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32')) - - for batch in data_loader: - if isinstance(batch, dict): - if "QuestionAnswering" in model.__class__.__name__: - input_ids, segment_ids, start_positions, end_positions = batch[ - 'input_ids'], batch['token_type_ids'], batch[ - 'start_positions'], batch['end_positions'] - else: - input_ids, segment_ids, labels = batch['input_ids'], batch[ - 'token_type_ids'], batch['labels'] - else: - input_ids, segment_ids, labels = batch - logits = model(input_ids, segment_ids, attention_mask=[None, head_mask]) - if "QuestionAnswering" in model.__class__.__name__: - start_logits, end_logits = logits - loss = (loss_fct(start_logits, start_positions) + - loss_fct(end_logits, end_positions)) / 2 - else: - loss = loss_fct(logits, labels) - loss.backward() - head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient())) - - for w1, b1, w2, current_importance in zip(intermediate_weight, - intermediate_bias, - output_weight, - neuron_importance): - current_importance += np.abs( - (np.sum(w1.numpy() * w1.gradient(), axis=0) + - b1.numpy() * b1.gradient())) - current_importance += np.abs( - np.sum(w2.numpy() * w2.gradient(), axis=1)) - - return head_importance, neuron_importance + (input_ids == self.pad_token_id).astype(wtype) * -1e4, axis=[1, 2]) + if "use_task_id" in self.config: + embedding_output = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids) + else: + embedding_output = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) + encoder_outputs = self.encoder(embedding_output, + src_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + if isinstance(encoder_outputs, type(embedding_output)): + sequence_output = encoder_outputs + pooled_output = self.pooler(sequence_output) + return (sequence_output, pooled_output) + else: + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) def soft_cross_entropy(inp, target): @@ -763,5 +634,4 @@ def soft_cross_entropy(inp, target): Trainer.compress = compress -Trainer.prune = prune Trainer.quant = quant diff --git a/paddlenlp/transformers/ofa_utils.py b/paddlenlp/transformers/ofa_utils.py new file mode 100644 index 000000000000..ceaf958807c8 --- /dev/null +++ b/paddlenlp/transformers/ofa_utils.py @@ -0,0 +1,350 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = [ + 'prepare_qkv_ofa', 'mha_ofa_forward', 'encoder_ofa_forward', + 'encoder_layer_ofa_forward', 'compute_neuron_head_importance', + 'reorder_neuron_head' +] + + +def prepare_qkv_ofa(self, query, key, value, cache=None): + q = self.q_proj(query) + if hasattr(self.q_proj, + 'fn') and self.q_proj.fn.cur_config['expand_ratio'] != None: + self.num_heads = int(self.num_heads * + self.q_proj.fn.cur_config['expand_ratio']) + q = paddle.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = paddle.concat([cache.k, k], axis=2) + v = paddle.concat([cache.v, v], axis=2) + cache = self.Cache(k, v) + + return (q, k, v) if cache is None else (q, k, v, cache) + + +def mha_ofa_forward(self, query, key, value, attn_mask=None, cache=None): + """ + monkey patch for MultiHeadAttention forward to accept head_mask + attn_mask[0] = attn_mask, attn_mask[1] = head_mask + """ + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + if cache is None: + q, k, v = self._prepare_qkv(query, key, value, cache) + else: + q, k, v, cache = self._prepare_qkv(query, key, value, cache) + + # scale dot product attention + # TODO: use paddle.matmul, however it doesn't support `alpha` + product = paddle.fluid.layers.matmul(x=q, + y=k, + transpose_y=True, + alpha=self.head_dim**-0.5) + if attn_mask[0] is not None: + # TODO(guosheng): support bool mask + product = product + attn_mask[0] + weights = F.softmax(product) + if self.dropout: + weights = F.dropout(weights, + self.dropout, + training=self.training, + mode="upscale_in_train") + + if attn_mask[1] is not None: + weights = weights * attn_mask[1] + + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + outs = [out] + if self.need_weights: + outs.append(weights) + if cache is not None: + outs.append(cache) + + if hasattr(self.q_proj, + 'fn') and self.q_proj.fn.cur_config['expand_ratio'] != None: + self.num_heads = int( + float(self.num_heads) / self.q_proj.fn.cur_config['expand_ratio']) + return out if len(outs) == 1 else tuple(outs) + + +def encoder_ofa_forward(self, + src, + src_mask=[None, None], + cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + """ + monkey patch for TransformerEncoder forward to accept head_mask + attn_mask[0] = attn_mask, attn_mask[1] = head_mask + """ + output = src + if src_mask[1] is not None: + head_mask = src_mask[1] + 
if len(head_mask.shape) == 1: + head_mask = paddle.unsqueeze( + paddle.unsqueeze( + paddle.unsqueeze(paddle.unsqueeze(head_mask, 0), 0), -1), + -1) + head_mask = paddle.expand(head_mask, + shape=[self.num_layers] + + head_mask.shape[1:]) + elif len(head_mask.shape) == 2: + head_mask = paddle.unsqueeze( + paddle.unsqueeze(paddle.unsqueeze(head_mask, 1), -1), -1) + else: + head_mask = [None] * self.num_layers + + for i, mod in enumerate(self.layers): + output = mod(output, src_mask=[src_mask[0], head_mask[i]]) + + if self.norm is not None: + output = self.norm(output) + + return output + + +def encoder_layer_ofa_forward(self, + src, + src_mask=None, + cache=None, + output_attentions=False): + residual = src + if self.normalize_before: + src = self.norm1(src) + # Add cache for encoder for the usage like UniLM + if cache is None: + src = self.self_attn(src, src, src, src_mask) + else: + src, incremental_cache = self.self_attn(src, src, src, src_mask, cache) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src if cache is None else (src, incremental_cache) + + +def reorder_head(layer, index): + """ + Reorder head weights according to the given index. + Args: + layer(paddle.nn.Layer): the instance of `paddle.nn.MultiHeadAttention` layer. + index(list): the sort indices of the attention heads. + """ + assert isinstance(layer, nn.MultiHeadAttention), \ + "layer in reorder_head must be the instance of `paddle.nn.MultiHeadAttention`." + n, a = layer.num_heads, layer.head_dim + idx = paddle.reshape(paddle.index_select(paddle.reshape(paddle.arange( + 0, n * a, dtype='int64'), + shape=[n, a]), + index=index, + axis=0), + shape=[-1]) + + def reorder_head_matrix(linearLayer, index, dim=1): + W = paddle.index_select(linearLayer.weight, index, axis=dim).detach() + if linearLayer.bias is not None: + if dim == 0: + b = paddle.assign(linearLayer.bias).detach() + else: + b = paddle.assign( + paddle.index_select(linearLayer.bias, index, + axis=0)).detach() + + linearLayer.weight.stop_gradient = True + linearLayer.weight.set_value(W) + linearLayer.weight.stop_gradient = False + if linearLayer.bias is not None: + linearLayer.bias.stop_gradient = True + linearLayer.bias.set_value(b) + linearLayer.bias.stop_gradient = False + + reorder_head_matrix( + layer.q_proj.fn if hasattr(layer.q_proj, 'fn') else layer.q_proj, idx) + reorder_head_matrix( + layer.k_proj.fn if hasattr(layer.k_proj, 'fn') else layer.k_proj, idx) + reorder_head_matrix( + layer.v_proj.fn if hasattr(layer.v_proj, 'fn') else layer.v_proj, idx) + reorder_head_matrix( + layer.out_proj.fn if hasattr(layer.out_proj, 'fn') else layer.out_proj, + idx, + dim=0) + + +def reorder_neuron(layer, index, dim=0): + """ + Reorder feed-forward weights according to the given index. + Args: + layer(paddle.nn.Layer): the instance of `paddle.nn.Linear` layer. + index(list): the sort indices of the feed-forward neurons. + dim(int): the dim along which weights are selected.
+ """ + linearLayer = layer.fn if hasattr(layer, 'fn') else layer + W = paddle.index_select(linearLayer.weight, index, axis=dim).detach() + if linearLayer.bias is not None: + if dim == 0: + b = paddle.assign(linearLayer.bias).detach() + else: + b = paddle.assign( + paddle.index_select(linearLayer.bias, index, axis=0)).detach() + linearLayer.weight.stop_gradient = True + linearLayer.weight.set_value(W) + linearLayer.weight.stop_gradient = False + + if linearLayer.bias is not None: + linearLayer.bias.stop_gradient = True + linearLayer.bias.set_value(b) + linearLayer.bias.stop_gradient = False + + +def reorder_neuron_head(model, head_importance, neuron_importance): + """ + Reorders weights according head importance and neuron importance + """ + # Reorders heads and ffn neurons + for layer, current_importance in enumerate(neuron_importance): + # Reorders heads + idx = paddle.argsort(head_importance[layer], descending=True) + reorder_head(model.base_model.encoder.layers[layer].self_attn, idx) + # Reorders neurons + idx = paddle.argsort(paddle.to_tensor(current_importance), + descending=True) + reorder_neuron(model.base_model.encoder.layers[layer].linear1.fn, + idx, + dim=1) + + reorder_neuron(model.base_model.encoder.layers[layer].linear2.fn, + idx, + dim=0) + + +def compute_neuron_head_importance(model, + data_loader, + num_layers, + num_heads, + loss_fct=nn.loss.CrossEntropyLoss(), + intermediate_name='linear1', + output_name='linear2'): + """ + Compute the importance of multi-head attention and feed-forward neuron in + each transformer layer. + + Args: + model(paddle.nn.Layer): + The instance of transformer model. + data_loader (DataLoader): + An iterable data loader is used for evaluate. An instance of + `paddle.io.Dataloader`. + num_layers (int): + Number of transformer layers. + num_heads (int): + Number of heads in each multi-head attention. + loss_fct (Loss|optional): + Loss function can be a `paddle.nn.Layer` instance. Default: `nn.loss.CrossEntropyLoss()`. + intermediate_name (str|optional): + The name of intermediate `Linear` layer in feed-forward. + Defaults to `linear1`. + output_name (str|optional): + The name of output `Linear` layer in feed-forward. + Defaults to `linear2`. 
+ """ + head_importance = paddle.zeros(shape=[num_layers, num_heads], + dtype='float32') + head_mask = paddle.ones(shape=[num_layers, num_heads], dtype='float32') + head_mask.stop_gradient = False + + intermediate_weight = [] + intermediate_bias = [] + output_weight = [] + + for name, w in model.named_parameters(): + if intermediate_name in name: + if len(w.shape) > 1: + intermediate_weight.append(w) + else: + intermediate_bias.append(w) + + if output_name in name: + if len(w.shape) > 1: + output_weight.append(w) + + neuron_importance = [] + for w in intermediate_weight: + neuron_importance.append(np.zeros(shape=[w.shape[1]], dtype='float32')) + + for batch in data_loader: + if isinstance(batch, dict): + if "QuestionAnswering" in model.__class__.__name__: + input_ids, segment_ids, start_positions, end_positions = batch[ + 'input_ids'], batch['token_type_ids'], batch[ + 'start_positions'], batch['end_positions'] + else: + input_ids, segment_ids, labels = batch['input_ids'], batch[ + 'token_type_ids'], batch['labels'] + else: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids, attention_mask=[None, head_mask]) + if "QuestionAnswering" in model.__class__.__name__: + start_logits, end_logits = logits + loss = (loss_fct(start_logits, start_positions) + + loss_fct(end_logits, end_positions)) / 2 + else: + loss = loss_fct(logits, labels) + loss.backward() + head_importance += paddle.abs(paddle.to_tensor(head_mask.gradient())) + + for w1, b1, w2, current_importance in zip(intermediate_weight, + intermediate_bias, + output_weight, + neuron_importance): + current_importance += np.abs( + (np.sum(w1.numpy() * w1.gradient(), axis=0) + + b1.numpy() * b1.gradient())) + current_importance += np.abs( + np.sum(w2.numpy() * w2.gradient(), axis=1)) + + return head_importance, neuron_importance