diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index 67fc81d280a79b..995296485b9402 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -186,6 +186,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] TFAutoModelForImageClassification +## TFAutoModelForSemanticSegmentation + +[[autodoc]] TFAutoModelForSemanticSegmentation + ## TFAutoModelForMaskedLM [[autodoc]] TFAutoModelForMaskedLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index be2be2727f0146..2f53db07f078f0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2105,6 +2105,7 @@ "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", "TFAutoModelForQuestionAnswering", + "TFAutoModelForSemanticSegmentation", "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", @@ -4599,6 +4600,7 @@ TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, + TFAutoModelForSemanticSegmentation, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, TFAutoModelForSpeechSeq2Seq, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 139d4feda336e0..ec253f6037a3d3 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -128,6 +128,7 @@ "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", "TFAutoModelForQuestionAnswering", + "TFAutoModelForSemanticSegmentation", "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", @@ -271,6 +272,7 @@ TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, + TFAutoModelForSemanticSegmentation, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, TFAutoModelForSpeechSeq2Seq, diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 6df601ca646af3..fec5ffe700808a 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -362,6 +362,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFAutoModelForSemanticSegmentation(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForSeq2SeqLM(metaclass=DummyObject): _backends = ["tf"]