Feature/bot: Multi-session memory support (#30)
* feat: Implement StatelessMemorySequentialChain for serving

* feat: Finish StatelessMemorySequentialChain implementation

* feat: Adapt history support to Beam

* fix: Linting issues

* chore: Remove eos_token_id from template

* fix: History key issues
iusztinpaul authored Oct 19, 2023
1 parent 00d3498 commit 07d1c7d
Showing 8 changed files with 213 additions and 38 deletions.
24 changes: 24 additions & 0 deletions .vscode/launch.json
@@ -131,6 +131,8 @@
"I am a student and I have some money that I want to invest.",
"--question",
"Should I consider investing in stocks from the Tech Sector?",
"--history",
"[[\"What is your opinion on investing in startup companies?\", \"Startup investments can be very lucrative, but they also come with a high degree of risk. It is important to do your due diligence and research the company thoroughly before investing.\"]]",
"--debug",
"False"
]
@@ -147,9 +149,31 @@
"I am a student and I have some money that I want to invest.",
"--question",
"Should I consider investing in stocks from the Tech Sector?",
"--history",
"[[\"What is your opinion on investing in startup companies?\", \"Startup investments can be very lucrative, but they also come with a high degree of risk. It is important to do your due diligence and research the company thoroughly before investing.\"]]",
"--debug",
"True"
]
},
{
"name": "Financial Bot UI",
"type": "python",
"request": "launch",
"module": "tools.ui",
"justMyCode": false,
"cwd": "${workspaceFolder}/modules/financial_bot",
"args": []
},
{
"name": "Financial Bot UI [Dev]",
"type": "python",
"request": "launch",
"module": "tools.ui",
"justMyCode": false,
"cwd": "${workspaceFolder}/modules/financial_bot",
"args": [
"--debug"
]
},
]
}
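
The new `--history` flag (wired into both launch configurations above and the Makefile targets below) carries the prior conversation as a JSON-encoded list of `[human, assistant]` pairs. The sketch below shows how an entrypoint such as `tools.bot` might decode it; the `parse_history` helper and the exact argparse layout are illustrative assumptions, since the entrypoint module itself is not part of this diff.

```python
# Hedged sketch, NOT part of this diff: decoding the new `--history` flag.
# The flag names come from the launch configuration above; the helper and
# argparse layout are illustrative assumptions.
import argparse
import json
from typing import List, Tuple


def parse_history(raw: str) -> List[Tuple[str, str]]:
    """Decode a JSON-encoded list of [human, assistant] pairs into tuples."""
    pairs = json.loads(raw) if raw else []
    return [(human, assistant) for human, assistant in pairs]


parser = argparse.ArgumentParser()
parser.add_argument("--about_me", type=str, required=True)
parser.add_argument("--question", type=str, required=True)
parser.add_argument("--history", type=str, default="[]")
parser.add_argument("--debug", type=str, default="False")
args = parser.parse_args()

history = parse_history(args.history)
# e.g. [("What is your opinion on investing in startup companies?", "Startup investments can be ...")]
```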
2 changes: 1 addition & 1 deletion README.md
@@ -83,7 +83,7 @@ Thus, check out the README for every module individually to see how to install &
1. [q_and_a_dataset_generator](/modules/q_and_a_dataset_generator/)
2. [training_pipeline](/modules/training_pipeline/)
3. [streaming_pipeline](/modules/streaming_pipeline/)
4. [inference_pipeline]()
4. [inference_pipeline](/modules/financial_bot/)


### 3.1 Run Notebooks Server
8 changes: 5 additions & 3 deletions modules/financial_bot/Makefile
@@ -28,6 +28,7 @@ run:
poetry run python -m tools.bot \
--about_me "I am a student and I have some money that I want to invest." \
--question "Should I consider investing in stocks from the Tech Sector?" \
--history "[[\"What is your opinion on investing in startup companies?\", \"Startup investments can be very lucrative, but they also come with a high degree of risk. It is important to do your due diligence and research the company thoroughly before investing.\"]]" \
--debug False

run_dev:
@@ -36,6 +37,7 @@ run_dev:
poetry run python -m tools.bot \
--about_me "I am a student and I have some money that I want to invest." \
--question "Should I consider investing in stocks from the Tech Sector?" \
--history "[[\"What is your opinion on investing in startup companies?\", \"Startup investments can be very lucrative, but they also come with a high degree of risk. It is important to do your due diligence and research the company thoroughly before investing.\"]]" \
--debug True

run_ui:
@@ -46,7 +48,7 @@ run_ui:
run_ui_dev:
@echo "Running financial_bot UI [Dev Mode]..."

poetry run gradio tools/ui.py
poetry run gradio tools/ui.py --debug


# === Beam ===
@@ -69,13 +71,13 @@ deploy_beam_dev: export_requirements

call_restful_api:
curl -X POST \
--compressed 'https://apps.beam.cloud/${DEPLOYMENT_ID}' \
--compressed 'https://${DEPLOYMENT_ID}.apps.beam.cloud' \
-H 'Accept: */*' \
-H 'Accept-Encoding: gzip, deflate' \
-H 'Authorization: Basic ${TOKEN}' \
-H 'Connection: keep-alive' \
-H 'Content-Type: application/json' \
-d '{"about_me": "I am a student and I have some money that I want to invest.", "question": "Should I consider investing in stocks from the Tech Sector?"}'
-d '{"about_me": "I am a student and I have some money that I want to invest.", "question": "Should I consider investing in stocks from the Tech Sector?", "history": [["What is your opinion on investing in startup companies?", "Startup investments can be very lucrative, but they also come with a high degree of risk. It is important to do your due diligence and research the company thoroughly before investing."]]}'

# === Formatting & Linting ===
# Be sure to install the dev dependencies first #
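
The `call_restful_api` target above now posts a `history` field alongside `about_me` and `question`, and the Beam URL follows the `https://<DEPLOYMENT_ID>.apps.beam.cloud` scheme. For reference, here is the same request as a hedged Python sketch using the third-party `requests` package; `DEPLOYMENT_ID` and `TOKEN` are placeholders you must supply yourself.

```python
# Hedged sketch of the `call_restful_api` target in Python instead of curl.
# DEPLOYMENT_ID and TOKEN are placeholders; the payload mirrors the `history`
# field added in this commit.
import requests

DEPLOYMENT_ID = "<your-beam-deployment-id>"
TOKEN = "<your-beam-auth-token>"

payload = {
    "about_me": "I am a student and I have some money that I want to invest.",
    "question": "Should I consider investing in stocks from the Tech Sector?",
    "history": [
        [
            "What is your opinion on investing in startup companies?",
            "Startup investments can be very lucrative, but they also come with "
            "a high degree of risk. It is important to do your due diligence and "
            "research the company thoroughly before investing.",
        ]
    ],
}

response = requests.post(
    f"https://{DEPLOYMENT_ID}.apps.beam.cloud",
    json=payload,
    headers={"Authorization": f"Basic {TOKEN}", "Accept": "*/*"},
    timeout=60,
)
print(response.json())
```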
51 changes: 45 additions & 6 deletions modules/financial_bot/financial_bot/chains.py
@@ -1,13 +1,54 @@
from typing import Any, Dict, List

import qdrant_client
from langchain import chains
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline

from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate


class StatelessMemorySequentialChain(chains.SequentialChain):
history_input_key: str = "to_load_history"

def _call(self, inputs: Dict[str, str], **kwargs) -> Dict[str, str]:
"""Override _call to load history before calling the chain."""

to_load_history = inputs[self.history_input_key]
for (
human,
ai,
) in to_load_history:
self.memory.save_context(
inputs={self.memory.input_key: human},
outputs={self.memory.output_key: ai},
)
memory_values = self.memory.load_memory_variables({})
inputs.update(memory_values)

del inputs[self.history_input_key]

return super()._call(inputs, **kwargs)

def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Override prep_outputs to clear the internal memory after each call."""

results = super().prep_outputs(inputs, outputs, return_only_outputs)

# Clear the internal memory.
self.memory.clear()
if self.memory.memory_key in results:
results[self.memory.memory_key] = ""

return results


class ContextExtractorChain(Chain):
"""
Encode the question, search the vector store for top-k articles and return
@@ -18,15 +59,14 @@ class ContextExtractorChain(Chain):
embedding_model: EmbeddingModelSingleton
vector_store: qdrant_client.QdrantClient
vector_collection: str
output_key: str = "context"

@property
def input_keys(self) -> List[str]:
return ["about_me", "question"]

@property
def output_keys(self) -> List[str]:
return [self.output_key]
return ["context"]

def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
_, quest_key = self.input_keys
@@ -46,7 +86,7 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
context += match.payload["summary"] + "\n"

return {
self.output_key: context,
"context": context,
}


@@ -55,15 +95,14 @@ class FinancialBotQAChain(Chain):

hf_pipeline: HuggingFacePipeline
template: PromptTemplate
output_key: str = "answer"

@property
def input_keys(self) -> List[str]:
return ["context"]

@property
def output_keys(self) -> List[str]:
return [self.output_key]
return ["answer"]

def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
prompt = self.template.format_infer(
@@ -76,4 +115,4 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
)["prompt"]
response = self.hf_pipeline(prompt)

return {self.output_key: response}
return {"answer": response}
45 changes: 33 additions & 12 deletions modules/financial_bot/financial_bot/langchain_bot.py
@@ -1,12 +1,16 @@
import logging
from pathlib import Path
from typing import Iterable
from typing import Iterable, List, Tuple

from langchain import chains
from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory

from financial_bot import constants
from financial_bot.chains import ContextExtractorChain, FinancialBotQAChain
from financial_bot.chains import (
ContextExtractorChain,
FinancialBotQAChain,
StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
@@ -24,6 +28,7 @@ def __init__(
vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
model_cache_dir: Path = constants.CACHE_DIR,
streaming: bool = False,
embedding_model_device: str = "cuda:0",
debug: bool = False,
):
@@ -40,12 +45,16 @@ def __init__(
self._llm_agent, self._streamer = build_huggingface_pipeline(
llm_model_id=llm_model_id,
llm_lora_model_id=llm_lora_model_id,
use_streamer=True,
use_streamer=streaming,
cache_dir=model_cache_dir,
debug=debug,
)
self.finbot_chain = self.build_chain()

@property
def is_streaming(self) -> bool:
return self._streamer is not None

def build_chain(self) -> chains.SequentialChain:
"""
Constructs and returns a financial bot chain.
@@ -91,13 +100,16 @@ def build_chain(self) -> chains.SequentialChain:
)

logger.info("Building 3/3 - Connecting chains into SequentialChain")
# TODO: Change memory to keep TOP k messages or a summary of the conversation.
seq_chain = chains.SequentialChain(
memory=ConversationBufferMemory(
memory_key="chat_history", input_key="question"
seq_chain = StatelessMemorySequentialChain(
history_input_key="to_load_history",
memory=ConversationBufferWindowMemory(
memory_key="chat_history",
input_key="question",
output_key="answer",
k=3,
),
chains=[context_retrieval_chain, llm_generator_chain],
input_variables=["about_me", "question"],
input_variables=["about_me", "question", "to_load_history"],
output_variables=["answer"],
verbose=True,
)
@@ -114,7 +126,12 @@ def build_chain(self) -> chains.SequentialChain:

return seq_chain

def answer(self, about_me: str, question: str) -> str:
def answer(
self,
about_me: str,
question: str,
to_load_history: List[Tuple[str, str]] = None,
) -> str:
"""
Given a short description about the user and a question make the LLM
generate a response.
@@ -132,7 +149,11 @@ def answer(
LLM generated response.
"""

inputs = {"about_me": about_me, "question": question}
inputs = {
"about_me": about_me,
"question": question,
"to_load_history": to_load_history if to_load_history else [],
}
response = self.finbot_chain.run(inputs)

return response
@@ -141,7 +162,7 @@ def stream_answer(self) -> Iterable[str]:
"""Stream the answer from the LLM after each token is generated after calling `answer()`."""

assert (
self._streamer
self.is_streaming
), "Stream answer not available. Build the bot with `use_streamer=True`."

partial_answer = ""
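
Putting the pieces together, the bot now accepts an optional `to_load_history` argument and can be built with `streaming=True` to expose token-by-token output via `stream_answer()`. The usage sketch below assumes the class in `langchain_bot.py` is named `FinancialBot` and that its model-id constructor arguments default from `constants` (neither is visible in this diff); the method signatures mirror the code above.

```python
# Hedged usage sketch; `FinancialBot` and its constructor defaults are assumptions.
from threading import Thread

from financial_bot.langchain_bot import FinancialBot

bot = FinancialBot(streaming=True)

history = [
    (
        "What is your opinion on investing in startup companies?",
        "Startup investments can be very lucrative, but they also come with a high degree of risk.",
    ),
]
kwargs = {
    "about_me": "I am a student and I have some money that I want to invest.",
    "question": "Should I consider investing in stocks from the Tech Sector?",
    "to_load_history": history,
}

if bot.is_streaming:
    # `answer()` blocks until generation finishes, so run it in a thread and
    # consume partial answers from `stream_answer()` as tokens arrive.
    thread = Thread(target=bot.answer, kwargs=kwargs)
    thread.start()
    for partial_answer in bot.stream_answer():
        print(partial_answer, end="", flush=True)
    thread.join()
else:
    print(bot.answer(**kwargs))
```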
36 changes: 29 additions & 7 deletions modules/financial_bot/financial_bot/models.py
@@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from comet_ml import API
@@ -11,6 +11,8 @@
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
pipeline,
)
@@ -47,6 +49,22 @@ def download_from_model_registry(model_id: str, cache_dir: Optional[Path] = None
return model_dir


class StopOnTokens(StoppingCriteria):
def __init__(self, stop_ids: List[int]):
super().__init__()

self._stop_ids = stop_ids

def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stop_id in self._stop_ids:
if input_ids[0][-1] == stop_id:
return True

return False


def build_huggingface_pipeline(
llm_model_id: str,
llm_lora_model_id: str,
@@ -60,8 +78,11 @@ def build_huggingface_pipeline(
"""Using our custom LLM + Finetuned checkpoint we create a HF pipeline"""

if debug is True:
return HuggingFacePipeline(
pipeline=MockedPipeline(f=lambda _: "You are doing great!")
return (
HuggingFacePipeline(
pipeline=MockedPipeline(f=lambda _: "You are doing great!")
),
None,
)

model, tokenizer, _ = build_qlora_model(
@@ -76,8 +97,11 @@ def build_huggingface_pipeline(
streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
stop_on_tokens = StopOnTokens(stop_ids=[tokenizer.eos_token_id])
stopping_criteria = StoppingCriteriaList([stop_on_tokens])
else:
streamer = None
stopping_criteria = []

pipe = pipeline(
"text-generation",
@@ -86,13 +110,11 @@
max_new_tokens=max_new_tokens,
temperature=temperature,
streamer=streamer,
stopping_criteria=stopping_criteria,
)
hf = HuggingFacePipeline(pipeline=pipe)

if use_streamer:
return hf, streamer

return hf
return hf, streamer


def build_qlora_model(
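
The new `StopOnTokens` criteria ends generation as soon as the last produced token matches one of the configured stop ids (here the tokenizer's `eos_token_id`), which matters when streaming so the `TextIteratorStreamer` terminates cleanly. A small, self-contained check of its behaviour, with hand-made tensors, might look like this (illustrative only, not part of the commit):

```python
# Hedged, illustrative check of the StopOnTokens criteria added above.
# `scores` is unused by this criteria but required by the StoppingCriteria
# interface; the token ids below are made up.
import torch

from financial_bot.models import StopOnTokens

criteria = StopOnTokens(stop_ids=[2])  # assume 2 is the tokenizer's eos_token_id

running = torch.tensor([[101, 57, 944]])  # last token is not a stop id
finished = torch.tensor([[101, 57, 2]])   # last token is the eos id
scores = torch.zeros((1, 32000))          # dummy logits, ignored by StopOnTokens

assert criteria(running, scores) is False
assert criteria(finished, scores) is True
```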