Multi ingestion support #27

Merged (12 commits) on Feb 13, 2024
Add the BatchProcessor class
This utility is a QA processor for batch ingestions
zaldivards committed Feb 12, 2024
commit 29c578647665bf99e7cffb78971b66377d35c98a
7 changes: 3 additions & 4 deletions api/contextqa/__init__.py
@@ -10,8 +10,10 @@
def get_logger() -> logging.Logger:
return logging.getLogger("contextqa")


class AppSettings(BaseSettings):
"""Project settings"""

default_collection: str = "contextqa-default"
tmp_separator: str = ":::sep:::"
media_home: Path = Path(".media/")
@@ -40,7 +42,7 @@ def validate_media_path(cls, value: Path) -> Path:
"""validator for media path"""
value.mkdir(parents=True, exist_ok=True)
return value

@property
def sqlalchemy_url(self) -> str:
"""sqlalchemy url built either from the sqlite url or the credential of a specific mysql server"""
@@ -52,9 +54,6 @@ def sqlalchemy_url(self) -> str:
if extras := self.mysql_extra_args:
uri += extras
return uri





settings = AppSettings()
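For orientation, the middle of the `sqlalchemy_url` property is collapsed in this view; only the docstring and the `mysql_extra_args` branch are visible above. A minimal sketch of the shape the diff suggests follows. Every field name other than `mysql_extra_args` is an assumption, not this repo's code:

from dataclasses import dataclass


@dataclass
class SettingsSketch:
    # Hypothetical fields -- only mysql_extra_args appears in the diff above
    sqlite_url: str | None = None
    mysql_user: str = "user"
    mysql_password: str = "password"
    mysql_host: str = "localhost"
    mysql_dbname: str = "contextqa"
    mysql_extra_args: str | None = None

    @property
    def sqlalchemy_url(self) -> str:
        """sqlalchemy url built either from the sqlite url or the credentials of a specific mysql server"""
        if self.sqlite_url:
            return self.sqlite_url
        uri = f"mysql+pymysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}/{self.mysql_dbname}"
        if extras := self.mysql_extra_args:  # appended verbatim, e.g. "?charset=utf8mb4"
            uri += extras
        return uri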
47 changes: 37 additions & 10 deletions api/contextqa/services/context.py
@@ -1,30 +1,31 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import BinaryIO, Type, AsyncGenerator
from typing import AsyncGenerator, BinaryIO, Type

import pinecone
from chromadb import PersistentClient
from fastapi import UploadFile
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.document_loaders import PyMuPDFLoader, TextLoader, CSVLoader
from langchain.document_loaders import CSVLoader, PyMuPDFLoader, TextLoader
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.pinecone import Pinecone
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.pinecone import Pinecone
from pydantic import BaseModel
from sqlalchemy.orm import Session


from contextqa import get_logger, settings
from contextqa.models.schemas import LLMResult, SimilarityProcessor, SourceFormat
from contextqa.utils import memory, prompts
from contextqa.utils.exceptions import VectorDBConnectionError
from contextqa.utils.sources import get_not_seen_chunks, check_digest

from contextqa.utils.streaming import stream, CustomQAChain
from contextqa.utils.sources import check_digest, get_not_seen_chunks
from contextqa.utils.streaming import CustomQAChain, stream


LOGGER = get_logger()
@@ -33,7 +34,7 @@
chroma_client = PersistentClient(path=str(settings.local_vectordb_home))


class LLMContextManager(ABC):
class LLMContextManager(BaseModel, ABC):
"""Base llm manager"""

def load_and_preprocess(self, filename: str, file_: BinaryIO, session: Session) -> tuple[list[Document], list[str]]:
@@ -52,7 +53,7 @@ def load_and_preprocess(self, filename: str, file_: BinaryIO, session: Session)
-------
tuple[list[Document], list[str]]
document chunks and their corresponding IDs

Raises
------
DuplicatedSourceError
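The body of `load_and_preprocess` is collapsed in this view. Judging only from the imports at the top of the file (NamedTemporaryFile, the three loaders, RecursiveCharacterTextSplitter), the loading step presumably looks roughly like the sketch below; the suffix-to-loader mapping and chunk sizes are assumptions, not this repo's code:

from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import BinaryIO

from langchain.docstore.document import Document
from langchain.document_loaders import CSVLoader, PyMuPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Hypothetical dispatch table -- the real mapping is not visible in the diff
LOADERS = {".pdf": PyMuPDFLoader, ".txt": TextLoader, ".csv": CSVLoader}


def load_and_split_sketch(filename: str, file_: BinaryIO) -> list[Document]:
    """Illustrative only: dump the upload to a temp file, pick a loader by
    suffix and split the result into chunks"""
    suffix = Path(filename).suffix
    with NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(file_.read())
        path = tmp.name
    docs = LOADERS[suffix](path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(docs)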
@@ -196,6 +197,32 @@ def context_object(self) -> VectorStore:
return processor


class BatchProcessor(BaseModel):
"""QA processor for batch ingestions"""

manager: LLMContextManager

def persist(self, sources: list[UploadFile], session: Session) -> LLMResult:
"""Ingest the uploaded sources

Parameters
----------
sources : list[UploadFile]
uploaded sources
session : Session
db session

Returns
-------
LLMResult
"""
func = self.manager.persist
with ThreadPoolExecutor(max_workers=5) as executor:
for source in sources:
executor.submit(func, source.filename, source.file, session)
return LLMResult(response="success")
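
One caveat worth noting about `persist`: leaving the `with ThreadPoolExecutor(...)` block waits for every submitted task, but an exception raised inside a worker is only re-raised once `Future.result()` is consumed, so a failed ingestion would still produce the hard-coded "success" response. A sketch of a variant that surfaces per-source failures, written as a standalone function for clarity; the error-reporting shape is an assumption, not part of this commit:

from concurrent.futures import ThreadPoolExecutor, as_completed

from contextqa.models.schemas import LLMResult


def persist_checked(manager, sources, session) -> LLMResult:
    """Sketch of an alternative `BatchProcessor.persist`: same fan-out, but
    each future is consumed so worker exceptions are not silently dropped"""
    failed: list[str] = []
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = {
            executor.submit(manager.persist, src.filename, src.file, session): src.filename
            for src in sources
        }
        for future in as_completed(futures):
            try:
                future.result()  # re-raises any exception from the worker thread
            except Exception:  # the concrete error types are not visible in this diff
                failed.append(futures[future])
    response = "success" if not failed else f"failed sources: {', '.join(failed)}"
    return LLMResult(response=response)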


def get_setter(processor: SimilarityProcessor | None = None) -> LLMContextManager:
"""LLMContextManager factory function

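Putting the pieces together, a caller would presumably obtain a manager from this factory and hand it to the batch processor. A minimal usage sketch, assuming a FastAPI route and a `get_session` dependency that are not part of this diff:

from fastapi import APIRouter, Depends, UploadFile
from sqlalchemy.orm import Session

from contextqa.services.context import BatchProcessor, get_setter

router = APIRouter()


@router.post("/sources")  # hypothetical route -- the real endpoint is not shown in this commit
def ingest_sources(sources: list[UploadFile], session: Session = Depends(get_session)):
    # get_session is assumed to exist elsewhere in the app
    processor = BatchProcessor(manager=get_setter())
    return processor.persist(sources, session)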