community[minor]: Add openvino backend support (langchain-ai#11591)
- **Description:** Add OpenVINO backend support via Hugging Face Optimum Intel.
- **Dependencies:** `optimum[openvino]`

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
openvino-dev-samples and baskaryan committed Mar 1, 2024
1 parent a89f007 commit f61cb8d
Showing 4 changed files with 212 additions and 7 deletions.
90 changes: 89 additions & 1 deletion docs/docs/integrations/llms/huggingface_pipelines.ipynb
@@ -192,6 +192,94 @@
"for answer in answers:\n",
" print(answer)"
]
},
{
"cell_type": "markdown",
"id": "df1d41d9",
"metadata": {},
"source": [
"### Inference with OpenVINO backend\n",
"\n",
"To deploy a model with OpenVINO, you can specify the `backend=\"openvino\"` parameter to trigger OpenVINO as backend inference framework.\n",
"\n",
"If you have an Intel GPU, you can specify `model_kwargs={\"device\": \"GPU\"}` to run inference on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efb73dd7-77bf-4436-92e5-51306af45bd7",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade-strategy eager \"optimum[openvino,nncf]\" --quiet"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70f6826c",
"metadata": {},
"outputs": [],
"source": [
"ov_config = {\"PERFORMANCE_HINT\": \"LATENCY\", \"NUM_STREAMS\": \"1\", \"CACHE_DIR\": \"\"}\n",
"\n",
"ov_llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"gpt2\",\n",
" task=\"text-generation\",\n",
" backend=\"openvino\",\n",
" model_kwargs={\"device\": \"CPU\", \"ov_config\": ov_config},\n",
" pipeline_kwargs={\"max_new_tokens\": 10},\n",
")\n",
"\n",
"ov_chain = prompt | ov_llm\n",
"\n",
"question = \"What is electroencephalography?\"\n",
"\n",
"print(ov_chain.invoke({\"question\": question}))"
]
},
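{
"cell_type": "markdown",
"id": "gpu-openvino-sketch-md",
"metadata": {},
"source": [
"A minimal sketch for Intel GPU inference (untested here; assumes an Intel GPU and a GPU-enabled OpenVINO runtime); only the `device` value changes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "gpu-openvino-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical GPU variant of the CPU example above; only \"device\" changes.\n",
"gpu_llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"gpt2\",\n",
" task=\"text-generation\",\n",
" backend=\"openvino\",\n",
" model_kwargs={\"device\": \"GPU\", \"ov_config\": ov_config},\n",
" pipeline_kwargs={\"max_new_tokens\": 10},\n",
")\n",
"\n",
"print((prompt | gpu_llm).invoke({\"question\": question}))"
]
},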
{
"cell_type": "markdown",
"id": "12524837-e9ab-455a-86be-66b95f4f893a",
"metadata": {},
"source": [
"### Inference with local OpenVINO model\n",
"\n",
"It is possible to [export your model](https://github.com/huggingface/optimum-intel?tab=readme-ov-file#export) to the OpenVINO IR format with the CLI, and load the model from local folder.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d1104a2-79c7-43a6-aa1c-8076a5ad7747",
"metadata": {},
"outputs": [],
"source": [
"!optimum-cli export openvino --model gpt2 ov_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac71e60d-5595-454e-8602-03ebb0248205",
"metadata": {},
"outputs": [],
"source": [
"ov_llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"ov_model\",\n",
" task=\"text-generation\",\n",
" backend=\"openvino\",\n",
" model_kwargs={\"device\": \"CPU\", \"ov_config\": ov_config},\n",
" pipeline_kwargs={\"max_new_tokens\": 10},\n",
")\n",
"\n",
"ov_chain = prompt | ov_llm\n",
"\n",
"question = \"What is electroencephalography?\"\n",
"\n",
"print(ov_chain.invoke({\"question\": question}))"
]
}
],
"metadata": {
@@ -210,7 +298,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.10.12"
}
},
"nbformat": 4,
16 changes: 16 additions & 0 deletions docs/docs/integrations/platforms/huggingface.mdx
@@ -40,6 +40,22 @@ See a [usage example](/docs/integrations/llms/huggingface_pipelines).
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
```

To use the OpenVINO backend with the local pipeline wrapper, install the optimum library and set HuggingFacePipeline's `backend` parameter to `"openvino"`:

```bash
pip install --upgrade-strategy eager "optimum[openvino,nncf]"
```
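
For example, a minimal sketch (mirroring the pipelines notebook; the model and generation parameters are illustrative):

```python
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

# Run gpt2 with the OpenVINO backend; the target device is passed via model_kwargs.
ov_llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    backend="openvino",
    model_kwargs={"device": "CPU"},
    pipeline_kwargs={"max_new_tokens": 10},
)
```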

See a [usage example](/docs/integrations/llms/huggingface_pipelines).

To export your model to the OpenVINO IR format with the CLI:

```bash
optimum-cli export openvino --model gpt2 ov_model
```
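
After exporting, pass the local folder as the `model_id` (a sketch mirroring the notebook example; `ov_model` is the output directory from the command above):

```python
ov_llm = HuggingFacePipeline.from_model_id(
    model_id="ov_model",
    task="text-generation",
    backend="openvino",
    model_kwargs={"device": "CPU"},
    pipeline_kwargs={"max_new_tokens": 10},
)
```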

You can also apply [weight-only quantization](https://github.com/huggingface/optimum-intel?tab=readme-ov-file#export) when exporting your model; a sketch follows.
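
For instance, an int8 weight-only export (a hypothetical invocation; assumes an `optimum-intel` version whose CLI supports the `--weight-format` flag):

```bash
# Hypothetical: export gpt2 with int8 weight-only quantized weights.
optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
```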

### Hugging Face TextGen Inference

>[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is
71 changes: 65 additions & 6 deletions libs/community/langchain_community/llms/huggingface_pipeline.py
@@ -68,6 +68,7 @@ def from_model_id(
         cls,
         model_id: str,
         task: str,
+        backend: str = "default",
         device: Optional[int] = -1,
         device_map: Optional[str] = None,
         model_kwargs: Optional[dict] = None,
@@ -95,9 +96,57 @@
 
         try:
             if task == "text-generation":
-                model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
+                if backend == "openvino":
+                    try:
+                        from optimum.intel.openvino import OVModelForCausalLM
+
+                    except ImportError:
+                        raise ValueError(
+                            "Could not import optimum-intel python package. "
+                            "Please install it with: "
+                            "pip install 'optimum[openvino,nncf]' "
+                        )
+                    try:
+                        # try to load from a local folder first
+                        model = OVModelForCausalLM.from_pretrained(
+                            model_id, **_model_kwargs
+                        )
+
+                    except Exception:
+                        # fall back to exporting the remote model on the fly
+                        model = OVModelForCausalLM.from_pretrained(
+                            model_id, export=True, **_model_kwargs
+                        )
+                else:
+                    model = AutoModelForCausalLM.from_pretrained(
+                        model_id, **_model_kwargs
+                    )
             elif task in ("text2text-generation", "summarization"):
-                model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
+                if backend == "openvino":
+                    try:
+                        from optimum.intel.openvino import OVModelForSeq2SeqLM
+
+                    except ImportError:
+                        raise ValueError(
+                            "Could not import optimum-intel python package. "
+                            "Please install it with: "
+                            "pip install 'optimum[openvino,nncf]' "
+                        )
+                    try:
+                        # try to load from a local folder first
+                        model = OVModelForSeq2SeqLM.from_pretrained(
+                            model_id, **_model_kwargs
+                        )
+
+                    except Exception:
+                        # fall back to exporting the remote model on the fly
+                        model = OVModelForSeq2SeqLM.from_pretrained(
+                            model_id, export=True, **_model_kwargs
+                        )
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(
+                        model_id, **_model_kwargs
+                    )
             else:
                 raise ValueError(
                     f"Got invalid task {task}, "
@@ -112,9 +161,13 @@
             tokenizer.pad_token_id = model.config.eos_token_id
 
         if (
-            getattr(model, "is_loaded_in_4bit", False)
-            or getattr(model, "is_loaded_in_8bit", False)
-        ) and device is not None:
+            (
+                getattr(model, "is_loaded_in_4bit", False)
+                or getattr(model, "is_loaded_in_8bit", False)
+            )
+            and device is not None
+            and backend == "default"
+        ):
             logger.warning(
                 f"Setting the `device` argument to None from {device} to avoid "
                 "the error caused by attempting to move the model that was already "
@@ -123,7 +176,11 @@
             )
             device = None
 
-        if device is not None and importlib.util.find_spec("torch") is not None:
+        if (
+            device is not None
+            and importlib.util.find_spec("torch") is not None
+            and backend == "default"
+        ):
             import torch
 
             cuda_device_count = torch.cuda.device_count()
@@ -142,6 +199,8 @@
                 "can be a positive integer associated with CUDA device id.",
                 cuda_device_count,
             )
+        if device is not None and device_map is not None and backend == "openvino":
+            logger.warning("Please set the device for OpenVINO through 'model_kwargs'.")
         if "trust_remote_code" in _model_kwargs:
             _model_kwargs = {
                 k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
@@ -80,3 +80,45 @@ def test_huggingface_pipeline_runtime_kwargs() -> None:
     prompt = "Say foo:"
     output = llm(prompt, pipeline_kwargs={"max_new_tokens": 2})
     assert len(output) < 10
+
+
+ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}
+
+
+def test_huggingface_pipeline_text_generation_ov() -> None:
+    """Test valid call to HuggingFace text generation model with openvino."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="gpt2",
+        task="text-generation",
+        backend="openvino",
+        model_kwargs={"device": "CPU", "ov_config": ov_config},
+        pipeline_kwargs={"max_new_tokens": 64},
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
+def test_huggingface_pipeline_text2text_generation_ov() -> None:
+    """Test valid call to HuggingFace text2text generation model with openvino."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="google/flan-t5-small",
+        task="text2text-generation",
+        backend="openvino",
+        model_kwargs={"device": "CPU", "ov_config": ov_config},
+        pipeline_kwargs={"max_new_tokens": 64},
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
+def test_huggingface_pipeline_summarization_ov() -> None:
+    """Test valid call to HuggingFace summarization model with openvino."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        backend="openvino",
+        model_kwargs={"device": "CPU", "ov_config": ov_config},
+        pipeline_kwargs={"max_new_tokens": 64},
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
