refactor inference (modelscope#1245)
Jintao-Huang committed Jun 28, 2024
1 parent 8728264 commit da58255
Showing 6 changed files with 152 additions and 187 deletions.
swift/llm/infer.py (2 changes: 1 addition & 1 deletion)
@@ -243,7 +243,7 @@ def prepare_model_template(args: InferArguments,

 def read_media_file(infer_kwargs: Dict[str, Any], infer_media_type: Literal['none', 'round', 'dialogue']) -> None:
     text = 'Input a media path or URL <<< '
-    images = infer_kwargs.get('images', [])
+    images = infer_kwargs.get('images') or []
     if infer_media_type == 'none':
         return
     if infer_media_type == 'round' or len(images) == 0:
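
The same one-line pattern recurs in the `tags` and `tools` lookups below. The distinction it addresses: `dict.get(key, [])` only falls back to the default when the key is absent, while `dict.get(key) or []` also normalizes a key that is present but set to None. A minimal standalone sketch of the difference (toy dict, not the real call site):

# Minimal standalone sketch; `infer_kwargs` here is a toy dict, not the repository's call site.
infer_kwargs = {'images': None}  # key exists, value is explicitly None

images = infer_kwargs.get('images', [])   # -> None: the default is ignored because the key exists
# 'x.png' in images                       # would raise TypeError: argument of type 'NoneType' is not iterable

images = infer_kwargs.get('images') or []  # -> []: None (or any other falsy value) is replaced by []
print(len(images))                         # 0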
swift/llm/utils/argument.py (2 changes: 1 addition & 1 deletion)
@@ -1299,7 +1299,7 @@ class DeployArguments(InferArguments):
     def __post_init__(self):
         super().__post_init__()
         model_info = MODEL_MAPPING[self.model_type]
-        tags = model_info.get('tags', [])
+        tags = model_info.get('tags') or []
         self.is_multimodal = 'multi-modal' in tags


swift/llm/utils/preprocess.py (2 changes: 1 addition & 1 deletion)
@@ -197,7 +197,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
         response = conversations[-1][self.value_key]
         system = sys
         history = h
-        tools = d.get('tools', [])
+        tools = d.get('tools') or []
         row = {'system': system, 'history': history, 'history_roles': hr}
         row.update({
             'query': query,
swift/llm/utils/utils.py (205 changes: 94 additions & 111 deletions)
@@ -542,38 +542,22 @@ def __next__(self) -> List[int]:
         return value


-@torch.inference_mode()
-def inference_stream(model: PreTrainedModel,
-                     template: Template,
-                     query: str,
-                     history: Optional[History] = None,
-                     system: Optional[str] = None,
-                     images: Optional[List[str]] = None,
-                     *,
-                     generation_config: Optional[GenerationConfig] = None,
-                     stop_words: Optional[StopWords] = None,
-                     generation_info: Optional[Dict[str, int]] = None,
-                     adapter_names: Optional[List[str]] = None,
-                     **kwargs) -> Iterator[Tuple[str, History]]:
-    """
-    generation_config: Priority: generation_config > model.generation_config.
-    """
+def _prepare_inputs(model: PreTrainedModel,
+                    template: Template,
+                    query: str,
+                    history: History,
+                    system: Optional[str] = None,
+                    images: Optional[List[str]] = None,
+                    *,
+                    generation_config: Optional[GenerationConfig] = None,
+                    stop_words: Optional[StopWords] = None,
+                    adapter_names: Optional[List[str]] = None,
+                    **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], int]:
     if stop_words is None:
         stop_words = []
-    if history is None:
-        history = []
-    else:
-        history = deepcopy(history)
     if images is None:
         images = []

-    # agent support
-    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
-    if is_observation:
-        history[-1][-1] = history[-1][-1] + query
-        act_length = len(history[-1][-1])
-        query = None

     example = {
         'query': query,
         'history': history,
@@ -587,7 +571,7 @@ def inference_stream(model: PreTrainedModel,
     truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
     if len(inputs) == 0 and truncation_strategy == 'delete':
         # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
+        return {}, tokenizer_kwargs, 0

     inputs.pop('labels', None)
     tokenizer = template.tokenizer
@@ -606,11 +590,8 @@ def inference_stream(model: PreTrainedModel,
         inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
     model.eval()
     if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
+        generation_config = getattr(model, 'generation_config')
     generation_config = deepcopy(generation_config)
-    if generation_config.num_beams != 1:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)

     if tokenizer.eos_token_id is not None:
         generation_config.eos_token_id = tokenizer.eos_token_id
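
A note on the `getattr` change above: without the `None` default, a model object that lacks `generation_config` now fails at the lookup itself instead of handing `None` to `deepcopy` and crashing later on the first attribute access. A small standalone sketch (toy object, not the real model class):

from copy import deepcopy


class _NoConfig:  # toy stand-in for a model without a generation_config attribute
    pass


model = _NoConfig()
cfg = getattr(model, 'generation_config', None)  # old form: silently yields None
cfg = deepcopy(cfg)                              # still "works"; the error only surfaces later
try:
    cfg = getattr(model, 'generation_config')    # new form: fails fast, right at the lookup
except AttributeError as err:
    print(err)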
@@ -627,21 +608,69 @@ def inference_stream(model: PreTrainedModel,
                 raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
     if template.suffix[-1] not in stop_words:
         stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
     inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
     if 'inputs_embeds' in inputs:
         inputs.pop('input_ids', None)
-    streamer = TokenListIteratorStreamer()
     if adapter_names is not None:
         inputs['adapter_names'] = adapter_names
-    generation_kwargs = {
-        'streamer': streamer,
-        'generation_config': generation_config,
-        'stopping_criteria': stopping_criteria,
-        **inputs
-    }

+    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
+    inputs['stopping_criteria'] = stopping_criteria
+    inputs['generation_config'] = generation_config
+    return inputs, tokenizer_kwargs, token_len
+
+
+@torch.inference_mode()
+def inference_stream(model: PreTrainedModel,
+                     template: Template,
+                     query: str,
+                     history: Optional[History] = None,
+                     system: Optional[str] = None,
+                     images: Optional[List[str]] = None,
+                     *,
+                     generation_config: Optional[GenerationConfig] = None,
+                     stop_words: Optional[StopWords] = None,
+                     generation_info: Optional[Dict[str, int]] = None,
+                     adapter_names: Optional[List[str]] = None,
+                     **kwargs) -> Iterator[Tuple[str, History]]:
+    """
+    generation_config: Priority: generation_config > model.generation_config.
+    """
+    if history is None:
+        history = []
+    else:
+        history = deepcopy(history)
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len
+
+    # agent support
+    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
+    if is_observation:
+        history[-1][-1] = history[-1][-1] + query
+        act_length = len(history[-1][-1])
+        query = None
+
+    generation_config = inputs['generation_config']
+    if generation_config.num_beams != 1:
+        error_msg = 'Streaming generation does not support beam search.'
+        raise ValueError(error_msg)
+
+    streamer = TokenListIteratorStreamer()
+    generation_kwargs = {'streamer': streamer, **inputs}
     _model_generate = model.generate
     if is_torch_npu_available():

@@ -667,8 +696,7 @@ def _model_generate(*args, **kwargs):
         except StopIteration:
             is_finished = True
         generate_ids = template.get_generate_ids(torch.tensor(raw_generate_ids)[None], token_len)
-        if generation_info is not None:
-            generation_info['num_generated_tokens'] = len(generate_ids)
+        generation_info['num_generated_tokens'] = len(generate_ids)
         response = template.generate_ids_to_response(
             generate_ids,
             is_finished,
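
For orientation, a hedged sketch of how the refactored streaming entry point is typically consumed; `model` and `template` are assumed to come from swift's surrounding setup code and are not shown here:

# Hypothetical caller-side sketch; names follow the signatures in this diff.
generation_info = {}
for response, history in inference_stream(model, template, 'Hello!', generation_info=generation_info):
    print(response)       # partial response so far; repeated until generation finishes
print(generation_info)    # {'num_prompt_tokens': ..., 'num_generated_tokens': ...}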
@@ -702,58 +730,38 @@ def inference(model: PreTrainedModel,
     """
     generation_config: Priority: generation_config > model.generation_config.
     """
-    if stop_words is None:
-        stop_words = []
     if history is None:
         history = []
     else:
         history = deepcopy(history)
-    if images is None:
-        images = []
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len

     # agent support
     is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
     if is_observation:
         history[-1][-1] = history[-1][-1] + query
         query = None

-    example = {
-        'query': query,
-        'history': history,
-        'system': system,
-        'images': images,  # for vl. str.
-        'tools': kwargs.pop('tools', None)
-    }
-    template.model = model
-    inputs, tokenizer_kwargs = template.encode(example)
-
-    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
-    if len(inputs) == 0 and truncation_strategy == 'delete':
-        # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
-
-    inputs.pop('labels', None)
-    tokenizer = template.tokenizer
-    device = next(model.parameters()).device
-    if 'input_ids' in inputs:
-        input_ids = torch.tensor(inputs['input_ids'])[None]
-        inputs['input_ids'] = input_ids
-        token_len = input_ids.shape[1]
-    if 'inputs_embeds' in inputs:
-        inputs_embeds = inputs['inputs_embeds'][None]
-        inputs['inputs_embeds'] = inputs_embeds
-        token_len = inputs_embeds.shape[1]
-
-    inputs['attention_mask'] = torch.ones(token_len)[None]
-    if 'token_type_ids' in inputs:
-        inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
-    model.eval()
-    if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
-    generation_config = deepcopy(generation_config)
     if stream and not verbose:
         logger.warning('Please set verbose to True to support TextStreamer, or use `inference_stream.`')
         stream = False
     streamer = None
+    tokenizer = template.tokenizer
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
     if verbose:
@@ -762,37 +770,12 @@ def inference(model: PreTrainedModel,
             print(
                 f'{prompt_prefix}{safe_tokenizer_decode(tokenizer, input_ids[0], **tokenizer_kwargs)}{output_prefix}',
                 end='')
-        elif 'query' in example:
-            query = example['query']
+        else:
             print(f'[QUERY]{query}\n{output_prefix}', end='')
-    if tokenizer.eos_token_id is not None:
-        generation_config.eos_token_id = tokenizer.eos_token_id
-    if tokenizer.pad_token_id is not None:
-        generation_config.pad_token_id = tokenizer.pad_token_id
-    if tokenizer.bos_token_id is not None:
-        generation_config.bos_token_id = tokenizer.bos_token_id
-    if generation_config.max_new_tokens is not None:
-        generation_config.max_length = 20  # fix max_length, max_new_tokens warning
-        max_length = get_max_model_len(model.config)
-        if max_length and token_len + generation_config.max_new_tokens > max_length:
-            generation_config.max_new_tokens = max_length - token_len
-            if generation_config.max_new_tokens <= 0:
-                raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
-    if template.suffix[-1] not in stop_words:
-        stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
-    inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
-    if 'inputs_embeds' in inputs:
-        inputs.pop('input_ids', None)
-    if adapter_names is not None:
-        inputs['adapter_names'] = adapter_names
-    generate_ids = model.generate(
-        streamer=streamer, generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs)

+    generate_ids = model.generate(streamer=streamer, **inputs)
     generate_ids = template.get_generate_ids(generate_ids, token_len)
-    if generation_info is not None:
-        generation_info['num_generated_tokens'] = len(generate_ids)
+    generation_info['num_generated_tokens'] = len(generate_ids)
     if verbose and stream is False:
         response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
         print(response)
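
A matching sketch for the blocking path; `model` and `template` are again assumed to exist, and the early return above suggests `inference` yields a `(response, history)` pair:

# Hypothetical caller-side sketch; parameter names follow the signatures in this diff.
generation_info = {}
response, history = inference(model, template, 'Who are you?', generation_info=generation_info)
print(response)
print(generation_info['num_prompt_tokens'], generation_info['num_generated_tokens'])

# stream=True only streams to stdout when verbose=True; otherwise the code above
# logs a warning and falls back to stream=False.
inference(model, template, 'Who are you?', stream=True, verbose=True)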