
Commit

Clean up hub (#18497)
* Clean up utils.hub

* Remove imports

* More fixes

* Last fix
sgugger authored Aug 8, 2022
1 parent a456255 commit 377cdde
Showing 14 changed files with 67 additions and 708 deletions.
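In short, the commit collapses the old two-step lookup (build a URL with `hf_bucket_url`, then download it with `cached_path`) into a single `cached_file(repo_id, filename, ...)` call. A minimal sketch of the pattern, using a placeholder repo id:

```python
from transformers.utils import WEIGHTS_NAME, cached_file

# Old two-step pattern removed by this commit:
#   url = hf_bucket_url("bert-base-uncased", filename=WEIGHTS_NAME)
#   path = cached_path(url, force_download=False)

# New single call: resolve a repo id plus filename straight to a cached local path.
path = cached_file("bert-base-uncased", WEIGHTS_NAME, force_download=False)
print(path)  # local path to the cached weights file
```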
2 changes: 0 additions & 2 deletions src/transformers/__init__.py
@@ -441,7 +441,6 @@
         "TensorType",
         "add_end_docstrings",
         "add_start_docstrings",
-        "cached_path",
         "is_apex_available",
         "is_datasets_available",
         "is_faiss_available",
@@ -3214,7 +3213,6 @@
         TensorType,
         add_end_docstrings,
         add_start_docstrings,
-        cached_path,
         is_apex_available,
         is_datasets_available,
         is_faiss_available,
19 changes: 9 additions & 10 deletions src/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -38,7 +38,6 @@
     T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
     TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
     WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    WEIGHTS_NAME,
     XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
     XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
     XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -91,11 +90,10 @@
     XLMConfig,
     XLMRobertaConfig,
     XLNetConfig,
-    cached_path,
     is_torch_available,
     load_pytorch_checkpoint_in_tf2_model,
 )
-from .utils import hf_bucket_url, logging
+from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging


 if is_torch_available():
@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(

     # Initialise TF model
     if config_file in aws_config_map:
-        config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
+        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
     config = config_class.from_json_file(config_file)
     config.output_hidden_states = True
     config.output_attentions = True
@@ -320,8 +318,9 @@

     # Load weights from tf checkpoint
     if pytorch_checkpoint_path in aws_config_map.keys():
-        pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
-        pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
+        pytorch_checkpoint_path = cached_file(
+            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
+        )
     # Load PyTorch checkpoint in tf2 model:
     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
         print("-" * 100)

         if config_shortcut_name in aws_config_map:
-            config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
+            config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
         else:
-            config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
+            config_file = config_shortcut_name

         if model_shortcut_name in aws_model_maps:
-            model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
+            model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
         else:
-            model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
+            model_file = model_shortcut_name

         if os.path.isfile(model_shortcut_name):
             model_shortcut_name = "converted_model"
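The conversion script now treats a known shortcut name as a hub repo id and fetches the standard file from it; anything else is assumed to already be a local path. A sketch of that branch as a standalone helper (hypothetical name, arguments taken from the script):

```python
from transformers.utils import CONFIG_NAME, cached_file

def resolve_config_file(config_shortcut_name, aws_config_map, use_cached_models=True):
    """Hypothetical helper mirroring the script's new resolution logic."""
    if config_shortcut_name in aws_config_map:
        # Known shortcut names double as hub repo ids: fetch config.json from the hub,
        # re-downloading when the caller opted out of cached models.
        return cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
    # Anything else is taken to be a local config file path and used as-is.
    return config_shortcut_name
```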
18 changes: 4 additions & 14 deletions src/transformers/dynamic_module_utils.py
@@ -24,14 +24,7 @@

 from huggingface_hub import HfFolder, model_info

-from .utils import (
-    HF_MODULES_CACHE,
-    TRANSFORMERS_DYNAMIC_MODULE_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    logging,
-)
+from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -219,18 +212,15 @@ def get_cached_module_file(
     # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isdir(pretrained_model_name_or_path):
-        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
         submodule = "local"
     else:
-        module_file_or_url = hf_bucket_url(
-            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
-        )
         submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

     try:
         # Load from URL or cache if already cached
-        resolved_module_file = cached_path(
-            module_file_or_url,
+        resolved_module_file = cached_file(
+            pretrained_model_name_or_path,
+            module_file,
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
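Because `cached_file` accepts a hub repo id or a local directory directly, `get_cached_module_file` no longer needs to pre-build a URL before downloading. A sketch of the simplified resolution, with hypothetical repo and file names:

```python
from transformers.utils import cached_file

# The first argument may be a local directory or a hub repo id;
# cached_file resolves both, so the URL-building branch above could be deleted.
resolved_module_file = cached_file(
    "user/model-with-custom-code",  # hypothetical repo id (or local directory)
    "modeling_custom.py",           # hypothetical module file name
    cache_dir=None,
    force_download=False,
)
```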
9 changes: 0 additions & 9 deletions src/transformers/file_utils.py
@@ -69,20 +69,14 @@
     add_end_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    cached_path,
     cached_property,
     copy_func,
     default_cache_path,
     define_sagemaker_information,
-    filename_to_url,
     get_cached_models,
     get_file_from_repo,
-    get_from_cache,
     get_full_repo_name,
-    get_list_of_files,
     has_file,
-    hf_bucket_url,
-    http_get,
     http_user_agent,
     is_apex_available,
     is_coloredlogs_available,
@@ -94,7 +88,6 @@
     is_in_notebook,
     is_ipex_available,
     is_librosa_available,
-    is_local_clone,
     is_offline_mode,
     is_onnx_available,
     is_pandas_available,
@@ -105,7 +98,6 @@
     is_pyctcdecode_available,
     is_pytesseract_available,
     is_pytorch_quantization_available,
-    is_remote_url,
     is_rjieba_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -141,5 +133,4 @@
     torch_only_method,
     torch_required,
     torch_version,
-    url_to_filename,
 )
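`file_utils` exists only as a backward-compatibility shim, so the deleted helpers (`cached_path`, `hf_bucket_url`, `get_from_cache`, `http_get`, ...) also disappear from its re-export surface. Downstream code that used them would migrate roughly as follows (a sketch; the exact rewrite depends on what the caller did with the URL):

```python
# Old pattern, no longer re-exported anywhere:
#   from transformers.file_utils import cached_path, hf_bucket_url
#   path = cached_path(hf_bucket_url("bert-base-uncased", filename="config.json"))

# New pattern: the single hub helper lives in transformers.utils.
from transformers.utils import cached_file

path = cached_file("bert-base-uncased", "config.json")
```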
77 changes: 27 additions & 50 deletions src/transformers/modelcard.py
@@ -43,15 +43,10 @@
 )
 from .training_args import ParallelMode
 from .utils import (
-    CONFIG_NAME,
     MODEL_CARD_NAME,
-    TF2_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    cached_path,
-    hf_bucket_url,
+    cached_file,
     is_datasets_available,
     is_offline_mode,
-    is_remote_url,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -153,11 +148,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
-            find_from_standard_name: (*optional*) boolean, default True:
-                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
-                with our standard modelcard filename. Can be used to directly feed a model/config url and access the
-                colocated modelcard.
             return_unused_kwargs: (*optional*) bool:
                 - If False, then this function returns just the final model card object.
@@ -168,59 +158,46 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         Examples:

         ```python
-        modelcard = ModelCard.from_pretrained(
-            "bert-base-uncased"
-        )  # Download model card from huggingface.co and cache.
-        modelcard = ModelCard.from_pretrained(
-            "./test/saved_model/"
-        )  # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
+        # Download model card from huggingface.co and cache.
+        modelcard = ModelCard.from_pretrained("bert-base-uncased")
+        # Model card was saved using *save_pretrained('./test/saved_model/')*
+        modelcard = ModelCard.from_pretrained("./test/saved_model/")
         modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
         modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
         ```"""
-        # This imports every model so let's do it dynamically here.
-        from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
-
         cache_dir = kwargs.pop("cache_dir", None)
         proxies = kwargs.pop("proxies", None)
-        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
         from_pipeline = kwargs.pop("_from_pipeline", None)

         user_agent = {"file_type": "model_card"}
         if from_pipeline is not None:
             user_agent["using_pipeline"] = from_pipeline

-        if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
-            # For simplicity we use the same pretrained url than the configuration files
-            # but with a different suffix (modelcard.json). This suffix is replaced below.
-            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        elif os.path.isdir(pretrained_model_name_or_path):
-            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
-        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            model_card_file = pretrained_model_name_or_path
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_model_card_file = pretrained_model_name_or_path
+            is_local = True
         else:
-            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
-
-        if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
-            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
-            model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
-            model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
-
-        try:
-            # Load from URL or cache if already cached
-            resolved_model_card_file = cached_path(
-                model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent
-            )
-            if resolved_model_card_file == model_card_file:
-                logger.info(f"loading model card file {model_card_file}")
-            else:
-                logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
-            # Load model card
-            modelcard = cls.from_json_file(resolved_model_card_file)
+            try:
+                # Load from URL or cache if already cached
+                resolved_model_card_file = cached_file(
+                    pretrained_model_name_or_path,
+                    filename=MODEL_CARD_NAME,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    user_agent=user_agent,
+                )
+                if is_local:
+                    logger.info(f"loading model card file {resolved_model_card_file}")
+                else:
+                    logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
+                # Load model card
+                modelcard = cls.from_json_file(resolved_model_card_file)

-        except (EnvironmentError, json.JSONDecodeError):
-            # We fall back on creating an empty model card
-            modelcard = cls()
+            except (EnvironmentError, json.JSONDecodeError):
+                # We fall back on creating an empty model card
+                modelcard = cls()

         # Update model card with kwargs if needed
         to_remove = []
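After the rewrite, `ModelCard.from_pretrained` uses an explicit file path as-is, routes directories and repo ids through `cached_file`, and falls back to an empty card on any load failure. From the caller's side that behaves roughly like this (the second repo id is hypothetical):

```python
from transformers import ModelCard

# Hub repo id: modelcard.json is downloaded and cached when present.
card = ModelCard.from_pretrained("bert-base-uncased")

# A repo or path without a readable card falls back to an empty ModelCard
# instead of raising, per the except (EnvironmentError, json.JSONDecodeError) branch.
empty = ModelCard.from_pretrained("some-user/repo-without-a-card")  # hypothetical repo id
```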
4 changes: 1 addition & 3 deletions src/transformers/modeling_tf_utils.py
@@ -2156,7 +2156,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        mirror = kwargs.pop("mirror", None)
+        _ = kwargs.pop("mirror", None)
         load_weight_prefix = kwargs.pop("load_weight_prefix", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
@@ -2270,7 +2270,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 # message.
                 has_file_kwargs = {
                     "revision": revision,
-                    "mirror": mirror,
                     "proxies": proxies,
                     "use_auth_token": use_auth_token,
                 }
@@ -2321,7 +2320,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             use_auth_token=use_auth_token,
             user_agent=user_agent,
             revision=revision,
-            mirror=mirror,
         )

         config.name_or_path = pretrained_model_name_or_path
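Note that `mirror` is still popped even though it is no longer forwarded, so existing callers that pass it do not break. The same accept-and-ignore idiom in isolation (a generic sketch, not the library's code):

```python
def load_model_like(**kwargs):
    # Accept and discard the retired kwarg so old call sites keep working;
    # it is popped from kwargs but never reaches the download logic.
    _ = kwargs.pop("mirror", None)
    revision = kwargs.pop("revision", "main")
    return revision, kwargs

print(load_model_like(mirror="deprecated-mirror", revision="main"))  # ('main', {})
```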
4 changes: 1 addition & 3 deletions src/transformers/modeling_utils.py
@@ -1784,7 +1784,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        mirror = kwargs.pop("mirror", None)
+        _ = kwargs.pop("mirror", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         _fast_init = kwargs.pop("_fast_init", True)
@@ -1955,7 +1955,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # message.
                 has_file_kwargs = {
                     "revision": revision,
-                    "mirror": mirror,
                     "proxies": proxies,
                     "use_auth_token": use_auth_token,
                 }
@@ -2012,7 +2011,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             use_auth_token=use_auth_token,
             user_agent=user_agent,
             revision=revision,
-            mirror=mirror,
             subfolder=subfolder,
         )

15 changes: 7 additions & 8 deletions src/transformers/models/rag/retrieval_rag.py
@@ -23,7 +23,7 @@

 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import BatchEncoding
-from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
+from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
 from .configuration_rag import RagConfig
 from .tokenization_rag import RagTokenizer

@@ -111,22 +111,21 @@ def __init__(self, vector_size, index_path):
         self._index_initialized = False

     def _resolve_path(self, index_path, filename):
-        assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
-        archive_file = os.path.join(index_path, filename)
+        is_local = os.path.isdir(index_path)
         try:
             # Load from URL or cache if already cached
-            resolved_archive_file = cached_path(archive_file)
+            resolved_archive_file = cached_file(index_path, filename)
         except EnvironmentError:
             msg = (
-                f"Can't load '{archive_file}'. Make sure that:\n\n"
+                f"Can't load '{filename}'. Make sure that:\n\n"
                 f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
                 f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
             )
             raise EnvironmentError(msg)
-        if resolved_archive_file == archive_file:
-            logger.info(f"loading file {archive_file}")
+        if is_local:
+            logger.info(f"loading file {resolved_archive_file}")
         else:
-            logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
+            logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
         return resolved_archive_file

     def _load_passages(self):
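The RAG index resolver follows the same recipe as the model card loader: compute `is_local` up front, make one `cached_file` call for both local and remote paths, and raise a readable `EnvironmentError` on failure. A condensed sketch with a hypothetical helper name:

```python
import os

from transformers.utils import cached_file

def resolve_index_file(index_path: str, filename: str) -> str:
    """Hypothetical standalone version of the _resolve_path pattern above."""
    is_local = os.path.isdir(index_path)
    try:
        # One call covers both a local directory and a remote repo id.
        resolved = cached_file(index_path, filename)
    except EnvironmentError:
        raise EnvironmentError(
            f"Can't load '{filename}'. Make sure that '{index_path}' is a local directory "
            f"or a correct remote path containing a file named {filename}."
        )
    if is_local:
        print(f"loading file {resolved}")
    else:
        print(f"loading file {filename} from cache at {resolved}")
    return resolved
```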
(Diffs for the remaining 6 changed files are not shown.)
