Skip to content

Commit

Permalink
mistral[patch]: Force tool use in withStructuredOutput (#5932)
Browse files Browse the repository at this point in the history
* mistral[patch]: Force tool use in withStructuredOutput

* added test

* chore: lint files
  • Loading branch information
bracesproul authored Jun 28, 2024
1 parent 0a9090d commit 8625c1e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
4 changes: 2 additions & 2 deletions libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ export class ChatMistralAI<
},
},
],
tool_choice: "auto",
tool_choice: "any",
} as Partial<CallOptions>);
outputParser = new JsonOutputKeyToolsParser({
returnSingle: true,
Expand Down Expand Up @@ -830,7 +830,7 @@ export class ChatMistralAI<
function: openAIFunctionDefinition,
},
],
tool_choice: "auto",
tool_choice: "any",
} as Partial<CallOptions>);
outputParser = new JsonOutputKeyToolsParser<RunOutput>({
returnSingle: true,
Expand Down
29 changes: 29 additions & 0 deletions libs/langchain-mistralai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,32 @@ test("Invoke token count usage_metadata", async () => {
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("withStructuredOutput will always force tool usage", async () => {
const model = new ChatMistralAI({
temperature: 0,
model: "mistral-large-latest",
});

const weatherTool = z
.object({
location: z.string().describe("The name of city to get the weather for."),
})
.describe(
"Get the weather of a specific location and return the temperature in Celsius."
);
const modelWithTools = model.withStructuredOutput(weatherTool, {
name: "get_weather",
includeRaw: true,
});
const response = await modelWithTools.invoke(
"What is the sum of 271623 and 281623? It is VERY important you use a calculator tool to give me the answer."
);

if (!("tool_calls" in response.raw)) {
throw new Error("Tool call not found in response");
}
const castMessage = response.raw as AIMessage;
expect(castMessage.tool_calls).toHaveLength(1);
expect(castMessage.tool_calls?.[0].name).toBe("get_weather");
});

0 comments on commit 8625c1e

Please sign in to comment.