Skip to content

Commit

Permalink
Add generation feature to Svelte front-end
Browse files Browse the repository at this point in the history
Fix a few bugs and refactor generation back-end to chains.py so it can be reused.
  • Loading branch information
oskarhane committed Oct 24, 2023
1 parent f8cf77d commit af43324
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 137 deletions.
17 changes: 16 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
load_llm,
configure_llm_only_chain,
configure_qa_rag_chain,
generate_ticket,
)
from fastapi import FastAPI, Depends
from pydantic import BaseModel
Expand Down Expand Up @@ -112,6 +113,10 @@ class Question(BaseModel):
rag: bool = False


class BaseTicket(BaseModel):
    """Query/request payload for ticket generation.

    Attributes:
        text: The user's free-form question text to rewrite into a ticket.
    """

    text: str

@app.get("/query-stream")
def qstream(question: Question = Depends()):
output_function = llm_chain
Expand Down Expand Up @@ -143,4 +148,14 @@ async def ask(question: Question = Depends()):
{"question": question.text, "chat_history": []}, callbacks=[]
)

return json.dumps({"result": result["answer"], "model": llm_name})
return {"result": result["answer"], "model": llm_name}


@app.get("/generate-ticket")
async def generate_ticket_api(question: BaseTicket = Depends()):
    """Draft a support-ticket title and body from the submitted question text.

    Delegates to ``generate_ticket`` (chains.py) and wraps the result in the
    same response envelope used by the other endpoints.
    """
    title, text = generate_ticket(
        neo4j_graph=neo4j_graph, llm_chain=llm_chain, input_question=question.text
    )
    return {"result": {"title": title, "text": text}, "model": llm_name}
72 changes: 6 additions & 66 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import (
extract_title_and_question,
create_vector_index,
)
from chains import (
load_embedding_model,
load_llm,
configure_llm_only_chain,
configure_qa_rag_chain,
generate_ticket,
)

load_dotenv(".env")
Expand Down Expand Up @@ -148,65 +143,6 @@ def mode_select() -> str:
output_function = rag_chain


def generate_ticket():
    """Draft a new ticket (title + body) styled after top-ranked questions.

    Uses the three highest-scored questions in the graph as few-shot style
    examples and asks the LLM to rewrite the user's most recent chat input
    in the same style.

    Returns:
        Tuple ``(new_title, new_question)`` parsed from the LLM answer.
    """
    # Get high ranked questions to use as style examples.
    records = neo4j_graph.query(
        "MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3"
    )
    # Fix: the loop index from enumerate() was unused here; a comprehension
    # is the idiomatic form.
    questions = [(question["title"], question["body"]) for question in records]

    # Ask LLM to generate new question in the same style.
    questions_prompt = ""
    for i, (title, body) in enumerate(questions, start=1):
        questions_prompt += f"{i}. {title}\n"
        questions_prompt += f"{body}\n\n"
        questions_prompt += "----\n\n"

    gen_system_template = f"""
    You're an expert in formulating high quality questions.
    Can you formulate a question in the same style, detail and tone as the following example questions?
    {questions_prompt}
    ---
    Don't make anything up, only use information in the following question.
    Return a title for the question, and the question post itself.
    Return example:
    ---
    Title: How do I use the Neo4j Python driver?
    Question: I'm trying to connect to Neo4j using the Python driver, but I'm getting an error.
    ---
    """
    # we need jinja2 since the questions themselves contain curly braces
    system_prompt = SystemMessagePromptTemplate.from_template(
        gen_system_template, template_format="jinja2"
    )
    # Fix: needless f-string prefix on a placeholder-free literal key.
    q_prompt = st.session_state["user_input"][-1]
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            system_prompt,
            SystemMessagePromptTemplate.from_template(
                """
                Respond in the following format or you will be unplugged.
                ---
                Title: New title
                Question: New question
                ---
                """
            ),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    llm_response = llm_chain(
        f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
        [],
        chat_prompt,
    )
    new_title, new_question = extract_title_and_question(llm_response["answer"])
    return (new_title, new_question)


def open_sidebar():
    """Flag the ticket-draft sidebar to render on the next Streamlit rerun."""
    st.session_state.open_sidebar = True

Expand All @@ -218,7 +154,11 @@ def close_sidebar():
if not "open_sidebar" in st.session_state:
st.session_state.open_sidebar = False
if st.session_state.open_sidebar:
new_title, new_question = generate_ticket()
new_title, new_question = generate_ticket(
neo4j_graph=neo4j_graph,
llm_chain=llm_chain,
input_question=st.session_state[f"user_input"][-1],
)
with st.sidebar:
st.title("Ticket draft")
st.write("Auto generated draft ticket")
Expand Down
64 changes: 62 additions & 2 deletions chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
HumanMessagePromptTemplate,
)
from typing import List, Any
from utils import BaseLogger
from utils import BaseLogger, extract_title_and_question


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
Expand Down Expand Up @@ -88,7 +88,9 @@ def generate_llm_output(
user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
chain = prompt | llm
answer = chain.invoke(user_input, config={"callbacks": callbacks}).content
answer = chain.invoke(
{"question": user_input}, config={"callbacks": callbacks}
).content
return {"answer": answer}

return generate_llm_output
Expand Down Expand Up @@ -160,3 +162,61 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
max_tokens_limit=3375,
)
return kg_qa


def generate_ticket(neo4j_graph, llm_chain, input_question):
    """Draft a new ticket (title + body) styled after top-ranked questions.

    Fetches the three highest-scored questions from the graph as few-shot
    style examples, then asks the LLM to rewrite ``input_question`` in the
    same style and parses the response.

    Args:
        neo4j_graph: Graph client exposing ``query(cypher)``; rows must carry
            ``title`` and ``body`` keys.
        llm_chain: Callable ``(user_input, callbacks, prompt)`` returning a
            dict with an ``"answer"`` string.
        input_question: The user's raw question text to rewrite.

    Returns:
        Tuple ``(new_title, new_question)`` parsed from the LLM answer.
    """
    # Get high ranked questions to use as style examples.
    records = neo4j_graph.query(
        "MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3"
    )
    # Fix: the loop index from enumerate() was unused here; a comprehension
    # is the idiomatic form.
    questions = [(record["title"], record["body"]) for record in records]

    # Ask LLM to generate new question in the same style.
    questions_prompt = ""
    for i, (title, body) in enumerate(questions, start=1):
        questions_prompt += f"{i}. {title}\n"
        questions_prompt += f"{body}\n\n"
        questions_prompt += "----\n\n"

    gen_system_template = f"""
    You're an expert in formulating high quality questions.
    Can you formulate a question in the same style, detail and tone as the following example questions?
    {questions_prompt}
    ---
    Don't make anything up, only use information in the following question.
    Return a title for the question, and the question post itself.
    Return example:
    ---
    Title: How do I use the Neo4j Python driver?
    Question: I'm trying to connect to Neo4j using the Python driver, but I'm getting an error.
    ---
    """
    # we need jinja2 since the questions themselves contain curly braces
    system_prompt = SystemMessagePromptTemplate.from_template(
        gen_system_template, template_format="jinja2"
    )
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            system_prompt,
            SystemMessagePromptTemplate.from_template(
                """
                Respond in the following format or you will be unplugged.
                ---
                Title: New title
                Question: New question
                ---
                """
            ),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    llm_response = llm_chain(
        f"Here's the question to rewrite in the expected format: ```{input_question}```",
        [],
        chat_prompt,
    )
    new_title, new_question = extract_title_and_question(llm_response["answer"])
    return (new_title, new_question)
101 changes: 33 additions & 68 deletions front-end/src/App.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,21 @@
import botImage from "./assets/images/bot.jpeg";
import meImage from "./assets/images/me.jpeg";
import MdLink from "./lib/MdLink.svelte";
import External from "./lib/External.svelte";
import { chatStates, chatStore } from "./lib/chat.store.js";
import Modal from "./lib/Modal.svelte";
import { generationStore } from "./lib/generation.store";
let messages = [];
let ragMode = true;
let ragMode = false;
let question = "How can I create a chatbot on top of my local PDF files using langchain?";
let shouldAutoScroll = true;
let input;
let appState = "idle"; // or receiving
let senderImages = { bot: botImage, me: meImage };
let generationModalOpen = false;
// Send the current question to the back-end and stream the bot's answer
// token-by-token into a fresh message via Server-Sent Events.
async function send() {
  // Ignore empty / whitespace-only input.
  if (!question.trim().length) {
    return;
  }
  appState = "receiving";
  // Echo the user's message, then create an empty bot message that the
  // stream below fills in incrementally.
  addMessage("me", question, ragMode);
  const messageId = addMessage("bot", "", ragMode);
  try {
    const evt = new EventSource(
      `http://localhost:8504/query-stream?text=${encodeURI(question)}&rag=${ragMode}`
    );
    // Clear the input only after the connection has been set up.
    question = "";
    evt.onmessage = (e) => {
      if (e.data) {
        const data = JSON.parse(e.data);
        // First event carries the model name instead of a token.
        if (data.init) {
          updateMessage(messageId, "", data.model);
          return;
        }
        updateMessage(messageId, data.token);
      }
    };
    evt.onerror = (e) => {
      // Stream will end with an error
      // and we want to close the connection on end (otherwise it will keep reconnecting)
      evt.close();
    };
  } catch (e) {
    updateMessage(messageId, "Error: " + e.message);
  } finally {
    // NOTE(review): this runs as soon as the EventSource is constructed, not
    // when the stream finishes — the input unlocks while tokens may still be
    // arriving. Confirm this is intended.
    appState = "idle";
  }
}
// Append streamed text (and, when provided, the model name) to the message
// identified by `existingId`; no-op if the id is missing or unknown.
function updateMessage(existingId, text, model = null) {
  if (!existingId) {
    return;
  }
  const idx = messages.findIndex((m) => m.id === existingId);
  if (idx < 0) {
    return;
  }
  const entry = messages[idx];
  entry.text += text;
  if (model) {
    entry.model = model;
  }
  // Self-assignment makes Svelte notice the in-place mutation.
  messages = messages;
}
function addMessage(from, text, rag) {
const newId = Math.random().toString(36).substring(2, 9);
const message = { id: newId, from, text, rag };
messages = messages.concat([message]);
return newId;
// Hand the question off to the chat store (which owns streaming state),
// then clear the input box.
function send() {
  chatStore.send(question, ragMode);
  question = "";
}
function scrollToBottom(node, _) {
Expand All @@ -79,7 +31,12 @@
shouldAutoScroll = e.target.scrollTop + e.target.clientHeight > e.target.scrollHeight - 55;
}
$: appState === "idle" && input && focus(input);
// Start ticket generation for the given message text and open the modal
// that displays the generated draft.
function generateTicket(text) {
  generationStore.generate(text);
  generationModalOpen = true;
}
$: $chatStore.state === chatStates.IDLE && input && focus(input);
async function focus(node) {
await tick();
node.focus();
Expand All @@ -88,24 +45,29 @@
</script>

<main class="h-full text-sm bg-gradient-to-t from-indigo-100 bg-fixed overflow-hidden">
<div on:scroll={scrolling} class="flex h-full flex-col py-12 overflow-y-auto" use:scrollToBottom={messages}>
<div on:scroll={scrolling} class="flex h-full flex-col py-12 overflow-y-auto" use:scrollToBottom={$chatStore}>
<div class="w-4/5 mx-auto flex flex-col mb-32">
{#each messages as message (message.id)}
{#each $chatStore.data as message (message.id)}
<div
class="max-w-[80%] min-w-[40%] rounded-lg p-4 mb-4 overflow-x-auto bg-white border border-indigo-200"
class:self-end={message.from === "me"}
class:text-right={message.from === "me"}
>
<div class="flex flex-row items-start gap-2">
<div class="flex flex-row gap-2">
{#if message.from === "me"}
<button
aria-label="Generate a new internal ticket from this question"
title="Generate a new internal ticket from this question"
on:click={() => generateTicket(message.text)}
class="w-6 h-6 flex flex-col justify-center items-center border rounded border-indigo-200"
><External --color="#ccc" --hover-color="#999" /></button
>
{/if}
<div
class:ml-auto={message.from === "me"}
class="relative w-12 h-12 border border-indigo-200 rounded-lg flex justify-center items-center"
class="relative w-12 h-12 border border-indigo-200 rounded flex justify-center items-center overflow-hidden"
>
<img
src={senderImages[message.from]}
alt=""
class="w-12 h-12 absolute top-0 left-0 rounded-lg"
/>
<img src={senderImages[message.from]} alt="" class="rounded-sm" />
</div>
{#if message.from === "bot"}
<div class="text-sm">
Expand Down Expand Up @@ -133,7 +95,7 @@
</div>
<form class="rounded-md w-full bg-white p-2 m-0" on:submit|preventDefault={send}>
<input
disabled={appState === "receiving"}
disabled={$chatStore.state === chatStates.RECEIVING}
class="text-lg w-full bg-white focus:outline-none px-4"
bind:value={question}
bind:this={input}
Expand All @@ -144,6 +106,9 @@
</div>
</div>
</main>
{#if generationModalOpen}
<Modal title="my title" text="my text" on:close={() => (generationModalOpen = false)} />
{/if}

<style>
:global(pre) {
Expand Down
28 changes: 28 additions & 0 deletions front-end/src/lib/External.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<!--
  External-link (arrow-out-of-box) icon component.
  Stroke color is themeable via the CSS custom properties `--color`
  (default #000) and `--hover-color` (default #000); the parent sizes it
  through the wrapping full-width/full-height div.
-->
<div class="w-full h-full">
<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"
><g id="SVGRepo_bgCarrier" stroke-width="0" /><g
id="SVGRepo_tracerCarrier"
stroke-linecap="round"
stroke-linejoin="round"
/><g id="SVGRepo_iconCarrier">
<g id="Interface / External_Link">
<path
id="Vector"
d="M10.0002 5H8.2002C7.08009 5 6.51962 5 6.0918 5.21799C5.71547 5.40973 5.40973 5.71547 5.21799 6.0918C5 6.51962 5 7.08009 5 8.2002V15.8002C5 16.9203 5 17.4801 5.21799 17.9079C5.40973 18.2842 5.71547 18.5905 6.0918 18.7822C6.5192 19 7.07899 19 8.19691 19H15.8031C16.921 19 17.48 19 17.9074 18.7822C18.2837 18.5905 18.5905 18.2839 18.7822 17.9076C19 17.4802 19 16.921 19 15.8031V14M20 9V4M20 4H15M20 4L13 11"
stroke-width="1"
stroke-linecap="round"
stroke-linejoin="round"
/>
</g>
</g></svg
>
</div>

<style>
/* Stroke falls back to black when no custom property is supplied. */
svg {
stroke: var(--color, #000);
}
svg:hover {
stroke: var(--hover-color, #000);
}
</style>
Loading

0 comments on commit af43324

Please sign in to comment.