Skip to content

Commit

Permalink
feat: return the sourceDocument when asking for documents info (#185)
Browse files Browse the repository at this point in the history
* feat: return the sourceDocument when asking for documents info

* feat: create summary when using documentChain
  • Loading branch information
romansharapov19 authored Sep 20, 2023
1 parent 22615b5 commit ce3c3c2
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 132 deletions.
4 changes: 0 additions & 4 deletions packages/api/src/ai/ai.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ import { aiModelsMongooseProviders } from '@/ai/ai-models.mongoose.providers';
import { AiModelsRepository } from '@/ai/ai-models.repository';
import { AiController } from '@/ai/ai.controller';
import { AiService } from '@/ai/facades/ai.service';
import { AgentConversationService } from '@/ai/services/agent-conversation.service';
import { ChatModelService } from '@/ai/services/chat-model.service';
import { MemoryService } from '@/ai/services/memory.service';
import { RedisKeepAliveService } from '@/ai/services/redis-keep-alive.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { ToolService } from '@/ai/services/tool.service';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { AdminAddAiModelUsecase } from '@/ai/usecases/admin-add-ai-model.usecase';
import { AdminFindAiModelsUsecase } from '@/ai/usecases/admin-find-ai-models.usecase';
Expand Down Expand Up @@ -46,9 +44,7 @@ import { ScheduleModule } from '@nestjs/schedule';
// Private services
AiModelsRepository,
MemoryService,
ToolService,
SimpleConversationChainService,
AgentConversationService,
RedisKeepAliveService,
VectorDbService,
ChatModelService,
Expand Down
28 changes: 11 additions & 17 deletions packages/api/src/ai/facades/ai.service.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { AgentConversationService } from '@/ai/services/agent-conversation.service';
import { ChatModelService } from '@/ai/services/chat-model.service';
import { DocumentConversationChain } from '@/ai/services/document-conversation-chain.service';
import { MemoryService } from '@/ai/services/memory.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { RedisChatMemoryNotFoundException } from '@/chats/exceptions/redis-chat-memory-not-found.exception';
import { ChatDocument } from '@/common/types/chat';
import { Injectable, Logger } from '@nestjs/common';
import { AgentExecutor } from 'langchain/agents';
import {
ConversationChain,
LLMChain,
Expand All @@ -17,7 +16,7 @@ import { Document } from 'langchain/document';
import { PromptTemplate } from 'langchain/prompts';
import { BaseMessage, ChainValues } from 'langchain/schema';

type AIExecutor = AgentExecutor | ConversationChain;
type AIExecutor = DocumentConversationChain | ConversationChain;
@Injectable()
export class AiService {
private readonly logger = new Logger(AiService.name);
Expand All @@ -31,7 +30,6 @@ Helpful answer:`

constructor(
private readonly simpleConversationChainService: SimpleConversationChainService,
private readonly agentConversationService: AgentConversationService,
private readonly appConfigService: AppConfigService,
private readonly memoryService: MemoryService,
private readonly vectorDbService: VectorDbService,
Expand All @@ -53,17 +51,20 @@ Helpful answer:`
}

if (documents.length > 0) {
aiExecutor = await this.agentConversationService.getAgent(
roomId,
await this.chatModelService.getChatModel(chatLlmId),
aiExecutor = new DocumentConversationChain({
memoryService: this.memoryService,
vectorDbService: this.vectorDbService,
simpleConversationChainService: this.simpleConversationChainService,
llmModel: await this.chatModelService.getChatModel(chatLlmId),
documents,
summary?.response
);
roomId,
summary: summary?.output,
});
} else {
aiExecutor = await this.simpleConversationChainService.getChain(
roomId,
await this.chatModelService.getChatModel(chatLlmId),
summary?.response
summary?.output
);
}

Expand All @@ -82,13 +83,6 @@ Helpful answer:`
);
}
}

invalidateAgentCache(roomId: string): void {
if (this.agentConversationService.agentMap.has(roomId)) {
this.agentConversationService.agentMap.delete(roomId);
}
}

async askAiToDescribeDocument(
lcDocuments: Document[],
chatLlmId: string
Expand Down
70 changes: 0 additions & 70 deletions packages/api/src/ai/services/agent-conversation.service.ts

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { MemoryService } from '@/ai/services/memory.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { ChatDocument } from '@/common/types/chat';
import { Injectable } from '@nestjs/common';
import { RetrievalQAChain, loadQAStuffChain } from 'langchain/chains';
import { BaseChatModel } from 'langchain/chat_models';
import { PromptTemplate } from 'langchain/prompts';

type DocumentConversationChainProps = {
memoryService: MemoryService;
vectorDbService: VectorDbService;
simpleConversationChainService: SimpleConversationChainService;
llmModel: BaseChatModel;
documents: ChatDocument[];
roomId: string;
summary?: string;
};
@Injectable()
export class DocumentConversationChain {
defaultChatPrompt: string;

constructor(private readonly args: DocumentConversationChainProps) {
this.defaultChatPrompt = `Use the following context to answer the question at the end.
{context}
Question: {question}
If you don't know the answer return only: Answer not found.
But if you have an answer, provide the most detailed response you can.`;
}

async call({ input }: { input: string }) {
const prompt = PromptTemplate.fromTemplate(this.defaultChatPrompt);

for (const document of this.args.documents) {
const vectorStore =
await this.args.vectorDbService.getVectorDbClientForExistingCollection(
this.args.roomId,
document.meta.filename
);

const chain = new RetrievalQAChain({
combineDocumentsChain: loadQAStuffChain(this.args.llmModel, {
prompt,
}),
retriever: vectorStore.asRetriever(),
returnSourceDocuments: true,
inputKey: 'input',
});

const res = await chain.call({
input,
});

if (!res.text.includes('Answer not found.')) {
await this.args.memoryService.createMemoryWithDocumentInput(
this.args.roomId,
input,
res.text,
this.args.summary
);
return { output: res.text, source: res.sourceDocuments };
}
}
const simpleConversationChain =
await this.args.simpleConversationChainService.getChain(
this.args.roomId,
this.args.llmModel,
this.args.summary
);
return simpleConversationChain.call({ input });
}
}
27 changes: 27 additions & 0 deletions packages/api/src/ai/services/memory.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,31 @@ export class MemoryService {
})
);
}

async createMemoryWithDocumentInput(
roomId: string,
input: string,
response: string,
summary?: string
) {
const redisChatHistory = new RedisChatMessageHistory({
sessionId: roomId,
client: this.cacheClient,
sessionTTL: this.appConfigService.getAiAppConfig().defaultChatContextTTL,
});
if (!!summary) {
await this.memoryMap.get(roomId).clear();
await redisChatHistory.addAIChatMessage(summary);
}
await redisChatHistory.addUserMessage(input);
await redisChatHistory.addAIChatMessage(response);
this.memoryMap.set(
roomId,
new BufferMemory({
returnMessages: true,
memoryKey: 'history',
chatHistory: redisChatHistory,
})
);
}
}
39 changes: 0 additions & 39 deletions packages/api/src/ai/services/tool.service.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ export class UploadDocumentsToChatUsecase implements Usecase {
)
);
}

this.aiService.invalidateAgentCache(roomId);
}
private checkMaxDocumentsSizePerRoomInvariant(
chat: Chat,
Expand Down

0 comments on commit ce3c3c2

Please sign in to comment.