Skip to content

Commit

Permalink
update dataset preprocess (modelscope#1257)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jun 30, 2024
1 parent dbfa9ce commit 6c4f9e5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 47 deletions.
1 change: 1 addition & 0 deletions requirements/llm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ charset_normalizer
cpm_kernels
fastapi
gradio>=3.40.0
openai
sentencepiece
tiktoken
uvicorn
3 changes: 1 addition & 2 deletions swift/llm/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from tqdm import tqdm

from swift.utils import append_to_jsonl, get_logger, get_main, seed_everything
from . import DeployArguments
from .infer import merge_lora, prepare_model_template
from .utils import EvalArguments, XRequestConfig, inference, inference_client_async
from .utils import DeployArguments, EvalArguments, XRequestConfig, inference, inference_client_async

logger = get_logger()
mp.set_start_method('spawn', force=True)
Expand Down
46 changes: 1 addition & 45 deletions swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import pandas as pd
from datasets import Dataset as HfDataset
from datasets import Sequence, Value, concatenate_datasets
from datasets import concatenate_datasets
from datasets import load_dataset as load_hf_dataset
from numpy.random import RandomState
from pandas import DataFrame
Expand Down Expand Up @@ -2383,56 +2383,12 @@ def get_dataset(
assert model_name is not None and model_author is not None
dataset = _preprocess_self_cognition_dataset(dataset, model_name, model_author)

def _reduce_column(row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the per-row replacement values used when reducing a dataset.

    * ``query`` / ``response`` / ``rejected_response``: when the value is a
      list or tuple, one element is drawn with ``np.random.choice`` and
      returned under the same key.
    * ``history`` / ``history_roles``: returned under the underscore-prefixed
      key (``_history`` / ``_history_roles``); falsy values become ``None``.
    * ``system``: returned unchanged under ``_system``.

    Keys that are absent from ``row`` produce no entry in the result.
    """
    reduced = {}

    # Candidate columns: collapse a list/tuple of alternatives to one sample.
    for key in ('query', 'response', 'rejected_response'):
        if key in row and isinstance(row[key], (list, tuple)):
            reduced[key] = np.random.choice(row[key])

    # Conversation columns: normalise "empty" to None under a temporary name.
    for key in ('history', 'history_roles'):
        if key in row:
            reduced[f'_{key}'] = row[key] if row[key] else None

    if 'system' in row:
        reduced['_system'] = row['system']
    return reduced

def _reduce_dataset(ds: HfDataset):
    """Apply ``_reduce_column`` to every row and normalise the optional columns.

    For each of ``history``, ``history_roles`` and ``system`` that exists in
    ``ds.features``, an underscore-prefixed feature is declared, filled by the
    ``map`` call, and finally swapped in for the original column (the original
    is removed and the temporary column renamed to take its place).
    """
    # Factories for the feature type of each temporary column, in the order
    # the columns are processed.
    temp_feature_factories = {
        'history': lambda: Sequence(feature=Sequence(feature=Value(dtype='string'))),
        'history_roles': lambda: Sequence(feature=Sequence(feature=Value(dtype='string'))),
        'system': lambda: Value(dtype='string'),
    }

    schema = ds.features.copy()
    for column, make_feature in temp_feature_factories.items():
        if column in ds.features:
            schema[f'_{column}'] = make_feature()

    ds = ds.map(_reduce_column, load_from_cache_file=False, features=schema)

    # Replace each original column with its freshly computed counterpart.
    for column in temp_feature_factories:
        if column in ds.features:
            ds = ds.remove_columns([column]).rename_column(f'_{column}', column)
    return ds

train_d: HfDataset
if isinstance(dataset, (list, tuple)):
train_d, val_d = dataset
else:
train_d, val_d = dataset, None

if train_d:
train_d = _reduce_dataset(train_d)
if val_d:
val_d = _reduce_dataset(val_d)

assert train_d is not None or val_d is not None
if train_d is not None:
train_dataset_list.append(train_d)
Expand Down
38 changes: 38 additions & 0 deletions swift/llm/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,42 @@
PreprocessFunc = Callable[[HfDataset], HfDataset]


def _reduce_dataset(cls: type) -> type:
    """Class decorator that drops dataset columns which never carry a useful value.

    ``cls.preprocess`` is wrapped so that every key of the returned row holding
    a meaningful value is recorded in ``self.column_state``; ``cls.__call__`` is
    wrapped so that, once the original ``__call__`` has returned, every column
    of the resulting dataset that was never recorded is removed.

    A value counts as meaningful when it is truthy, with two exceptions:

    * ``query_role`` is only recorded when it differs from ``'user'``;
    * ``history_roles`` is only recorded when at least one pair differs from
      ``('user', 'assistant')``.

    Args:
        cls: A class exposing ``preprocess(self, row)`` and
            ``__call__(self, dataset)``.

    Returns:
        The same class object, patched in place. A class on which a truthy
        ``_patching`` attribute is already visible is returned untouched.
    """
    import functools

    if getattr(cls, '_patching', False):
        return cls

    call_func = cls.__call__
    preprocess = cls.preprocess
    cls._patching = True

    @functools.wraps(call_func)
    def new_call_func(self, dataset: HfDataset) -> HfDataset:
        # Start from a clean slate on every run so that state from a previous
        # dataset cannot keep a column alive.
        self.column_state = set()
        # NOTE(review): column_state is filled through new_preprocess on this
        # instance; presumably the wrapped __call__ runs preprocess in-process
        # - confirm if worker processes are ever used.
        dataset = call_func(self, dataset)
        unused = [k for k in dataset.features.keys() if k not in self.column_state]
        if unused:
            # One call for all unused columns instead of one call per column.
            dataset = dataset.remove_columns(unused)
        return dataset

    @functools.wraps(preprocess)
    def new_preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row = preprocess(self, row)
        # preprocess may be invoked directly, before __call__ has created the
        # state set; create it lazily instead of raising AttributeError.
        seen = getattr(self, 'column_state', None)
        if seen is None:
            seen = self.column_state = set()
        for k, v in row.items():
            if k in seen or not v:
                continue
            if k == 'query_role':
                informative = v != 'user'
            elif k == 'history_roles':
                informative = any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v)
            else:
                informative = True
            if informative:
                seen.add(k)
        return row

    cls.__call__ = new_call_func
    cls.preprocess = new_preprocess

    return cls


def parse_medias(d: Dict[str, Any], media_key=None):
if isinstance(media_key, str):
if media_key in d:
Expand Down Expand Up @@ -90,6 +126,7 @@ def __call__(self, dataset: HfDataset) -> HfDataset:
return dataset


@_reduce_dataset
class AlpacaPreprocessor(MediaMixin, RowPreprocessMixin):

def __init__(self, concat_inst_inp: Optional[Callable[[str, str], str]] = None, **kwargs):
Expand Down Expand Up @@ -138,6 +175,7 @@ def _default_repair_conversations(s: Union[str, Any]) -> Any:
return s


@_reduce_dataset
class ConversationsPreprocessor(MediaMixin, RowPreprocessMixin):

def __init__(self,
Expand Down

0 comments on commit 6c4f9e5

Please sign in to comment.