Commit

Merge branch 'dev'
mmz-001 committed Jul 6, 2023
2 parents ecff70f + f82314b, commit 4b7f366
Showing 32 changed files with 769 additions and 254 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,5 +1,5 @@
 # Local data
-data/local_data/
+resources/local_data/
 
 # Prototyping notebooks
 notebooks/
Binary file removed data/paul_graham_essay.docx
7 changes: 1 addition & 6 deletions knowledge_gpt/components/sidebar.py
@@ -7,10 +7,6 @@
 load_dotenv()
 
 
-def set_openai_api_key(api_key: str):
-    st.session_state["OPENAI_API_KEY"] = api_key
-
-
 def sidebar():
     with st.sidebar:
         st.markdown(
@@ -28,8 +24,7 @@ def sidebar():
             or st.session_state.get("OPENAI_API_KEY", ""),
         )
 
-        if api_key_input:
-            set_openai_api_key(api_key_input)
+        st.session_state["OPENAI_API_KEY"] = api_key_input
 
         st.markdown("---")
         st.markdown("# About")
File renamed without changes.
33 changes: 33 additions & 0 deletions knowledge_gpt/core/caching.py
@@ -0,0 +1,33 @@
import streamlit as st
from streamlit.runtime.caching.hashing import HashFuncsDict

import knowledge_gpt.core.parsing as parsing
import knowledge_gpt.core.chunking as chunking
import knowledge_gpt.core.embedding as embedding
from knowledge_gpt.core.parsing import File


def file_hash_func(file: File) -> str:
"""Get a unique hash for a file"""
return file.id


@st.cache_resource()
def bootstrap_caching():
"""Patch module functions with caching"""

    # Get all subtypes of File from the module
file_subtypes = [
cls
for cls in vars(parsing).values()
if isinstance(cls, type) and issubclass(cls, File) and cls != File
]
file_hash_funcs: HashFuncsDict = {cls: file_hash_func for cls in file_subtypes}

parsing.read_file = st.cache_data(show_spinner=False)(parsing.read_file)
chunking.chunk_file = st.cache_data(show_spinner=False, hash_funcs=file_hash_funcs)(
chunking.chunk_file
)
embedding.embed_files = st.cache_data(
show_spinner=False, hash_funcs=file_hash_funcs
)(embedding.embed_files)
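
For context, a minimal usage sketch (illustrative, not part of this commit) of how an app might apply this cache patching; the uploader label and accepted types are assumptions:

import streamlit as st
import knowledge_gpt.core.parsing as parsing
from knowledge_gpt.core.caching import bootstrap_caching

# Patch read_file / chunk_file / embed_files with Streamlit caching before first use.
# The patching replaces module attributes, so call the functions through their
# modules (parsing.read_file) rather than a from-import taken before patching.
bootstrap_caching()

uploaded_file = st.file_uploader("Upload a document", type=["pdf", "docx", "txt"])
if uploaded_file:
    file = parsing.read_file(uploaded_file)  # served from cache on reruns
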
38 changes: 38 additions & 0 deletions knowledge_gpt/core/chunking.py
@@ -0,0 +1,38 @@
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from knowledge_gpt.core.parsing import File


def chunk_file(
file: File, chunk_size: int, chunk_overlap: int = 0, model_name="gpt-3.5-turbo"
) -> File:
"""Chunks each document in a file into smaller documents
according to the specified chunk size and overlap
    where the size is determined by the number of tokens for the specified model.
"""

# split each document into chunks
chunked_docs = []
for doc in file.docs:
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
model_name=model_name,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)

chunks = text_splitter.split_text(doc.page_content)

for i, chunk in enumerate(chunks):
doc = Document(
page_content=chunk,
metadata={
"page": doc.metadata.get("page", 1),
"chunk": i + 1,
"source": f"{doc.metadata.get('page', 1)}-{i + 1}",
},
)
chunked_docs.append(doc)

chunked_file = file.copy()
chunked_file.docs = chunked_docs
return chunked_file
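
A minimal sketch (not part of the commit) of calling chunk_file on a parsed file; the sample text and chunk size are arbitrary:

from io import BytesIO
from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.chunking import chunk_file

# Build an in-memory .txt upload; read_file dispatches on the .name extension,
# and BytesIO accepts ad-hoc attributes, so a name can be attached here.
data = BytesIO(b"First paragraph.\n\nSecond paragraph.\n")
data.name = "example.txt"

chunked = chunk_file(read_file(data), chunk_size=100, chunk_overlap=0)
for doc in chunked.docs:
    # Each chunk carries page/chunk numbers and a "page-chunk" source key.
    print(doc.metadata["source"], repr(doc.page_content))
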
71 changes: 71 additions & 0 deletions knowledge_gpt/core/embedding.py
@@ -0,0 +1,71 @@
from langchain.vectorstores import VectorStore
from knowledge_gpt.core.parsing import File
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from typing import List, Type
from langchain.docstore.document import Document


class FolderIndex:
"""Index for a collection of files (a folder)"""

def __init__(self, files: List[File], index: VectorStore):
self.name: str = "default"
self.files = files
self.index: VectorStore = index

@staticmethod
def _combine_files(files: List[File]) -> List[Document]:
"""Combines all the documents in a list of files into a single list."""

all_texts = []
for file in files:
for doc in file.docs:
doc.metadata["file_name"] = file.name
doc.metadata["file_id"] = file.id
all_texts.append(doc)

return all_texts

@classmethod
def from_files(
cls, files: List[File], embeddings: Embeddings, vector_store: Type[VectorStore]
) -> "FolderIndex":
"""Creates an index from files."""

all_docs = cls._combine_files(files)

index = vector_store.from_documents(
documents=all_docs,
embedding=embeddings,
)

return cls(files=files, index=index)


def embed_files(
files: List[File], embedding: str, vector_store: str, **kwargs
) -> FolderIndex:
"""Embeds a collection of files and stores them in a FolderIndex."""

supported_embeddings: dict[str, Type[Embeddings]] = {
"openai": OpenAIEmbeddings,
}
supported_vector_stores: dict[str, Type[VectorStore]] = {
"faiss": FAISS,
}

if embedding in supported_embeddings:
_embeddings = supported_embeddings[embedding](**kwargs)
else:
raise NotImplementedError(f"Embedding {embedding} not supported.")

if vector_store in supported_vector_stores:
_vector_store = supported_vector_stores[vector_store]
else:
raise NotImplementedError(f"Vector store {vector_store} not supported.")

return FolderIndex.from_files(
files=files, embeddings=_embeddings, vector_store=_vector_store
)
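
A hedged usage sketch (not part of the commit): index one parsed-and-chunked file with OpenAI embeddings in FAISS, then search it; the API key is a placeholder and is forwarded to OpenAIEmbeddings via **kwargs:

from io import BytesIO
from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.embedding import embed_files

data = BytesIO(b"Alice founded the company in 2019.\n")
data.name = "notes.txt"
chunked = chunk_file(read_file(data), chunk_size=300)

folder_index = embed_files(
    files=[chunked],
    embedding="openai",
    vector_store="faiss",
    openai_api_key="sk-...",  # placeholder key
)
hits = folder_index.index.similarity_search("Who founded the company?", k=1)
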
102 changes: 102 additions & 0 deletions knowledge_gpt/core/parsing.py
@@ -0,0 +1,102 @@
from io import BytesIO
from typing import List, Any, Optional
import re

import docx2txt
from langchain.docstore.document import Document
from pypdf import PdfReader
from hashlib import md5

from abc import abstractmethod, ABC
from copy import deepcopy


class File(ABC):
"""Represents an uploaded file comprised of Documents"""

def __init__(
self,
name: str,
id: str,
metadata: Optional[dict[str, Any]] = None,
docs: Optional[List[Document]] = None,
):
self.name = name
self.id = id
self.metadata = metadata or {}
self.docs = docs or []

@classmethod
@abstractmethod
def from_bytes(cls, file: BytesIO) -> "File":
"""Creates a File from a BytesIO object"""

def __repr__(self) -> str:
return (
f"File(name={self.name}, id={self.id},"
" metadata={self.metadata}, docs={self.docs})"
)

def __str__(self) -> str:
return f"File(name={self.name}, id={self.id}, metadata={self.metadata})"

def copy(self) -> "File":
"""Create a deep copy of this File"""
return self.__class__(
name=self.name,
id=self.id,
metadata=deepcopy(self.metadata),
docs=deepcopy(self.docs),
)


def strip_consecutive_newlines(text: str) -> str:
"""Strips consecutive newlines from a string
possibly with whitespace in between
"""
return re.sub(r"\s*\n\s*", "\n", text)


class DocxFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "DocxFile":
text = docx2txt.process(file)
text = strip_consecutive_newlines(text)
doc = Document(page_content=text.strip())
return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


class PdfFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "PdfFile":
pdf = PdfReader(file)
docs = []
for i, page in enumerate(pdf.pages):
text = page.extract_text()
text = strip_consecutive_newlines(text)
doc = Document(page_content=text.strip())
doc.metadata["page"] = i + 1
docs.append(doc)
return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=docs)


class TxtFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "TxtFile":
text = file.read().decode("utf-8")
text = strip_consecutive_newlines(text)
file.seek(0)
doc = Document(page_content=text.strip())
return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


def read_file(file: BytesIO) -> File:
"""Reads an uploaded file and returns a File object"""
if file.name.endswith(".docx"):
return DocxFile.from_bytes(file)
elif file.name.endswith(".pdf"):
return PdfFile.from_bytes(file)
elif file.name.endswith(".txt"):
return TxtFile.from_bytes(file)
else:
raise NotImplementedError
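
A small sketch (not part of the commit) of the dispatch in read_file and the newline normalization; the sample bytes are arbitrary:

from io import BytesIO
from knowledge_gpt.core.parsing import read_file, TxtFile

data = BytesIO(b"Hello\n\n\nworld\n")
data.name = "hello.txt"           # read_file picks the parser from the extension

file = read_file(data)
assert isinstance(file, TxtFile)
print(file.id)                    # md5 hex digest of the raw bytes
print(file.docs[0].page_content)  # consecutive newlines collapsed: "Hello\nworld"
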
File renamed without changes.
62 changes: 62 additions & 0 deletions knowledge_gpt/core/qa.py
@@ -0,0 +1,62 @@
from typing import Any, List
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from knowledge_gpt.core.prompts import STUFF_PROMPT
from langchain.docstore.document import Document
from langchain.chat_models import ChatOpenAI
from knowledge_gpt.core.embedding import FolderIndex
from pydantic import BaseModel


class AnswerWithSources(BaseModel):
answer: str
sources: List[Document]


def query_folder(
query: str, folder_index: FolderIndex, return_all: bool = False, **model_kwargs: Any
) -> AnswerWithSources:
"""Queries a folder index for an answer.
Args:
query (str): The query to search for.
folder_index (FolderIndex): The folder index to search.
return_all (bool): Whether to return all the documents from the embedding or
just the sources for the answer.
**model_kwargs (Any): Keyword arguments for the model.
Returns:
AnswerWithSources: The answer and the source documents.
"""

chain = load_qa_with_sources_chain(
llm=ChatOpenAI(**model_kwargs),
chain_type="stuff",
prompt=STUFF_PROMPT,
)

relevant_docs = folder_index.index.similarity_search(query, k=5)
result = chain(
{"input_documents": relevant_docs, "question": query}, return_only_outputs=True
)
sources = relevant_docs

if not return_all:
sources = get_sources(result["output_text"], folder_index)

answer = result["output_text"].split("SOURCES: ")[0]

return AnswerWithSources(answer=answer, sources=sources)


def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]:
"""Retrieves the docs that were used to answer the question the generated answer."""

source_keys = [s for s in answer.split("SOURCES: ")[-1].split(", ")]

source_docs = []
for file in folder_index.files:
for doc in file.docs:
if doc.metadata["source"] in source_keys:
source_docs.append(doc)

return source_docs
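
An end-to-end sketch (not part of the commit) tying the new modules together; the question, model, and keys are placeholders, and model_name / openai_api_key are forwarded to ChatOpenAI via **model_kwargs:

from io import BytesIO
from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.embedding import embed_files
from knowledge_gpt.core.qa import query_folder

data = BytesIO(b"The essay argues that writing clarifies thinking.\n")
data.name = "essay.txt"
folder_index = embed_files(
    files=[chunk_file(read_file(data), chunk_size=300)],
    embedding="openai",
    vector_store="faiss",
    openai_api_key="sk-...",  # placeholder key
)

result = query_folder(
    query="What does the essay argue?",
    folder_index=folder_index,
    model_name="gpt-3.5-turbo",
    openai_api_key="sk-...",  # forwarded to ChatOpenAI
    temperature=0,
)
print(result.answer)
for doc in result.sources:
    print(doc.metadata["source"])  # "page-chunk" keys set during chunking
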
