Skip to content

Commit

Permalink
First attempt at OpenAI compatibilty wrapper, refs #325
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jan 26, 2024
1 parent 9957c71 commit 1c21929
Showing 1 changed file with 51 additions and 11 deletions.
62 changes: 51 additions & 11 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,54 @@
import json
import yaml

if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):

def log_response(response, *args, **kwargs):
click.echo(response.text, err=True)
return response
def _log_response(response, *args, **kwargs):
click.echo(response.text, err=True)
return response

openai.requestssession = requests.Session()
openai.requestssession.hooks["response"].append(log_response)

_log_session = requests.Session()
_log_session.hooks["response"].append(_log_response)


def is_openai_pre_1():
return openai.version.VERSION.startswith("0.")


if is_openai_pre_1 and os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
openai.requestssession = _log_session


class OpenAILegacyWrapper:
def __init__(self, client):
self.client = client

@property
def ChatCompletion(self):
return self.client.chat.completions

@property
def Completion(self):
return self.client.completions

@property
def Embedding(self):
return self.client.embeddings


def get_openai_client():
if is_openai_pre_1:
return openai

if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
client = openai.OpenAI(requestssession=_log_session)
else:
client = openai.OpenAI()

return OpenAILegacyWrapper(client)


client = get_openai_client()


@hookimpl
Expand Down Expand Up @@ -111,7 +151,7 @@ def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float
}
if self.dimensions:
kwargs["dimensions"] = self.dimensions
results = openai.Embedding.create(**kwargs)["data"]
results = client.Embedding.create(**kwargs)["data"]
return ([float(r) for r in result["embedding"]] for result in results)


Expand Down Expand Up @@ -305,7 +345,7 @@ def execute(self, prompt, stream, response, conversation=None):
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
if stream:
completion = openai.ChatCompletion.create(
completion = client.ChatCompletion.create(
model=self.model_name or self.model_id,
messages=messages,
stream=True,
Expand All @@ -319,7 +359,7 @@ def execute(self, prompt, stream, response, conversation=None):
yield content
response.response_json = combine_chunks(chunks)
else:
completion = openai.ChatCompletion.create(
completion = client.ChatCompletion.create(
model=self.model_name or self.model_id,
messages=messages,
stream=False,
Expand Down Expand Up @@ -384,7 +424,7 @@ def execute(self, prompt, stream, response, conversation=None):
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
if stream:
completion = openai.Completion.create(
completion = client.Completion.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=True,
Expand All @@ -398,7 +438,7 @@ def execute(self, prompt, stream, response, conversation=None):
yield content
response.response_json = combine_chunks(chunks)
else:
completion = openai.Completion.create(
completion = client.Completion.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=False,
Expand Down

0 comments on commit 1c21929

Please sign in to comment.