diff --git a/.gitignore b/.gitignore index 93ecc0ff..0fcdd970 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,149 @@ data/ embedding_model/* !embedding_model/.ignore .DS_Store + +### Python template + +.idea/ +.vscode/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +*.sqlite3 +*.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# secret config files +*.secret.* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..0c530150 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,53 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.4.0 + hooks: + - id: check-ast + - id: trailing-whitespace + - id: check-toml + - id: end-of-file-fixer + +- repo: https://github.com/asottile/add-trailing-comma + rev: v2.1.0 + hooks: + - id: add-trailing-comma + +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.1.0 + hooks: + - id: pretty-format-yaml + args: + - --autofix + - --preserve-quotes + - --indent=2 + +- repo: local + hooks: + - id: autoflake + name: autoflake + entry: poetry run autoflake + language: system + types: [python] + args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys] + + - id: black + name: Format with Black + entry: poetry run black + language: system + types: [python] + + - id: isort + name: isort + entry: poetry run isort + language: system + types: [python] + + - id: flake8 + name: Check with Flake8 + entry: poetry run flake8 + language: system + pass_filenames: false + types: [python] + args: [--count, .] diff --git a/bot.py b/bot.py index 2290602a..397b2785 100644 --- a/bot.py +++ b/bot.py @@ -83,44 +83,52 @@ def on_llm_new_token(self, token: str, **kwargs) -> None: def chat_input(): - user_input = st.chat_input("What coding issue can I help you resolve today?") - - if user_input: - with st.chat_message("user"): - st.write(user_input) - with st.chat_message("assistant"): - st.caption(f"RAG: {name}") - stream_handler = StreamHandler(st.empty()) - result = output_function( - {"question": user_input, "chat_history": []}, callbacks=[stream_handler] - )["answer"] - output = result - st.session_state[f"user_input"].append(user_input) - st.session_state[f"generated"].append(output) - st.session_state[f"rag_mode"].append(name) + if not ( + user_input := st.chat_input( + "What coding issue can I help you resolve today?" + ) + ): + return + with st.chat_message("user"): + st.write(user_input) + with st.chat_message("assistant"): + _extracted_from_chat_input_(user_input) + + +# TODO Rename this here and in `chat_input` +def _extracted_from_chat_input_(user_input): + st.caption(f"RAG: {name}") + stream_handler = StreamHandler(st.empty()) + result = output_function( + {"question": user_input, "chat_history": []}, callbacks=[stream_handler] + )["answer"] + output = result + st.session_state["user_input"].append(user_input) + st.session_state["generated"].append(output) + st.session_state["rag_mode"].append(name) def display_chat(): # Session state if "generated" not in st.session_state: - st.session_state[f"generated"] = [] + st.session_state["generated"] = [] if "user_input" not in st.session_state: - st.session_state[f"user_input"] = [] + st.session_state["user_input"] = [] if "rag_mode" not in st.session_state: - st.session_state[f"rag_mode"] = [] + st.session_state["rag_mode"] = [] - if st.session_state[f"generated"]: - size = len(st.session_state[f"generated"]) + if st.session_state["generated"]: + size = len(st.session_state["generated"]) # Display only the last three exchanges for i in range(max(size - 3, 0), size): with st.chat_message("user"): - st.write(st.session_state[f"user_input"][i]) + st.write(st.session_state["user_input"][i]) with st.chat_message("assistant"): - st.caption(f"RAG: {st.session_state[f'rag_mode'][i]}") - st.write(st.session_state[f"generated"][i]) + st.caption(f"RAG: {st.session_state['rag_mode'][i]}") + st.write(st.session_state["generated"][i]) with st.expander("Not finding what you're looking for?"): st.write( @@ -142,9 +150,9 @@ def mode_select() -> str: name = mode_select() -if name == "LLM only" or name == "Disabled": +if name in ["LLM only", "Disabled"]: output_function = llm_chain -elif name == "Vector + Graph" or name == "Enabled": +elif name in ["Vector + Graph", "Enabled"]: output_function = rag_chain @@ -153,9 +161,10 @@ def generate_ticket(): records = neo4j_graph.query( "MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3" ) - questions = [] - for i, question in enumerate(records, start=1): - questions.append((question["title"], question["body"])) + questions = [ + (question["title"], question["body"]) + for i, question in enumerate(records, start=1) + ] # Ask LLM to generate new question in the same style questions_prompt = "" for i, question in enumerate(questions, start=1): @@ -182,7 +191,7 @@ def generate_ticket(): system_prompt = SystemMessagePromptTemplate.from_template( gen_system_template, template_format="jinja2" ) - q_prompt = st.session_state[f"user_input"][-1] + q_prompt = st.session_state["user_input"][-1] chat_prompt = ChatPromptTemplate.from_messages( [ system_prompt, @@ -215,7 +224,7 @@ def close_sidebar(): st.session_state.open_sidebar = False -if not "open_sidebar" in st.session_state: +if "open_sidebar" not in st.session_state: st.session_state.open_sidebar = False if st.session_state.open_sidebar: new_title, new_question = generate_ticket() diff --git a/chains.py b/chains.py index c8b82b75..5bb2e324 100644 --- a/chains.py +++ b/chains.py @@ -17,7 +17,9 @@ from utils import BaseLogger -def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}): +def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=None): + if config is None: + config = {} if embedding_model_name == "ollama": embeddings = OllamaEmbeddings( base_url=config["ollama_base_url"], model="llama2" @@ -41,7 +43,9 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config= return embeddings, dimension -def load_llm(llm_name: str, logger=BaseLogger(), config={}): +def load_llm(llm_name: str, logger=BaseLogger(), config=None): + if config is None: + config = {} if llm_name == "gpt-4": logger.info("LLM: Using GPT-4") return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True) @@ -153,10 +157,9 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass """, ) - kg_qa = RetrievalQAWithSourcesChain( + return RetrievalQAWithSourcesChain( combine_documents_chain=qa_chain, retriever=kg.as_retriever(search_kwargs={"k": 2}), reduce_k_below_max_tokens=False, max_tokens_limit=3375, ) - return kg_qa diff --git a/loader.py b/loader.py index 8cc08023..c047448f 100644 --- a/loader.py +++ b/loader.py @@ -97,10 +97,9 @@ def insert_so_data(data: dict) -> None: # Streamlit def get_tag() -> str: - input_text = st.text_input( + return st.text_input( "Which tag questions do you want to import?", value="neo4j" ) - return input_text def get_pages(): diff --git a/pdf_bot.py b/pdf_bot.py index fde7772b..979a96ef 100644 --- a/pdf_bot.py +++ b/pdf_bot.py @@ -57,10 +57,7 @@ def main(): if pdf is not None: pdf_reader = PdfReader(pdf) - text = "" - for page in pdf_reader.pages: - text += page.extract_text() - + text = "".join(page.extract_text() for page in pdf_reader.pages) # langchain_textspliter text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=200, length_function=len @@ -83,10 +80,9 @@ def main(): llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever() ) - # Accept user questions/query - query = st.text_input("Ask questions about related your upload pdf file") - - if query: + if query := st.text_input( + "Ask questions about related your upload pdf file" + ): stream_handler = StreamHandler(st.empty()) qa.run(query, callbacks=[stream_handler]) diff --git a/requirements.txt b/requirements.txt index 6c257a19..4db86c6b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,9 @@ Pillow fastapi PyPDF2 torch==2.0.1 +black +flake8 pydantic uvicorn sse-starlette -boto3 +boto3 \ No newline at end of file diff --git a/utils.py b/utils.py index 9404f154..ff521e8d 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,4 @@ +import contextlib class BaseLogger: def __init__(self) -> None: self.info = print @@ -28,15 +29,11 @@ def extract_title_and_question(input_string): def create_vector_index(driver, dimension: int) -> None: index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')" - try: + with contextlib.suppress(Exception): driver.query(index_query, {"dimension": dimension}) - except: # Already exists - pass index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')" - try: + with contextlib.suppress(Exception): driver.query(index_query, {"dimension": dimension}) - except: # Already exists - pass def create_constraints(driver):