Skip to content

Commit

Permalink
Fixing offline mode for pipeline (when inferring task). (#21113)
Browse files Browse the repository at this point in the history
* Fixing offline mode for pipeline (when inferring task).

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Updating test to reflect change in exception.

* Fixing offline mode.

* Clean.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
Narsil and sgugger authored Jan 17, 2023
1 parent 8896ebb commit 25ddd91
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ..utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
is_kenlm_available,
is_offline_mode,
is_pyctcdecode_available,
is_tf_available,
is_torch_available,
Expand Down Expand Up @@ -398,6 +399,8 @@ def get_supported_tasks() -> List[str]:


def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
if is_offline_mode():
raise RuntimeError(f"You cannot infer task automatically within `pipeline` when using offline mode")
try:
info = model_info(model, token=use_auth_token)
except Exception as e:
Expand Down
84 changes: 72 additions & 12 deletions tests/utils/test_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import subprocess
import sys

from transformers import BertConfig, BertModel, BertTokenizer, pipeline
from transformers.testing_utils import TestCasePlus, require_torch


Expand All @@ -30,42 +31,77 @@ def test_offline_mode(self):

# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""

run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""

mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
socket.socket = offline_socket
"""

# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)

# baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", "\n".join([load, run])]
cmd = [sys.executable, "-c", "\n".join([load, run, mock])]

# should succeed
env = self.get_env()
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1"
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())

# next emulate no network
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
@require_torch
def test_offline_mode_no_internet(self):
# python one-liner segments
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""

# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
# env["TRANSFORMERS_OFFLINE"] = "0"
# result = subprocess.run(cmd, env=env, check=False, capture_output=True)
# self.assertEqual(result.returncode, 1, result.stderr)
run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""

# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1"
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
socket.socket = offline_socket
"""

# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)

# baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", "\n".join([load, run, mock])]

# should succeed
env = self.get_env()
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())
Expand Down Expand Up @@ -93,7 +129,7 @@ def test_offline_mode_sharded_checkpoint(self):

mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
socket.socket = offline_socket
"""

Expand All @@ -119,3 +155,27 @@ def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())

@require_torch
def test_offline_mode_pipeline_exception(self):
load = """
from transformers import pipeline
"""
run = """
mname = "hf-internal-testing/tiny-random-bert"
pipe = pipeline(model=mname)
"""

mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
env = self.get_env()
env["TRANSFORMERS_OFFLINE"] = "1"
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 1, result.stderr)
self.assertIn(
"You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
)

0 comments on commit 25ddd91

Please sign in to comment.