Move to Langchain LCEL
kaarthik108 committed Feb 3, 2024
1 parent a694108 commit f573826
Showing 6 changed files with 140 additions and 98 deletions.
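The change replaces the legacy ConversationalRetrievalChain/LLMChain stack with LangChain Expression Language (LCEL), where runnables compose with the | operator and every stage exposes .invoke() and .stream(). A minimal illustrative sketch of the pattern (the prompt text and model name here are placeholders, not the project's exact chain):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Each stage is a runnable; piping them builds a new runnable end to end.
prompt = ChatPromptTemplate.from_template("Answer briefly: {question}")
chain = prompt | ChatOpenAI(model_name="gpt-3.5-turbo-0125") | StrOutputParser()
print(chain.invoke({"question": "What is LCEL?"}))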
22 changes: 21 additions & 1 deletion .vscode/settings.json
@@ -2,5 +2,25 @@
"[python]": {
"editor.defaultFormatter": "ms-python.python"
},
"python.formatting.provider": "none"
"python.formatting.provider": "none",
"workbench.colorCustomizations": {
"activityBar.activeBackground": "#7c185f",
"activityBar.background": "#7c185f",
"activityBar.foreground": "#e7e7e7",
"activityBar.inactiveForeground": "#e7e7e799",
"activityBarBadge.background": "#000000",
"activityBarBadge.foreground": "#e7e7e7",
"commandCenter.border": "#e7e7e799",
"sash.hoverBorder": "#7c185f",
"statusBar.background": "#51103e",
"statusBar.foreground": "#e7e7e7",
"statusBarItem.hoverBackground": "#7c185f",
"statusBarItem.remoteBackground": "#51103e",
"statusBarItem.remoteForeground": "#e7e7e7",
"titleBar.activeBackground": "#51103e",
"titleBar.activeForeground": "#e7e7e7",
"titleBar.inactiveBackground": "#51103e99",
"titleBar.inactiveForeground": "#e7e7e799"
},
"peacock.color": "#51103e"
}
133 changes: 75 additions & 58 deletions chain.py
@@ -2,16 +2,25 @@

import boto3
import streamlit as st
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI, BedrockChat
from langchain.chat_models import BedrockChat, ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import SupabaseVectorStore
from pydantic import BaseModel, validator
from supabase.client import Client, create_client

from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT
from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT

from operator import itemgetter

from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

supabase_url = st.secrets["SUPABASE_URL"]
supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
@@ -25,7 +34,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["gpt", "claude", "mixtral"]:
if v not in ["gpt", "codellama", "mixtral"]:
raise ValueError(f"Unsupported model type: {v}")
return v

@@ -44,23 +53,15 @@ def __init__(self, config: ModelConfig):
def setup(self):
if self.model_type == "gpt":
self.setup_gpt()
elif self.model_type == "claude":
self.setup_claude()
elif self.model_type == "codellama":
self.setup_codellama()
elif self.model_type == "mixtral":
self.setup_mixtral()

def setup_gpt(self):
self.q_llm = OpenAI(
temperature=0.1,
api_key=self.secrets["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo-16k",
max_tokens=500,
base_url=self.gateway_url,
)

self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo-16k",
temperature=0.5,
model_name="gpt-3.5-turbo-0125",
temperature=0.2,
api_key=self.secrets["OPENAI_API_KEY"],
max_tokens=500,
callbacks=[self.callback_handler],
@@ -69,60 +70,76 @@ def setup_gpt(self):
)

def setup_mixtral(self):
self.q_llm = OpenAI(
temperature=0.1,
api_key=self.secrets["MIXTRAL_API_KEY"],
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens=500,
base_url="https://api.together.xyz/v1",
)

self.llm = ChatOpenAI(
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
temperature=0.5,
temperature=0.2,
api_key=self.secrets["MIXTRAL_API_KEY"],
max_tokens=500,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://api.together.xyz/v1",
)

def setup_claude(self):
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
region_name="us-east-1",
)
parameters = {
"max_tokens_to_sample": 1000,
"stop_sequences": [],
"temperature": 0,
"top_p": 0.9,
}
self.q_llm = BedrockChat(
model_id="anthropic.claude-instant-v1", client=bedrock_runtime
)

self.llm = BedrockChat(
model_id="anthropic.claude-instant-v1",
client=bedrock_runtime,
def setup_codellama(self):
self.llm = ChatOpenAI(
model_name="codellama/codellama-70b-instruct",
temperature=0.2,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=500,
callbacks=[self.callback_handler],
streaming=True,
model_kwargs=parameters,
base_url="https://openrouter.ai/api/v1",
)

# def setup_claude(self):
# bedrock_runtime = boto3.client(
# service_name="bedrock-runtime",
# aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
# aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
# region_name="us-east-1",
# )
# parameters = {
# "max_tokens_to_sample": 1000,
# "stop_sequences": [],
# "temperature": 0,
# "top_p": 0.9,
# }
# self.q_llm = BedrockChat(
# model_id="anthropic.claude-instant-v1", client=bedrock_runtime
# )

# self.llm = BedrockChat(
# model_id="anthropic.claude-instant-v1",
# client=bedrock_runtime,
# callbacks=[self.callback_handler],
# streaming=True,
# model_kwargs=parameters,
# )

def get_chain(self, vectorstore):
if not self.q_llm or not self.llm:
raise ValueError("Models have not been properly initialized.")
question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
conv_chain = ConversationalRetrievalChain(
retriever=vectorstore.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=question_generator,
def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)

_inputs = RunnableParallel(
standalone_question=RunnablePassthrough.assign(
chat_history=lambda x: get_buffer_string(x["chat_history"])
)
| CONDENSE_QUESTION_PROMPT
| OpenAI()
| StrOutputParser(),
)
return conv_chain
_context = {
"context": itemgetter("standalone_question")
| vectorstore.as_retriever()
| _combine_documents,
"question": lambda x: x["standalone_question"],
}
conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm

return conversational_qa_chain


def load_chain(model_name="GPT-3.5", callback_handler=None):
@@ -136,8 +153,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
query_name="v_match_documents",
)

if "claude" in model_name.lower():
model_type = "claude"
if "codellama" in model_name.lower():
model_type = "codellama"
elif "GPT-3.5" in model_name:
model_type = "gpt"
elif "mixtral" in model_name.lower():
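With LCEL, get_chain now returns a runnable that is invoked directly instead of being called like the old ConversationalRetrievalChain. A hedged usage sketch (the question and history are invented; get_buffer_string expects LangChain message objects, so the history must hold HumanMessage/AIMessage instances):

from langchain_core.messages import AIMessage, HumanMessage

chain = load_chain("GPT-3.5", callback_handler=handler)  # handler: a StreamlitUICallbackHandler
result = chain.invoke(
    {
        "question": "Which table stores orders?",
        "chat_history": [HumanMessage(content="hi"), AIMessage(content="Hello!")],
    }
)
print(result.content)  # no output parser in the chain, so a chat message comes back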
43 changes: 23 additions & 20 deletions main.py
@@ -5,7 +5,8 @@
from snowflake.snowpark.exceptions import SnowparkSQLException

from chain import load_chain
from utils.snow_connect import SnowflakeConnection

# from utils.snow_connect import SnowflakeConnection
from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
from utils.snowddl import Snowddl

@@ -17,11 +18,10 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["✨ GPT-3.5", "♾️ Claude", "⛰️ Mixtral"],
options=["✨ GPT-3.5", "♾️ codellama", "⛰️ Mixtral"],
index=0,
horizontal=True,
)

st.session_state["model"] = model

INITIAL_MESSAGE = [
@@ -97,15 +97,10 @@ def get_sql(text):
return sql_match.group(1) if sql_match else None


def append_message(content, role="assistant", display=False):
message = {"role": role, "content": content}
st.session_state.messages.append(message)
if role != "data":
append_chat_history(st.session_state.messages[-2]["content"], content)

if callback_handler.has_streaming_ended:
callback_handler.has_streaming_ended = False
return
def append_message(content, role="assistant"):
"""Appends a message to the session state messages."""
if content.strip():
st.session_state.messages.append({"role": role, "content": content})


def handle_sql_exception(query, conn, e, retries=2):
@@ -135,14 +130,22 @@ def execute_sql(query, conn, retries=2):
return handle_sql_exception(query, conn, e, retries)


if st.session_state.messages[-1]["role"] != "assistant":
content = st.session_state.messages[-1]["content"]
if isinstance(content, str):
result = chain(
{"question": content, "chat_history": st.session_state["history"]}
)["answer"]
print(result)
append_message(result)
if (
"messages" in st.session_state
and st.session_state["messages"][-1]["role"] != "assistant"
):
user_input_content = st.session_state["messages"][-1]["content"]
# print(f"User input content is: {user_input_content}")

if isinstance(user_input_content, str):
result = chain.invoke(
{
"question": user_input_content,
"chat_history": [h for h in st.session_state["history"]],
}
)
append_message(result.content)

# if get_sql(result):
# conn = SnowflakeConnection().get_session()
# df = execute_sql(get_sql(result), conn)
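The guard above only invokes the chain when the newest message came from the user, so Streamlit's rerun-on-interaction model doesn't trigger duplicate completions. A stripped-down sketch of that turn-taking pattern (the input widget and seed message are illustrative, not snowChat's exact UI):

import streamlit as st

if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "Hey there!"}]

if user_text := st.chat_input("Ask about your data"):
    st.session_state["messages"].append({"role": "user", "content": user_text})

# Respond only when the user spoke last; appending the answer flips the last
# role back to "assistant", making the next rerun a no-op.
if st.session_state["messages"][-1]["role"] != "assistant":
    pass  # invoke the chain here, then append_message(result.content)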
7 changes: 3 additions & 4 deletions requirements.txt
@@ -1,13 +1,12 @@
langchain==0.0.350
langchain==0.1.5
pandas==1.5.0
pydantic==1.10.8
snowflake_snowpark_python==1.5.0
snowflake-snowpark-python[pandas]
streamlit==1.27.1
streamlit==1.31.0
supabase==1.0.3
unstructured==0.7.12
tiktoken==0.4.0
openai==0.27.8
openai==1.11.0
black==23.3.0
replicate==0.8.4
boto3==1.28.57
20 changes: 13 additions & 7 deletions template.py
@@ -1,4 +1,5 @@
from langchain.prompts.prompt import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate

template = """You are an AI chatbot having a conversation with a human.
@@ -27,11 +28,13 @@
Write your response in markdown format.
Human: ```{question}```
User: {question}
{context}
Assistant:
"""


B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

@@ -54,11 +57,14 @@
"""

LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
# LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST

CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_template(template)

# QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
# LLAMA_PROMPT = PromptTemplate(
# template=LLAMA_TEMPLATE, input_variables=["question", "context"]
# )

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template)

QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
LLAMA_PROMPT = PromptTemplate(
template=LLAMA_TEMPLATE, input_variables=["question", "context"]
)
QA_PROMPT = ChatPromptTemplate.from_template(TEMPLATE)
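Moving from PromptTemplate to ChatPromptTemplate matters for LCEL because prompt templates are themselves runnables: invoking one yields a prompt value that a chat model accepts directly, which is what lets _inputs | _context | QA_PROMPT | self.llm compose in chain.py. A small sketch (the template text is a placeholder):

from langchain_core.prompts import ChatPromptTemplate

qa = ChatPromptTemplate.from_template("Context:\n{context}\n\nQuestion: {question}")
prompt_value = qa.invoke({"context": "orders(id, total)", "question": "What columns exist?"})
print(prompt_value.to_messages())  # a single HumanMessage carrying the filled template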
13 changes: 5 additions & 8 deletions utils/snowchat_ui.py
@@ -86,7 +86,7 @@ class StreamlitUICallbackHandler(BaseCallbackHandler):
def __init__(self):
# Buffer to accumulate tokens
self.token_buffer = []
self.placeholder = None
self.placeholder = st.empty()
self.has_streaming_ended = False

def _get_bot_message_container(self, text):
@@ -111,13 +111,10 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
"""
self.token_buffer.append(token)
complete_message = "".join(self.token_buffer)
if self.placeholder is None:
container_content = self._get_bot_message_container(complete_message)
self.placeholder = st.markdown(container_content, unsafe_allow_html=True)
else:
# Update the placeholder content
container_content = self._get_bot_message_container(complete_message)
self.placeholder.markdown(container_content, unsafe_allow_html=True)

# Update the placeholder content with the complete message
container_content = self._get_bot_message_container(complete_message)
self.placeholder.markdown(container_content, unsafe_allow_html=True)

def display_dataframe(self, df):
"""
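Creating the placeholder eagerly with st.empty() gives the handler a stable slot to write into, and every token re-renders the full accumulated message rather than appending fragments. A stripped-down sketch of the same streaming pattern (the class name is invented; snowChat additionally wraps the text in styled HTML via _get_bot_message_container):

import streamlit as st
from langchain_core.callbacks import BaseCallbackHandler

class StreamToPlaceholder(BaseCallbackHandler):
    def __init__(self):
        self.token_buffer = []
        self.placeholder = st.empty()  # reserve the output slot up front

    def on_llm_new_token(self, token, **kwargs):
        self.token_buffer.append(token)
        # Re-render the whole message so partially streamed markdown stays well-formed.
        self.placeholder.markdown("".join(self.token_buffer))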
