[TENTATIVE] Attempt to reduce number of HEAD calls during model warmup. #18429

Closed · wants to merge 4 commits
6 changes: 6 additions & 0 deletions src/transformers/models/auto/auto_factory.py
@@ -420,9 +420,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
# TODO this seems to eat up from_auto_class and `_from_pipeline` kwargs
# which are necessary for tracking.
from_auto_class = kwargs.pop("_from_auto", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
)
kwargs["_from_auto"] = from_auto_class
kwargs["_from_pipeline"] = from_pipeline
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
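The save/restore around `AutoConfig.from_pretrained` exists because `return_unused_kwargs=True` only hands back kwargs the config did not consume, so the telemetry keys would otherwise be lost, as the TODO notes. A minimal standalone sketch of that mechanism (the helper is hypothetical, not transformers API):

```python
# Hypothetical stand-in for a from_pretrained(..., return_unused_kwargs=True)
# call: it consumes _from_auto internally, so the key never comes back.
def consume_kwargs(**kwargs):
    kwargs.pop("_from_auto", None)  # read for telemetry, not returned
    return "config", kwargs         # (config, unused kwargs)

config, unused = consume_kwargs(_from_auto=True, output_attentions=True)
assert "_from_auto" not in unused and "output_attentions" in unused
unused["_from_auto"] = True  # the diff's fix: put the flag back afterwards
```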
11 changes: 10 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -389,6 +389,14 @@ def get_tokenizer_config(
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```"""
user_agent = {
"file_type": "tokenizer",
"from_auto_class": kwargs.get("_from_auto", False),
"is_fast": kwargs.get("is_fast", False),
}
if "_from_pipeline" in kwargs:
user_agent["using_pipeline"] = kwargs.get("_from_pipeline")

resolved_config_file = get_file_from_repo(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
@@ -399,6 +407,7 @@ def get_tokenizer_config(
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
user_agent=user_agent,
)
if resolved_config_file is None:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
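For context, the dict built above ends up serialized into the request's user-agent header as `key/value` fragments; the exact base string and separator below are assumptions, but the tests at the end of this PR assert exactly these fragments:

```python
# Sketch of the user_agent dict -> header-fragment serialization (assumed).
ua = {"file_type": "tokenizer", "from_auto_class": True, "is_fast": True}
header = "; ".join(f"{k}/{v}" for k, v in ua.items())
assert "file_type/tokenizer" in header
assert "is_fast/True" in header
```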
@@ -502,7 +511,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True

- use_fast = kwargs.pop("use_fast", True)
+ use_fast = kwargs.get("use_fast", True)
Collaborator review comment on the line above: This means the use_fast stays in the kwargs, which are then sent to classes like AutoConfig, which don't use it. So I'd leave it as pop here and if it's missing in a call down there, let's add it manually.
tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

5 changes: 3 additions & 2 deletions src/transformers/pipelines/__init__.py
@@ -579,10 +579,11 @@ def pipeline(

# Config is the primordial information item.
# Instantiate config if needed
+ _from_pipeline = "auto" if task is None else task
if isinstance(config, str):
- config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+ config = AutoConfig.from_pretrained(config, _from_pipeline=_from_pipeline, **hub_kwargs, **model_kwargs)
elif config is None and isinstance(model, str):
- config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+ config = AutoConfig.from_pretrained(model, _from_pipeline=_from_pipeline, **hub_kwargs, **model_kwargs)

custom_tasks = {}
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
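The net effect of the fallback, as a standalone illustration (the helper name is made up):

```python
# Without an explicit task, the old code sent _from_pipeline=None; the new
# fallback tags such calls as "auto" so they still appear in telemetry.
def pipeline_tag(task=None):
    return "auto" if task is None else task

assert pipeline_tag() == "auto"
assert pipeline_tag("text-generation") == "text-generation"
```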
4 changes: 2 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -1701,6 +1701,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
@@ -3196,8 +3197,7 @@ def truncate_sequences(
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
- error_msg
- + "Please select another truncation strategy than "
+ error_msg + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
48 changes: 47 additions & 1 deletion src/transformers/utils/hub.py
@@ -16,6 +16,7 @@
"""
import copy
import fnmatch
import functools
import io
import json
import os
@@ -25,6 +26,7 @@
import sys
import tarfile
import tempfile
import time
import traceback
import warnings
from contextlib import contextmanager
@@ -456,6 +458,42 @@ def http_get(
progress.close()


def timed_cache(ttl):
cache = {}

def true_decorator(f):
@functools.wraps(f)
def wrapped(url, headers, *args, **kwargs):

key = (url, headers.get("authorization", None))
result_time, result = cache.get(key, (None, None))
if result_time is not None and (time.time() - result_time) < ttl:
return result
result = f(url, headers, *args, **kwargs)
cache[key] = (time.time(), result)
return result

def clear_cache():
nonlocal cache
cache.clear()

wrapped.clear_cache = clear_cache

return wrapped

return true_decorator


def _request_head(url, headers, allow_redirects, proxies, timeout):
r = requests.head(url, headers=headers, allow_redirects=allow_redirects, proxies=proxies, timeout=timeout)
return r


@timed_cache(ttl=10)
def request_head(url, headers, allow_redirects, proxies, timeout):
return _request_head(url, headers, allow_redirects, proxies, timeout)


def get_from_cache(
url: str,
cache_dir=None,
@@ -497,7 +535,13 @@ def get_from_cache(
etag = None
if not local_files_only:
try:
- r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
+ r = request_head(
+     url,
+     headers=headers,
+     allow_redirects=False,
+     proxies=proxies,
+     timeout=etag_timeout,
+ )
_raise_for_status(r)
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
# We favor a custom header indicating the etag of the linked resource, and
@@ -844,6 +888,7 @@ def get_file_from_repo(
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
user_agent: Dict[str, str] = None,
):
"""
Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
@@ -911,6 +956,7 @@
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
user_agent=user_agent,
)


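With these pieces in place, repeated metadata lookups for the same URL (and authorization header) within the 10-second TTL are served from memory instead of issuing another HEAD request. A hedged usage sketch, assuming this branch is installed:

```python
from transformers.utils.hub import request_head

url = "https://huggingface.co/hf-internal-testing/tiny-random-gpt2/resolve/main/config.json"

# Two identical lookups inside the TTL: only the first hits the network;
# the second returns the cached Response object.
r1 = request_head(url, headers={}, allow_redirects=False, proxies=None, timeout=10)
r2 = request_head(url, headers={}, allow_redirects=False, proxies=None, timeout=10)
assert r2 is r1

request_head.clear_cache()  # e.g. between tests, to force fresh HEADs
```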
202 changes: 201 additions & 1 deletion tests/utils/test_file_utils.py
@@ -20,11 +20,13 @@
import unittest
from pathlib import Path

import requests
import transformers

# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
- from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
+ from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer, pipeline
+ from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_tokenizers, require_torch
from transformers.utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
@@ -218,3 +220,201 @@ def test_find_labels(self):
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])


class TelemetryTest(unittest.TestCase):
url = "https://huggingface.co/hf-internal-testing/tiny-random-gpt2/resolve/main/"

def setUp(self):
# Record every HTTP call by patching requests.request, so each test can
# assert on the URLs fetched and the user-agent telemetry they carried.
self.real_request = requests.request

self.calls = []

def counter_request_head(*args, **kwargs):
self.calls.append([args, kwargs])
return self.real_request(*args, **kwargs)

requests.request = counter_request_head

def tearDown(self):
requests.request = self.real_request

@require_torch
def test_pipeline(self):
pipeline(task="text-generation", model="hf-internal-testing/tiny-random-gpt2")

files = [kwargs["url"][len(self.url) :] for _, kwargs in self.calls]
self.assertEqual(
files,
[
"config.json",
"pytorch_model.bin",
"tokenizer_config.json",
"vocab.json",
"merges.txt",
"tokenizer.json",
"added_tokens.json",
"special_tokens_map.json",
],
)

for file, (_, kwargs) in zip(files, self.calls):
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"using_pipeline/text-generation",
user_agent,
f"user_agent is incorrect for file {file}",
)

@require_tokenizers
def test_tokenizer(self):
AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

files = [kwargs["url"][len(self.url) :] for _, kwargs in self.calls]
self.assertEqual(
files,
[
"tokenizer_config.json",
"vocab.json",
"merges.txt",
"tokenizer.json",
"added_tokens.json",
"special_tokens_map.json",
],
)

for file, (_, kwargs) in zip(files, self.calls):
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"file_type/tokenizer",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"is_fast/True",
user_agent,
f"user_agent is incorrect for file {file}",
)

def test_slow_tokenizer(self):
AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2", use_fast=False)

files = [kwargs["url"][len(self.url) :] for _, kwargs in self.calls]
self.assertEqual(
files,
[
"tokenizer_config.json",
"vocab.json",
"merges.txt",
"added_tokens.json",
"special_tokens_map.json",
],
)

for file, (_, kwargs) in zip(files, self.calls):
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"file_type/tokenizer",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"from_auto_class/True",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"is_fast/False",
user_agent,
f"user_agent is incorrect for file {file}",
)

@require_torch
def test_model(self):
AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")

files = [kwargs["url"][len(self.url) :] for _, kwargs in self.calls]
self.assertEqual(
files,
[
"config.json",
"pytorch_model.bin",
],
)

for file, (_, kwargs) in zip(files, self.calls):
if file == "config.json":
# Skip config
continue
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"file_type/model",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"from_auto_class/True",
user_agent,
f"user_agent is incorrect for file {file}",
)

@require_torch
def test_model_raw(self):
from transformers import GPT2Model

GPT2Model.from_pretrained("hf-internal-testing/tiny-random-gpt2")

files = [kwargs["url"][len(self.url) :] for _, kwargs in self.calls]
self.assertEqual(
files,
[
"config.json",
"pytorch_model.bin",
],
)

for file, (_, kwargs) in zip(files, self.calls):
if file == "config.json":
# Skip config
continue
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"file_type/model",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"from_auto_class/False",
user_agent,
f"user_agent is incorrect for file {file}",
)

def test_feature_extractor(self):
AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
url = "https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/"

files = [args[0][len(url) :] for args, _ in self.calls]
self.assertEqual(
files,
["preprocessor_config.json"],
)

for file, (_, kwargs) in zip(files, self.calls):
user_agent = kwargs["headers"]["user-agent"]

self.assertIn(
"file_type/feature extractor",
user_agent,
f"user_agent is incorrect for file {file}",
)
self.assertIn(
"from_auto_class/True",
user_agent,
f"user_agent is incorrect for file {file}",
)