From 81b1054811a5cc503d297d93d2ac03798086fafb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 9 Oct 2024 16:03:28 -0400 Subject: [PATCH] thread --- js/ai/src/generate.ts | 2 +- js/ai/src/model.ts | 2 +- js/genkit/src/genkit.ts | 60 +++-- js/genkit/src/session.ts | 248 ++++++++++++-------- js/genkit/tests/environment_test.ts | 141 +++--------- js/genkit/tests/helpers.ts | 2 +- js/genkit/tests/models_test.ts | 5 +- js/genkit/tests/session_test.ts | 344 +++++++++++++++++++++------- js/testapps/rag/src/bar.ts | 31 --- js/testapps/rag/src/foo.ts | 93 -------- 10 files changed, 478 insertions(+), 450 deletions(-) delete mode 100644 js/testapps/rag/src/bar.ts delete mode 100644 js/testapps/rag/src/foo.ts diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 9b5fc259..988457fe 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -549,7 +549,7 @@ export async function generate< config: { ...resolvedModel.config, version: resolvedModel.version, - ...resolvedOptions.config + ...resolvedOptions.config, }, output: resolvedOptions.output && { format: resolvedOptions.output.format, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 05477f3b..14b6b15e 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -385,7 +385,7 @@ export interface ModelReference { configSchema?: CustomOptions; info?: ModelInfo; version?: string; - config?: z.infer, + config?: z.infer; } /** Cretes a model reference. */ diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 46684451..0df98484 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -46,7 +46,7 @@ import { retrieve, RetrieverParams, ToolAction, - ToolConfig + ToolConfig, } from '@genkit-ai/ai'; import { CallableFlow, @@ -84,7 +84,6 @@ import { } from './model.js'; import { lookupAction, Registry, runWithRegistry } from './registry.js'; import { - BaseGenerateOptions, Environment, getCurrentSession, Session, @@ -700,43 +699,42 @@ export class Genkit { chat( options?: SessionOptions - ): Session; - - chat( - requestBase: BaseGenerateOptions, - options?: SessionOptions - ): Session; - - chat( - requestBaseOrOpts?: SessionOptions | BaseGenerateOptions, - maybeOptions?: SessionOptions ): Session { - // parse overloaded args - let baseGenerateOptions: BaseGenerateOptions | undefined = undefined; - let options: SessionOptions | undefined = undefined; - if (maybeOptions !== undefined) { - options = maybeOptions; - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; - } else if (requestBaseOrOpts !== undefined) { - if ( - (requestBaseOrOpts as SessionOptions).state || - (requestBaseOrOpts as SessionOptions).store || - (requestBaseOrOpts as SessionOptions).schema - ) { - options = requestBaseOrOpts as SessionOptions; - } else { - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; + return new Session( + this, + { + ...options, + }, + { + sessionData: { + state: options?.state, + }, + stateSchema: options?.stateSchema, + store: options?.store, } - } + ); + } + async loadChat( + sessionId: string, + options: SessionOptions + ): Promise> { + if (!options.store) { + throw new Error('options.store is required for loading chat sessions'); + } + const sessionData = await options.store.get(sessionId); + if (!sessionData) { + throw new Error(`chat session ${sessionId} not found`); + } return new Session( this, { - ...baseGenerateOptions, + ...options, }, { - state: options?.state, - schema: options?.schema, + id: sessionId, + sessionData, + stateSchema: options?.stateSchema, store: options?.store, } ); diff --git a/js/genkit/src/session.ts b/js/genkit/src/session.ts index b2f259b7..7104f3dc 100644 --- a/js/genkit/src/session.ts +++ b/js/genkit/src/session.ts @@ -1,3 +1,19 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + import { GenerateOptions, GenerateResponse, @@ -23,13 +39,16 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { v4 as uuidv4 } from 'uuid'; import { ExecutablePrompt, Genkit } from './genkit'; +const MAIN_THREAD = '__main'; + export type BaseGenerateOptions = Omit; -export interface SessionOptions { - state?: z.infer; - schema?: S; +export type SessionOptions = BaseGenerateOptions & { + stateSchema?: S; store?: SessionStore; -} + state?: z.infer; + sessionId?: string; +}; type EnvironmentType = Pick< Genkit, @@ -46,7 +65,7 @@ export class Environment implements EnvironmentType { private name: string; constructor( - private genkit: Genkit, + readonly genkit: Genkit, config: { name: string; stateSchema?: S; @@ -111,83 +130,31 @@ export class Environment implements EnvironmentType { return this.genkit.definePrompt(options, templateOrFn as PromptFn); } - createSession(options?: EnvironmentSessionOptions): Session; - - createSession( - requestBase: BaseGenerateOptions, - options?: EnvironmentSessionOptions - ): Session; - - createSession( - requestBaseOrOpts?: EnvironmentSessionOptions | BaseGenerateOptions, - maybeOptions?: EnvironmentSessionOptions - ): Session { - // parse overloaded args - let baseGenerateOptions: BaseGenerateOptions | undefined = undefined; - let options: EnvironmentSessionOptions | undefined = undefined; - if (maybeOptions !== undefined) { - options = maybeOptions; - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; - } else if (requestBaseOrOpts !== undefined) { - if ( - (requestBaseOrOpts as EnvironmentSessionOptions).state || - (requestBaseOrOpts as EnvironmentSessionOptions).schema - ) { - options = requestBaseOrOpts as EnvironmentSessionOptions; - } else { - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; - } - } - + createSession(options?: EnvironmentSessionOptions): Session { return new Session( this.genkit, { - ...baseGenerateOptions, + ...options, }, { - state: options?.state, - schema: options?.schema, + id: options?.sessionId, + sessionData: { + state: options?.state, + }, + stateSchema: options?.stateSchema, store: this.store, } ); } - loadSession( - sessionId: string, - options?: EnvironmentSessionOptions - ): Promise>; - - loadSession( - sessionId: string, - requestBase: BaseGenerateOptions, - options?: EnvironmentSessionOptions - ): Promise>; - async loadSession( sessionId: string, - requestBaseOrOpts?: EnvironmentSessionOptions | BaseGenerateOptions, - maybeOptions?: EnvironmentSessionOptions + options?: EnvironmentSessionOptions ): Promise> { - // parse overloaded args - let baseGenerateOptions: BaseGenerateOptions | undefined = undefined; - let options: EnvironmentSessionOptions | undefined = undefined; - if (maybeOptions !== undefined) { - options = maybeOptions; - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; - } else if (requestBaseOrOpts !== undefined) { - if ( - (requestBaseOrOpts as EnvironmentSessionOptions).state || - (requestBaseOrOpts as EnvironmentSessionOptions).schema - ) { - options = requestBaseOrOpts as EnvironmentSessionOptions; - } else { - baseGenerateOptions = requestBaseOrOpts as BaseGenerateOptions; - } - } - - const state = this.store.get(sessionId); + const state = await this.store.get(sessionId); - return this.createSession(baseGenerateOptions!, { + return this.createSession({ + sessionId, ...options, state, }); @@ -205,27 +172,85 @@ export class Environment implements EnvironmentType { export class Session { readonly id: string; readonly schema?: S; - readonly sessionData: SessionData; + private sessionData?: SessionData; private store: SessionStore; + private threadName: string; constructor( - readonly environment: Genkit, + readonly parent: Genkit | Environment | Session, readonly requestBase?: BaseGenerateOptions, options?: { - schema?: S; - state: z.infer; + id?: string; + stateSchema?: S; + sessionData?: SessionData; store?: SessionStore; + threadName?: string; } ) { - this.id = uuidv4(); - this.schema = options?.schema; - this.sessionData = { - state: options?.state ?? {}, - messages: requestBase?.messages ?? [], - }; + this.id = options?.id ?? uuidv4(); + this.schema = options?.stateSchema; + this.threadName = options?.threadName ?? MAIN_THREAD; + this.sessionData = options?.sessionData; + if (!this.sessionData) { + this.sessionData = {}; + } + if (!this.sessionData.threads) { + this.sessionData!.threads = {}; + } + // this is handling dotprompt render case + if (requestBase && requestBase['prompt']) { + const basePrompt = requestBase['prompt'] as string | Part | Part[]; + let promptMessage: MessageData; + if (typeof basePrompt === 'string') { + promptMessage = { + role: 'user', + content: [{ text: basePrompt }], + }; + } else if (Array.isArray(basePrompt)) { + promptMessage = { + role: 'user', + content: basePrompt, + }; + } else { + promptMessage = { + role: 'user', + content: [basePrompt], + }; + } + requestBase.messages = [...(requestBase.messages ?? []), promptMessage]; + } + if (parent instanceof Session) { + if (!this.sessionData.threads[this.threadName]) { + this!.sessionData.threads[this.threadName] = [ + ...(parent.messages ?? []), + ...(requestBase?.messages ?? []), + ]; + } + } else { + if (!this.sessionData.threads[this.threadName]) { + this.sessionData.threads[this.threadName] = [ + ...(requestBase?.messages ?? []), + ]; + } + } this.store = options?.store ?? new InMemorySessionStore(); } + thread(threadName: string): Session { + const requestBase = { + ...this.requestBase, + }; + delete requestBase.messages; + const parent = this.parent instanceof Session ? this.parent : this; + return new Session(parent, requestBase, { + id: this.id, + stateSchema: this.schema, + store: this.store, + threadName, + sessionData: this.sessionData, + }); + } + async send< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -244,9 +269,9 @@ export class Session { prompt: options, } as GenerateOptions; } - const response = await this.environment.generate({ + const response = await this.genkit.generate({ ...this.requestBase, - messages: this.sessionData.messages, + messages: this.messages, ...options, }); await this.updateMessages(response.toHistory()); @@ -271,9 +296,9 @@ export class Session { prompt: options, } as GenerateOptions; } - const { response, stream } = await this.environment.generateStream({ + const { response, stream } = await this.genkit.generateStream({ ...this.requestBase, - messages: this.sessionData.messages, + messages: this.messages, ...options, }); @@ -285,6 +310,16 @@ export class Session { }; } + private get genkit(): Genkit { + if (this.parent instanceof Session) { + return this.parent.genkit; + } + if (this.parent instanceof Environment) { + return this.parent.genkit; + } + return this.parent; + } + runFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, @@ -293,30 +328,52 @@ export class Session { } get state(): z.infer { - return this.sessionData.state; + // We always get state from the parent. Parent session is the source of truth. + if (this.parent instanceof Session) { + return this.parent.state; + } + return this.sessionData!.state; } async updateState(data: z.infer): Promise { - this.sessionData.state = data; - await this.store.save(this.id, { - state: this.state, - messages: this.messages, - }); + // We always update the state on the parent. Parent session is the source of truth. + if (this.parent instanceof Session) { + return this.parent.updateState(data); + } + let sessionData = await this.store.get(this.id); + if (!sessionData) { + sessionData = {} as SessionData; + } + sessionData.state = data; + this.sessionData = sessionData; + + await this.store.save(this.id, sessionData); } get messages(): MessageData[] | undefined { - return this.sessionData.messages; + if (!this.sessionData?.threads) { + return undefined; + } + return this.sessionData?.threads[this.threadName]; } async updateMessages(messages: MessageData[]): Promise { - this.sessionData.messages = messages; - await this.store.save(this.id, { - state: this.state, - messages: this.messages, - }); + let sessionData = await this.store.get(this.id); + if (!sessionData) { + sessionData = { threads: {} }; + } + if (!sessionData.threads) { + sessionData.threads = {}; + } + sessionData.threads[this.threadName] = messages; + this.sessionData = sessionData; + await this.store.save(this.id, sessionData); } toJSON() { + if (this.parent instanceof Session) { + return this.parent.toJSON(); + } return this.sessionData; } @@ -327,7 +384,7 @@ export class Session { export interface SessionData { state?: z.infer; - messages?: MessageData[]; + threads?: Record; } const sessionAls = new AsyncLocalStorage>(); @@ -356,6 +413,7 @@ export class SessionError extends Error { } } +/** Session store persists session data such as state and chat messages. */ export interface SessionStore { get(sessionId: string): Promise | undefined>; diff --git a/js/genkit/tests/environment_test.ts b/js/genkit/tests/environment_test.ts index 0488bbb3..7f751ddf 100644 --- a/js/genkit/tests/environment_test.ts +++ b/js/genkit/tests/environment_test.ts @@ -17,8 +17,8 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { Genkit, genkit } from '../src/genkit'; -import { defineEchoModel, TestMemorySessionStore } from './helpers'; import { z } from '../src/index'; +import { TestMemorySessionStore, defineEchoModel } from './helpers'; describe('environment', () => { let ai: Genkit; @@ -35,7 +35,7 @@ describe('environment', () => { name: 'agent', stateSchema: z.object({ name: z.string(), - }) + }), }); const session = env.createSession(); @@ -50,41 +50,16 @@ describe('environment', () => { 'Echo: hi,Echo: hi,; config: {},bye; config: {}' ); assert.deepStrictEqual(response.toHistory(), [ + { content: [{ text: 'hi' }], role: 'user' }, { - content: [ - { - text: 'hi', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi', - }, - { - text: '; config: {}', - }, - ], + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], role: 'model', }, + { content: [{ text: 'bye' }], role: 'user' }, { content: [ - { - text: 'bye', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi,Echo: hi,; config: {},bye', - }, - { - text: '; config: {}', - }, + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, ], role: 'model', }, @@ -96,11 +71,11 @@ describe('environment', () => { name: 'agent', stateSchema: z.object({ name: z.string(), - }) + }), }); const session = env.createSession(); - let {response, stream} = await session.sendStream('hi'); + let { response, stream } = await session.sendStream('hi'); let chunks: string[] = []; for await (const chunk of stream) { @@ -109,7 +84,7 @@ describe('environment', () => { assert.strictEqual((await response).text(), 'Echo: hi; config: {}'); assert.deepStrictEqual(chunks, ['3', '2', '1']); - ({response, stream} = await session.sendStream('bye')); + ({ response, stream } = await session.sendStream('bye')); chunks = []; for await (const chunk of stream) { @@ -122,43 +97,18 @@ describe('environment', () => { 'Echo: hi,Echo: hi,; config: {},bye; config: {}' ); assert.deepStrictEqual((await response).toHistory(), [ + { content: [{ text: 'hi' }], role: 'user' }, { - content: [ - { - text: 'hi', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi', - }, - { - text: '; config: {}', - }, - ], role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], }, + { content: [{ text: 'bye' }], role: 'user' }, { + role: 'model', content: [ - { - text: 'bye', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi,Echo: hi,; config: {},bye', - }, - { - text: '; config: {}', - }, + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, ], - role: 'model', }, ]); }); @@ -170,7 +120,7 @@ describe('environment', () => { store, stateSchema: z.object({ name: z.string(), - }) + }), }); const session = env.createSession(); await session.send('hi'); @@ -178,45 +128,22 @@ describe('environment', () => { const state = await store.get(session.id); - assert.deepStrictEqual(state?.messages, [ - { - content: [ - { - text: 'hi', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi', - }, - { - text: '; config: {}', - }, - ], - role: 'model', - }, - { - content: [ - { - text: 'bye', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi,Echo: hi,; config: {},bye', - }, - { - text: '; config: {}', - }, - ], - role: 'model', - }, - ]); + assert.deepStrictEqual(state?.threads, { + __main: [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + }); }); }); diff --git a/js/genkit/tests/helpers.ts b/js/genkit/tests/helpers.ts index 31f50964..51f0b6e4 100644 --- a/js/genkit/tests/helpers.ts +++ b/js/genkit/tests/helpers.ts @@ -82,7 +82,7 @@ async function runAsync(fn: () => O): Promise { export class TestMemorySessionStore implements SessionStore { - data: Record> = {}; + private data: Record> = {}; async get(sessionId: string): Promise | undefined> { return this.data[sessionId]; diff --git a/js/genkit/tests/models_test.ts b/js/genkit/tests/models_test.ts index d070283a..1a3233bf 100644 --- a/js/genkit/tests/models_test.ts +++ b/js/genkit/tests/models_test.ts @@ -47,10 +47,7 @@ describe('models', () => { for await (const chunk of stream) { chunks.push(chunk.text()); } - assert.strictEqual( - (await response).text(), - 'Echo: hi; config: {}' - ); + assert.strictEqual((await response).text(), 'Echo: hi; config: {}'); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); }); diff --git a/js/genkit/tests/session_test.ts b/js/genkit/tests/session_test.ts index 5f1d771a..1c9e5365 100644 --- a/js/genkit/tests/session_test.ts +++ b/js/genkit/tests/session_test.ts @@ -17,9 +17,7 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { Genkit, genkit } from '../src/genkit'; -import { z } from '../src/index'; -import { SessionData, SessionStore } from '../src/session'; -import { defineEchoModel, TestMemorySessionStore } from './helpers'; +import { TestMemorySessionStore, defineEchoModel } from './helpers'; describe('session', () => { let ai: Genkit; @@ -44,41 +42,16 @@ describe('session', () => { 'Echo: hi,Echo: hi,; config: {},bye; config: {}' ); assert.deepStrictEqual(response.toHistory(), [ + { content: [{ text: 'hi' }], role: 'user' }, { - content: [ - { - text: 'hi', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi', - }, - { - text: '; config: {}', - }, - ], + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], role: 'model', }, + { content: [{ text: 'bye' }], role: 'user' }, { content: [ - { - text: 'bye', - }, - ], - role: 'user', - }, - { - content: [ - { - text: 'Echo: hi,Echo: hi,; config: {},bye', - }, - { - text: '; config: {}', - }, + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, ], role: 'model', }, @@ -109,94 +82,293 @@ describe('session', () => { 'Echo: hi,Echo: hi,; config: {},bye; config: {}' ); assert.deepStrictEqual((await response).toHistory(), [ + { content: [{ text: 'hi' }], role: 'user' }, { - content: [ - { - text: 'hi', - }, - ], - role: 'user', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', }, + { content: [{ text: 'bye' }], role: 'user' }, { content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ]); + }); + + it('stores state and messages in the store', async () => { + const store = new TestMemorySessionStore(); + const session = ai.chat({ store }); + await session.send('hi'); + await session.send('bye'); + + const state = await store.get(session.id); + + assert.deepStrictEqual(state?.threads, { + __main: [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + }); + }); + + it('can init a session with a prompt', async () => { + const prompt = ai.definePrompt({ name: 'hi' }, 'hi {{ name }}'); + const session = ai.chat( + await prompt.render({ + input: { name: 'Genkit' }, + config: { temperature: 11 }, + }) + ); + const response = await session.send('hi'); + + assert.strictEqual( + response.text(), + 'Echo: hi Genkit,hi; config: {"temperature":11}' + ); + }); + + it('can send a prompt session to a session', async () => { + const prompt = ai.definePrompt( + { name: 'hi', config: { version: 'abc' } }, + 'hi {{ name }}' + ); + const session = ai.chat(); + const response = await session.send( + await prompt.render({ + input: { name: 'Genkit' }, + config: { temperature: 11 }, + }) + ); + + assert.strictEqual( + response.text(), + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' + ); + }); + + describe('loadChat', () => { + it('inherits history from parent session', async () => { + const store = new TestMemorySessionStore(); + // init the store + const originalMainChat = ai.chat({ store }); + await originalMainChat.send('hi'); + const originalSideChat = originalMainChat.thread('sideChat'); + await originalSideChat.send('bye'); + + const sessionId = originalMainChat.id; + + // load + const mainChat = await ai.loadChat(sessionId, { store }); + assert.deepStrictEqual(mainChat.messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + }, + ]); + let response = await mainChat.send('hi again'); + assert.strictEqual( + response.text(), + 'Echo: hi,Echo: hi,; config: {},hi again; config: {}' + ); + assert.deepStrictEqual(mainChat.messages, [ + { role: 'user', content: [{ text: 'hi' }] }, + { + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + }, + { content: [{ text: 'hi again' }], role: 'user' }, + { + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},hi again' }, + { text: '; config: {}' }, + ], + }, + ]); + + // make sure we can load the thread + const sideChat = mainChat.thread('sideChat'); + assert.deepStrictEqual(sideChat.messages, [ + { role: 'user', content: [{ text: 'hi' }] }, + { + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + }, + { role: 'user', content: [{ text: 'bye' }] }, + { + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + }, + ]); + + response = await sideChat.send('bye again'); + assert.strictEqual( + response.text(), + 'Echo: hi,Echo: hi,; config: {},bye,Echo: hi,Echo: hi,; config: {},bye,; config: {},bye again; config: {}' + ); + assert.deepStrictEqual(sideChat.messages, [ + { role: 'user', content: [{ text: 'hi' }] }, + { + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + }, + { role: 'user', content: [{ text: 'bye' }] }, + { + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + }, + { content: [{ text: 'bye again' }], role: 'user' }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye,Echo: hi,Echo: hi,; config: {},bye,; config: {},bye again', + }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ]); + + const state = await store.get(sessionId); + assert.deepStrictEqual(state?.threads, { + __main: [ + { content: [{ text: 'hi' }], role: 'user' }, { - text: 'Echo: hi', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', }, + { content: [{ text: 'hi again' }], role: 'user' }, { - text: '; config: {}', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},hi again' }, + { text: '; config: {}' }, + ], + role: 'model', }, ], - role: 'model', - }, - { - content: [ + sideChat: [ + { role: 'user', content: [{ text: 'hi' }] }, { - text: 'bye', + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], }, - ], - role: 'user', - }, - { - content: [ + { role: 'user', content: [{ text: 'bye' }] }, { - text: 'Echo: hi,Echo: hi,; config: {},bye', + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], }, + { content: [{ text: 'bye again' }], role: 'user' }, { - text: '; config: {}', + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye,Echo: hi,Echo: hi,; config: {},bye,; config: {},bye again', + }, + { text: '; config: {}' }, + ], + role: 'model', }, ], - role: 'model', - }, - ]); + }); + }); }); - it('stores state and messages in the store', async () => { - const store = new TestMemorySessionStore(); - const session = ai.chat({ store }); - await session.send('hi'); - await session.send('bye'); + describe('threads', () => { + it('inherits history from parent session', async () => { + const store = new TestMemorySessionStore(); + const mainChat = ai.chat({ store }); + await mainChat.send('hi'); - const state = await store.get(session.id); + const sideChat = mainChat.thread('sideChat'); + await sideChat.send('bye'); - assert.deepStrictEqual(state?.messages, [ - { - content: [ + const state = await store.get(mainChat.id); + + assert.deepStrictEqual(state?.threads, { + __main: [ + { content: [{ text: 'hi' }], role: 'user' }, { - text: 'hi', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', }, ], - role: 'user', - }, - { - content: [ + sideChat: [ + { role: 'user', content: [{ text: 'hi' }] }, { - text: 'Echo: hi', + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], }, + { role: 'user', content: [{ text: 'bye' }] }, { - text: '; config: {}', + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], }, ], - role: 'model', - }, - { - content: [ + }); + + // continue main thread + await mainChat.send('hi again'); + + assert.deepStrictEqual(state?.threads, { + __main: [ + { content: [{ text: 'hi' }], role: 'user' }, { - text: 'bye', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + // new lines + { content: [{ text: 'hi again' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},hi again' }, + { text: '; config: {}' }, + ], + role: 'model', }, ], - role: 'user', - }, - { - content: [ + sideChat: [ + // <---- sideChat unchanged from previous iteration + { role: 'user', content: [{ text: 'hi' }] }, { - text: 'Echo: hi,Echo: hi,; config: {},bye', + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], }, + { role: 'user', content: [{ text: 'bye' }] }, { - text: '; config: {}', + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], }, ], - role: 'model', - }, - ]); + }); + }); }); }); diff --git a/js/testapps/rag/src/bar.ts b/js/testapps/rag/src/bar.ts deleted file mode 100644 index ee0bbceb..00000000 --- a/js/testapps/rag/src/bar.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { vertexAI } from '@genkit-ai/vertexai'; -import { genkit, z } from 'genkit'; -import { modelRef } from 'genkit/model'; - -const ai = genkit({ - plugins: [vertexAI({ location: 'us-central1' })], - model: modelRef({ - name: 'vertexai/gemini-1.5-flash', - config: { - temperature: 1, - }, - }), -}); - -const hi = ai.definePrompt( - { - name: 'hi', - input: { - schema: z.object({ - name: z.string(), - }), - }, - }, - 'hi {{ name }}' -); - - -(async () => { - const response = await hi({ name: 'Genkit' }); - console.log(response.text()); -})() diff --git a/js/testapps/rag/src/foo.ts b/js/testapps/rag/src/foo.ts deleted file mode 100644 index 9048eec0..00000000 --- a/js/testapps/rag/src/foo.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { vertexAI } from '@genkit-ai/vertexai'; -import { genkit, z } from 'genkit'; -import { modelRef } from 'genkit/model'; - -const ai = genkit({ - plugins: [vertexAI({ location: 'us-central1' })], - model: modelRef({ - name: 'vertexai/gemini-1.5-flash', - config: { - temperature: 1, - }, - }), -}); - -const Weapon = z.object({ - name: z.string().describe('name of the weapon'), - description: z.string().describe('description on the weapon, one paragraph'), - power: z.number().describe('power level, between 1 and 10'), -}); - -const Character = z.object({ - name: z.string().describe('name of the character'), - background: z.string().describe('background story, one paragraph'), - weapon: Weapon, -}); - -(async () => { - let response; - - response = await ai.generate('tell me joke?'); - console.log(response.text()); - - response = await ai.generate({ - prompt: 'create a RPG character, archer', - output: { - format: 'json', - schema: Character, - }, - }); - - // text chat - let chatbotSession = ai.chat(); - response = await chatbotSession.send('hi my name is John'); - console.log(response.text()); - response = await chatbotSession.send('who am I?'); - console.log(response.text()); // { answer: '...John...' } - - - // json chat - chatbotSession = ai.chat({ - output: { - schema: z.object({ - answer: z.string(), - }), - format: 'json', - }, - }); - response = await chatbotSession.send('hi my name is John'); - console.log(response.output()); - response = await chatbotSession.send('who am I?'); - console.log(response.output()); // { answer: '...John...' } - - // Agent - const agent = ai.defineEnvironment({ - name: 'agent', - stateSchema: z.object({ name: z.string(), done: z.boolean() }), - }); - - const agentFlow = agent.defineFlow({ name: 'agentFlow' }, async () => { - const response = await agent.currentSession.send( - `hi, my name is ${agent.currentSession.state.name}` - ); - await agent.currentSession.updateState({ - ...agent.currentSession.state, - done: true, - }); - return response.text(); - }); - - const session = agent.createSession({ - state: { - name: 'Bob', - done: false, - }, - }); - - console.log(session.state); // { name: 'Bob', done: false } - console.log(await session.runFlow(agentFlow, undefined)); - console.log(session.state); // { name: 'Bob', done: true } - response = await session.send('What is my name?'); - console.log(response.text()); // { answer: '...Bob...' } - console.log(JSON.stringify(session.messages, undefined, ' ')); -})();