Skip to content

Commit

Permalink
feat: Add support to createAIExtractFunction for custom validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
transitive-bullshit committed Dec 1, 2023
1 parent ad54806 commit 65129d8
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 52 deletions.
27 changes: 3 additions & 24 deletions examples/ai-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
ChatModel,
Msg,
createAIFunction,
createAIExtractFunction,
createAIRunner,
} from '@dexaai/dexter';

Expand Down Expand Up @@ -69,30 +68,10 @@ const weatherCapitalRunner = createAIRunner({
systemMessage: `You use functions to answer questions about the weather and capital cities.`,
});

/** A function to extract people names from a message. */
const extractPeopleNamesRunner = createAIExtractFunction({
chatModel: new ChatModel({ params: { model: 'gpt-4-1106-preview' } }),
systemMessage: `You use functions to extract people names from a message.`,
name: 'log_people_names',
description: `Use this to log the full names of people from a message. Don't include duplicate names.`,
schema: z.object({
names: z.array(
z
.string()
.describe(
`The name of a person from the message. Normalize the name by removing suffixes, prefixes, and fixing capitalization`
)
),
}),
});

/**
* npx tsx examples/ai-runner.ts
*/
async function main() {
// Use OpenAI functions to extract data adhering to a Zod schema
const peopleNames = await extractPeopleNamesRunner(
`Dr. Andrew Huberman interviewed Tony Hawk, an idol of Andrew Hubermans.`
);
console.log('peopleNames', peopleNames);

// Run with a string input
const rString = await weatherCapitalRunner(
`Whats the capital of California and NY and the weather for both`
Expand Down
33 changes: 33 additions & 0 deletions examples/extract-people-names.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import 'dotenv/config';
import { z } from 'zod';
import { ChatModel, createAIExtractFunction } from '@dexaai/dexter';

/** A function to extract people names from text. */
const extractPeopleNamesRunner = createAIExtractFunction({
chatModel: new ChatModel({ params: { model: 'gpt-4-1106-preview' } }),
systemMessage: `You use functions to extract people names from a message.`,
name: 'log_people_names',
description: `Use this to log the full names of people from a message. Don't include duplicate names.`,
schema: z.object({
names: z.array(
z
.string()
.describe(
`The name of a person from the message. Normalize the name by removing suffixes, prefixes, and fixing capitalization`
)
),
}),
});

/**
* npx tsx examples/extract-people-names.ts
*/
async function main() {
// Use OpenAI functions to extract data adhering to a Zod schema
const peopleNames = await extractPeopleNamesRunner(
`Dr. Andrew Huberman interviewed Tony Hawk, an idol of Andrew Hubermans.`
);
console.log('peopleNames', peopleNames);
}

main().catch(console.error);
69 changes: 47 additions & 22 deletions src/prompt/functions/ai-extract-function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,71 @@ import { createAIRunner } from './ai-runner.js';
/**
* Use OpenAI function calling to extract data from a message.
*/
export function createAIExtractFunction<Schema extends z.ZodObject<any>>({
chatModel,
name,
description,
schema,
maxRetries = 0,
systemMessage,
}: {
/** The ChatModel used to make API calls. */
chatModel: Model.Chat.Model;
/** The name of the extractor function. */
name: string;
/** The description of the extractor function. */
description?: string;
/** The Zod schema for the data to extract. */
schema: Schema;
/** The maximum number of times to retry the function call. */
maxRetries?: number;
/** Add a system message to the beginning of the messages array. */
systemMessage?: string;
}): Prompt.ExtractFunction<Schema> {
export function createAIExtractFunction<Schema extends z.ZodObject<any>>(
{
chatModel,
name,
description,
schema,
maxRetries = 0,
systemMessage,
params,
context,
functionCallConcurrency,
}: {
/** The ChatModel used to make API calls. */
chatModel: Model.Chat.Model;
/** The name of the extractor function. */
name: string;
/** The description of the extractor function. */
description?: string;
/** The Zod schema for the data to extract. */
schema: Schema;
/** The maximum number of times to retry the function call. */
maxRetries?: number;
/** Add a system message to the beginning of the messages array. */
systemMessage?: string;
/** Model params to use for each API call (optional). */
params?: Prompt.Runner.ModelParams;
/** Optional context to pass to ChatModel.run calls */
context?: Model.Ctx;
/** The number of function calls to make concurrently. */
functionCallConcurrency?: number;
},
/**
* Optional custom extraction function to call with the parsed arguments.
*
* This is useful for adding custom validation to the extracted data.
*/
customExtractImplementation?: (
params: z.infer<Schema>
) => z.infer<Schema> | Promise<z.infer<Schema>>
): Prompt.ExtractFunction<Schema> {
// The AIFunction that will be used to extract the data
const extractFunction = createAIFunction(
{
name,
description,
argsSchema: schema,
},
async (args) => args
async (args): Promise<z.infer<Schema>> => {
if (customExtractImplementation) return customExtractImplementation(args);
else return args;
}
);

// Create a runner that will call the function, validate the args and retry
// if necessary, and return the result.
const runner = createAIRunner({
chatModel,
systemMessage,
context,
functions: [extractFunction],
mode: 'functions',
maxIterations: maxRetries + 1,
functionCallConcurrency,
params: {
...params,
function_call: { name },
},
shouldBreakLoop: (message) => Msg.isFuncResult(message),
Expand Down
16 changes: 10 additions & 6 deletions src/prompt/functions/ai-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ import { Msg, getErrorMsg } from '../index.js';
import type { Prompt } from '../types.js';
import type { Model } from '../../index.js';

type RunnerModelParams = Partial<
Omit<Model.Chat.Run & Model.Chat.Config, 'messages' | 'functions' | 'tools'>
>;

/**
* Creates a function to run a chat model in a loop
* - Handles parsing, running, and inserting responses for function & tool call messages
Expand All @@ -31,7 +27,9 @@ export function createAIRunner<Content extends any = string>(args: {
/** Add a system message to the beginning of the messages array. */
systemMessage?: string;
/** Model params to use for each API call (optional). */
params?: RunnerModelParams;
params?: Prompt.Runner.ModelParams;
/** Optional context to pass to ChatModel.run calls */
context?: Model.Ctx;
}): Prompt.Runner<Content> {
/** Return the content string or an empty string if null. */
function defaultValidateContent(content: string | null): Content {
Expand All @@ -53,10 +51,16 @@ export function createAIRunner<Content extends any = string>(args: {
functionCallConcurrency,
systemMessage,
params: runnerModelParams,
context: runnerContext,
validateContent = defaultValidateContent,
shouldBreakLoop = defaultShouldBreakLoop,
} = args;

const mergedContext = {
...runnerContext,
...context,
};

// Add the functions/tools to the model params
const additonalParams = getParams({ functions, mode });

Expand Down Expand Up @@ -86,7 +90,7 @@ export function createAIRunner<Content extends any = string>(args: {
...additonalParams,
messages,
};
const { message } = await chatModel.run(runParams, context);
const { message } = await chatModel.run(runParams, mergedContext);

This comment has been minimized.

Copy link
@rileytomasek

rileytomasek Dec 2, 2023

Contributor

@transitive-bullshit The ChatModel already stores context and params state, and chatModel.run() will merge the new context with the existing model context.

I'm really not a fan of having so many ways to do the same thing and increasing the amount of args passed to functions/classes like this. Can we remove this duplicate code path for passing params/context to keep things cleaner.

This comment has been minimized.

Copy link
@transitive-bullshit

transitive-bullshit Dec 2, 2023

Author Collaborator

I agree w/ the preference to not have multiple ways of doing the same thing, but it wasn't possible to do this and have createAIExtractFunction support context and params at the same time.

e.g., it's very reasonable to want to customize context/params for createAIExtractFunction, and if you can do it there, it makes sense to have the two constructor options mirror each other.

This is an artifact of having createAIRunner and createAIExtractFunction btw which are very slight variations on each other, whereas I still think it was cleaner to just have createAIChain which did both.

This comment has been minimized.

Copy link
@rileytomasek

rileytomasek Dec 2, 2023

Contributor

does this not fix the issue? #14

This comment has been minimized.

Copy link
@transitive-bullshit

transitive-bullshit Dec 3, 2023

Author Collaborator

discussed offline and went w/ a variation of #14 and #17 🚀

messages.push(message);

// Run functions from tool/function call messages and append the new messages
Expand Down
7 changes: 7 additions & 0 deletions src/prompt/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ export namespace Prompt {
'model'
>;

export type ModelParams = Partial<
Omit<
Model.Chat.Run & Model.Chat.Config,
'messages' | 'functions' | 'tools'
>
>;

/** Response from executing a runner */
export type Response<Content extends any = string> =
| {
Expand Down

0 comments on commit 65129d8

Please sign in to comment.