From 0d0aada56444ad554021947addaa035feb55948f Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 10 Aug 2022 11:55:18 -0400
Subject: [PATCH] Use commit hash to look in cache instead of calling head (#18534)

* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond

* Address Julien's comments

Co-authored-by: Julien Chaumond
---
 src/transformers/configuration_utils.py     | 23 +++++++-
 src/transformers/modeling_flax_utils.py     |  7 +++
 src/transformers/modeling_tf_utils.py       |  7 +++
 src/transformers/modeling_utils.py          |  6 ++
 .../models/auto/tokenization_auto.py        | 15 ++++-
 src/transformers/pipelines/__init__.py      | 11 +++-
 src/transformers/testing_utils.py           | 28 +++++++++
 src/transformers/tokenization_utils_base.py | 16 +++++-
 src/transformers/utils/__init__.py          |  1 +
 src/transformers/utils/hub.py               | 57 +++++++++++++++----
 tests/models/auto/test_modeling_auto.py     | 19 +++++++
 tests/models/auto/test_modeling_tf_auto.py  | 19 +++++++
 tests/models/auto/test_tokenization_auto.py | 12 ++++
 tests/pipelines/test_pipelines_common.py    | 11 ++++
 tests/test_configuration_common.py          | 12 ++--
 15 files changed, 221 insertions(+), 23 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index b924cec9ae021c..41503255ac2adb 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -27,7 +27,15 @@
 from . import __version__
 from .dynamic_module_utils import custom_object_save
-from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
+from .utils import (
+    CONFIG_NAME,
+    PushToHubMixin,
+    cached_file,
+    copy_func,
+    extract_commit_hash,
+    is_torch_available,
+    logging,
+)


 logger = logging.get_logger(__name__)

@@ -343,6 +351,8 @@ def __init__(self, **kwargs):

         # Name or path to the pretrained checkpoint
         self._name_or_path = str(kwargs.pop("name_or_path", ""))
+        # Config hash
+        self._commit_hash = kwargs.pop("_commit_hash", None)

         # Drop the transformers version info
         self.transformers_version = kwargs.pop("transformers_version", None)
@@ -539,6 +549,8 @@ def get_config_dict(
         original_kwargs = copy.deepcopy(kwargs)
         # Get config dict associated with the base config file
         config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if "_commit_hash" in config_dict:
+            original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

         # That config file may point us toward another config file to use.
         if "configuration_files" in config_dict:
@@ -564,6 +576,7 @@ def _get_config_dict(
         subfolder = kwargs.pop("subfolder", "")
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -599,7 +612,9 @@ def _get_config_dict(
                 user_agent=user_agent,
                 revision=revision,
                 subfolder=subfolder,
+                _commit_hash=commit_hash,
             )
+            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
         except EnvironmentError:
             # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
             # the original exception.
@@ -616,6 +631,7 @@ def _get_config_dict(
         try:
             # Load config dict
             config_dict = cls._dict_from_json_file(resolved_config_file)
+            config_dict["_commit_hash"] = commit_hash
         except (json.JSONDecodeError, UnicodeDecodeError):
             raise EnvironmentError(
                 f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
@@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
         # We remove them so they don't appear in `return_unused_kwargs`.
         kwargs.pop("_from_auto", None)
         kwargs.pop("_from_pipeline", None)
+        # The commit hash might have been updated in the `config_dict`; we don't want the kwargs to erase that update.
+        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
+            kwargs["_commit_hash"] = config_dict["_commit_hash"]

         config = cls(**config_dict)

@@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
             output["model_type"] = self.__class__.model_type
         if "_auto_class" in output:
             del output["_auto_class"]
+        if "_commit_hash" in output:
+            del output["_commit_hash"]

         # Transformers version when serializing the model
         output["transformers_version"] = __version__
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index af75b418cad23e..683e25631c0f44 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -595,6 +595,7 @@ def from_pretrained(
         from_auto_class = kwargs.pop("_from_auto", False)
         _do_init = kwargs.pop("_do_init", True)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -625,11 +626,15 @@ def from_pretrained(
                 revision=revision,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
                 **kwargs,
             )
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # Add the dtype to model_kwargs
         model_kwargs["dtype"] = dtype

@@ -682,6 +687,7 @@ def from_pretrained(
                         revision=revision,
                         subfolder=subfolder,
                         _raise_exceptions_for_missing_entries=False,
+                        _commit_hash=commit_hash,
                     )
                     resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -748,6 +754,7 @@ def from_pretrained(
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
+                _commit_hash=commit_hash,
             )

         # init random models
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 68ee4117a2f9db..3587354b9326a9 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -2161,6 +2161,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -2191,11 +2192,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 revision=revision,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
                 **kwargs,
             )
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
@@ -2253,6 +2258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 revision=revision,
                 subfolder=subfolder,
                 _raise_exceptions_for_missing_entries=False,
+                _commit_hash=commit_hash,
             )
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -2320,6 +2326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
+                _commit_hash=commit_hash,
             )

         config.name_or_path = pretrained_model_name_or_path
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1d895baecfedac..d77258c94ea089 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1840,6 +1840,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         load_in_8bit = kwargs.pop("load_in_8bit", False)
         int8_threshold = kwargs.pop("int8_threshold", 6.0)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -1918,6 +1919,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
@@ -2004,6 +2008,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 revision=revision,
                 subfolder=subfolder,
                 _raise_exceptions_for_missing_entries=False,
+                _commit_hash=commit_hash,
             )
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -2078,6 +2083,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 user_agent=user_agent,
                 revision=revision,
                 subfolder=subfolder,
+                _commit_hash=commit_hash,
             )

         # load pt weights early so that we know which dtype to init the model under
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index d8759fd4e7842e..8ece13b79fe3fa 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -25,7 +25,7 @@
 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
-from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
+from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
 from ..encoder_decoder import EncoderDecoderConfig
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
@@ -389,7 +389,8 @@ def get_tokenizer_config(
     tokenizer.save_pretrained("tokenizer-test")
     tokenizer_config = get_tokenizer_config("tokenizer-test")
     ```"""
-    resolved_config_file = get_file_from_repo(
+    commit_hash = kwargs.get("_commit_hash", None)
+    resolved_config_file = cached_file(
         pretrained_model_name_or_path,
         TOKENIZER_CONFIG_FILE,
         cache_dir=cache_dir,
@@ -399,13 +400,19 @@ def get_tokenizer_config(
         use_auth_token=use_auth_token,
         revision=revision,
         local_files_only=local_files_only,
+        _raise_exceptions_for_missing_entries=False,
+        _raise_exceptions_for_connection_errors=False,
+        _commit_hash=commit_hash,
     )
     if resolved_config_file is None:
         logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
         return {}
+    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

     with open(resolved_config_file, encoding="utf-8") as reader:
-        return json.load(reader)
+        result = json.load(reader)
+    result["_commit_hash"] = commit_hash
+    return result


 class AutoTokenizer:
@@ -532,6 +539,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+        if "_commit_hash" in tokenizer_config:
+            kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
         tokenizer_auto_map = None
         if "auto_map" in tokenizer_config:
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index dfa75768d8f811..5752790aa9614b 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -557,7 +557,12 @@ def pipeline(
     # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
     # this is to keep BC).
     use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
-    hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
+    hub_kwargs = {
+        "revision": revision,
+        "use_auth_token": use_auth_token,
+        "trust_remote_code": trust_remote_code,
+        "_commit_hash": None,
+    }

     if task is None and model is None:
         raise RuntimeError(
@@ -583,8 +588,10 @@ def pipeline(
     # Instantiate config if needed
     if isinstance(config, str):
         config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+        hub_kwargs["_commit_hash"] = config._commit_hash
     elif config is None and isinstance(model, str):
         config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+        hub_kwargs["_commit_hash"] = config._commit_hash

     custom_tasks = {}
     if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
@@ -639,6 +646,7 @@ def pipeline(
         )
         if config is None and isinstance(model, str):
             config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+            hub_kwargs["_commit_hash"] = config._commit_hash

     if device_map is not None:
         if "device_map" in model_kwargs:
@@ -672,6 +680,7 @@ def pipeline(
     )

     model_config = model.config
+    hub_kwargs["_commit_hash"] = model.config._commit_hash
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 80f7bf9c863c87..d21f353a60a8f5 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -31,6 +31,7 @@
 from typing import Iterator, List, Union
 from unittest import mock

+import huggingface_hub
 from transformers import logging as transformers_logging

 from .deepspeed import is_deepspeed_available
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
         raise SubprocessCallException(
             f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
         ) from e
+
+
+class RequestCounter:
+    """
+    Helper class that will count all requests made online.
+ """ + + def __enter__(self): + self.head_request_count = 0 + self.get_request_count = 0 + self.other_request_count = 0 + self.old_request = huggingface_hub.file_download.requests.request + huggingface_hub.file_download.requests.request = self.new_request + return self + + def __exit__(self, *args, **kwargs): + huggingface_hub.file_download.requests.request = self.old_request + + def new_request(self, method, **kwargs): + if method == "GET": + self.get_request_count += 1 + elif method == "HEAD": + self.head_request_count += 1 + else: + self.other_request_count += 1 + + return self.old_request(method=method, **kwargs) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f85dc73cb659cb..566fd3fbf92b05 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -42,7 +42,7 @@ add_end_docstrings, cached_file, copy_func, - get_file_from_repo, + extract_commit_hash, is_flax_available, is_offline_mode, is_tf_available, @@ -1651,6 +1651,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], subfolder = kwargs.pop("subfolder", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) + commit_hash = kwargs.pop("_commit_hash", None) user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__} if from_pipeline is not None: @@ -1690,7 +1691,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], if "tokenizer_file" in vocab_files: # Try to get the tokenizer config to see if there are versioned tokenizer files. fast_tokenizer_file = FULL_TOKENIZER_FILE - resolved_config_file = get_file_from_repo( + resolved_config_file = cached_file( pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, cache_dir=cache_dir, @@ -1701,7 +1702,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], revision=revision, local_files_only=local_files_only, subfolder=subfolder, + user_agent=user_agent, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) if resolved_config_file is not None: with open(resolved_config_file, encoding="utf-8") as reader: tokenizer_config = json.load(reader) @@ -1730,7 +1736,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], subfolder=subfolder, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, ) + commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash) if len(unresolved_files) > 0: logger.info( @@ -1763,6 +1771,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], use_auth_token=use_auth_token, cache_dir=cache_dir, local_files_only=local_files_only, + _commit_hash=commit_hash, **kwargs, ) @@ -1776,6 +1785,7 @@ def _from_pretrained( use_auth_token=None, cache_dir=None, local_files_only=False, + _commit_hash=None, **kwargs ): # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json @@ -1791,6 +1801,7 @@ def _from_pretrained( use_auth_token=use_auth_token, cache_dir=cache_dir, local_files_only=local_files_only, + _commit_hash=_commit_hash, **(copy.deepcopy(kwargs)), ) else: @@ -1823,6 +1834,7 @@ def _from_pretrained( use_auth_token=use_auth_token, 
                    cache_dir=cache_dir,
                    local_files_only=local_files_only,
+                    _commit_hash=_commit_hash,
                )
                config_tokenizer_class = config.tokenizer_class
            except (OSError, ValueError, KeyError):
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 27276aa4946d5e..2a2a4c41257492 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -63,6 +63,7 @@
     cached_file,
     default_cache_path,
     define_sagemaker_information,
+    extract_commit_hash,
     get_cached_models,
     get_file_from_repo,
     get_full_repo_name,
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 07164e735db901..00f9c277c41773 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -38,6 +38,7 @@
     whoami,
 )
 from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
+from huggingface_hub.file_download import REGEX_COMMIT_HASH
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests.exceptions import HTTPError
 from transformers.utils.logging import tqdm
@@ -200,11 +201,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua


-def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
+    """
+    Extracts the commit hash from the resolved filename of a cache file.
+    """
+    if resolved_file is None or commit_hash is not None:
+        return commit_hash
+
+    search = re.search(r"snapshots/([^/]+)/", resolved_file)
+    if search is None:
+        return None
+    commit_hash = search.groups()[0]
+    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
+
+
+def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
     """
     Explores the cache to return the latest cached file for a given revision.
""" - if revision is None: + if commit_hash is not None and revision is not None: + raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.") + if revision is None and commit_hash is None: revision = "main" model_id = repo_id.replace("/", "--") @@ -216,18 +233,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None): if not os.path.isdir(os.path.join(model_cache, subfolder)): return None - # Resolve refs (for instance to convert main to the associated commit sha) - cached_refs = os.listdir(os.path.join(model_cache, "refs")) - if revision in cached_refs: - with open(os.path.join(model_cache, "refs", revision)) as f: - revision = f.read() + if commit_hash is None: + # Resolve refs (for instance to convert main to the associated commit sha) + cached_refs = os.listdir(os.path.join(model_cache, "refs")) + if revision in cached_refs: + with open(os.path.join(model_cache, "refs", revision)) as f: + commit_hash = f.read() cached_shas = os.listdir(os.path.join(model_cache, "snapshots")) - if revision not in cached_shas: + if commit_hash not in cached_shas: # No cache for this revision and we won't try to return a random revision return None - cached_file = os.path.join(model_cache, "snapshots", revision, filename) + cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename) return cached_file if os.path.isfile(cached_file) else None @@ -265,8 +283,9 @@ def cached_file( local_files_only: bool = False, subfolder: str = "", user_agent: Optional[Union[str, Dict[str, str]]] = None, - _raise_exceptions_for_missing_entries=True, - _raise_exceptions_for_connection_errors=True, + _raise_exceptions_for_missing_entries: bool = True, + _raise_exceptions_for_connection_errors: bool = True, + _commit_hash: Optional[str] = None, ): """ Tries to locate a file in a local folder and repo, downloads and cache it if necessary. @@ -318,6 +337,13 @@ def cached_file( # Download a model weight from the Hub and cache it. model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin") ```""" + # Private arguments + # _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return + # None. + # _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return + # None. + # _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or + # a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache. if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True @@ -339,6 +365,13 @@ def cached_file( cache_dir = TRANSFORMERS_CACHE if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + + if _commit_hash is not None: + # If the file is cached under that commit hash, we return it directly. 
+        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
+        if resolved_file is not None:
+            return resolved_file
+
     user_agent = http_user_agent(user_agent)
     try:
         # Load from URL or cache if already cached
@@ -803,6 +836,7 @@ def get_checkpoint_shard_files(
     user_agent=None,
     revision=None,
     subfolder="",
+    _commit_hash=None,
 ):
     """
     For a given model:
@@ -848,6 +882,7 @@ def get_checkpoint_shard_files(
                 user_agent=user_agent,
                 revision=revision,
                 subfolder=subfolder,
+                _commit_hash=_commit_hash,
             )
             # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
             # we don't have to catch them here.
diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py
index 3731d70f5bb5af..2e1e51a81daac6 100644
--- a/tests/models/auto/test_modeling_auto.py
+++ b/tests/models/auto/test_modeling_auto.py
@@ -24,6 +24,7 @@
 from transformers.testing_utils import (
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_scatter,
     require_torch,
     slow,
@@ -354,3 +355,21 @@ def test_model_from_tf_suggestion(self):
     def test_model_from_flax_suggestion(self):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+
+    def test_cached_model_has_minimum_calls_to_head(self):
+        # Make sure we have cached the model.
+        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        self.assertEqual(counter.head_request_count, 1)
+        self.assertEqual(counter.other_request_count, 0)
+
+        # With a sharded checkpoint
+        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+        with RequestCounter() as counter:
+            _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+        self.assertEqual(counter.get_request_count, 0)
+        # There is no pytorch_model.bin so we still get one call for this one.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
diff --git a/tests/models/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py
index a803a3451107e2..bbde4f582bdfb0 100644
--- a/tests/models/auto/test_modeling_tf_auto.py
+++ b/tests/models/auto/test_modeling_tf_auto.py
@@ -21,6 +21,7 @@
 from transformers.testing_utils import (
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_tensorflow_probability,
     require_tf,
     slow,
@@ -287,3 +288,21 @@ def test_model_file_not_found(self):
     def test_model_from_pt_suggestion(self):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
             _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
+
+    def test_cached_model_has_minimum_calls_to_head(self):
+        # Make sure we have cached the model.
+        _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        self.assertEqual(counter.head_request_count, 1)
+        self.assertEqual(counter.other_request_count, 0)
+
+        # With a sharded checkpoint
+        _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+        with RequestCounter() as counter:
+            _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+        self.assertEqual(counter.get_request_count, 0)
+        # There is no pytorch_model.bin so we still get one call for this one.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py
index 1e1abb9245842c..830362e29cd654 100644
--- a/tests/models/auto/test_tokenization_auto.py
+++ b/tests/models/auto/test_tokenization_auto.py
@@ -48,6 +48,7 @@
     DUMMY_DIFF_TOKENIZER_IDENTIFIER,
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_tokenizers,
     slow,
 )
@@ -213,6 +214,7 @@ def test_auto_tokenizer_fast_no_slow(self):
     def test_get_tokenizer_config(self):
         # Check we can load the tokenizer config of an online model.
         config = get_tokenizer_config("bert-base-cased")
+        _ = config.pop("_commit_hash", None)
         # If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
         self.assertEqual(config, {"do_lower_case": False})

@@ -340,3 +342,13 @@ def test_revision_not_found(self):
             EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
         ):
             _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
+
+    def test_cached_tokenizer_has_minimum_calls_to_head(self):
+        # Make sure we have cached the tokenizer.
+        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        # We still have one extra call because the model does not have an added_tokens.json file.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 5d5c8fa2333eb6..5e0296c7136725 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -49,6 +49,7 @@
     TOKEN,
     USER,
     CaptureLogger,
+    RequestCounter,
     is_pipeline_test,
     is_staging_test,
     nested_simplify,
@@ -877,6 +878,16 @@ def test_dynamic_pipeline(self):
             [{"label": "LABEL_0", "score": 0.505}],
         )

+    def test_cached_pipeline_has_minimum_calls_to_head(self):
+        # Make sure we have cached the pipeline.
+        _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        # We still have one extra call because the model does not have an added_tokens.json file.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
+

 @require_torch
 @is_staging_test
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index 397346c7deec77..5447fb6afb70eb 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -246,7 +246,7 @@ def test_push_to_hub(self):
         config.push_to_hub("test-config", use_auth_token=self._token)

         new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))

@@ -258,7 +258,7 @@ def test_push_to_hub(self):
             config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)

         new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))

@@ -269,7 +269,7 @@ def test_push_to_hub_in_organization(self):
         config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)

         new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))

@@ -283,7 +283,7 @@ def test_push_to_hub_in_organization(self):
             )

         new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))

@@ -323,7 +323,9 @@ def test_config_common_kwargs_is_complete(self):
         base_config = PretrainedConfig()
         missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
         # If this part of the test fails, you have arguments to addin config_common_kwargs above.
-        self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
+        self.assertListEqual(
+            missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
+        )
         keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
         if len(keys_with_defaults) > 0:
             raise ValueError(
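
The round trip this patch saves is easiest to see outside the diff. Below is a minimal usage sketch (an editor's illustration, not part of the patch) built only from names the patch itself defines or exports: `cached_file` and `extract_commit_hash` from `transformers.utils`, the private `_commit_hash` keyword added to `cached_file`, and a repo id borrowed from the tests above.

```python
from transformers.utils import cached_file, extract_commit_hash

# First resolution may go online: one HEAD request resolves the revision, then
# the file lands in the cache under <cache>/snapshots/<commit-sha>/config.json.
resolved = cached_file("hf-internal-testing/tiny-random-bert", "config.json")

# The commit sha is recovered from the resolved path itself -- no extra request.
commit_hash = extract_commit_hash(resolved, None)

# Passing the sha back in short-circuits into try_to_load_from_cache, so this
# second lookup is served from the local snapshot with no HTTP request at all.
assert resolved == cached_file(
    "hf-internal-testing/tiny-random-bert", "config.json", _commit_hash=commit_hash
)
```

This is the same pattern the `RequestCounter` tests assert on: the first loader in a chain (config, then tokenizer, then model weights) pays for one HEAD call, and every later file resolved under the same commit hash comes straight from the cache.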