Skip to content

Commit

Permalink
fix(Model): Compatible with openai 1.x.x, compatible with AzureOpenAI (e…
Browse files Browse the repository at this point in the history
…osphoros-ai#804)

Update chatgpt.py
  • Loading branch information
Aries-ckt authored Nov 16, 2023
2 parents 8eaf369 + 2ff9625 commit 1ad09c8
Showing 1 changed file with 121 additions and 26 deletions.
147 changes: 121 additions & 26 deletions pilot/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from typing import List
import logging

import importlib.metadata as metadata
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
Expand Down Expand Up @@ -57,14 +57,48 @@ def _initialize_openai(params: ProxyModelParameters):
return openai_params


def _initialize_openai_v1(params: ProxyModelParameters):
    """Resolve connection settings for the openai>=1.0 client API.

    Precedence for each setting: explicit value on ``params``, then the
    standard OpenAI environment variable, then the Azure-specific variable
    when the api type is "azure".

    Args:
        params: Proxy model parameters carrying api key / base / type / version.

    Returns:
        Tuple of ``(openai_params, api_type, api_version)`` where
        ``openai_params`` is a dict with ``api_key``, ``base_url`` and
        ``proxies`` keys.

    Raises:
        ValueError: If the ``openai`` package (>=1.0) is not installed.
    """
    try:
        # Imported only to verify that the 1.x client API is available.
        from openai import OpenAI  # noqa: F401
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai`"
        ) from exc

    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    base_url = params.proxy_api_base or os.getenv(
        # Fixed: previously read OPENAI_API_TYPE here, which would have put
        # the api-type string ("azure"/"open_ai") into the base URL slot.
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not base_url and params.proxy_server_url:
        # Adapt previous proxy_server_url configuration
        base_url = params.proxy_server_url.split("/chat/completions")[0]

    # Fixed: the previous ``openai.proxies = params.http_proxy`` referenced
    # the module name ``openai``, which is never imported in this function
    # (NameError when a proxy was configured). The proxy setting is carried
    # in the returned dict instead.
    openai_params = {
        "api_key": api_key,
        "base_url": base_url,
        # NOTE(review): the openai>=1.0 client constructors do not accept a
        # ``proxies`` kwarg directly; callers must not splat this key into
        # ``OpenAI(**...)`` — route it through ``http_client`` if needed.
        "proxies": params.http_proxy,
    }

    return openai_params, api_type, api_version


def _build_request(model: ProxyModel, params):
history = []

model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}")

openai_params = _initialize_openai(model_params)

messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
Expand Down Expand Up @@ -95,13 +129,19 @@ def _build_request(model: ProxyModel, params):
}
proxyllm_backend = model_params.proxyllm_backend

if openai_params["api_type"] == "azure":
# engine = "deployment_name".
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
payloads["engine"] = proxyllm_backend
else:
if metadata.version("openai") >= "1.0.0":
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
else:
openai_params = _initialize_openai(model_params)
if openai_params["api_type"] == "azure":
# engine = "deployment_name".
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
payloads["engine"] = proxyllm_backend
else:
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend

logger.info(
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
Expand All @@ -112,32 +152,87 @@ def _build_request(model: ProxyModel, params):
def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Stream chat completions from an OpenAI-compatible proxy model.

    Supports both the legacy (<1.0) and the new (>=1.0) openai SDK, and
    Azure OpenAI deployments when the resolved api type is "azure".

    Yields:
        The accumulated response text after each received delta chunk.
    """
    history, payloads = _build_request(model, params)
    # Fixed: compare the major version numerically — a plain string compare
    # against "1.0.0" is fragile (e.g. pre-release suffixes).
    if int(metadata.version("openai").split(".")[0]) >= 1:
        model_params = model.get_params()
        openai_params, api_type, api_version = _initialize_openai_v1(model_params)
        if api_type == "azure":
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                # Your Azure OpenAI resource's endpoint value.
                azure_endpoint=openai_params["base_url"],
            )
        else:
            from openai import OpenAI

            # Fixed: pass only supported constructor kwargs — the 1.x
            # client does not accept a ``proxies`` keyword, so splatting
            # the whole dict would raise TypeError.
            client = OpenAI(
                api_key=openai_params["api_key"],
                base_url=openai_params["base_url"],
            )
        res = client.chat.completions.create(messages=history, **payloads)
        text = ""
        for r in res:
            # Delta chunks may omit content (e.g. the role-only first chunk).
            if r.choices and r.choices[0].delta.content is not None:
                text += r.choices[0].delta.content
                yield text
    else:
        import openai

        res = openai.ChatCompletion.create(messages=history, **payloads)
        # Fixed: removed leftover debug ``print("res", res)``.
        text = ""
        for r in res:
            if r["choices"][0]["delta"].get("content") is not None:
                text += r["choices"][0]["delta"]["content"]
                yield text


async def async_chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Async variant of ``chatgpt_generate_stream``.

    Supports both the legacy (<1.0) and the new (>=1.0) openai SDK, and
    Azure OpenAI deployments when the resolved api type is "azure".

    Yields:
        The accumulated response text after each received delta chunk.
    """
    history, payloads = _build_request(model, params)
    # Fixed: compare the major version numerically — a plain string compare
    # against "1.0.0" is fragile (e.g. pre-release suffixes).
    if int(metadata.version("openai").split(".")[0]) >= 1:
        model_params = model.get_params()
        openai_params, api_type, api_version = _initialize_openai_v1(model_params)
        if api_type == "azure":
            from openai import AsyncAzureOpenAI

            client = AsyncAzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                # Your Azure OpenAI resource's endpoint value.
                azure_endpoint=openai_params["base_url"],
            )
        else:
            from openai import AsyncOpenAI

            # Fixed: pass only supported constructor kwargs — the 1.x
            # client does not accept a ``proxies`` keyword, so splatting
            # the whole dict would raise TypeError.
            client = AsyncOpenAI(
                api_key=openai_params["api_key"],
                base_url=openai_params["base_url"],
            )

        res = await client.chat.completions.create(messages=history, **payloads)
        text = ""
        # Fixed: the async client returns an async iterator for streamed
        # responses; a plain ``for`` would raise TypeError.
        async for r in res:
            # Delta chunks may omit content (e.g. the role-only first chunk).
            if r.choices and r.choices[0].delta.content is not None:
                text += r.choices[0].delta.content
                yield text
    else:
        import openai

        res = await openai.ChatCompletion.acreate(messages=history, **payloads)

        text = ""
        async for r in res:
            if r["choices"][0]["delta"].get("content") is not None:
                text += r["choices"][0]["delta"]["content"]
                yield text

0 comments on commit 1ad09c8

Please sign in to comment.