Add assistant with internet access #13

Merged · 16 commits · Jul 10, 2023
Add get_llm_assistant function
This function returns an assistant with or without internet access, depending on the `enable_internet_access` setting.
zaldivards committed Jul 7, 2023
commit 3ab942d276200a38d43b9e6eccb75df6a87e7976
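Note: the `searcher` tool imported below from `contextqa.agents.tools` is added elsewhere in this PR, so its definition does not appear in this file's diff. For orientation, here is a minimal sketch of what such a tool might look like, assuming LangChain's `Tool` wrapper around a DuckDuckGo search backend (the actual implementation in the repo may differ):

```python
# Hypothetical sketch of the `searcher` tool exposed by contextqa.agents.tools.
# Assumes LangChain's DuckDuckGoSearchRun as the search backend; the real tool
# may wrap a different search utility.
from langchain.agents import Tool
from langchain.tools import DuckDuckGoSearchRun

searcher = Tool(
    name="Searcher",
    func=DuckDuckGoSearchRun().run,
    description="Useful for answering questions about current events or recent information",
)
```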
39 changes: 33 additions & 6 deletions api/contextqa/services/chat.py
@@ -1,6 +1,7 @@
+from langchain.agents import initialize_agent, AgentType, Agent
+from langchain.chat_models import ChatOpenAI
 from langchain import ConversationChain
 from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE
-from langchain.chat_models import ChatOpenAI
 from langchain.prompts.chat import (
     AIMessagePromptTemplate,
     ChatPromptTemplate,
@@ -9,7 +10,8 @@
 )
 
 from contextqa import models, settings
-from contextqa.utils import memory
+from contextqa.utils import memory, prompts
+from contextqa.agents.tools import searcher
 
 
 _MESSAGES = [
@@ -26,6 +28,33 @@
 ]
 
 
+def get_llm_assistant() -> ConversationChain | Agent:
+    """Return an LLM assistant, with or without internet access, based on the system configuration
+
+    Returns
+    -------
+    ConversationChain | Agent
+    """
+    llm = ChatOpenAI(temperature=0)
+
+    if settings().enable_internet_access:
+        return initialize_agent(
+            [searcher],
+            llm=llm,
+            agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
+            memory=memory.Redis("default"),
+            verbose=settings().debug,
+            agent_kwargs={
+                # "output_parser": CustomOP(),
+                # "format_instructions": prompts.CONTEXTQA_AGENT_TEMPLATE,
+                "prefix": prompts.PREFIX,
+            },
+            handle_parsing_errors=True,
+        )
+    prompt = ChatPromptTemplate.from_messages(_MESSAGES)
+    return ConversationChain(llm=llm, prompt=prompt, memory=memory.Redis("default"), verbose=settings().debug)
+
+
 def qa_service(message: str) -> models.LLMResult:
     """Chat with the llm
 
@@ -39,7 +68,5 @@ def qa_service(message: str) -> models.LLMResult:
     models.LLMResult
         LLM response
     """
-    llm = ChatOpenAI(temperature=0)
-    prompt = ChatPromptTemplate.from_messages(_MESSAGES)
-    chain = ConversationChain(llm=llm, prompt=prompt, memory=memory.Redis("default"), verbose=settings().debug)
-    return models.LLMResult(response=chain.run(input=message))
+    assistant = get_llm_assistant()
+    return models.LLMResult(response=assistant.run(input=message))
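Callers of `qa_service` are unaffected by this change; whether the assistant can reach the internet is decided inside `get_llm_assistant` by the `enable_internet_access` setting. A minimal usage sketch, assuming an OpenAI API key is configured and the Redis-backed memory is reachable (both are environment concerns outside this diff):

```python
# Minimal usage sketch. Assumes OPENAI_API_KEY is set, Redis is running for
# memory.Redis("default"), and enable_internet_access is toggled via settings.
from contextqa import models
from contextqa.services.chat import qa_service

result: models.LLMResult = qa_service("What is the latest LangChain release?")
print(result.response)  # routed through the search agent when internet access is enabled
```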