Skip to content

Commit

Permalink
Use commit hash to look in cache instead of calling head (huggingface…
Browse files Browse the repository at this point in the history
…#18534)

* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>
  • Loading branch information
2 people authored and amyeroberts committed Oct 18, 2022
1 parent 28a6894 commit fa72984
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 17 deletions.
12 changes: 11 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@

from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
from .utils import (
CONFIG_NAME,
PushToHubMixin,
cached_file,
copy_func,
extract_commit_hash,
is_torch_available,
logging,
)


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -607,7 +615,9 @@ def _get_config_dict(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def from_pretrained(
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -691,6 +692,7 @@ def from_pretrained(
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2362,6 +2362,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -2458,6 +2459,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down Expand Up @@ -2525,6 +2527,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)

config.name_or_path = pretrained_model_name_or_path
Expand Down
1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down
8 changes: 7 additions & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
add_end_docstrings,
cached_file,
copy_func,
get_file_from_repo,
extract_commit_hash,
is_flax_available,
is_offline_mode,
is_tf_available,
Expand Down Expand Up @@ -1705,6 +1705,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
Expand Down Expand Up @@ -1739,7 +1743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)

if len(unresolved_files) > 0:
logger.info(
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
cached_file,
default_cache_path,
define_sagemaker_information,
extract_commit_hash,
get_cached_models,
get_file_from_repo,
get_full_repo_name,
Expand Down
56 changes: 45 additions & 11 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
Expand Down Expand Up @@ -203,11 +204,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua


def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
    """
    Extracts the commit hash from a resolved filename pointing to a cache file.

    Args:
        resolved_file (`str`, *optional*):
            The path to a file resolved inside the cache (expected to contain a
            `snapshots/<commit_hash>/` component), or `None` if resolution failed.
        commit_hash (`str`, *optional*):
            A commit hash that is already known. If provided, it is returned as-is
            so that an explicitly passed hash always wins over path extraction.

    Returns:
        `str` or `None`: the commit hash embedded in `resolved_file`, the provided
        `commit_hash`, or `None` when no valid hash can be determined.
    """
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    # Normalize to POSIX separators so the pattern below also matches cache paths
    # produced on Windows (which use backslashes).
    from pathlib import Path

    resolved_file = str(Path(resolved_file).as_posix())
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    if search is None:
        return None
    commit_hash = search.groups()[0]
    # Only accept strings that look like a real commit sha; a branch/tag name that
    # ended up in the path (e.g. "main") must not be mistaken for a hash.
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None


def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
"""
Explores the cache to return the latest cached file for a given revision.
"""
if revision is None:
if commit_hash is not None and revision is not None:
raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
if revision is None and commit_hash is None:
revision = "main"

model_id = repo_id.replace("/", "--")
Expand All @@ -219,18 +236,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
if not os.path.isdir(os.path.join(model_cache, subfolder)):
return None

# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(model_cache, "refs", revision)) as f:
revision = f.read()
if commit_hash is None:
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(model_cache, "refs", revision)) as f:
commit_hash = f.read()

cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
if revision not in cached_shas:
if commit_hash not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None

cached_file = os.path.join(model_cache, "snapshots", revision, filename)
cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
return cached_file if os.path.isfile(cached_file) else None


Expand Down Expand Up @@ -268,8 +286,9 @@ def cached_file(
local_files_only: bool = False,
subfolder: str = "",
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_missing_entries=True,
_raise_exceptions_for_connection_errors=True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
):
"""
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
Expand Down Expand Up @@ -321,6 +340,13 @@ def cached_file(
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```"""
# Private arguments
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
# None.
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
# a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
Expand All @@ -342,6 +368,13 @@ def cached_file(
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)

if _commit_hash is not None:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
if resolved_file is not None:
return resolved_file

user_agent = http_user_agent(user_agent)
try:
# Load from URL or cache if already cached
Expand Down Expand Up @@ -852,6 +885,7 @@ def get_checkpoint_shard_files(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
Expand Down
3 changes: 2 additions & 1 deletion tests/models/auto/test_modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,6 @@ def test_cached_model_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
3 changes: 2 additions & 1 deletion tests/models/auto/test_modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,5 +303,6 @@ def test_cached_model_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
3 changes: 2 additions & 1 deletion tests/models/auto/test_tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,5 +349,6 @@ def test_cached_tokenizer_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
3 changes: 2 additions & 1 deletion tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,8 @@ def test_cached_pipeline_has_minimum_calls_to_head(self):
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)


Expand Down

0 comments on commit fa72984

Please sign in to comment.