Support generation config in ORTModel (#651)
* support generation config

* add can_generate method in ORTModelForConditionalGeneration

* trigger actions

* fix typo

* rollback
fxmarty authored Jan 25, 2023
1 parent d2436cf commit 50e87a6
Showing 2 changed files with 66 additions and 3 deletions.
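
In practice, this change means every ORTModel decoder now carries a generation_config attribute: loaded from the repository's generation_config.json when available, otherwise derived from the model config, and consumed by GenerationMixin.generate(). A minimal usage sketch, assuming the optimum API of this era (the from_transformers=True export flag, later renamed export=True) and gpt2 as an illustrative checkpoint:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# Export the checkpoint to ONNX and load it; model.generation_config is
# populated from generation_config.json if present, else from config.json.
model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The attribute is a regular transformers.GenerationConfig, so decoding
# defaults can be set once instead of passed to every generate() call.
model.generation_config.max_new_tokens = 16

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))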
29 changes: 28 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -21,7 +21,7 @@

import numpy as np
import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

@@ -339,6 +339,7 @@ def __init__(
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs
):
"""
@@ -357,6 +358,9 @@
The directory under which the model exported to ONNX was saved.
preprocessors (`Optional[List]`, defaults to `None`):
The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
generation_config (`Optional[GenerationConfig]`, defaults to `None`):
The generation configuration used by default when calling `generate()`.
Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
"""
self.shared_attributes_init(
decoder_session,
@@ -400,6 +404,10 @@ def show_deprecated_argument(arg_name):
self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)
self.generation_config = generation_config

@staticmethod
def load_model(
decoder_path: Union[str, Path],
@@ -617,13 +625,28 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
)
except OSError:
logger.info("Generation config file not found, using a generation config created from the model config.")

return cls(
model[0],
config,
decoder_with_past_session=model[1],
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
)

@classmethod
@@ -777,3 +800,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)

def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True
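
The loading logic added in _from_pretrained is worth seeing in isolation: GenerationConfig.from_pretrained raises OSError when the repository ships no generation_config.json, in which case generation_config stays None and __init__ falls back to GenerationConfig.from_model_config(config). A standalone sketch of that fallback (the model id is illustrative):

from transformers import AutoConfig, GenerationConfig

model_id = "gpt2"  # illustrative checkpoint

try:
    # Succeeds only if the repository contains a generation_config.json.
    generation_config = GenerationConfig.from_pretrained(model_id)
except OSError:
    # Same fallback as __init__ above: derive the generation defaults
    # (eos/bos/pad token ids, max_length, ...) from the model config.
    config = AutoConfig.from_pretrained(model_id)
    generation_config = GenerationConfig.from_model_config(config)

print(generation_config)

The can_generate override matters too: GenerationMixin.generate() refuses to run on models whose can_generate() returns False, so returning True here is what lets the ONNX decoder go through generate() at all.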
40 changes: 38 additions & 2 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -17,7 +17,6 @@
"""

import logging
-import re
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
@@ -26,7 +25,7 @@

import numpy as np
import torch
-from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
+from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

@@ -728,6 +727,7 @@ def __init__(
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
"""
@@ -748,6 +748,9 @@
The directory under which the model exported to ONNX was saved.
preprocessors (`Optional[List]`, defaults to `None`):
The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
generation_config (`Optional[GenerationConfig]`, defaults to `None`):
The generation configuration used by default when calling `generate()`.
Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
"""
# TODO: remove at version 2.0
def show_deprecated_argument(arg_name):
@@ -804,6 +807,10 @@ def show_deprecated_argument(arg_name):
self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)
self.generation_config = generation_config

@abstractmethod
def _initialize_encoder(
self,
@@ -1066,13 +1073,28 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
)
except OSError:
logger.info("Generation config file not found, using a generation config created from the model config.")

return cls(
*inference_sessions[:2],
config,
decoder_with_past_session=inference_sessions[2],
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
)

@classmethod
@@ -1170,6 +1192,12 @@ def to(self, device: Union[torch.device, str, int]):

return self

def can_generate(self):
logger.warning(
"ORTModelForConditionalGeneration is an abstract class and is not meant to be used for generation. Please use ORTModelForSeq2SeqLM or ORTModelForSpeechSeq2Seq."
)
return False


class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
"""
@@ -1278,6 +1306,10 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
)
return reordered_past

def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True


class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin):
"""
@@ -1390,3 +1422,7 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True
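
The seq2seq classes receive the same treatment, so encoder-decoder generation also reads its defaults from the stored config. A sketch with t5-small as an illustrative checkpoint, again assuming the from_transformers=True export path of this era:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", from_transformers=True)
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Beam search settings applied once, model-wide, rather than per call.
model.generation_config.num_beams = 4
model.generation_config.max_length = 64

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))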
