
Commit

make interfaces the same
ianmacartney committed Oct 31, 2023
1 parent ce215d3 commit e1115c9
Showing 3 changed files with 75 additions and 118 deletions.
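
The common thread across the three files below is a single OpenAI-style completion interface: ollamaChatCompletion now takes the same request body as chatCompletion, so call sites pick a backend once instead of branching on UseOllama at every call. A minimal sketch of the resulting calling pattern, assuming it sits next to the agent modules (exampleReply and its parameters are illustrative, not part of the commit; UseOllama, ollamaChatCompletion, chatCompletion, and LLMMessage are the exports shown in the diff):

import { UseOllama, ollamaChatCompletion } from '../util/ollama';
import { chatCompletion, LLMMessage } from '../util/openai';

// Both backends now accept the same OpenAI-style request body, so the choice
// is made once at module scope and reused by every agent function.
const completionFn = UseOllama ? ollamaChatCompletion : chatCompletion;

// Hypothetical caller, analogous to the conversation functions changed below.
async function exampleReply(messages: LLMMessage[], stop: string[]) {
  const { content } = await completionFn({
    messages,
    max_tokens: 300,
    stream: true,
    stop,
  });
  return content;
}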
convex/agent/conversation.ts: 80 changes (27 additions & 53 deletions)
@@ -9,6 +9,7 @@ import * as embeddingsCache from './embeddingsCache';
import { GameId, conversationId, playerId } from '../aiTown/ids';

const selfInternal = internal.agent.conversation;
const completionFn = UseOllama ? ollamaChatCompletion : chatCompletion;

export async function startConversationMessage(
ctx: ActionCtx,
@@ -50,27 +51,18 @@ export async function startConversationMessage(
}
prompt.push(`${player.name}:`);

if (UseOllama) {
const { content } = await ollamaChatCompletion({
prompt: prompt.join('\n'),
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
} else {
const { content } = await chatCompletion({
messages: [
{
role: 'user',
content: prompt.join('\n'),
},
],
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}
const { content } = await completionFn({
messages: [
{
role: 'user',
content: prompt.join('\n'),
},
],
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}

export async function continueConversationMessage(
@@ -122,22 +114,13 @@ export async function continueConversationMessage(
];
llmMessages.push({ role: 'user', content: `${player.name}:` });

if (UseOllama) {
const { content } = await ollamaChatCompletion({
messages: llmMessages,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
} else {
const { content } = await chatCompletion({
messages: llmMessages,
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}
const { content } = await completionFn({
messages: llmMessages,
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}

export async function leaveConversationMessage(
@@ -180,22 +163,13 @@ export async function leaveConversationMessage(
];
llmMessages.push({ role: 'user', content: `${player.name}:` });

if (UseOllama) {
const { content } = await ollamaChatCompletion({
messages: llmMessages,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
} else {
const { content } = await chatCompletion({
messages: llmMessages,
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}
const { content } = await completionFn({
messages: llmMessages,
max_tokens: 300,
stream: true,
stop: stopWords(otherPlayer.name, player.name),
});
return content;
}

function agentPrompts(
convex/agent/memory.ts: 63 changes (20 additions & 43 deletions)
@@ -9,6 +9,8 @@ import { SerializedPlayer } from '../aiTown/player';
import { UseOllama, ollamaChatCompletion } from '../util/ollama';
import { memoryFields } from './schema';

const completionFn = UseOllama ? ollamaChatCompletion : chatCompletion;

// How long to wait before updating a memory's last access time.
export const MEMORY_ACCESS_THROTTLE = 300_000; // In ms
// We fetch 10x the number of memories by relevance, to have more candidates
@@ -60,25 +62,13 @@ export async function rememberConversation(
});
}
llmMessages.push({ role: 'user', content: 'Summary:' });
let summaryResult: string;

if (UseOllama) {
console.log('### Using Ollama for conversation summary ###');
const ollamaPrompt = llmMessages.map((m) => m.content).join('\n');
const { content } = await ollamaChatCompletion({
prompt: ollamaPrompt,
});
summaryResult = content;
} else {
const { content } = await chatCompletion({
messages: llmMessages,
max_tokens: 500,
});
summaryResult = content;
}
const { content } = await completionFn({
messages: llmMessages,
max_tokens: 500,
});
const description = `Conversation with ${otherPlayer.name} at ${new Date(
data.conversation._creationTime,
).toLocaleString()}: ${summaryResult}`;
).toLocaleString()}: ${content}`;
const importance = await calculateImportance(description);
const { embedding } = await fetchEmbedding(description);
authors.delete(player.id as GameId<'players'>);
@@ -253,38 +243,25 @@ export const loadMessages = internalQuery({
});

async function calculateImportance(description: string) {
const llmMessages: LLMMessage[] = [
{
role: 'user',
content: `On the scale of 0 to 9, where 0 is purely mundane (e.g., brushing teeth, making bed) and 9 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following piece of memory.
const { content: importanceRaw } = await completionFn({
messages: [
{
role: 'user',
content: `On the scale of 0 to 9, where 0 is purely mundane (e.g., brushing teeth, making bed) and 9 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following piece of memory.
Memory: ${description}
Answer on a scale of 0 to 9. Respond with number only, e.g. "5"`,
},
];
let returnedImportanceRaw: string;

if (UseOllama) {
console.log('### Using Ollama for memory scoring ###');
const { content: importanceRaw } = await ollamaChatCompletion({
prompt: llmMessages.map((m) => m.content).join('\n'),
});
console.log('### Ollama returned: ', importanceRaw);
returnedImportanceRaw = importanceRaw;
} else {
const { content: importanceRaw } = await chatCompletion({
messages: llmMessages,
temperature: 0.0,
max_tokens: 1,
});
returnedImportanceRaw = importanceRaw;
}
},
],
temperature: 0.0,
max_tokens: 1,
});

let importance = parseFloat(returnedImportanceRaw);
let importance = parseFloat(importanceRaw);
if (isNaN(importance)) {
importance = +(returnedImportanceRaw.match(/\d+/)?.[0] ?? NaN);
importance = +(importanceRaw.match(/\d+/)?.[0] ?? NaN);
}
if (isNaN(importance)) {
console.debug('Could not parse memory importance from: ', returnedImportanceRaw);
console.debug('Could not parse memory importance from: ', importanceRaw);
importance = 5;
}
return importance;
convex/util/ollama.ts: 50 changes (28 additions & 22 deletions)
@@ -1,31 +1,36 @@
import { Ollama } from 'langchain/llms/ollama';
import { ChatCompletionContent, LLMMessage, retryWithBackoff } from './openai';
import {
ChatCompletionContent,
CreateChatCompletionRequest,
LLMMessage,
retryWithBackoff,
} from './openai';
import { IterableReadableStream } from 'langchain/dist/util/stream';

const ollamaModel = process.env.OLLAMA_MODEL || 'llama2';
export const UseOllama = process.env.OLLAMA_HOST !== undefined;

type Body =
| {
messages: LLMMessage[];
stop?: string[];
stream?: boolean;
model?: string;
}
| {
prompt: string;
stop?: string[];
stream?: boolean;
model?: string;
};

// Overload for non-streaming
export async function ollamaChatCompletion(
body: Body & { stream?: false | undefined },
body: Omit<CreateChatCompletionRequest, 'model'> & {
model?: string;
} & {
stream?: false | null | undefined;
},
): Promise<{ content: string; retries: number; ms: number }>;
// Overload for streaming
export async function ollamaChatCompletion(
body: Body & { stream: true },
body: Omit<CreateChatCompletionRequest, 'model'> & {
model?: string;
} & {
stream?: true;
},
): Promise<{ content: OllamaCompletionContent; retries: number; ms: number }>;
export async function ollamaChatCompletion(body: Body) {
export async function ollamaChatCompletion(
body: Omit<CreateChatCompletionRequest, 'model'> & {
model?: string;
},
) {
body.model = body.model ?? 'llama2';
const {
result: content,
@@ -34,16 +39,17 @@ export async function ollamaChatCompletion(body: Body) {
} = await retryWithBackoff(async () => {
console.log('#### Ollama api ####, using ', ollamaModel);

const stop = typeof body.stop === 'string' ? [body.stop] : body.stop;
const ollama = new Ollama({
model: ollamaModel,
baseUrl: process.env.OLLAMA_HOST,
stop: body.stop,
stop,
});
const prompt = 'prompt' in body ? body.prompt : body.messages.map((m) => m.content).join('\n');
const prompt = body.messages.map((m) => m.content).join('\n');
console.log('body.prompt', prompt);
const stream = await ollama.stream(prompt, { stop: body.stop });
const stream = await ollama.stream(prompt, { stop });
if (body.stream) {
return new OllamaCompletionContent(stream, body.stop ?? []);
return new OllamaCompletionContent(stream, stop ?? []);
}
let ollamaResult = '';
for await (const chunk of stream) {
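
For reference, a short sketch of how the two overloads above resolve at a call site (llmMessages is an assumed, already-built LLMMessage[]; the variable names are illustrative):

// Omitting `stream` (or passing false) selects the first overload: content is a plain string.
const single = await ollamaChatCompletion({ messages: llmMessages, max_tokens: 300 });
console.log(single.content);

// Passing stream: true selects the second overload: content is an OllamaCompletionContent
// wrapping the underlying Ollama stream, as constructed in the function body above.
const streamed = await ollamaChatCompletion({ messages: llmMessages, stream: true });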
