Skip to content

Commit

Permalink
[FX] _generate_dummy_input supports audio-classification models for labels (huggingface#18580)
Browse files Browse the repository at this point in the history

* Support audio-classification architectures for labels generation, and also provide a flag controlling whether warnings are printed

* Use ENV_VARS_TRUE_VALUES
  • Loading branch information
michaelbenayoun authored and oneraghavan committed Sep 26, 2022
1 parent 2abc343 commit 4ee3590
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import math
import operator
import os
import random
import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
Expand Down Expand Up @@ -48,11 +49,12 @@
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata


logger = logging.get_logger(__name__)
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES


def _generate_supported_model_class_names(
Expand Down Expand Up @@ -678,7 +680,12 @@ def _generate_dummy_input(
if input_name in ["labels", "start_positions", "end_positions"]:

batch_size = shape[0]
if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
if model_class_name in [
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class_name in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
Expand Down Expand Up @@ -710,11 +717,6 @@ def _generate_dummy_input(
)
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

elif model_class_name in [
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class_name in [
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
Expand All @@ -725,7 +727,9 @@ def _generate_dummy_input(
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
raise NotImplementedError(f"{model_class_name} not supported yet.")
raise NotImplementedError(
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
)
elif "pixel_values" in input_name:
batch_size = shape[0]
image_size = getattr(model.config, "image_size", None)
Expand Down Expand Up @@ -846,7 +850,8 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
raise ValueError("Don't support composite output yet")
rv.install_metadata(meta_out)
except Exception as e:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
if _IS_IN_DEBUG_MODE:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

return rv

Expand Down

0 comments on commit 4ee3590

Please sign in to comment.