Use commit hash to look in cache instead of calling head (#18534)
* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>
sgugger and julien-c authored Aug 10, 2022
1 parent 6eb5145 commit 0d0aada
Showing 15 changed files with 221 additions and 23 deletions.
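The core mechanism lives in src/transformers/utils/hub.py (changed by this commit but not expanded in the excerpt below): `cached_file` accepts a `_commit_hash` argument and, when the hash is already known, looks the file up directly in the local cache instead of sending a HEAD request to resolve the revision, while `extract_commit_hash` recovers the hash from a resolved cache path. Below is a minimal sketch of the idea, assuming the standard huggingface_hub cache layout (models--<org>--<name>/snapshots/<commit_hash>/<file>); `lookup_in_cache` is a hypothetical helper name and the real implementations may differ.

import os
import re
from pathlib import Path

# Sketch only: full commit hashes are 40 lowercase hex characters.
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")


def extract_commit_hash(resolved_file, commit_hash):
    """Return `commit_hash` if already known, otherwise parse it from a cached file path."""
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    resolved_file = str(Path(resolved_file).as_posix())
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    if search is None:
        return None
    candidate = search.groups()[0]
    return candidate if REGEX_COMMIT_HASH.match(candidate) else None


def lookup_in_cache(cache_dir, repo_id, filename, commit_hash):
    """If the file for this exact snapshot is already cached, return its path; no HTTP call needed."""
    snapshot_file = os.path.join(
        cache_dir, f"models--{repo_id.replace('/', '--')}", "snapshots", commit_hash, filename
    )
    return snapshot_file if os.path.isfile(snapshot_file) else None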
23 changes: 22 additions & 1 deletion src/transformers/configuration_utils.py
@@ -27,7 +27,15 @@

from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
from .utils import (
CONFIG_NAME,
PushToHubMixin,
cached_file,
copy_func,
extract_commit_hash,
is_torch_available,
logging,
)


logger = logging.get_logger(__name__)
@@ -343,6 +351,8 @@ def __init__(self, **kwargs):

# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)

# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@@ -539,6 +549,8 @@ def get_config_dict(
original_kwargs = copy.deepcopy(kwargs)
# Get config dict associated with the base config file
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in config_dict:
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

# That config file may point us toward another config file to use.
if "configuration_files" in config_dict:
@@ -564,6 +576,7 @@ def _get_config_dict(
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
@@ -599,7 +612,9 @@ def _get_config_dict(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@@ -616,6 +631,7 @@
try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
@@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
# We remove them so they don't appear in `return_unused_kwargs`.
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]

config = cls(**config_dict)

@@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
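Net effect of the configuration_utils.py changes above: a config loaded from the Hub remembers the commit hash of the snapshot it was resolved from, and `to_dict` strips the attribute again so it never leaks into serialized configs. A hedged usage sketch (the model id is illustrative and `_commit_hash` is private API):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
print(config._commit_hash)  # hash of the cached snapshot the config file was resolved from

# Passing the preloaded config along lets the model loader fall back to
# `config._commit_hash` (see the `getattr(config, "_commit_hash", None)` fallbacks in the
# modeling files below), so weights already in the cache are found without a new HEAD request.
model = AutoModel.from_pretrained("bert-base-uncased", config=config)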
7 changes: 7 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -595,6 +595,7 @@ def from_pretrained(
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
@@ -625,11 +626,15 @@ def from_pretrained(
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# Add the dtype to model_kwargs
model_kwargs["dtype"] = dtype

@@ -682,6 +687,7 @@ def from_pretrained(
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -748,6 +754,7 @@ def from_pretrained(
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)

# init random models
7 changes: 7 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -2161,6 +2161,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
@@ -2191,11 +2192,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
@@ -2253,6 +2258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -2320,6 +2326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)

config.name_or_path = pretrained_model_name_or_path
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1840,6 +1840,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
@@ -1918,6 +1919,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
@@ -2004,6 +2008,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

@@ -2078,6 +2083,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)

# load pt weights early so that we know which dtype to init the model under
15 changes: 12 additions & 3 deletions src/transformers/models/auto/tokenization_auto.py
@@ -25,7 +25,7 @@
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
@@ -389,7 +389,8 @@ def get_tokenizer_config(
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```"""
resolved_config_file = get_file_from_repo(
commit_hash = kwargs.get("_commit_hash", None)
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
@@ -399,13 +400,19 @@
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
if resolved_config_file is None:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
result = json.load(reader)
result["_commit_hash"] = commit_hash
return result


class AutoTokenizer:
@@ -532,6 +539,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in tokenizer_config:
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = None
if "auto_map" in tokenizer_config:
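With the change above, `get_tokenizer_config` resolves the file through `cached_file` (so it can reuse a known commit hash) and stashes the hash it resolved in the returned dict; `AutoTokenizer.from_pretrained` then forwards it to the downstream tokenizer-file lookups. A small hedged sketch (repo id illustrative):

from transformers.models.auto.tokenization_auto import get_tokenizer_config

tokenizer_config = get_tokenizer_config("bert-base-uncased")
# Alongside the usual keys, the dict now records the snapshot it was resolved at.
print(tokenizer_config.get("_commit_hash"))
print(tokenizer_config.get("tokenizer_class"))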
11 changes: 10 additions & 1 deletion src/transformers/pipelines/__init__.py
@@ -557,7 +557,12 @@ def pipeline(
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
# this is to keep BC).
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
hub_kwargs = {
"revision": revision,
"use_auth_token": use_auth_token,
"trust_remote_code": trust_remote_code,
"_commit_hash": None,
}

if task is None and model is None:
raise RuntimeError(
@@ -583,8 +588,10 @@
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
elif config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash

custom_tasks = {}
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
@@ -639,6 +646,7 @@ def pipeline(
)
if config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash

if device_map is not None:
if "device_map" in model_kwargs:
@@ -672,6 +680,7 @@
)

model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash

load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
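In pipelines/__init__.py, `hub_kwargs` carries the commit hash as soon as the config has been resolved, so the subsequent model, tokenizer and feature-extractor lookups reuse the same pinned snapshot instead of re-resolving the revision. A hedged sketch of the observable effect (model id illustrative):

from transformers import pipeline

# The config is resolved first; its commit hash is copied into hub_kwargs["_commit_hash"]
# and reused for every other file the pipeline needs.
pipe = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
print(pipe.model.config._commit_hash)  # the snapshot the pipeline's files were resolved against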
28 changes: 28 additions & 0 deletions src/transformers/testing_utils.py
@@ -31,6 +31,7 @@
from typing import Iterator, List, Union
from unittest import mock

import huggingface_hub
from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e


class RequestCounter:
"""
Helper class that will count all requests made online.
"""

def __enter__(self):
self.head_request_count = 0
self.get_request_count = 0
self.other_request_count = 0
self.old_request = huggingface_hub.file_download.requests.request
huggingface_hub.file_download.requests.request = self.new_request
return self

def __exit__(self, *args, **kwargs):
huggingface_hub.file_download.requests.request = self.old_request

def new_request(self, method, **kwargs):
if method == "GET":
self.get_request_count += 1
elif method == "HEAD":
self.head_request_count += 1
else:
self.other_request_count += 1

return self.old_request(method=method, **kwargs)
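
A hedged sketch of how the new `RequestCounter` can be exercised (repo id illustrative; exact counts depend on network state and what is already cached):

from transformers import AutoModel
from transformers.testing_utils import RequestCounter

with RequestCounter() as counter:
    model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# The point of this commit: resolving the config performs the single HEAD call, and files
# already present in the cache are then found by commit hash with no further HEAD requests.
print(counter.head_request_count, counter.get_request_count, counter.other_request_count)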