✨ feat: add Ai21Labs model provider (lobehub#3727)
* ✨ feat: add Ai21Labs model provider

* 🔨 chore: fix rebase conflicts

* 🐛 fix: fix CI error

* 💄 style: add model price

* 🐛 fix: fix CI error
hezhijie0327 committed Sep 18, 2024
1 parent 8eac1bd commit d2fe0f0
Showing 14 changed files with 349 additions and 0 deletions. (Only the file diffs that loaded on the page are reproduced below.)
2 changes: 2 additions & 0 deletions Dockerfile
@@ -105,6 +105,8 @@ ENV ACCESS_CODE="" \

# Model Variables
ENV \
# AI21
AI21_API_KEY="" \
# Ai360
AI360_API_KEY="" \
# Anthropic
2 changes: 2 additions & 0 deletions Dockerfile.database
@@ -137,6 +137,8 @@ ENV NEXT_PUBLIC_S3_DOMAIN="" \

# Model Variables
ENV \
# AI21
AI21_API_KEY="" \
# Ai360
AI360_API_KEY="" \
# Anthropic
2 changes: 2 additions & 0 deletions src/app/(main)/settings/llm/ProviderList/providers.tsx
@@ -1,6 +1,7 @@
import { useMemo } from 'react';

import {
Ai21ProviderCard,
Ai360ProviderCard,
AnthropicProviderCard,
BaichuanProviderCard,
@@ -57,6 +58,7 @@ export const useProviderList = (): ProviderItem[] => {
TogetherAIProviderCard,
FireworksAIProviderCard,
UpstageProviderCard,
Ai21ProviderCard,
QwenProviderCard,
SparkProviderCard,
ZhiPuProviderCard,
7 changes: 7 additions & 0 deletions src/app/api/chat/agentRuntime.ts
@@ -237,6 +237,13 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {

const apiKey = apiKeyManager.pick(payload?.apiKey || SPARK_API_KEY);

return { apiKey };
}
case ModelProvider.Ai21: {
const { AI21_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || AI21_API_KEY);

return { apiKey };
}
}
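The new case follows the same key-resolution pattern as the other providers in this file: an API key supplied with the request takes precedence, otherwise the server-side `AI21_API_KEY` is used, and `apiKeyManager.pick` selects one key when several are configured. A minimal, self-contained sketch of that fallback (the `pickApiKey` helper here is hypothetical; the real `apiKeyManager` is not part of this diff):

```ts
// Hypothetical stand-in for the request-key-first, env-key-fallback pattern above.
// It simply returns the first key from a possibly comma-separated list.
const pickApiKey = (clientKey: string | undefined, serverKeys = ''): string => {
  const source = clientKey || serverKeys;
  const [first = ''] = source.split(',').map((key) => key.trim());
  return first;
};

console.log(pickApiKey(undefined, 'sk-ai21-a,sk-ai21-b')); // 'sk-ai21-a'
console.log(pickApiKey('sk-from-request', 'sk-ai21-a')); // 'sk-from-request'
```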
6 changes: 6 additions & 0 deletions src/config/llm.ts
@@ -114,6 +114,9 @@ export const getLLMConfig = () => {

ENABLED_SPARK: z.boolean(),
SPARK_API_KEY: z.string().optional(),

ENABLED_AI21: z.boolean(),
AI21_API_KEY: z.string().optional(),
},
runtimeEnv: {
API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE,
@@ -225,6 +228,9 @@ export const getLLMConfig = () => {

ENABLED_SPARK: !!process.env.SPARK_API_KEY,
SPARK_API_KEY: process.env.SPARK_API_KEY,

ENABLED_AI21: !!process.env.AI21_API_KEY,
AI21_API_KEY: process.env.AI21_API_KEY,
},
});
};
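As with the other providers, `ENABLED_AI21` is not set directly: it is derived from whether `AI21_API_KEY` is present in the environment. A tiny illustration (the key value is a placeholder):

```ts
// ENABLED_AI21 is true exactly when AI21_API_KEY is a non-empty string.
process.env.AI21_API_KEY = 'sk-ai21-placeholder'; // placeholder for the example

const ENABLED_AI21 = !!process.env.AI21_API_KEY; // true
const AI21_API_KEY = process.env.AI21_API_KEY; // 'sk-ai21-placeholder'

console.log({ AI21_API_KEY, ENABLED_AI21 });
```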
37 changes: 37 additions & 0 deletions src/config/modelProviders/ai21.ts
@@ -0,0 +1,37 @@
import { ModelProviderCard } from '@/types/llm';

// ref https://docs.ai21.com/reference/jamba-15-api-ref
const Ai21: ModelProviderCard = {
chatModels: [
{
displayName: 'Jamba 1.5 Mini',
enabled: true,
functionCall: true,
id: 'jamba-1.5-mini',
pricing: {
input: 0.2,
output: 0.4,
},
tokens: 256_000,
},
{
displayName: 'Jamba 1.5 Large',
enabled: true,
functionCall: true,
id: 'jamba-1.5-large',
pricing: {
input: 2,
output: 8,
},
tokens: 256_000,
},
],
checkModel: 'jamba-1.5-mini',
id: 'ai21',
modelList: { showModelFetcher: true },
modelsUrl: 'https://docs.ai21.com/reference',
name: 'Ai21Labs',
url: 'https://studio.ai21.com',
};

export default Ai21;
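Both Jamba 1.5 models ship enabled by default, with function calling and a 256K-token context window; the pricing figures appear to be USD per million tokens, in line with AI21's published rates. A hypothetical helper mirroring what `filterEnabledModels` (used later in `src/const/settings/llm.ts`) is expected to produce from this card:

```ts
import type { ModelProviderCard } from '@/types/llm';

// Collect the ids of the models flagged `enabled` on a provider card.
const enabledModelIds = (card: ModelProviderCard): string[] =>
  (card.chatModels ?? []).filter((model) => model.enabled).map((model) => model.id);

// For the Ai21 card above this yields ['jamba-1.5-mini', 'jamba-1.5-large'].
```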
4 changes: 4 additions & 0 deletions src/config/modelProviders/index.ts
@@ -1,5 +1,6 @@
import { ChatModelCard, ModelProviderCard } from '@/types/llm';

import Ai21Provider from './ai21';
import Ai360Provider from './ai360';
import AnthropicProvider from './anthropic';
import AzureProvider from './azure';
@@ -55,6 +56,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
SiliconCloudProvider.chatModels,
UpstageProvider.chatModels,
SparkProvider.chatModels,
Ai21Provider.chatModels,
].flat();

export const DEFAULT_MODEL_PROVIDER_LIST = [
@@ -74,6 +76,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
TogetherAIProvider,
FireworksAIProvider,
UpstageProvider,
Ai21Provider,
QwenProvider,
SparkProvider,
ZhiPuProvider,
@@ -96,6 +99,7 @@ export const isProviderDisableBroswerRequest = (id: string) => {
return !!provider;
};

export { default as Ai21ProviderCard } from './ai21';
export { default as Ai360ProviderCard } from './ai360';
export { default as AnthropicProviderCard } from './anthropic';
export { default as AzureProviderCard } from './azure';
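With the card added to both aggregates and re-exported, the default model list should now contain the two Jamba entries. A quick check using only names visible in this diff (the list entries are `ChatModelCard`s, so filtering by id is enough):

```ts
import { Ai21ProviderCard, LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';

const jambaModels = LOBE_DEFAULT_MODEL_LIST.filter((model) => model.id.startsWith('jamba-1.5'));

console.log(Ai21ProviderCard.id); // 'ai21'
console.log(jambaModels.length); // expected: 2
```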
5 changes: 5 additions & 0 deletions src/const/settings/llm.ts
@@ -1,4 +1,5 @@
import {
Ai21ProviderCard,
Ai360ProviderCard,
AnthropicProviderCard,
BaichuanProviderCard,
@@ -31,6 +32,10 @@ import { ModelProvider } from '@/libs/agent-runtime';
import { UserModelProviderConfig } from '@/types/user/settings';

export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
ai21: {
enabled: false,
enabledModels: filterEnabledModels(Ai21ProviderCard),
},
ai360: {
enabled: false,
enabledModels: filterEnabledModels(Ai360ProviderCard),
7 changes: 7 additions & 0 deletions src/libs/agent-runtime/AgentRuntime.ts
@@ -3,6 +3,7 @@ import { ClientOptions } from 'openai';
import type { TracePayload } from '@/const/trace';

import { LobeRuntimeAI } from './BaseAI';
import { LobeAi21AI } from './ai21';
import { LobeAi360AI } from './ai360';
import { LobeAnthropicAI } from './anthropic';
import { LobeAzureOpenAI } from './azureOpenai';
@@ -117,6 +118,7 @@ class AgentRuntime {
static async initializeWithProviderOptions(
provider: string,
params: Partial<{
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
@@ -289,6 +291,11 @@ class AgentRuntime {
runtimeModel = new LobeSparkAI(params.spark);
break;
}

case ModelProvider.Ai21: {
runtimeModel = new LobeAi21AI(params.ai21);
break;
}
}

return new AgentRuntime(runtimeModel);
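Hypothetical usage of the new branch, based only on what is visible in this diff: initializing the runtime with provider options for `ai21` should construct a `LobeAi21AI` instance. The import path and default export are assumptions, and the API key is a placeholder.

```ts
import AgentRuntime from '@/libs/agent-runtime/AgentRuntime';

// 'ai21' is assumed to be the string value behind ModelProvider.Ai21, matching the
// provider card id added in src/config/modelProviders/ai21.ts.
const runtime = await AgentRuntime.initializeWithProviderOptions('ai21', {
  ai21: { apiKey: 'sk-ai21-placeholder' },
});
```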