Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further reduce the number of alls to head for cached objects #18871

Merged
merged 3 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def is_offline_mode():
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"

# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()


def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Expand Down Expand Up @@ -222,6 +225,22 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
"""
Explores the cache to return the latest cached file for a given revision.

Args:
cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
repo_id (`str`): The ID of the repo on huggingface.co.
filename (`str`): The filename to look for inside `repo_id`.
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.

Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
Will return `None` if the file was not cached. Otherwise:
- The exact path to the cached file if it's found in the cache
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
"""
if commit_hash is not None and revision is not None:
raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
Expand All @@ -244,6 +263,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
with open(os.path.join(model_cache, "refs", revision)) as f:
commit_hash = f.read()

if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
return _CACHED_NO_EXIST

cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
if commit_hash not in cached_shas:
# No cache for this revision and we won't try to return a random revision
Expand Down Expand Up @@ -338,7 +360,10 @@ def cached_file(
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
)
else:
return None
return resolved_file
Expand All @@ -352,7 +377,12 @@ def cached_file(
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
if resolved_file is not None:
return resolved_file
if resolved_file is not _CACHED_NO_EXIST:
return resolved_file
elif not _raise_exceptions_for_missing_entries:
return None
else:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

user_agent = http_user_agent(user_agent)
try:
Expand Down
3 changes: 1 addition & 2 deletions tests/models/auto/test_modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,5 @@ def test_cached_model_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
3 changes: 1 addition & 2 deletions tests/models/auto/test_modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,5 @@ def test_cached_model_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
3 changes: 1 addition & 2 deletions tests/models/auto/test_tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,5 @@ def test_cached_tokenizer_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
3 changes: 1 addition & 2 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,7 @@ def test_cached_pipeline_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)


Expand Down