Dev - 1.4.0 #28

Merged: 36 commits, Feb 13, 2024
Commits
0e82bc2  Add streaming support (zaldivards, Nov 26, 2023)
b8c836e  Remove unused prompts (zaldivards, Nov 27, 2023)
44fa10b  Update state handlers (zaldivards, Nov 27, 2023)
98a317e  Add `AsyncCallback` for the agent (zaldivards, Nov 27, 2023)
3c3984a  Update chatbox (zaldivards, Nov 27, 2023)
8509578  Remove deprecated code (zaldivards, Nov 27, 2023)
c8a3f61  Add `general` module (zaldivards, Nov 27, 2023)
1e8b2bb  Add streaming support for `QA` (zaldivards, Nov 27, 2023)
a0c1bfe  Rename `general` module to `streaming` (zaldivards, Nov 27, 2023)
b78a31b  Update the `ask` function (zaldivards, Nov 27, 2023)
66db0c4  Update text utils (zaldivards, Nov 28, 2023)
acb1684  Update the `stream` function (zaldivards, Nov 28, 2023)
703bccb  Update the `ask` function (zaldivards, Nov 28, 2023)
30aac42  Remove `clean` function (zaldivards, Nov 28, 2023)
6f58500  Merge pull request #24 from zaldivards/feature/streaming (zaldivards, Dec 1, 2023)
2529bc9  Update `build_sources` function (zaldivards, Dec 18, 2023)
9305abf  Update how the sources are streamed (zaldivards, Dec 18, 2023)
3f1b2d7  Add the `SourcesBox` component (zaldivards, Dec 18, 2023)
f78fc7a  Add rendering of sources (zaldivards, Dec 18, 2023)
b53dba8  Update `build_sources` function (zaldivards, Feb 10, 2024)
440312c  Add state of the latest sources (zaldivards, Feb 10, 2024)
e36af74  Update sources layout (zaldivards, Feb 10, 2024)
d00799a  Merge pull request #26 from zaldivards/feature/sourceRendering (zaldivards, Feb 12, 2024)
3c7b528  Update file uploader (zaldivards, Feb 12, 2024)
29c5786  Add the `BatchProcessor` class (zaldivards, Feb 12, 2024)
6339245  Update the `/ingest/` endpoint (zaldivards, Feb 12, 2024)
8b9b8a4  Fix bug when working with multiple threads (zaldivards, Feb 12, 2024)
2621722  Add `/check-sources` endpoint (zaldivards, Feb 12, 2024)
3299493  Update mounted function to check the sources availability (zaldivards, Feb 12, 2024)
e5f7e9c  Update QA session warnings (zaldivards, Feb 12, 2024)
2ff4f89  Fix error related to the connection pool (zaldivards, Feb 12, 2024)
e9a368c  Update the `/ingest/` endpoint (zaldivards, Feb 12, 2024)
f7b7d28  Update QA messages (zaldivards, Feb 12, 2024)
933c16a  Update section names (zaldivards, Feb 12, 2024)
b21a044  Fix bug regarding the `latestSources` state (zaldivards, Feb 12, 2024)
d1aa05a  Merge pull request #27 from zaldivards/feature/multiIngestion (zaldivards, Feb 13, 2024)
7 changes: 3 additions & 4 deletions api/contextqa/__init__.py
@@ -10,8 +10,10 @@
def get_logger() -> logging.Logger:
return logging.getLogger("contextqa")


class AppSettings(BaseSettings):
"""Project settings"""

default_collection: str = "contextqa-default"
tmp_separator: str = ":::sep:::"
media_home: Path = Path(".media/")
@@ -40,7 +42,7 @@ def validate_media_path(cls, value: Path) -> Path:
"""validator for media path"""
value.mkdir(parents=True, exist_ok=True)
return value

@property
def sqlalchemy_url(self) -> str:
"""sqlalchemy url built either from the sqlite url or the credential of a specific mysql server"""
@@ -52,9 +54,6 @@ def sqlalchemy_url(self) -> str:
if extras := self.mysql_extra_args:
uri += extras
return uri





settings = AppSettings()
46 changes: 33 additions & 13 deletions api/contextqa/models/schemas.py
@@ -1,49 +1,69 @@
# pylint: disable=E0611
from enum import Enum
from typing import Annotated
from typing import Annotated, Literal

from pydantic import BaseModel, Field


class SimilarityProcessor(str, Enum):
"""Enum representing the supported vector stores

Note that the LOCAL identifier refers to ChromaDB
"""

LOCAL = "local"
PINECONE = "pinecone"


class SourceFormat(str, Enum):
"""Enum representing the supported file formats"""

PDF = "pdf"
TXT = "txt"
CSV = "csv"


class Source(BaseModel):
"""Source returned as metadata in QA sessions"""

title: str
format_: Annotated[SourceFormat, Field(alias="format")]
content: str | dict
content: str | list


class SourceStatus(BaseModel):
"""Response model returning the status of data sources"""

status: Literal["ready", "empty"]

@classmethod
def from_count_status(cls, status_flag: bool) -> "SourceStatus":
"""Get instance given the status flag"""
status = "ready" if status_flag else "empty"
return cls(status=status)


class LLMResult(BaseModel):
response: str
"""LLM chat response object"""

response: str

class QAResult(LLMResult):
sources: list[Source]

class IngestionResult(BaseModel):
"""Result of the ingestion process"""

class LLMRequestBodyBase(BaseModel):
separator: str = Field(description="Separator to use for the text splitting", default=".")
chunk_size: int = Field(description="size of each splitted chunk", default=100)
chunk_overlap: int = 50
completed: int
skipped_files: list[str]


class LLMContextQueryRequest(BaseModel):
"""QA session request object"""

question: str


class LLMQueryRequest(BaseModel):
"""Chat request object"""

message: str
internet_access: bool = False


class LLMQueryRequestBody(LLMRequestBodyBase):
query: str = Field(description="The query we want the llm to respond", min_length=10)
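
The two new response models are small enough to exercise directly. A quick illustration (not part of the PR) of how they behave, using only the fields defined above; the file name is made up:

from contextqa.models.schemas import IngestionResult, SourceStatus

# SourceStatus maps the "are there any ingested sources?" flag
# onto the two literal values the client checks for
status = SourceStatus.from_count_status(True)
print(status.status)  # -> "ready"

# IngestionResult replaces LLMResult as the /ingest/ response:
# it reports how many files were processed and which were skipped
result = IngestionResult(completed=3, skipped_files=["duplicated.pdf"])
print(result.completed, result.skipped_files)  # -> 3 ['duplicated.pdf']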
13 changes: 6 additions & 7 deletions api/contextqa/routes/conversational.py
@@ -1,22 +1,21 @@
# pylint: disable=C0413
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse

from contextqa import chat
from contextqa.models.schemas import (
LLMResult,
LLMQueryRequest,
)
from contextqa.models.schemas import LLMQueryRequest

router = APIRouter()


@router.post("/", response_model=LLMResult)
def get_answer(params: LLMQueryRequest):
@router.post("/")
async def get_answer(params: LLMQueryRequest):
"""
Provide a message and receive a response from the LLM
"""
try:
return chat.qa_service(params)
generator = chat.qa_service(params)
return StreamingResponse(generator, media_type="text/event-stream")
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_424_FAILED_DEPENDENCY, detail={"message": "Something went wrong", "cause": str(ex)}
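
Because the endpoint now returns a StreamingResponse with media type text/event-stream, clients read the reply incrementally instead of waiting for a single JSON body. A minimal sketch of a consumer, assuming a local server; the host, port and route prefix are assumptions, not part of this diff:

import asyncio

import httpx


async def chat(message: str) -> None:
    payload = {"message": message, "internet_access": False}
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:  # assumed host/port
        async with client.stream("POST", "/", json=payload) as response:  # router prefix assumed
            async for chunk in response.aiter_text():  # tokens arrive as they are generated
                print(chunk, end="", flush=True)


asyncio.run(chat("What is ContextQA?"))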
3 changes: 2 additions & 1 deletion api/contextqa/routes/dependencies.py
@@ -1,6 +1,7 @@
from typing import Generator

from contextqa.services.db import SessionLocal
from sqlalchemy.orm import scoped_session


def get_db() -> Generator:
@@ -12,7 +13,7 @@ def get_db() -> Generator:
db session
"""
try:
session = SessionLocal()
session = scoped_session(SessionLocal)
yield session
session.commit()
except:
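
Wrapping SessionLocal in scoped_session gives each thread its own Session, so concurrent requests served by different worker threads no longer share, commit, or close each other's connection, which matches the "Fix bug when working with multiple threads" commit. A standalone sketch of the behaviour, with a placeholder engine URL:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite:///:memory:")  # placeholder engine, not the project's real URL
SessionLocal = sessionmaker(bind=engine)

Session = scoped_session(SessionLocal)  # a registry: one Session per thread
Session.execute(text("SELECT 1"))       # proxies to the current thread's Session
Session.commit()
Session.remove()                        # discard the thread-local Session when the request ends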
37 changes: 23 additions & 14 deletions api/contextqa/routes/qa.py
@@ -1,15 +1,11 @@
from typing import Annotated

from fastapi import APIRouter, HTTPException, UploadFile, Depends, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from contextqa import context, get_logger
from contextqa.models.schemas import (
LLMResult,
QAResult,
SimilarityProcessor,
LLMContextQueryRequest,
)
from contextqa.models.schemas import SimilarityProcessor, SourceStatus, LLMContextQueryRequest, IngestionResult
from contextqa.routes.dependencies import get_db
from contextqa.utils.exceptions import VectorDBConnectionError, DuplicatedSourceError

@@ -19,15 +15,16 @@
router = APIRouter()


@router.post("/ingest/", response_model=LLMResult)
def ingest_source(document: UploadFile, session: Annotated[Session, Depends(get_db)]):
@router.post("/ingest/", response_model=IngestionResult)
def ingest_source(documents: list[UploadFile], session: Annotated[Session, Depends(get_db)]):
"""
Ingest a data source into the vector database
"""
try:
context_setter = context.get_setter(SimilarityProcessor.LOCAL)
context_manager = context.get_setter(SimilarityProcessor.LOCAL)
processor = context.BatchProcessor(manager=context_manager)
# pylint: disable=E1102
return context_setter.persist(document.filename, document.file, session)
return processor.persist(documents, session)
except DuplicatedSourceError as ex:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -55,17 +52,29 @@ def ingest_source(document: UploadFile, session: Annotated[Session, Depends(get_
) from ex


@router.post("/", response_model=QAResult)
def qa(params: LLMContextQueryRequest):
@router.post("/")
async def qa(params: LLMContextQueryRequest):
"""
Perform a QA process against the documents you have ingested
"""
try:
context_setter = context.get_setter()
# pylint: disable=E1102
return context_setter.load_and_respond(params.question)
generator = context_setter.load_and_respond(params.question)
return StreamingResponse(generator, media_type="text/event-stream")
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": "ContextQA server did not process the request successfully", "cause": str(ex)},
) from ex


@router.get("/check-sources")
async def check_sources(session: Annotated[Session, Depends(get_db)]):
try:
status_flag = context.sources_exists(session)
return SourceStatus.from_count_status(status_flag)
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": "ContextQA could not get the results from the DB", "cause": str(ex)},
) from ex
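
The /ingest/ endpoint now accepts a list of uploads. With FastAPI's list[UploadFile], each file is sent as a separate "documents" multipart field. A hypothetical client call; the host, port, router prefix and file names are assumptions for illustration only:

import requests

files = [
    ("documents", ("guide.pdf", open("guide.pdf", "rb"), "application/pdf")),
    ("documents", ("notes.txt", open("notes.txt", "rb"), "text/plain")),
]
response = requests.post("http://localhost:8000/qa/ingest/", files=files, timeout=120)
print(response.json())  # shape per IngestionResult: {"completed": ..., "skipped_files": [...]}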
57 changes: 34 additions & 23 deletions api/contextqa/services/chat.py
@@ -1,4 +1,9 @@
from typing import AsyncGenerator

from langchain.agents import initialize_agent, AgentType, Agent
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE
@@ -11,8 +16,10 @@

from contextqa import settings
from contextqa.agents.tools import searcher
from contextqa.models.schemas import LLMResult, LLMQueryRequest
from contextqa.models.schemas import LLMQueryRequest
from contextqa.utils import memory, prompts
from contextqa.agents.tools import searcher
from contextqa.utils.streaming import stream


_MESSAGES = [
@@ -30,7 +37,7 @@
]


def get_llm_assistant(internet_access: bool) -> ConversationChain | Agent:
def get_llm_assistant(internet_access: bool) -> tuple[ConversationChain | Agent, AsyncCallbackHandler]:
"""Return certain LLM assistant based on the system configuration

Parameters
@@ -40,40 +47,44 @@ def get_llm_assistant(internet_access: bool) -> ConversationChain | Agent:

Returns
-------
ConversationChain | Agent
ConversationChain | Agent, AsyncCallbackHandler
"""
llm = ChatOpenAI(temperature=0)

if internet_access:
return initialize_agent(
[searcher],
llm=llm,
agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
memory=memory.Redis("default", internet_access=True),
verbose=settings.debug,
agent_kwargs={
# "output_parser": CustomOP(),
# "format_instructions": prompts.CONTEXTQA_AGENT_TEMPLATE,
"prefix": prompts.PREFIX,
},
handle_parsing_errors=True,
callback = AsyncFinalIteratorCallbackHandler(
answer_prefix_tokens=["Final", "Answer", '",', "", '"', "action", "_input", '":', '"']
)
llm = ChatOpenAI(temperature=0, streaming=True, callbacks=[callback])
return (
initialize_agent(
[searcher],
llm=llm,
agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
memory=memory.Redis("default", internet_access=True),
verbose=settings.debug,
agent_kwargs={"prefix": prompts.PREFIX},
handle_parsing_errors=True,
),
callback,
)
callback = AsyncIteratorCallbackHandler()
llm = ChatOpenAI(temperature=0, streaming=True, callbacks=[callback])
prompt = ChatPromptTemplate.from_messages(_MESSAGES)
return ConversationChain(llm=llm, prompt=prompt, memory=memory.Redis("default"), verbose=settings.debug)
return ConversationChain(llm=llm, prompt=prompt, memory=memory.Redis("default"), verbose=settings.debug), callback


def qa_service(params: LLMQueryRequest) -> LLMResult:
def qa_service(params: LLMQueryRequest) -> AsyncGenerator:
"""Chat with the llm

Parameters
----------
params : models.LLMQueryRequest
params : LLMQueryRequest
request body parameters

Returns
-------
models.LLMResult
LLM response
AsyncGenerator
"""
assistant = get_llm_assistant(params.internet_access)
return LLMResult(response=assistant.run(input=params.message))

assistant, callback = get_llm_assistant(params.internet_access)
return stream(assistant.arun(input=params.message), callback)
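
The stream helper imported from contextqa.utils.streaming is not shown in this diff. As an assumption only, it likely follows the common LangChain pattern of running the chain or agent as a background task while relaying the tokens pushed into the async callback handler, roughly:

import asyncio
from typing import AsyncGenerator, Coroutine

from langchain.callbacks import AsyncIteratorCallbackHandler


async def stream(coro: Coroutine, callback: AsyncIteratorCallbackHandler) -> AsyncGenerator[str, None]:
    task = asyncio.create_task(coro)       # start the LLM call without awaiting it
    async for token in callback.aiter():   # yield tokens as the model produces them
        yield token
    await task                             # re-raise any exception from the chain/agent

FastAPI's StreamingResponse consumes this async generator directly, which is how the updated get_answer and qa endpoints return it.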