From f9a14a8f57fa5cfb50b7c854b103ebd6319df70c Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 9 Oct 2024 10:25:57 -0400 Subject: [PATCH] WIP --- js/ai/src/generate.ts | 12 +- js/genkit/src/genkit.ts | 7 +- js/genkit/src/session.ts | 23 ++- js/genkit/tests/environment_test.ts | 222 ++++++++++++++++++++++++++++ js/genkit/tests/helpers.ts | 16 ++ js/genkit/tests/models_test.ts | 6 +- js/genkit/tests/prompts_test.ts | 8 +- js/genkit/tests/session_test.ts | 202 +++++++++++++++++++++++++ js/testapps/rag/src/bar.ts | 31 ++++ js/testapps/rag/src/foo.ts | 4 +- 10 files changed, 501 insertions(+), 30 deletions(-) create mode 100644 js/genkit/tests/environment_test.ts create mode 100644 js/genkit/tests/session_test.ts create mode 100644 js/testapps/rag/src/bar.ts diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 263f14e3..9b5fc259 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -451,7 +451,9 @@ async function resolveModel(options: GenerateOptions): Promise { return { modelAction: (await lookupAction(`/model/${model}`)) as ModelAction, }; - } else if (model.hasOwnProperty('name')) { + } else if (model.hasOwnProperty('__action')) { + return { modelAction: model as ModelAction }; + } else { const ref = model as ModelReference; return { modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction, @@ -460,8 +462,6 @@ async function resolveModel(options: GenerateOptions): Promise { }, version: ref.version, }; - } else { - return { modelAction: model as ModelAction }; } } @@ -515,10 +515,10 @@ export async function generate< let modelId: string; if (typeof resolvedOptions.model === 'string') { modelId = resolvedOptions.model; - } else if ((resolvedOptions.model as ModelReference).name) { - modelId = (resolvedOptions.model as ModelReference).name; - } else { + } else if ((resolvedOptions.model as ModelAction)?.__action?.name) { modelId = (resolvedOptions.model as ModelAction).__action.name; + } else { + modelId = (resolvedOptions.model as ModelReference).name; } throw new Error(`Model ${modelId} not found`); } diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 5293480e..46684451 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -698,16 +698,16 @@ export class Genkit { }); } - createSession( + chat( options?: SessionOptions ): Session; - createSession( + chat( requestBase: BaseGenerateOptions, options?: SessionOptions ): Session; - createSession( + chat( requestBaseOrOpts?: SessionOptions | BaseGenerateOptions, maybeOptions?: SessionOptions ): Session { @@ -720,6 +720,7 @@ export class Genkit { } else if (requestBaseOrOpts !== undefined) { if ( (requestBaseOrOpts as SessionOptions).state || + (requestBaseOrOpts as SessionOptions).store || (requestBaseOrOpts as SessionOptions).schema ) { options = requestBaseOrOpts as SessionOptions; diff --git a/js/genkit/src/session.ts b/js/genkit/src/session.ts index 03626f8a..b2f259b7 100644 --- a/js/genkit/src/session.ts +++ b/js/genkit/src/session.ts @@ -33,10 +33,7 @@ export interface SessionOptions { type EnvironmentType = Pick< Genkit, - | 'defineFlow' - | 'defineStreamingFlow' - | 'defineTool' - | 'definePrompt' + 'defineFlow' | 'defineStreamingFlow' | 'defineTool' | 'definePrompt' >; type EnvironmentSessionOptions = Omit< @@ -274,16 +271,18 @@ export class Session { prompt: options, } as GenerateOptions; } - const response = await this.environment.generateStream({ + const { response, stream } = await this.environment.generateStream({ ...this.requestBase, + messages: this.sessionData.messages, ...options, }); - try { - return response; - } finally { - await this.updateMessages((await response.response).toHistory()); - } + return { + response: response.finally(async () => { + this.updateMessages((await response).toHistory()); + }), + stream, + }; } runFlow< @@ -370,7 +369,7 @@ class InMemorySessionStore implements SessionStore { return this.data[sessionId]; } - async save(sessionId: string, data: SessionData): Promise { - data[sessionId] = data; + async save(sessionId: string, sessionData: SessionData): Promise { + this.data[sessionId] = sessionData; } } diff --git a/js/genkit/tests/environment_test.ts b/js/genkit/tests/environment_test.ts new file mode 100644 index 00000000..0488bbb3 --- /dev/null +++ b/js/genkit/tests/environment_test.ts @@ -0,0 +1,222 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; +import { defineEchoModel, TestMemorySessionStore } from './helpers'; +import { z } from '../src/index'; + +describe('environment', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('maintains history in the session', async () => { + const env = await ai.defineEnvironment({ + name: 'agent', + stateSchema: z.object({ + name: z.string(), + }) + }); + + const session = env.createSession(); + let response = await session.send('hi'); + + assert.strictEqual(response.text(), 'Echo: hi; config: {}'); + + response = await session.send('bye'); + + assert.strictEqual( + response.text(), + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual(response.toHistory(), [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); + + it('maintains history in the session with streaming', async () => { + const env = await ai.defineEnvironment({ + name: 'agent', + stateSchema: z.object({ + name: z.string(), + }) + }); + const session = env.createSession(); + + let {response, stream} = await session.sendStream('hi'); + + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text()); + } + assert.strictEqual((await response).text(), 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + + ({response, stream} = await session.sendStream('bye')); + + chunks = []; + for await (const chunk of stream) { + chunks.push(chunk.text()); + } + + assert.deepStrictEqual(chunks, ['3', '2', '1']); + assert.strictEqual( + (await response).text(), + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual((await response).toHistory(), [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); + + it('stores state and messages in the store', async () => { + const store = new TestMemorySessionStore(); + const env = await ai.defineEnvironment({ + name: 'agent', + store, + stateSchema: z.object({ + name: z.string(), + }) + }); + const session = env.createSession(); + await session.send('hi'); + await session.send('bye'); + + const state = await store.get(session.id); + + assert.deepStrictEqual(state?.messages, [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); +}); diff --git a/js/genkit/tests/helpers.ts b/js/genkit/tests/helpers.ts index 10f9e224..31f50964 100644 --- a/js/genkit/tests/helpers.ts +++ b/js/genkit/tests/helpers.ts @@ -15,6 +15,8 @@ */ import { Genkit } from '../src/genkit'; +import { z } from '../src/index'; +import { SessionData, SessionStore } from '../src/session'; export function defineEchoModel(ai: Genkit) { ai.defineModel( @@ -76,3 +78,17 @@ export function defineEchoModel(ai: Genkit) { async function runAsync(fn: () => O): Promise { return Promise.resolve(fn()); } + +export class TestMemorySessionStore + implements SessionStore +{ + data: Record> = {}; + + async get(sessionId: string): Promise | undefined> { + return this.data[sessionId]; + } + + async save(sessionId: string, sessionData: SessionData): Promise { + this.data[sessionId] = sessionData; + } +} diff --git a/js/genkit/tests/models_test.ts b/js/genkit/tests/models_test.ts index 8b7c03a6..d070283a 100644 --- a/js/genkit/tests/models_test.ts +++ b/js/genkit/tests/models_test.ts @@ -35,7 +35,7 @@ describe('models', () => { const response = await ai.generate({ prompt: 'hi', }); - assert.strictEqual(response.text(), 'Echo: hi; config: undefined'); + assert.strictEqual(response.text(), 'Echo: hi; config: {}'); }); it('streams the default model', async () => { @@ -49,7 +49,7 @@ describe('models', () => { } assert.strictEqual( (await response).text(), - 'Echo: hi; config: undefined' + 'Echo: hi; config: {}' ); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); @@ -68,7 +68,7 @@ describe('models', () => { model: 'echoModel', prompt: 'hi', }); - assert.strictEqual(response.text(), 'Echo: hi; config: undefined'); + assert.strictEqual(response.text(), 'Echo: hi; config: {}'); }); }); }); diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 10625e88..25f79949 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -94,7 +94,7 @@ describe('definePrompt - dotprompt', () => { }); assert.strictEqual( response.text(), - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); }); }); @@ -186,7 +186,7 @@ describe('definePrompt - dotprompt', () => { assert.strictEqual( responseText, - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); @@ -518,7 +518,7 @@ describe('definePrompt', () => { ); assert.strictEqual( response.text(), - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); }); @@ -581,7 +581,7 @@ describe('definePrompt', () => { assert.strictEqual( responseText, - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); diff --git a/js/genkit/tests/session_test.ts b/js/genkit/tests/session_test.ts new file mode 100644 index 00000000..5f1d771a --- /dev/null +++ b/js/genkit/tests/session_test.ts @@ -0,0 +1,202 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; +import { z } from '../src/index'; +import { SessionData, SessionStore } from '../src/session'; +import { defineEchoModel, TestMemorySessionStore } from './helpers'; + +describe('session', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('maintains history in the session', async () => { + const session = ai.chat(); + let response = await session.send('hi'); + + assert.strictEqual(response.text(), 'Echo: hi; config: {}'); + + response = await session.send('bye'); + + assert.strictEqual( + response.text(), + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual(response.toHistory(), [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); + + it('maintains history in the session with streaming', async () => { + const session = await ai.chat(); + let { response, stream } = await session.sendStream('hi'); + + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text()); + } + assert.strictEqual((await response).text(), 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + + ({ response, stream } = await session.sendStream('bye')); + + chunks = []; + for await (const chunk of stream) { + chunks.push(chunk.text()); + } + + assert.deepStrictEqual(chunks, ['3', '2', '1']); + assert.strictEqual( + (await response).text(), + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual((await response).toHistory(), [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); + + it('stores state and messages in the store', async () => { + const store = new TestMemorySessionStore(); + const session = ai.chat({ store }); + await session.send('hi'); + await session.send('bye'); + + const state = await store.get(session.id); + + assert.deepStrictEqual(state?.messages, [ + { + content: [ + { + text: 'hi', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + { + content: [ + { + text: 'bye', + }, + ], + role: 'user', + }, + { + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {},bye', + }, + { + text: '; config: {}', + }, + ], + role: 'model', + }, + ]); + }); +}); diff --git a/js/testapps/rag/src/bar.ts b/js/testapps/rag/src/bar.ts new file mode 100644 index 00000000..ee0bbceb --- /dev/null +++ b/js/testapps/rag/src/bar.ts @@ -0,0 +1,31 @@ +import { vertexAI } from '@genkit-ai/vertexai'; +import { genkit, z } from 'genkit'; +import { modelRef } from 'genkit/model'; + +const ai = genkit({ + plugins: [vertexAI({ location: 'us-central1' })], + model: modelRef({ + name: 'vertexai/gemini-1.5-flash', + config: { + temperature: 1, + }, + }), +}); + +const hi = ai.definePrompt( + { + name: 'hi', + input: { + schema: z.object({ + name: z.string(), + }), + }, + }, + 'hi {{ name }}' +); + + +(async () => { + const response = await hi({ name: 'Genkit' }); + console.log(response.text()); +})() diff --git a/js/testapps/rag/src/foo.ts b/js/testapps/rag/src/foo.ts index 20ab4eea..9048eec0 100644 --- a/js/testapps/rag/src/foo.ts +++ b/js/testapps/rag/src/foo.ts @@ -39,7 +39,7 @@ const Character = z.object({ }); // text chat - let chatbotSession = ai.createSession(); + let chatbotSession = ai.chat(); response = await chatbotSession.send('hi my name is John'); console.log(response.text()); response = await chatbotSession.send('who am I?'); @@ -47,7 +47,7 @@ const Character = z.object({ // json chat - chatbotSession = ai.createSession({ + chatbotSession = ai.chat({ output: { schema: z.object({ answer: z.string(),