forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LLMs support for Anyscale Service (langchain-ai#4350)
Add Anyscale service integration under LLM Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
- Loading branch information
Showing
4 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Anyscale | ||
|
||
This page covers how to use the Anyscale ecosystem within LangChain. | ||
It is broken into two parts: installation and setup, and then references to specific Anyscale wrappers. | ||
|
||
## Installation and Setup | ||
- Get an Anyscale Service URL, route and API key and set them as environment variables (`ANYSCALE_SERVICE_URL`,`ANYSCALE_SERVICE_ROUTE`, `ANYSCALE_SERVICE_TOKEN`). | ||
- Please see [the Anyscale docs](https://docs.anyscale.com/productionize/services-v2/get-started) for more details. | ||
|
||
## Wrappers | ||
|
||
### LLM | ||
|
||
There exists an Anyscale LLM wrapper, which you can access with | ||
```python | ||
from langchain.llms import Anyscale | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9597802c", | ||
"metadata": {}, | ||
"source": [ | ||
"# Anysacle\n", | ||
"\n", | ||
"[Anyscale](https://www.anyscale.com/) is a fully-managed [Ray](https://www.ray.io/) platform, on which you can build, deploy, and manage scalable AI and Python applications\n", | ||
"\n", | ||
"This example goes over how to use LangChain to interact with `Anyscale` [service](https://docs.anyscale.com/productionize/services-v2/get-started)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5472a7cd-af26-48ca-ae9b-5f6ae73c74d2", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"os.environ[\"ANYSCALE_SERVICE_URL\"] = ANYSCALE_SERVICE_URL\n", | ||
"os.environ[\"ANYSCALE_SERVICE_ROUTE\"] = ANYSCALE_SERVICE_ROUTE\n", | ||
"os.environ[\"ANYSCALE_SERVICE_TOKEN\"] = ANYSCALE_SERVICE_TOKEN" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6fb585dd", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.llms import Anyscale\n", | ||
"from langchain import PromptTemplate, LLMChain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "035dea0f", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"template = \"\"\"Question: {question}\n", | ||
"\n", | ||
"Answer: Let's think step by step.\"\"\"\n", | ||
"\n", | ||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3f3458d9", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"llm = Anyscale()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a641dbd9", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"llm_chain = LLMChain(prompt=prompt, llm=llm)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9f844993", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"question = \"When was George Washington president?\"\n", | ||
"\n", | ||
"llm_chain.run(question)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "42f05b34-1a44-4cbd-8342-35c1572b6765", | ||
"metadata": {}, | ||
"source": [ | ||
"With Ray, we can distribute the queries without asyncrhonized implementation. This not only applies to Anyscale LLM model, but to any other Langchain LLM models which do not have `_acall` or `_agenerate` implemented" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "08b23adc-2b29-4c38-b538-47b3c3d840a6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"prompt_list = [\n", | ||
" \"When was George Washington president?\",\n", | ||
" \"Explain to me the difference between nuclear fission and fusion.\",\n", | ||
" \"Give me a list of 5 science fiction books I should read next.\",\n", | ||
" \"Explain the difference between Spark and Ray.\",\n", | ||
" \"Suggest some fun holiday ideas.\",\n", | ||
" \"Tell a joke.\",\n", | ||
" \"What is 2+2?\",\n", | ||
" \"Explain what is machine learning like I am five years old.\",\n", | ||
" \"Explain what is artifical intelligence.\",\n", | ||
"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2b45abb9-b764-497d-af99-0df1d4e335e0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import ray\n", | ||
"\n", | ||
"@ray.remote\n", | ||
"def send_query(llm, prompt):\n", | ||
" resp = llm(prompt)\n", | ||
" return resp\n", | ||
"\n", | ||
"futures = [send_query.remote(llm, prompt) for prompt in prompt_list]\n", | ||
"results = ray.get(futures)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.8" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""Wrapper around Anyscale""" | ||
from typing import Any, Dict, List, Mapping, Optional | ||
|
||
import requests | ||
from pydantic import Extra, root_validator | ||
|
||
from langchain.callbacks.manager import CallbackManagerForLLMRun | ||
from langchain.llms.base import LLM | ||
from langchain.llms.utils import enforce_stop_tokens | ||
from langchain.utils import get_from_dict_or_env | ||
|
||
|
||
class Anyscale(LLM): | ||
"""Wrapper around Anyscale Services. | ||
To use, you should have the environment variable ``ANYSCALE_SERVICE_URL``, | ||
``ANYSCALE_SERVICE_ROUTE`` and ``ANYSCALE_SERVICE_TOKEN`` set with your Anyscale | ||
Service, or pass it as a named parameter to the constructor. | ||
Example: | ||
.. code-block:: python | ||
from langchain.llms import Anyscale | ||
anyscale = Anyscale(anyscale_service_url="SERVICE_URL", | ||
anyscale_service_route="SERVICE_ROUTE", | ||
anyscale_service_token="SERVICE_TOKEN") | ||
# Use Ray for distributed processing | ||
import ray | ||
prompt_list=[] | ||
@ray.remote | ||
def send_query(llm, prompt): | ||
resp = llm(prompt) | ||
return resp | ||
futures = [send_query.remote(anyscale, prompt) for prompt in prompt_list] | ||
results = ray.get(futures) | ||
""" | ||
|
||
model_kwargs: Optional[dict] = None | ||
"""Key word arguments to pass to the model. Reserved for future use""" | ||
|
||
anyscale_service_url: Optional[str] = None | ||
anyscale_service_route: Optional[str] = None | ||
anyscale_service_token: Optional[str] = None | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
extra = Extra.forbid | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key and python package exists in environment.""" | ||
anyscale_service_url = get_from_dict_or_env( | ||
values, "anyscale_service_url", "ANYSCALE_SERVICE_URL" | ||
) | ||
anyscale_service_route = get_from_dict_or_env( | ||
values, "anyscale_service_route", "ANYSCALE_SERVICE_ROUTE" | ||
) | ||
anyscale_service_token = get_from_dict_or_env( | ||
values, "anyscale_service_token", "ANYSCALE_SERVICE_TOKEN" | ||
) | ||
try: | ||
anyscale_service_endpoint = f"{anyscale_service_url}/-/route" | ||
headers = {"Authorization": f"Bearer {anyscale_service_token}"} | ||
requests.get(anyscale_service_endpoint, headers=headers) | ||
except requests.exceptions.RequestException as e: | ||
raise ValueError(e) | ||
values["anyscale_service_url"] = anyscale_service_url | ||
values["anyscale_service_route"] = anyscale_service_route | ||
values["anyscale_service_token"] = anyscale_service_token | ||
return values | ||
|
||
@property | ||
def _identifying_params(self) -> Mapping[str, Any]: | ||
"""Get the identifying parameters.""" | ||
return { | ||
"anyscale_service_url": self.anyscale_service_url, | ||
"anyscale_service_route": self.anyscale_service_route, | ||
} | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
"""Return type of llm.""" | ||
return "anyscale" | ||
|
||
def _call( | ||
self, | ||
prompt: str, | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
) -> str: | ||
"""Call out to Anyscale Service endpoint. | ||
Args: | ||
prompt: The prompt to pass into the model. | ||
stop: Optional list of stop words to use when generating. | ||
Returns: | ||
The string generated by the model. | ||
Example: | ||
.. code-block:: python | ||
response = anyscale("Tell me a joke.") | ||
""" | ||
|
||
anyscale_service_endpoint = ( | ||
f"{self.anyscale_service_url}/{self.anyscale_service_route}" | ||
) | ||
headers = {"Authorization": f"Bearer {self.anyscale_service_token}"} | ||
body = {"prompt": prompt} | ||
resp = requests.post(anyscale_service_endpoint, headers=headers, json=body) | ||
|
||
if resp.status_code != 200: | ||
raise ValueError( | ||
f"Error returned by service, status code {resp.status_code}" | ||
) | ||
text = resp.text | ||
|
||
if stop is not None: | ||
# This is a bit hacky, but I can't figure out a better way to enforce | ||
# stop tokens when making calls to huggingface_hub. | ||
text = enforce_stop_tokens(text, stop) | ||
return text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Test Anyscale API wrapper.""" | ||
|
||
from langchain.llms.anyscale import Anyscale | ||
|
||
|
||
def test_anyscale_call() -> None: | ||
"""Test valid call to Anyscale.""" | ||
llm = Anyscale() | ||
output = llm("Say foo:") | ||
assert isinstance(output, str) |