diff --git a/packages/api/src/ai/ai.module.ts b/packages/api/src/ai/ai.module.ts index f0d6594..a566eec 100644 --- a/packages/api/src/ai/ai.module.ts +++ b/packages/api/src/ai/ai.module.ts @@ -3,12 +3,10 @@ import { aiModelsMongooseProviders } from '@/ai/ai-models.mongoose.providers'; import { AiModelsRepository } from '@/ai/ai-models.repository'; import { AiController } from '@/ai/ai.controller'; import { AiService } from '@/ai/facades/ai.service'; -import { AgentConversationService } from '@/ai/services/agent-conversation.service'; import { ChatModelService } from '@/ai/services/chat-model.service'; import { MemoryService } from '@/ai/services/memory.service'; import { RedisKeepAliveService } from '@/ai/services/redis-keep-alive.service'; import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service'; -import { ToolService } from '@/ai/services/tool.service'; import { VectorDbService } from '@/ai/services/vector-db.service'; import { AdminAddAiModelUsecase } from '@/ai/usecases/admin-add-ai-model.usecase'; import { AdminFindAiModelsUsecase } from '@/ai/usecases/admin-find-ai-models.usecase'; @@ -46,9 +44,7 @@ import { ScheduleModule } from '@nestjs/schedule'; // Private services AiModelsRepository, MemoryService, - ToolService, SimpleConversationChainService, - AgentConversationService, RedisKeepAliveService, VectorDbService, ChatModelService, diff --git a/packages/api/src/ai/facades/ai.service.ts b/packages/api/src/ai/facades/ai.service.ts index ea4048d..e455518 100644 --- a/packages/api/src/ai/facades/ai.service.ts +++ b/packages/api/src/ai/facades/ai.service.ts @@ -1,5 +1,5 @@ -import { AgentConversationService } from '@/ai/services/agent-conversation.service'; import { ChatModelService } from '@/ai/services/chat-model.service'; +import { DocumentConversationChain } from '@/ai/services/document-conversation-chain.service'; import { MemoryService } from '@/ai/services/memory.service'; import { SimpleConversationChainService } 
from '@/ai/services/simple-conversation-chain.service'; import { VectorDbService } from '@/ai/services/vector-db.service'; @@ -7,7 +7,6 @@ import { AppConfigService } from '@/app-config/app-config.service'; import { RedisChatMemoryNotFoundException } from '@/chats/exceptions/redis-chat-memory-not-found.exception'; import { ChatDocument } from '@/common/types/chat'; import { Injectable, Logger } from '@nestjs/common'; -import { AgentExecutor } from 'langchain/agents'; import { ConversationChain, LLMChain, @@ -17,7 +16,7 @@ import { Document } from 'langchain/document'; import { PromptTemplate } from 'langchain/prompts'; import { BaseMessage, ChainValues } from 'langchain/schema'; -type AIExecutor = AgentExecutor | ConversationChain; +type AIExecutor = DocumentConversationChain | ConversationChain; @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); @@ -31,7 +30,6 @@ Helpful answer:` constructor( private readonly simpleConversationChainService: SimpleConversationChainService, - private readonly agentConversationService: AgentConversationService, private readonly appConfigService: AppConfigService, private readonly memoryService: MemoryService, private readonly vectorDbService: VectorDbService, @@ -53,17 +51,20 @@ Helpful answer:` } if (documents.length > 0) { - aiExecutor = await this.agentConversationService.getAgent( - roomId, - await this.chatModelService.getChatModel(chatLlmId), + aiExecutor = new DocumentConversationChain({ + memoryService: this.memoryService, + vectorDbService: this.vectorDbService, + simpleConversationChainService: this.simpleConversationChainService, + llmModel: await this.chatModelService.getChatModel(chatLlmId), documents, - summary?.response - ); + roomId, + summary: summary?.output, + }); } else { aiExecutor = await this.simpleConversationChainService.getChain( roomId, await this.chatModelService.getChatModel(chatLlmId), - summary?.response + summary?.output ); } @@ -82,13 +83,6 @@ Helpful answer:` 
); } } - - invalidateAgentCache(roomId: string): void { - if (this.agentConversationService.agentMap.has(roomId)) { - this.agentConversationService.agentMap.delete(roomId); - } - } - async askAiToDescribeDocument( lcDocuments: Document[], chatLlmId: string diff --git a/packages/api/src/ai/services/agent-conversation.service.ts b/packages/api/src/ai/services/agent-conversation.service.ts deleted file mode 100644 index fe98d30..0000000 --- a/packages/api/src/ai/services/agent-conversation.service.ts +++ /dev/null @@ -1,70 +0,0 @@ -import { MemoryService } from '@/ai/services/memory.service'; -import { ToolService } from '@/ai/services/tool.service'; -import { ChatDocument } from '@/common/types/chat'; -import { Injectable } from '@nestjs/common'; -import { - AgentExecutor, - initializeAgentExecutorWithOptions, -} from 'langchain/agents'; -import { BaseChatModel } from 'langchain/chat_models'; -import { BufferMemory } from 'langchain/memory'; -import { Calculator } from 'langchain/tools/calculator'; - -@Injectable() -export class AgentConversationService { - agentMap: Map<string, AgentExecutor>; - - constructor( - private readonly memoryService: MemoryService, - private readonly toolService: ToolService - ) { - this.agentMap = new Map(); - } - - async getAgent( - roomId: string, - llmModel: BaseChatModel, - documents: ChatDocument[], - summary?: string - ): Promise<AgentExecutor> { - if (!this.hasAgent(roomId) || !!summary) { - await this.createAgent(roomId, llmModel, documents, summary); - } - - return this.agentMap.get(roomId); - } - - private hasAgent(roomId: string) { - return this.agentMap.has(roomId); - } - - private async createAgent( - roomId: string, - llmModel: BaseChatModel, - documents: ChatDocument[], - summary?: string - ) { - const agentDocumentTools = await this.toolService.getDocumentQATools( - roomId, - llmModel, - documents - ); - - const agent = await initializeAgentExecutorWithOptions( - [new Calculator(), ...agentDocumentTools], - llmModel, - { - agentType: 
'chat-conversational-react-description', - memory: new BufferMemory({ - returnMessages: true, - memoryKey: 'chat_history', - chatHistory: ( - await this.memoryService.getMemory(roomId, summary) - ).chatHistory, - }), - } - ); - - this.agentMap.set(roomId, agent); - } -} diff --git a/packages/api/src/ai/services/document-conversation-chain.service.ts b/packages/api/src/ai/services/document-conversation-chain.service.ts new file mode 100644 index 0000000..c57ac88 --- /dev/null +++ b/packages/api/src/ai/services/document-conversation-chain.service.ts @@ -0,0 +1,72 @@ +import { MemoryService } from '@/ai/services/memory.service'; +import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service'; +import { VectorDbService } from '@/ai/services/vector-db.service'; +import { ChatDocument } from '@/common/types/chat'; +import { Injectable } from '@nestjs/common'; +import { RetrievalQAChain, loadQAStuffChain } from 'langchain/chains'; +import { BaseChatModel } from 'langchain/chat_models'; +import { PromptTemplate } from 'langchain/prompts'; + +type DocumentConversationChainProps = { + memoryService: MemoryService; + vectorDbService: VectorDbService; + simpleConversationChainService: SimpleConversationChainService; + llmModel: BaseChatModel; + documents: ChatDocument[]; + roomId: string; + summary?: string; +}; +@Injectable() +export class DocumentConversationChain { + defaultChatPrompt: string; + + constructor(private readonly args: DocumentConversationChainProps) { + this.defaultChatPrompt = `Use the following context to answer the question at the end. + {context} + Question: {question} + If you don't know the answer return only: Answer not found. 
+ But if you have an answer, provide the most detailed response you can.`; + } + + async call({ input }: { input: string }) { + const prompt = PromptTemplate.fromTemplate(this.defaultChatPrompt); + + for (const document of this.args.documents) { + const vectorStore = + await this.args.vectorDbService.getVectorDbClientForExistingCollection( + this.args.roomId, + document.meta.filename + ); + + const chain = new RetrievalQAChain({ + combineDocumentsChain: loadQAStuffChain(this.args.llmModel, { + prompt, + }), + retriever: vectorStore.asRetriever(), + returnSourceDocuments: true, + inputKey: 'input', + }); + + const res = await chain.call({ + input, + }); + + if (!res.text.includes('Answer not found.')) { + await this.args.memoryService.createMemoryWithDocumentInput( + this.args.roomId, + input, + res.text, + this.args.summary + ); + return { output: res.text, source: res.sourceDocuments }; + } + } + const simpleConversationChain = + await this.args.simpleConversationChainService.getChain( + this.args.roomId, + this.args.llmModel, + this.args.summary + ); + return simpleConversationChain.call({ input }); + } +} diff --git a/packages/api/src/ai/services/memory.service.ts b/packages/api/src/ai/services/memory.service.ts index d646b4c..6701088 100644 --- a/packages/api/src/ai/services/memory.service.ts +++ b/packages/api/src/ai/services/memory.service.ts @@ -70,4 +70,31 @@ export class MemoryService { }) ); } + + async createMemoryWithDocumentInput( + roomId: string, + input: string, + response: string, + summary?: string + ) { + const redisChatHistory = new RedisChatMessageHistory({ + sessionId: roomId, + client: this.cacheClient, + sessionTTL: this.appConfigService.getAiAppConfig().defaultChatContextTTL, + }); + if (!!summary) { + await this.memoryMap.get(roomId).clear(); + await redisChatHistory.addAIChatMessage(summary); + } + await redisChatHistory.addUserMessage(input); + await redisChatHistory.addAIChatMessage(response); + this.memoryMap.set( + roomId, + new 
BufferMemory({ + returnMessages: true, + memoryKey: 'history', + chatHistory: redisChatHistory, + }) + ); + } } diff --git a/packages/api/src/ai/services/tool.service.ts b/packages/api/src/ai/services/tool.service.ts deleted file mode 100644 index 53e3d86..0000000 --- a/packages/api/src/ai/services/tool.service.ts +++ /dev/null @@ -1,39 +0,0 @@ -import { VectorDbService } from '@/ai/services/vector-db.service'; -import { ChatDocument } from '@/common/types/chat'; -import { Injectable } from '@nestjs/common'; -import { VectorDBQAChain } from 'langchain/chains'; -import { BaseChatModel } from 'langchain/chat_models'; -import { ChainTool } from 'langchain/tools'; - -@Injectable() -export class ToolService { - constructor(private readonly vectorDbService: VectorDbService) {} - - async getDocumentQATools( - roomId: string, - llmModel: BaseChatModel, - documents: ChatDocument[] - ): Promise<ChainTool[]> { - const documentQATools = []; - - for (const document of documents) { - const vectorStore = - await this.vectorDbService.getVectorDbClientForExistingCollection( - roomId, - document.meta.filename - ); - - const chain = VectorDBQAChain.fromLLM(llmModel, vectorStore); - - const qaTool = new ChainTool({ - name: document.meta.vectorDBDocumentName, - description: document.meta.vectorDBDocumentDescription, - chain, - }); - - documentQATools.push(qaTool); - } - - return documentQATools; - } -} diff --git a/packages/api/src/chats/usecases/upload-documents-to-chat.usecase.ts b/packages/api/src/chats/usecases/upload-documents-to-chat.usecase.ts index 6cc0be4..24a57bb 100644 --- a/packages/api/src/chats/usecases/upload-documents-to-chat.usecase.ts +++ b/packages/api/src/chats/usecases/upload-documents-to-chat.usecase.ts @@ -59,8 +59,6 @@ export class UploadDocumentsToChatUsecase implements Usecase { ) ); } - - this.aiService.invalidateAgentCache(roomId); } private checkMaxDocumentsSizePerRoomInvariant( chat: Chat,