Commit

pipeline support for `device="mps"` (or any other string) (huggingface#18494)

* `pipeline` support for `device="mps"` (or any other string)

* Simplify `if` nesting

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix? @sgugger

* passing `attr=None` is not the same as not passing `attr` 🤯

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2 people authored and oneraghavan committed Sep 26, 2022
1 parent 261a3d5 commit 16aa38d
Showing 2 changed files with 20 additions and 6 deletions.
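
One bullet in the commit message deserves unpacking, since it shaped the change to `__init__.py`: explicitly passing `attr=None` overrides a callee's default, whereas not passing `attr` at all lets the default apply. A minimal illustration (hypothetical function, not part of the patch):

    def f(attr=-1):
        return attr

    print(f())           # -1: the default applies when `attr` is omitted
    print(f(attr=None))  # None: an explicit None overrides the default

This is why `pipeline()` forwards `device` only when it was actually set, as the last hunk of the first file shows.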
7 changes: 7 additions & 0 deletions src/transformers/pipelines/__init__.py
@@ -422,6 +422,7 @@ def pipeline(
     revision: Optional[str] = None,
     use_fast: bool = True,
     use_auth_token: Optional[Union[str, bool]] = None,
+    device: Optional[Union[int, str, "torch.device"]] = None,
     device_map=None,
     torch_dtype=None,
     trust_remote_code: Optional[bool] = None,
@@ -508,6 +509,9 @@ def pipeline(
         use_auth_token (`str` or *bool*, *optional*):
             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
             when running `huggingface-cli login` (stored in `~/.huggingface`).
+        device (`int` or `str` or `torch.device`):
+            Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
+            pipeline will be allocated.
         device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
             Sent directly as `model_kwargs` (just a simpler shortcut). When the `accelerate` library is present, set
             `device_map="auto"` to compute the most optimized `device_map` automatically. [More
@@ -811,4 +815,7 @@ def pipeline(
     if feature_extractor is not None:
         kwargs["feature_extractor"] = feature_extractor
 
+    if device is not None:
+        kwargs["device"] = device
+
     return pipeline_class(model=model, framework=framework, task=task, **kwargs)
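
With the three hunks above, `pipeline()` accepts `device` directly and forwards it only when set, so a call that omits `device` still falls back to `Pipeline.__init__`'s default of `-1`. A minimal usage sketch (assuming a PyTorch build with MPS support, e.g. on Apple Silicon):

    from transformers import pipeline

    # New in this change: `device` accepts a string such as "mps" or "cuda:1".
    pipe = pipeline("text-classification", device="mps")

    # The previously supported spellings keep working.
    pipe = pipeline("text-classification", device=0)   # ordinal -> "cuda:0"
    pipe = pipeline("text-classification", device=-1)  # -1 -> CPU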
19 changes: 13 additions & 6 deletions src/transformers/pipelines/base.py
@@ -704,7 +704,7 @@ def predict(self, X):
         Reference to the object in charge of parsing supplied pipeline parameters.
     device (`int`, *optional*, defaults to -1):
         Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a positive will run the model on
-        the associated CUDA device id. You can pass native `torch.device` too.
+        the associated CUDA device id. You can pass native `torch.device` or a `str` too.
     binary_output (`bool`, *optional*, defaults to `False`):
         Flag indicating if the output of the pipeline should happen in a binary format (i.e., pickle) or as raw text.
     """
@@ -747,7 +747,7 @@ def __init__(
         framework: Optional[str] = None,
         task: str = "",
         args_parser: ArgumentHandler = None,
-        device: int = -1,
+        device: Union[int, str, "torch.device"] = -1,
         binary_output: bool = False,
         **kwargs,
     ):
@@ -760,14 +760,21 @@ def __init__(
         self.feature_extractor = feature_extractor
         self.modelcard = modelcard
         self.framework = framework
-        if is_torch_available() and isinstance(device, torch.device):
-            self.device = device
+        if is_torch_available() and self.framework == "pt":
+            if isinstance(device, torch.device):
+                self.device = device
+            elif isinstance(device, str):
+                self.device = torch.device(device)
+            elif device < 0:
+                self.device = torch.device("cpu")
+            else:
+                self.device = torch.device(f"cuda:{device}")
         else:
-            self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
+            self.device = device
         self.binary_output = binary_output
 
         # Special handling
-        if self.framework == "pt" and self.device.type == "cuda":
+        if self.framework == "pt" and self.device.type != "cpu":
             self.model = self.model.to(self.device)
 
         # Update config with task specific parameters
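
Two details stand out in this last hunk: the int/str/`torch.device` normalization only runs for the PyTorch framework (TensorFlow pipelines keep `device` as passed), and the `.to()` condition is relaxed from `== "cuda"` to `!= "cpu"`, so the model is now moved to MPS (or any other non-CPU device) as well. A standalone sketch of the normalization, using a hypothetical `resolve_device` helper and assuming PyTorch >= 1.12, where "mps" is a recognized device type:

    import torch

    def resolve_device(device):
        # Mirrors the branch added above: int, str, and torch.device inputs
        # all normalize to a torch.device instance.
        if isinstance(device, torch.device):
            return device
        elif isinstance(device, str):
            return torch.device(device)            # "cpu", "mps", "cuda:1", ...
        elif device < 0:
            return torch.device("cpu")             # the legacy -1 convention
        else:
            return torch.device(f"cuda:{device}")  # GPU ordinal

    assert resolve_device(-1) == torch.device("cpu")
    assert resolve_device("mps") == torch.device("mps")
    assert resolve_device(1) == torch.device("cuda:1")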
