Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj committed Oct 9, 2024
1 parent 96cfca8 commit f9a14a8
Show file tree
Hide file tree
Showing 10 changed files with 501 additions and 30 deletions.
12 changes: 6 additions & 6 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ async function resolveModel(options: GenerateOptions): Promise<ResolvedModel> {
return {
modelAction: (await lookupAction(`/model/${model}`)) as ModelAction,
};
} else if (model.hasOwnProperty('name')) {
} else if (model.hasOwnProperty('__action')) {
return { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
return {
modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction,
Expand All @@ -460,8 +462,6 @@ async function resolveModel(options: GenerateOptions): Promise<ResolvedModel> {
},
version: ref.version,
};
} else {
return { modelAction: model as ModelAction };
}
}

Expand Down Expand Up @@ -515,10 +515,10 @@ export async function generate<
let modelId: string;
if (typeof resolvedOptions.model === 'string') {
modelId = resolvedOptions.model;
} else if ((resolvedOptions.model as ModelReference<any>).name) {
modelId = (resolvedOptions.model as ModelReference<any>).name;
} else {
} else if ((resolvedOptions.model as ModelAction)?.__action?.name) {
modelId = (resolvedOptions.model as ModelAction).__action.name;
} else {
modelId = (resolvedOptions.model as ModelReference<any>).name;
}
throw new Error(`Model ${modelId} not found`);
}
Expand Down
7 changes: 4 additions & 3 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -698,16 +698,16 @@ export class Genkit {
});
}

createSession<S extends z.ZodTypeAny = z.ZodTypeAny>(
chat<S extends z.ZodTypeAny = z.ZodTypeAny>(
options?: SessionOptions<S>
): Session<S>;

createSession<S extends z.ZodTypeAny = z.ZodTypeAny>(
chat<S extends z.ZodTypeAny = z.ZodTypeAny>(
requestBase: BaseGenerateOptions,
options?: SessionOptions<S>
): Session<S>;

createSession<S extends z.ZodTypeAny = z.ZodTypeAny>(
chat<S extends z.ZodTypeAny = z.ZodTypeAny>(
requestBaseOrOpts?: SessionOptions<S> | BaseGenerateOptions,
maybeOptions?: SessionOptions<S>
): Session<S> {
Expand All @@ -720,6 +720,7 @@ export class Genkit {
} else if (requestBaseOrOpts !== undefined) {
if (
(requestBaseOrOpts as SessionOptions<S>).state ||
(requestBaseOrOpts as SessionOptions<S>).store ||
(requestBaseOrOpts as SessionOptions<S>).schema
) {
options = requestBaseOrOpts as SessionOptions<S>;
Expand Down
23 changes: 11 additions & 12 deletions js/genkit/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ export interface SessionOptions<S extends z.ZodTypeAny> {

type EnvironmentType = Pick<
Genkit,
| 'defineFlow'
| 'defineStreamingFlow'
| 'defineTool'
| 'definePrompt'
'defineFlow' | 'defineStreamingFlow' | 'defineTool' | 'definePrompt'
>;

type EnvironmentSessionOptions<S extends z.ZodTypeAny> = Omit<
Expand Down Expand Up @@ -274,16 +271,18 @@ export class Session<S extends z.ZodTypeAny> {
prompt: options,
} as GenerateOptions<O, CustomOptions>;
}
const response = await this.environment.generateStream({
const { response, stream } = await this.environment.generateStream({
...this.requestBase,
messages: this.sessionData.messages,
...options,
});

try {
return response;
} finally {
await this.updateMessages((await response.response).toHistory());
}
return {
response: response.finally(async () => {
this.updateMessages((await response).toHistory());
}),
stream,
};
}

runFlow<
Expand Down Expand Up @@ -370,7 +369,7 @@ class InMemorySessionStore<S extends z.ZodTypeAny> implements SessionStore<S> {
return this.data[sessionId];
}

async save(sessionId: string, data: SessionData<S>): Promise<void> {
data[sessionId] = data;
async save(sessionId: string, sessionData: SessionData<S>): Promise<void> {
this.data[sessionId] = sessionData;
}
}
222 changes: 222 additions & 0 deletions js/genkit/tests/environment_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { Genkit, genkit } from '../src/genkit';
import { defineEchoModel, TestMemorySessionStore } from './helpers';
import { z } from '../src/index';

describe('environment', () => {
let ai: Genkit;

beforeEach(() => {
ai = genkit({
model: 'echoModel',
});
defineEchoModel(ai);
});

it('maintains history in the session', async () => {
const env = await ai.defineEnvironment({
name: 'agent',
stateSchema: z.object({
name: z.string(),
})
});

const session = env.createSession();
let response = await session.send('hi');

assert.strictEqual(response.text(), 'Echo: hi; config: {}');

response = await session.send('bye');

assert.strictEqual(
response.text(),
'Echo: hi,Echo: hi,; config: {},bye; config: {}'
);
assert.deepStrictEqual(response.toHistory(), [
{
content: [
{
text: 'hi',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi',
},
{
text: '; config: {}',
},
],
role: 'model',
},
{
content: [
{
text: 'bye',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi,Echo: hi,; config: {},bye',
},
{
text: '; config: {}',
},
],
role: 'model',
},
]);
});

it('maintains history in the session with streaming', async () => {
const env = await ai.defineEnvironment({
name: 'agent',
stateSchema: z.object({
name: z.string(),
})
});
const session = env.createSession();

let {response, stream} = await session.sendStream('hi');

let chunks: string[] = [];
for await (const chunk of stream) {
chunks.push(chunk.text());
}
assert.strictEqual((await response).text(), 'Echo: hi; config: {}');
assert.deepStrictEqual(chunks, ['3', '2', '1']);

({response, stream} = await session.sendStream('bye'));

chunks = [];
for await (const chunk of stream) {
chunks.push(chunk.text());
}

assert.deepStrictEqual(chunks, ['3', '2', '1']);
assert.strictEqual(
(await response).text(),
'Echo: hi,Echo: hi,; config: {},bye; config: {}'
);
assert.deepStrictEqual((await response).toHistory(), [
{
content: [
{
text: 'hi',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi',
},
{
text: '; config: {}',
},
],
role: 'model',
},
{
content: [
{
text: 'bye',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi,Echo: hi,; config: {},bye',
},
{
text: '; config: {}',
},
],
role: 'model',
},
]);
});

it('stores state and messages in the store', async () => {
const store = new TestMemorySessionStore();
const env = await ai.defineEnvironment({
name: 'agent',
store,
stateSchema: z.object({
name: z.string(),
})
});
const session = env.createSession();
await session.send('hi');
await session.send('bye');

const state = await store.get(session.id);

assert.deepStrictEqual(state?.messages, [
{
content: [
{
text: 'hi',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi',
},
{
text: '; config: {}',
},
],
role: 'model',
},
{
content: [
{
text: 'bye',
},
],
role: 'user',
},
{
content: [
{
text: 'Echo: hi,Echo: hi,; config: {},bye',
},
{
text: '; config: {}',
},
],
role: 'model',
},
]);
});
});
16 changes: 16 additions & 0 deletions js/genkit/tests/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

import { Genkit } from '../src/genkit';
import { z } from '../src/index';
import { SessionData, SessionStore } from '../src/session';

export function defineEchoModel(ai: Genkit) {
ai.defineModel(
Expand Down Expand Up @@ -76,3 +78,17 @@ export function defineEchoModel(ai: Genkit) {
async function runAsync<O>(fn: () => O): Promise<O> {
return Promise.resolve(fn());
}

export class TestMemorySessionStore<S extends z.ZodTypeAny>
implements SessionStore<S>
{
data: Record<string, SessionData<S>> = {};

async get(sessionId: string): Promise<SessionData<S> | undefined> {
return this.data[sessionId];
}

async save(sessionId: string, sessionData: SessionData<S>): Promise<void> {
this.data[sessionId] = sessionData;
}
}
6 changes: 3 additions & 3 deletions js/genkit/tests/models_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ describe('models', () => {
const response = await ai.generate({
prompt: 'hi',
});
assert.strictEqual(response.text(), 'Echo: hi; config: undefined');
assert.strictEqual(response.text(), 'Echo: hi; config: {}');
});

it('streams the default model', async () => {
Expand All @@ -49,7 +49,7 @@ describe('models', () => {
}
assert.strictEqual(
(await response).text(),
'Echo: hi; config: undefined'
'Echo: hi; config: {}'
);
assert.deepStrictEqual(chunks, ['3', '2', '1']);
});
Expand All @@ -68,7 +68,7 @@ describe('models', () => {
model: 'echoModel',
prompt: 'hi',
});
assert.strictEqual(response.text(), 'Echo: hi; config: undefined');
assert.strictEqual(response.text(), 'Echo: hi; config: {}');
});
});
});
Expand Down
Loading

0 comments on commit f9a14a8

Please sign in to comment.