From 65129d8fd45d8a631c74bf2ee89eeabc0e33c8fe Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Thu, 30 Nov 2023 23:58:10 -0600 Subject: [PATCH] feat: Add support to createAIExtractFunction for custom validation logic --- examples/ai-runner.ts | 27 +------- examples/extract-people-names.ts | 33 ++++++++++ src/prompt/functions/ai-extract-function.ts | 69 ++++++++++++++------- src/prompt/functions/ai-runner.ts | 16 +++-- src/prompt/types.ts | 7 +++ 5 files changed, 100 insertions(+), 52 deletions(-) create mode 100644 examples/extract-people-names.ts diff --git a/examples/ai-runner.ts b/examples/ai-runner.ts index 8c87ee9..daef4c5 100644 --- a/examples/ai-runner.ts +++ b/examples/ai-runner.ts @@ -4,7 +4,6 @@ import { ChatModel, Msg, createAIFunction, - createAIExtractFunction, createAIRunner, } from '@dexaai/dexter'; @@ -69,30 +68,10 @@ const weatherCapitalRunner = createAIRunner({ systemMessage: `You use functions to answer questions about the weather and capital cities.`, }); -/** A function to extract people names from a message. */ -const extractPeopleNamesRunner = createAIExtractFunction({ - chatModel: new ChatModel({ params: { model: 'gpt-4-1106-preview' } }), - systemMessage: `You use functions to extract people names from a message.`, - name: 'log_people_names', - description: `Use this to log the full names of people from a message. Don't include duplicate names.`, - schema: z.object({ - names: z.array( - z - .string() - .describe( - `The name of a person from the message. Normalize the name by removing suffixes, prefixes, and fixing capitalization` - ) - ), - }), -}); - +/** + * npx tsx examples/ai-runner.ts + */ async function main() { - // Use OpenAI functions to extract data adhering to a Zod schema - const peopleNames = await extractPeopleNamesRunner( - `Dr. Andrew Huberman interviewed Tony Hawk, an idol of Andrew Hubermans.` - ); - console.log('peopleNames', peopleNames); - // Run with a string input const rString = await weatherCapitalRunner( `Whats the capital of California and NY and the weather for both` diff --git a/examples/extract-people-names.ts b/examples/extract-people-names.ts new file mode 100644 index 0000000..326102a --- /dev/null +++ b/examples/extract-people-names.ts @@ -0,0 +1,33 @@ +import 'dotenv/config'; +import { z } from 'zod'; +import { ChatModel, createAIExtractFunction } from '@dexaai/dexter'; + +/** A function to extract people names from text. */ +const extractPeopleNamesRunner = createAIExtractFunction({ + chatModel: new ChatModel({ params: { model: 'gpt-4-1106-preview' } }), + systemMessage: `You use functions to extract people names from a message.`, + name: 'log_people_names', + description: `Use this to log the full names of people from a message. Don't include duplicate names.`, + schema: z.object({ + names: z.array( + z + .string() + .describe( + `The name of a person from the message. Normalize the name by removing suffixes, prefixes, and fixing capitalization` + ) + ), + }), +}); + +/** + * npx tsx examples/extract-people-names.ts + */ +async function main() { + // Use OpenAI functions to extract data adhering to a Zod schema + const peopleNames = await extractPeopleNamesRunner( + `Dr. Andrew Huberman interviewed Tony Hawk, an idol of Andrew Hubermans.` + ); + console.log('peopleNames', peopleNames); +} + +main().catch(console.error); diff --git a/src/prompt/functions/ai-extract-function.ts b/src/prompt/functions/ai-extract-function.ts index 9197df7..81450f3 100644 --- a/src/prompt/functions/ai-extract-function.ts +++ b/src/prompt/functions/ai-extract-function.ts @@ -7,27 +7,46 @@ import { createAIRunner } from './ai-runner.js'; /** * Use OpenAI function calling to extract data from a message. */ -export function createAIExtractFunction>({ - chatModel, - name, - description, - schema, - maxRetries = 0, - systemMessage, -}: { - /** The ChatModel used to make API calls. */ - chatModel: Model.Chat.Model; - /** The name of the extractor function. */ - name: string; - /** The description of the extractor function. */ - description?: string; - /** The Zod schema for the data to extract. */ - schema: Schema; - /** The maximum number of times to retry the function call. */ - maxRetries?: number; - /** Add a system message to the beginning of the messages array. */ - systemMessage?: string; -}): Prompt.ExtractFunction { +export function createAIExtractFunction>( + { + chatModel, + name, + description, + schema, + maxRetries = 0, + systemMessage, + params, + context, + functionCallConcurrency, + }: { + /** The ChatModel used to make API calls. */ + chatModel: Model.Chat.Model; + /** The name of the extractor function. */ + name: string; + /** The description of the extractor function. */ + description?: string; + /** The Zod schema for the data to extract. */ + schema: Schema; + /** The maximum number of times to retry the function call. */ + maxRetries?: number; + /** Add a system message to the beginning of the messages array. */ + systemMessage?: string; + /** Model params to use for each API call (optional). */ + params?: Prompt.Runner.ModelParams; + /** Optional context to pass to ChatModel.run calls */ + context?: Model.Ctx; + /** The number of function calls to make concurrently. */ + functionCallConcurrency?: number; + }, + /** + * Optional custom extraction function to call with the parsed arguments. + * + * This is useful for adding custom validation to the extracted data. + */ + customExtractImplementation?: ( + params: z.infer + ) => z.infer | Promise> +): Prompt.ExtractFunction { // The AIFunction that will be used to extract the data const extractFunction = createAIFunction( { @@ -35,7 +54,10 @@ export function createAIExtractFunction>({ description, argsSchema: schema, }, - async (args) => args + async (args): Promise> => { + if (customExtractImplementation) return customExtractImplementation(args); + else return args; + } ); // Create a runner that will call the function, validate the args and retry @@ -43,10 +65,13 @@ export function createAIExtractFunction>({ const runner = createAIRunner({ chatModel, systemMessage, + context, functions: [extractFunction], mode: 'functions', maxIterations: maxRetries + 1, + functionCallConcurrency, params: { + ...params, function_call: { name }, }, shouldBreakLoop: (message) => Msg.isFuncResult(message), diff --git a/src/prompt/functions/ai-runner.ts b/src/prompt/functions/ai-runner.ts index ac0e04c..372aed3 100644 --- a/src/prompt/functions/ai-runner.ts +++ b/src/prompt/functions/ai-runner.ts @@ -3,10 +3,6 @@ import { Msg, getErrorMsg } from '../index.js'; import type { Prompt } from '../types.js'; import type { Model } from '../../index.js'; -type RunnerModelParams = Partial< - Omit ->; - /** * Creates a function to run a chat model in a loop * - Handles parsing, running, and inserting responses for function & tool call messages @@ -31,7 +27,9 @@ export function createAIRunner(args: { /** Add a system message to the beginning of the messages array. */ systemMessage?: string; /** Model params to use for each API call (optional). */ - params?: RunnerModelParams; + params?: Prompt.Runner.ModelParams; + /** Optional context to pass to ChatModel.run calls */ + context?: Model.Ctx; }): Prompt.Runner { /** Return the content string or an empty string if null. */ function defaultValidateContent(content: string | null): Content { @@ -53,10 +51,16 @@ export function createAIRunner(args: { functionCallConcurrency, systemMessage, params: runnerModelParams, + context: runnerContext, validateContent = defaultValidateContent, shouldBreakLoop = defaultShouldBreakLoop, } = args; + const mergedContext = { + ...runnerContext, + ...context, + }; + // Add the functions/tools to the model params const additonalParams = getParams({ functions, mode }); @@ -86,7 +90,7 @@ export function createAIRunner(args: { ...additonalParams, messages, }; - const { message } = await chatModel.run(runParams, context); + const { message } = await chatModel.run(runParams, mergedContext); messages.push(message); // Run functions from tool/function call messages and append the new messages diff --git a/src/prompt/types.ts b/src/prompt/types.ts index 4dbe8a6..db36e08 100644 --- a/src/prompt/types.ts +++ b/src/prompt/types.ts @@ -18,6 +18,13 @@ export namespace Prompt { 'model' >; + export type ModelParams = Partial< + Omit< + Model.Chat.Run & Model.Chat.Config, + 'messages' | 'functions' | 'tools' + > + >; + /** Response from executing a runner */ export type Response = | {