Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generation config in ORTModel #651

Merged
merged 5 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import numpy as np
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

Expand Down Expand Up @@ -339,6 +339,7 @@ def __init__(
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs
):
"""
Expand All @@ -357,6 +358,9 @@ def __init__(
The directory under which the model exported to ONNX was saved.
preprocessors (`Optional[List]`, defaults to `None`):
The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
generation_config (`Optional[GenerationConfig]`, defaults to `None`):
The generation configuration used by default when calling `generate()`.
Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
"""
# TODO: remove at version 2.0
def show_deprecated_argument(arg_name):
Expand Down Expand Up @@ -399,6 +403,10 @@ def show_deprecated_argument(arg_name):
self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

# Fall back to a generation config derived from the model config when the
# caller does not supply one. The fallback is bound to the local name first
# so that the single assignment below cannot overwrite it with None.
if generation_config is None:
    generation_config = GenerationConfig.from_model_config(config)
self.generation_config = generation_config

@staticmethod
def load_model(
decoder_path: Union[str, Path],
Expand Down Expand Up @@ -626,13 +634,28 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
)
except OSError:
logger.info("Generation config file not found, using a generation config created from the model config.")
fxmarty marked this conversation as resolved.
Show resolved Hide resolved

return cls(
model[0],
config,
decoder_with_past_session=model[1],
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
)

@classmethod
Expand Down Expand Up @@ -784,3 +807,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)

def can_generate(self):
    """
    Report that this model supports generation.

    Returns:
        `bool`: always `True`, so that the check made for models using
        `GenerationMixin.generate()` passes.
    """
    return True
36 changes: 33 additions & 3 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import logging
import re
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
Expand All @@ -26,7 +25,7 @@

import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq
from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

Expand All @@ -37,7 +36,7 @@
from ..exporters.tasks import TasksManager
from ..onnx.utils import _get_external_data_paths
from ..utils import NormalizedConfigManager, check_if_transformers_greater
from ..utils.file_utils import find_files_matching_pattern, validate_file_exists
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_decoder import ORTDecoder
Expand Down Expand Up @@ -728,6 +727,7 @@ def __init__(
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
"""
Expand All @@ -748,6 +748,9 @@ def __init__(
The directory under which the model exported to ONNX was saved.
preprocessors (`Optional[List]`, defaults to `None`):
The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
generation_config (`Optional[GenerationConfig]`, defaults to `None`):
The generation configuration used by default when calling `generate()`.
Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
"""
# TODO: remove at version 2.0
def show_deprecated_argument(arg_name):
Expand Down Expand Up @@ -804,6 +807,10 @@ def show_deprecated_argument(arg_name):
self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)
self.generation_config = generation_config

@abstractmethod
def _initialize_encoder(
self,
Expand Down Expand Up @@ -1076,13 +1083,28 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
)
except OSError:
logger.info("Generation config file not found, using a generation config created from the model config.")
fxmarty marked this conversation as resolved.
Show resolved Hide resolved

return cls(
*model[:2],
config,
decoder_with_past_session=model[2],
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
)

@classmethod
Expand Down Expand Up @@ -1286,6 +1308,10 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
)
return reordered_past

def can_generate(self):
fxmarty marked this conversation as resolved.
Show resolved Hide resolved
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True


class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin):
"""
Expand Down Expand Up @@ -1398,3 +1424,7 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

def can_generate(self):
    """
    Report that this model supports generation.

    Returns:
        `bool`: always `True`, so that the check made for models using
        `GenerationMixin.generate()` passes.
    """
    return True