Commit
code quality fix
sachin-duhan committed Oct 8, 2023
1 parent 5168133 commit 17ca0e7
Showing 7 changed files with 140 additions and 50 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/quality.yaml
@@ -0,0 +1,33 @@
name:

on: push

jobs:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: Install deps
uses: knowsuchagency/poetry-install@v1
env:
POETRY_VIRTUALENVS_CREATE: false
- name: Run black check
run: python3 -m black --check .
flake8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: Install deps
uses: knowsuchagency/poetry-install@v1
env:
POETRY_VIRTUALENVS_CREATE: false
- name: Run flake8 check
run: python3 -m flake8 --count .
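
Two notes on the workflow above: its `name:` field is left empty, so the Actions UI will fall back to the workflow's file path, and both jobs run the same checkers the pre-commit hooks below wire up. A run that fails in CI can be reproduced locally with the same commands; here is a minimal sketch, assuming black and flake8 are already installed in the active environment (`check_quality.py` is a hypothetical helper, not part of this commit):

```python
# check_quality.py -- hypothetical local mirror of the two CI jobs above.
import subprocess
import sys

CHECKS = [
    ["python3", "-m", "black", "--check", "."],   # mirrors the black job
    ["python3", "-m", "flake8", "--count", "."],  # mirrors the flake8 job
]

def main() -> int:
    failed = 0
    for cmd in CHECKS:
        print("$", " ".join(cmd))
        # A nonzero return code means the corresponding CI job would fail.
        if subprocess.run(cmd).returncode != 0:
            failed += 1
    return failed

if __name__ == "__main__":
    sys.exit(main())
```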
53 changes: 53 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,53 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
hooks:
- id: check-ast
- id: trailing-whitespace
- id: check-toml
- id: end-of-file-fixer

- repo: https://github.com/asottile/add-trailing-comma
rev: v2.1.0
hooks:
- id: add-trailing-comma

- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.1.0
hooks:
- id: pretty-format-yaml
args:
- --autofix
- --preserve-quotes
- --indent=2

- repo: local
hooks:
- id: autoflake
name: autoflake
entry: poetry run autoflake
language: system
types: [python]
args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]

- id: black
name: Format with Black
entry: poetry run black
language: system
types: [python]

- id: isort
name: isort
entry: poetry run isort
language: system
types: [python]

- id: flake8
name: Check with Flake8
entry: poetry run flake8
language: system
pass_filenames: false
types: [python]
args: [--count, .]
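
The hooks under `repo: local` shell out to `poetry run`, so they assume the project's Poetry environment exists on the developer's machine; none of them fire until pre-commit is installed into the clone. A sketch of the bootstrap, assuming the standard pre-commit CLI is available (wrapping it in a script is this example's own convention, not something the commit adds):

```python
# bootstrap_hooks.py -- hypothetical helper around the standard pre-commit CLI.
import subprocess

# Register the git hook so every `git commit` runs the checks above.
subprocess.run(["pre-commit", "install"], check=True)

# Run every hook against the whole tree once, up front. No check=True here:
# a nonzero exit simply means some hook failed or rewrote files.
subprocess.run(["pre-commit", "run", "--all-files"])
```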
69 changes: 39 additions & 30 deletions bot.py
@@ -83,44 +83,52 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:


def chat_input():
user_input = st.chat_input("What coding issue can I help you resolve today?")

if user_input:
with st.chat_message("user"):
st.write(user_input)
with st.chat_message("assistant"):
st.caption(f"RAG: {name}")
stream_handler = StreamHandler(st.empty())
result = output_function(
{"question": user_input, "chat_history": []}, callbacks=[stream_handler]
)["answer"]
output = result
st.session_state[f"user_input"].append(user_input)
st.session_state[f"generated"].append(output)
st.session_state[f"rag_mode"].append(name)
if not (
user_input := st.chat_input(
"What coding issue can I help you resolve today?"
)
):
return
with st.chat_message("user"):
st.write(user_input)
with st.chat_message("assistant"):
_extracted_from_chat_input_(user_input)
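
The rewritten `chat_input` pairs an assignment expression with a guard clause: bind the input and return immediately when it is empty, rather than indenting the whole body under `if user_input:` as the old version did. A standalone illustration of the pattern (Python 3.8+, not project code):

```python
# Guard clause with an assignment expression: bind and test in one step.
def greet() -> None:
    if not (user_input := input("Your name? ").strip()):
        return  # nothing entered, so skip the rest of the function
    print(f"Hello, {user_input}!")
```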


# TODO Rename this here and in `chat_input`
def _extracted_from_chat_input_(user_input):
st.caption(f"RAG: {name}")
stream_handler = StreamHandler(st.empty())
result = output_function(
{"question": user_input, "chat_history": []}, callbacks=[stream_handler]
)["answer"]
output = result
st.session_state["user_input"].append(user_input)
st.session_state["generated"].append(output)
st.session_state["rag_mode"].append(name)


def display_chat():
# Session state
if "generated" not in st.session_state:
st.session_state[f"generated"] = []
st.session_state["generated"] = []

if "user_input" not in st.session_state:
st.session_state[f"user_input"] = []
st.session_state["user_input"] = []

if "rag_mode" not in st.session_state:
st.session_state[f"rag_mode"] = []
st.session_state["rag_mode"] = []

if st.session_state[f"generated"]:
size = len(st.session_state[f"generated"])
if st.session_state["generated"]:
size = len(st.session_state["generated"])
# Display only the last three exchanges
for i in range(max(size - 3, 0), size):
with st.chat_message("user"):
st.write(st.session_state[f"user_input"][i])
st.write(st.session_state["user_input"][i])

with st.chat_message("assistant"):
st.caption(f"RAG: {st.session_state[f'rag_mode'][i]}")
st.write(st.session_state[f"generated"][i])
st.caption(f"RAG: {st.session_state['rag_mode'][i]}")
st.write(st.session_state["generated"][i])

with st.expander("Not finding what you're looking for?"):
st.write(
@@ -142,9 +150,9 @@ def mode_select() -> str:


name = mode_select()
if name == "LLM only" or name == "Disabled":
if name in ["LLM only", "Disabled"]:
output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
elif name in ["Vector + Graph", "Enabled"]:
output_function = rag_chain


@@ -153,9 +161,10 @@ def generate_ticket():
records = neo4j_graph.query(
"MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3"
)
questions = []
for i, question in enumerate(records, start=1):
questions.append((question["title"], question["body"]))
questions = [
(question["title"], question["body"])
for i, question in enumerate(records, start=1)
]
# Ask LLM to generate new question in the same style
questions_prompt = ""
for i, question in enumerate(questions, start=1):
@@ -182,7 +191,7 @@ def generate_ticket():
system_prompt = SystemMessagePromptTemplate.from_template(
gen_system_template, template_format="jinja2"
)
q_prompt = st.session_state[f"user_input"][-1]
q_prompt = st.session_state["user_input"][-1]
chat_prompt = ChatPromptTemplate.from_messages(
[
system_prompt,
@@ -215,7 +224,7 @@ def close_sidebar():
st.session_state.open_sidebar = False


if not "open_sidebar" in st.session_state:
if "open_sidebar" not in st.session_state:
st.session_state.open_sidebar = False
if st.session_state.open_sidebar:
new_title, new_question = generate_ticket()
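
One leftover in the `generate_ticket` hunk above: the new list comprehension still routes through `enumerate` even though its index `i` is unused there (the enumeration is only needed later, when `questions_prompt` is assembled). A tighter form, offered as an observation on the committed code rather than as part of the diff:

```python
# The index was never used, so plain iteration over the records is enough.
questions = [(question["title"], question["body"]) for question in records]
```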
11 changes: 7 additions & 4 deletions chains.py
@@ -13,7 +13,9 @@
from utils import BaseLogger


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=None):
if config is None:
config = {}
if embedding_model_name == "ollama":
embeddings = OllamaEmbeddings(
base_url=config["ollama_base_url"], model="llama2"
@@ -33,7 +35,9 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=
return embeddings, dimension


def load_llm(llm_name: str, logger=BaseLogger(), config={}):
def load_llm(llm_name: str, logger=BaseLogger(), config=None):
if config is None:
config = {}
if llm_name == "gpt-4":
logger.info("LLM: Using GPT-4")
return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
@@ -140,10 +144,9 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
""",
)

kg_qa = RetrievalQAWithSourcesChain(
return RetrievalQAWithSourcesChain(
combine_documents_chain=qa_chain,
retriever=kg.as_retriever(search_kwargs={"k": 2}),
reduce_k_below_max_tokens=False,
max_tokens_limit=3375,
)
return kg_qa
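
Switching the `config={}` defaults to `config=None` plus an explicit check sidesteps Python's classic mutable-default pitfall: a default value is evaluated once, at function definition time, so a dict default is silently shared by every call. A minimal standalone illustration (not project code):

```python
def buggy(config={}):        # one dict, created when the function is defined
    config["hits"] = config.get("hits", 0) + 1
    return config

def fixed(config=None):      # fresh dict on every call, as in the commit
    if config is None:
        config = {}
    config["hits"] = config.get("hits", 0) + 1
    return config

print(buggy())  # {'hits': 1}
print(buggy())  # {'hits': 2} -- state leaked out of the first call
print(fixed())  # {'hits': 1}
print(fixed())  # {'hits': 1} -- no shared state
```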
3 changes: 1 addition & 2 deletions loader.py
@@ -97,10 +97,9 @@ def insert_so_data(data: dict) -> None:

# Streamlit
def get_tag() -> str:
input_text = st.text_input(
return st.text_input(
"Which tag questions do you want to import?", value="neo4j"
)
return input_text


def get_pages():
12 changes: 4 additions & 8 deletions pdf_bot.py
@@ -57,10 +57,7 @@ def main():
if pdf is not None:
pdf_reader = PdfReader(pdf)

text = ""
for page in pdf_reader.pages:
text += page.extract_text()

text = "".join(page.extract_text() for page in pdf_reader.pages)
        # LangChain text splitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200, length_function=len
@@ -83,10 +80,9 @@ def main():
llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever()
)

# Accept user questions/query
query = st.text_input("Ask questions about related your upload pdf file")

if query:
        if query := st.text_input(
            "Ask questions related to your uploaded PDF file"
        ):
stream_handler = StreamHandler(st.empty())
qa.run(query, callbacks=[stream_handler])

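
One caveat with the new single-expression join in `main`: on some older PyPDF2 releases, `extract_text()` can return `None` for pages with no extractable text, and `str.join` raises a `TypeError` on `None`. A defensive variant, hedged on that version behavior rather than on this project's pinned dependency:

```python
# Coerce None to "" in case extract_text() yields None on some PyPDF2 versions.
text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
```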
9 changes: 3 additions & 6 deletions utils.py
@@ -1,3 +1,4 @@
import contextlib
class BaseLogger:
def __init__(self) -> None:
self.info = print
@@ -28,15 +29,11 @@ def extract_title_and_question(input_string):

def create_vector_index(driver, dimension: int) -> None:
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
try:
with contextlib.suppress(Exception):
driver.query(index_query, {"dimension": dimension})
except: # Already exists
pass
index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
try:
with contextlib.suppress(Exception):
driver.query(index_query, {"dimension": dimension})
except: # Already exists
pass


def create_constraints(driver):
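
`contextlib.suppress` is the idiomatic replacement for `try/except: pass`, but it is exactly as broad as the exception it is handed: `suppress(Exception)` still swallows every failure, not just the "index already exists" case the old comment described. A narrower sketch, reusing the function's locals and assuming the official neo4j driver, where such failures surface as `ClientError`:

```python
import contextlib

from neo4j.exceptions import ClientError

# Suppress only the "index already exists" class of failure, not everything.
with contextlib.suppress(ClientError):
    driver.query(index_query, {"dimension": dimension})
```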
