diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index af562651e9e387..b72d68de5d4e57 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -103,7 +103,7 @@ def _generate_supported_model_class_names(
     "deberta",
     "deberta-v2",
     "distilbert",
-    "donut",
+    "donut-swin",
     "electra",
     "gpt2",
     "gpt_neo",
diff --git a/tests/models/donut/test_feature_extraction_donut.py b/tests/models/donut/test_feature_extraction_donut.py
index 6391ad87777379..9807d5b46cb235 100644
--- a/tests/models/donut/test_feature_extraction_donut.py
+++ b/tests/models/donut/test_feature_extraction_donut.py
@@ -72,7 +72,7 @@ def prepare_feat_extract_dict(self):
             "do_pad": self.do_pad,
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
-            "image_std": self.image_std, 
+            "image_std": self.image_std,
         }
 
@@ -196,4 +196,4 @@ def test_call_pytorch(self):
                 self.feature_extract_tester.size[1],
                 self.feature_extract_tester.size[0],
             ),
-        )
\ No newline at end of file
+        )
diff --git a/tests/models/donut/test_modeling_donut_swin.py b/tests/models/donut/test_modeling_donut_swin.py
index f6772ac233d7ee..f909d961880a97 100644
--- a/tests/models/donut/test_modeling_donut_swin.py
+++ b/tests/models/donut/test_modeling_donut_swin.py
@@ -22,8 +22,8 @@
 import unittest
 
 from transformers import DonutSwinConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
+from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import is_torch_available, is_torch_fx_available
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
@@ -36,9 +36,6 @@
     from transformers import DonutSwinModel
     from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
 
-if is_vision_available():
-    from transformers import AutoFeatureExtractor
-
 if is_torch_fx_available():
     from transformers.utils.fx import symbolic_trace
 
@@ -465,15 +462,3 @@ def flatten_output(output):
                     torch.allclose(model_output[i], loaded_output[i]),
                     f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
                 )
-
-
-@require_vision
-@require_torch
-class DonutSwinModelIntegrationTest(unittest.TestCase):
-    @cached_property
-    def default_feature_extractor(self):
-        return AutoFeatureExtractor.from_pretrained("naver-clova-ix/donut-base") if is_vision_available() else None
-
-    @slow
-    def test_inference_image_classification_head(self):
-        raise NotImplementedError("To do")
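
The fx.py hunk above registers the Donut encoder under its architecture name, "donut-swin", so DonutSwinModel is accepted by the library's FX tracing utilities that the modified tests exercise. A minimal sketch of how that can be tried out after this change; the tiny configuration values below are illustrative assumptions, not taken from this diff:

import torch

from transformers import DonutSwinConfig, DonutSwinModel
from transformers.utils.fx import symbolic_trace

# Small, illustrative config so the sketch runs quickly; real Donut checkpoints use larger values.
config = DonutSwinConfig(image_size=32, patch_size=2, embed_dim=16, depths=[1, 1], num_heads=[2, 2], window_size=2)
model = DonutSwinModel(config).eval()

# With "donut-swin" registered as a supported architecture, symbolic_trace accepts the model
# and returns a torch.fx.GraphModule built from dummy inputs for the given input names.
traced = symbolic_trace(model, input_names=["pixel_values"])

# The traced graph can be called like the original model; the exact output container may vary by version.
outputs = traced(pixel_values=torch.rand(1, config.num_channels, config.image_size, config.image_size))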