-
Notifications
You must be signed in to change notification settings - Fork 817
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] reward model inferencer support
- Loading branch information
Yizhen
committed
Jun 22, 2024
1 parent
e5ab2fd
commit 3b2cb04
Showing
14 changed files
with
747 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. | ||
import logging | ||
import os | ||
import sys | ||
|
||
from transformers import ( | ||
HfArgumentParser | ||
) | ||
|
||
from lmflow.datasets import Dataset | ||
from lmflow.models.auto_model import AutoModel | ||
from lmflow.pipeline.auto_pipeline import AutoPipeline | ||
from lmflow.args import ( | ||
ModelArguments, | ||
DatasetArguments, | ||
AutoArguments, | ||
) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def main(): | ||
# Parses arguments | ||
pipeline_name = "rm_inferencer" | ||
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) | ||
|
||
parser = HfArgumentParser(( | ||
ModelArguments, | ||
DatasetArguments, | ||
PipelineArguments | ||
)) | ||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): | ||
# If we pass only one argument to the script and it's the path to a json file, | ||
# let's parse it to get our arguments. | ||
model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) | ||
else: | ||
model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() | ||
|
||
dataset = Dataset(data_args) | ||
model = AutoModel.get_model(model_args, tune_strategy='none', use_accelerator=pipeline_args.use_accelerator) | ||
inferencer = AutoPipeline.get_pipeline( | ||
pipeline_name=pipeline_name, | ||
model_args=model_args, | ||
data_args=data_args, | ||
pipeline_args=pipeline_args | ||
) | ||
|
||
res = inferencer.inference( | ||
model, | ||
dataset, | ||
) | ||
|
||
if pipeline_args.save_results: | ||
res.save(pipeline_args.results_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/bin/bash | ||
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. | ||
|
||
# Parses arguments | ||
run_name=rm_inference | ||
# model_name_or_path=sfairXC/FsfairX-LLaMA3-RM-v0.1 | ||
model_name_or_path=/vol/yizhenjia/projs/RLHFlow-fox/models/rm/sfairXC-FsfairX-LLaMA3-RM-v0.1 | ||
dataset_path=data/alpaca/test | ||
output_dir=data/rm_inference_results | ||
output_file_name=results.json | ||
|
||
# Safety related arguments | ||
trust_remote_code=0 | ||
|
||
while [[ $# -ge 1 ]]; do | ||
key="$1" | ||
case ${key} in | ||
-r|--run_name) | ||
run_name="$2" | ||
shift | ||
;; | ||
-m|--model_name_or_path) | ||
model_name_or_path="$2" | ||
shift | ||
;; | ||
-d|--dataset_path) | ||
dataset_path="$2" | ||
shift | ||
;; | ||
--output_dir) | ||
output_dir="$2" | ||
shift | ||
;; | ||
--output_file_name) | ||
output_file_name="$2" | ||
shift | ||
;; | ||
--trust_remote_code) | ||
trust_remote_code="$2" | ||
shift | ||
;; | ||
*) | ||
echo "error: unknown option \"${key}\"" 1>&2 | ||
exit 1 | ||
esac | ||
shift | ||
done | ||
|
||
# inference | ||
project_dir=$(cd "$(dirname $0)"/..; pwd) | ||
log_dir=${project_dir}/log/${run_name} | ||
output_file_path=${output_dir}/${run_name}/${output_file_name} | ||
mkdir -p ${output_dir}/${run_name} ${log_dir} | ||
|
||
accelerate launch --config_file configs/accelerator_multigpu_config.yaml \ | ||
examples/rm_inference.py \ | ||
--trust_remote_code ${trust_remote_code} \ | ||
--model_name_or_path ${model_name_or_path} \ | ||
--arch_type text_regression \ | ||
--use_accelerator True \ | ||
--block_size 4096 \ | ||
--inference_batch_size 16 \ | ||
--dataset_path ${dataset_path} \ | ||
--preprocessing_num_workers 16 \ | ||
--save_results True \ | ||
--results_path ${output_file_path} \ | ||
2>&1 | tee ${log_dir}/rm_inference.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.