From 8ffd6e452a4e9fb43d0076b34daa7da7be1ab680 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 2 Sep 2022 11:06:59 -0400 Subject: [PATCH 1/3] Further reduce the number of calls to head for cached models/tokenizers/pipelines --- src/transformers/utils/hub.py | 5 ++++- tests/models/auto/test_modeling_auto.py | 3 +-- tests/models/auto/test_modeling_tf_auto.py | 3 +-- tests/models/auto/test_tokenization_auto.py | 3 +-- tests/pipelines/test_pipelines_common.py | 3 +-- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 9b1e9a5b85eb02..01666a96a0c815 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -244,6 +244,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h with open(os.path.join(model_cache, "refs", revision)) as f: commit_hash = f.read() + if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)): + return -1 + cached_shas = os.listdir(os.path.join(model_cache, "snapshots")) if commit_hash not in cached_shas: # No cache for this revision and we won't try to return a random revision @@ -352,7 +355,7 @@ def cached_file( # If the file is cached under that commit hash, we return it directly. 
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash) if resolved_file is not None: - return resolved_file + return resolved_file if resolved_file != -1 else None user_agent = http_user_agent(user_agent) try: diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 2e1e51a81daac6..91222c4d0062ee 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -370,6 +370,5 @@ def test_cached_model_has_minimum_calls_to_head(self): with RequestCounter() as counter: _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded") self.assertEqual(counter.get_request_count, 0) - # There is no pytorch_model.bin so we still get one call for this one. - self.assertEqual(counter.head_request_count, 2) + self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.other_request_count, 0) diff --git a/tests/models/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py index bbde4f582bdfb0..2b4b625e2305c3 100644 --- a/tests/models/auto/test_modeling_tf_auto.py +++ b/tests/models/auto/test_modeling_tf_auto.py @@ -303,6 +303,5 @@ def test_cached_model_has_minimum_calls_to_head(self): with RequestCounter() as counter: _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") self.assertEqual(counter.get_request_count, 0) - # There is no pytorch_model.bin so we still get one call for this one. 
- self.assertEqual(counter.head_request_count, 2) + self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.other_request_count, 0) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 830362e29cd654..020eea72cdda21 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -349,6 +349,5 @@ def test_cached_tokenizer_has_minimum_calls_to_head(self): with RequestCounter() as counter: _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") self.assertEqual(counter.get_request_count, 0) - # We still have one extra call because the model does not have a added_tokens.json file - self.assertEqual(counter.head_request_count, 2) + self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.other_request_count, 0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 5e0296c7136725..ea32f5cac4d467 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -884,8 +884,7 @@ def test_cached_pipeline_has_minimum_calls_to_head(self): with RequestCounter() as counter: _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") self.assertEqual(counter.get_request_count, 0) - # We still have one extra call because the model does not have a added_tokens.json file - self.assertEqual(counter.head_request_count, 2) + self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter.other_request_count, 0) From 562d2b142452851d402c84aa36893f385cce28b4 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 2 Sep 2022 12:09:44 -0400 Subject: [PATCH 2/3] Fix tests --- src/transformers/utils/hub.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 01666a96a0c815..486b1b6ad0c5df 100644 --- a/src/transformers/utils/hub.py 
+++ b/src/transformers/utils/hub.py @@ -341,7 +341,10 @@ def cached_file( resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename) if not os.path.isfile(resolved_file): if _raise_exceptions_for_missing_entries: - raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.") + raise EnvironmentError( + f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " + f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files." + ) else: return None return resolved_file @@ -355,7 +358,12 @@ def cached_file( # If the file is cached under that commit hash, we return it directly. resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash) if resolved_file is not None: - return resolved_file if resolved_file != -1 else None + if resolved_file != -1: + return resolved_file + elif not _raise_exceptions_for_missing_entries: + return None + else: + raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.") user_agent = http_user_agent(user_agent) try: From cc0ddcdb1b2d760380e2a5e6e9364868e816a9fa Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 6 Sep 2022 11:45:21 -0400 Subject: [PATCH 3/3] Address review comments --- src/transformers/utils/hub.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 486b1b6ad0c5df..31c3257ffd3646 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -120,6 +120,9 @@ def is_offline_mode(): HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples" +# Return value when trying to load a file from cache but the file does not exist in the distant repo. 
+_CACHED_NO_EXIST = object() + def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]: """ @@ -222,6 +225,22 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None): """ Explores the cache to return the latest cached file for a given revision. + + Args: + cache_dir (`str` or `os.PathLike`): The folder where the cached files lie. + repo_id (`str`): The ID of the repo on huggingface.co. + filename (`str`): The filename to look for inside `repo_id`. + revision (`str`, *optional*): + The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is + provided either. + commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache. + + Returns: + `Optional[str]` or `_CACHED_NO_EXIST`: + Will return `None` if the file was not cached. Otherwise: + - The exact path to the cached file if it's found in the cache + - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was + cached. """ if commit_hash is not None and revision is not None: raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.") @@ -245,7 +264,7 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h commit_hash = f.read() if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)): - return -1 + return _CACHED_NO_EXIST cached_shas = os.listdir(os.path.join(model_cache, "snapshots")) if commit_hash not in cached_shas: @@ -358,7 +377,7 @@ def cached_file( # If the file is cached under that commit hash, we return it directly. 
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash) if resolved_file is not None: - if resolved_file != -1: + if resolved_file is not _CACHED_NO_EXIST: return resolved_file elif not _raise_exceptions_for_missing_entries: return None