diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000000..7f5566fb979 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM node:18-bullseye + +RUN useradd -m -s /bin/bash vscode +RUN mkdir -p /workspaces && chown -R vscode:vscode /workspaces +WORKDIR /workspaces diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ebfd2685ee6..a3bb7805501 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -13,5 +13,6 @@ } }, "postCreateCommand": "", - "features": { "ghcr.io/devcontainers/features/git:1": {} } + "features": { "ghcr.io/devcontainers/features/git:1": {} }, + "remoteUser": "vscode" } diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index c67fca63019..277ac84f856 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -2,7 +2,9 @@ version: "3.8" services: app: - image: node:19-bullseye + build: + context: .. + dockerfile: .devcontainer/Dockerfile # restart: always links: - mongodb @@ -30,8 +32,8 @@ services: # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. # (Adding the "ports" property to this file will not forward from a Codespace.) - # Uncomment the next line to use a non-root user for all processes - See https://aka.ms/vscode-remote/containers/non-root for details. - # user: vscode + # Use a non-root user for all processes - See https://aka.ms/vscode-remote/containers/non-root for details. + user: vscode # Overrides default command so things don't shut down after the process ends. command: /bin/sh -c "while sleep 1000; do :; done" diff --git a/.env.example b/.env.example index 2e23a09a349..bd212cc7baa 100644 --- a/.env.example +++ b/.env.example @@ -1,21 +1,18 @@ -#=============================================================# -# LibreChat Configuration # -#=============================================================# -# Please refer to the reference documentation for assistance # -# with configuring your LibreChat environment. The guide is # -# available both online and within your local LibreChat # -# directory: # -# Online: https://docs.librechat.ai/install/dotenv.html # -# Locally: ./docs/install/dotenv.md # -#=============================================================# +#=====================================================================# +# LibreChat Configuration # +#=====================================================================# +# Please refer to the reference documentation for assistance # +# with configuring your LibreChat environment. The guide is # +# available both online and within your local LibreChat # +# directory: # +# Online: https://docs.librechat.ai/install/configuration/dotenv.html # +# Locally: ./docs/install/configuration/dotenv.md # +#=====================================================================# #==================================================# # Server Configuration # #==================================================# -APP_TITLE=LibreChat -# CUSTOM_FOOTER="My custom footer" - HOST=localhost PORT=3080 @@ -26,6 +23,13 @@ DOMAIN_SERVER=http://localhost:3080 NO_INDEX=true +#===============# +# JSON Logging # +#===============# + +# Use when process console logs in cloud deployment like GCP/AWS +CONSOLE_JSON=false + #===============# # Debug Logging # #===============# @@ -40,38 +44,62 @@ DEBUG_CONSOLE=false # UID=1000 # GID=1000 +#===============# +# Configuration # +#===============# +# Use an absolute path, a relative path, or a URL + +# CONFIG_PATH="/alternative/path/to/librechat.yaml" + #===================================================# # Endpoints # #===================================================# -# ENDPOINTS=openAI,azureOpenAI,bingAI,chatGPTBrowser,google,gptPlugins,anthropic +# ENDPOINTS=openAI,assistants,azureOpenAI,bingAI,google,gptPlugins,anthropic PROXY= +#===================================# +# Known Endpoints - librechat.yaml # +#===================================# +# https://docs.librechat.ai/install/configuration/ai_endpoints.html + +# GROQ_API_KEY= +# SHUTTLEAI_KEY= +# OPENROUTER_KEY= +# MISTRAL_API_KEY= +# ANYSCALE_API_KEY= +# FIREWORKS_API_KEY= +# PERPLEXITY_API_KEY= +# TOGETHERAI_API_KEY= + #============# # Anthropic # #============# ANTHROPIC_API_KEY=user_provided -ANTHROPIC_MODELS=claude-1,claude-instant-1,claude-2 +# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k # ANTHROPIC_REVERSE_PROXY= #============# # Azure # #============# -# AZURE_API_KEY= -AZURE_OPENAI_MODELS=gpt-3.5-turbo,gpt-4 -# AZURE_OPENAI_DEFAULT_MODEL=gpt-3.5-turbo -# PLUGINS_USE_AZURE="true" -AZURE_USE_MODEL_AS_DEPLOYMENT_NAME=TRUE +# Note: these variables are DEPRECATED +# Use the `librechat.yaml` configuration for `azureOpenAI` instead +# You may also continue to use them if you opt out of using the `librechat.yaml` configuration -# AZURE_OPENAI_API_INSTANCE_NAME= -# AZURE_OPENAI_API_DEPLOYMENT_NAME= -# AZURE_OPENAI_API_VERSION= -# AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME= -# AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME= +# AZURE_OPENAI_DEFAULT_MODEL=gpt-3.5-turbo # Deprecated +# AZURE_OPENAI_MODELS=gpt-3.5-turbo,gpt-4 # Deprecated +# AZURE_USE_MODEL_AS_DEPLOYMENT_NAME=TRUE # Deprecated +# AZURE_API_KEY= # Deprecated +# AZURE_OPENAI_API_INSTANCE_NAME= # Deprecated +# AZURE_OPENAI_API_DEPLOYMENT_NAME= # Deprecated +# AZURE_OPENAI_API_VERSION= # Deprecated +# AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME= # Deprecated +# AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME= # Deprecated +# PLUGINS_USE_AZURE="true" # Deprecated #============# # BingAI # @@ -80,14 +108,6 @@ AZURE_USE_MODEL_AS_DEPLOYMENT_NAME=TRUE BINGAI_TOKEN=user_provided # BINGAI_HOST=https://cn.bing.com -#============# -# ChatGPT # -#============# - -CHATGPT_TOKEN= -CHATGPT_MODELS=text-davinci-002-render-sha -# CHATGPT_REVERSE_PROXY= - #============# # Google # #============# @@ -101,7 +121,7 @@ GOOGLE_KEY=user_provided #============# OPENAI_API_KEY=user_provided -# OPENAI_MODELS=gpt-3.5-turbo-1106,gpt-4-1106-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-0301,gpt-4,gpt-4-0314,gpt-4-0613 +# OPENAI_MODELS=gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k DEBUG_OPENAI=false @@ -115,7 +135,15 @@ DEBUG_OPENAI=false # OPENAI_REVERSE_PROXY= -# OPENAI_ORGANIZATION= +# OPENAI_ORGANIZATION= + +#====================# +# Assistants API # +#====================# + +ASSISTANTS_API_KEY=user_provided +# ASSISTANTS_BASE_URL= +# ASSISTANTS_MODELS=gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview #============# # OpenRouter # @@ -127,7 +155,7 @@ DEBUG_OPENAI=false # Plugins # #============# -# PLUGIN_MODELS=gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-0301,gpt-4,gpt-4-0314,gpt-4-0613 +# PLUGIN_MODELS=gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613 DEBUG_PLUGINS=true @@ -147,20 +175,20 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT= # DALL·E #---------------- -# DALLE_API_KEY= # Key for both DALL-E-2 and DALL-E-3 -# DALLE3_API_KEY= # Key for DALL-E-3 only -# DALLE2_API_KEY= # Key for DALL-E-2 only -# DALLE3_SYSTEM_PROMPT="Your DALL-E-3 System Prompt here" -# DALLE2_SYSTEM_PROMPT="Your DALL-E-2 System Prompt here" -# DALLE_REVERSE_PROXY= # Reverse proxy for DALL-E-2 and DALL-E-3 -# DALLE3_BASEURL= # Base URL for DALL-E-3 -# DALLE2_BASEURL= # Base URL for DALL-E-2 +# DALLE_API_KEY= +# DALLE3_API_KEY= +# DALLE2_API_KEY= +# DALLE3_SYSTEM_PROMPT= +# DALLE2_SYSTEM_PROMPT= +# DALLE_REVERSE_PROXY= +# DALLE3_BASEURL= +# DALLE2_BASEURL= # DALL·E (via Azure OpenAI) # Note: requires some of the variables above to be set #---------------- -# DALLE3_AZURE_API_VERSION= # Azure OpenAI API version for DALL-E-3 -# DALLE2_AZURE_API_VERSION= # Azure OpenAI API versiion for DALL-E-2 +# DALLE3_AZURE_API_VERSION= +# DALLE2_AZURE_API_VERSION= # Google #----------------- @@ -175,6 +203,14 @@ SERPAPI_API_KEY= #----------------- SD_WEBUI_URL=http://host.docker.internal:7860 +# Tavily +#----------------- +TAVILY_API_KEY= + +# Traversaal +#----------------- +TRAVERSAAL_API_KEY= + # WolframAlpha #----------------- WOLFRAM_APP_ID= @@ -202,7 +238,7 @@ MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt OPENAI_MODERATION=false OPENAI_MODERATION_API_KEY= -# OPENAI_MODERATION_REVERSE_PROXY=not working with some reverse proxys +# OPENAI_MODERATION_REVERSE_PROXY= BAN_VIOLATIONS=true BAN_DURATION=1000 * 60 * 60 * 2 @@ -230,6 +266,8 @@ LIMIT_MESSAGE_USER=false MESSAGE_USER_MAX=40 MESSAGE_USER_WINDOW=1 +ILLEGAL_MODEL_REQ_SCORE=5 + #========================# # Balance # #========================# @@ -278,6 +316,9 @@ OPENID_ISSUER= OPENID_SESSION_SECRET= OPENID_SCOPE="openid profile email" OPENID_CALLBACK_URL=/oauth/openid/callback +OPENID_REQUIRED_ROLE= +OPENID_REQUIRED_ROLE_TOKEN_KIND= +OPENID_REQUIRED_ROLE_PARAMETER_PATH= OPENID_BUTTON_LABEL= OPENID_IMAGE_URL= @@ -286,15 +327,15 @@ OPENID_IMAGE_URL= # Email Password Reset # #========================# -EMAIL_SERVICE= -EMAIL_HOST= -EMAIL_PORT=25 -EMAIL_ENCRYPTION= -EMAIL_ENCRYPTION_HOSTNAME= -EMAIL_ALLOW_SELFSIGNED= -EMAIL_USERNAME= -EMAIL_PASSWORD= -EMAIL_FROM_NAME= +EMAIL_SERVICE= +EMAIL_HOST= +EMAIL_PORT=25 +EMAIL_ENCRYPTION= +EMAIL_ENCRYPTION_HOSTNAME= +EMAIL_ALLOW_SELFSIGNED= +EMAIL_USERNAME= +EMAIL_PASSWORD= +EMAIL_FROM_NAME= EMAIL_FROM=noreply@librechat.ai #========================# @@ -308,6 +349,16 @@ FIREBASE_STORAGE_BUCKET= FIREBASE_MESSAGING_SENDER_ID= FIREBASE_APP_ID= +#===================================================# +# UI # +#===================================================# + +APP_TITLE=LibreChat +# CUSTOM_FOOTER="My custom footer" +HELP_AND_FAQ_URL=https://librechat.ai + +# SHOW_BIRTHDAY_ICON=true + #==================================================# # Others # #==================================================# diff --git a/.eslintrc.js b/.eslintrc.js index a3d71acd69f..e85e0d768ca 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -19,6 +19,7 @@ module.exports = { 'e2e/playwright-report/**/*', 'packages/data-provider/types/**/*', 'packages/data-provider/dist/**/*', + 'packages/data-provider/test_bundle/**/*', 'data-node/**/*', 'meili_data/**/*', 'node_modules/**/*', @@ -131,6 +132,12 @@ module.exports = { }, ], }, + { + files: ['./packages/data-provider/specs/**/*.ts'], + parserOptions: { + project: './packages/data-provider/tsconfig.spec.json', + }, + }, ], settings: { react: { diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md index 3f39cc00b3b..cb767cbd7cd 100644 --- a/.github/CODE_OF_CONDUCT.md +++ b/.github/CODE_OF_CONDUCT.md @@ -60,7 +60,7 @@ representative at an online or offline event. Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement here on GitHub or -on the official [Discord Server](https://discord.gg/uDyZ5Tzhct). +on the official [Discord Server](https://discord.librechat.ai). All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 36618437fab..142f67c953f 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -8,7 +8,7 @@ If the feature you would like to contribute has not already received prior appro Please note that a pull request involving a feature that has not been reviewed and approved by the project maintainers may be rejected. We appreciate your understanding and cooperation. -If you would like to discuss the changes you wish to make, join our [Discord community](https://discord.gg/uDyZ5Tzhct), where you can engage with other contributors and seek guidance from the community. +If you would like to discuss the changes you wish to make, join our [Discord community](https://discord.librechat.ai), where you can engage with other contributors and seek guidance from the community. ## Our Standards diff --git a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml index b6b64c3f2de..5c88b9f70dc 100644 --- a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml +++ b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml @@ -50,7 +50,7 @@ body: id: terms attributes: label: Code of Conduct - description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/CODE_OF_CONDUCT.md) + description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md) options: - label: I agree to follow this project's Code of Conduct required: true diff --git a/.github/SECURITY.md b/.github/SECURITY.md index bd105f2526c..b01e04e0160 100644 --- a/.github/SECURITY.md +++ b/.github/SECURITY.md @@ -12,7 +12,7 @@ When reporting a security vulnerability, you have the following options to reach - **Option 2: GitHub Issues**: You can initiate first contact via GitHub Issues. However, please note that initial contact through GitHub Issues should not include any sensitive details. -- **Option 3: Discord Server**: You can join our [Discord community](https://discord.gg/5rbRxn4uME) and initiate first contact in the `#issues` channel. However, please ensure that initial contact through Discord does not include any sensitive details. +- **Option 3: Discord Server**: You can join our [Discord community](https://discord.librechat.ai) and initiate first contact in the `#issues` channel. However, please ensure that initial contact through Discord does not include any sensitive details. _After the initial contact, we will establish a private communication channel for further discussion._ @@ -39,11 +39,11 @@ Please note that as a security-conscious community, we may not always disclose d This security policy applies to the following GitHub repository: -- Repository: [LibreChat](https://github.com/danny-avila/LibreChat) +- Repository: [LibreChat](https://github.librechat.ai) ## Contact -If you have any questions or concerns regarding the security of our project, please join our [Discord community](https://discord.gg/NGaa9RPCft) and report them in the appropriate channel. You can also reach out to us by [opening an issue](https://github.com/danny-avila/LibreChat/issues/new) on GitHub. Please note that the response time may vary depending on the nature and severity of the inquiry. +If you have any questions or concerns regarding the security of our project, please join our [Discord community](https://discord.librechat.ai) and report them in the appropriate channel. You can also reach out to us by [opening an issue](https://github.com/danny-avila/LibreChat/issues/new) on GitHub. Please note that the response time may vary depending on the nature and severity of the inquiry. ## Acknowledgments diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 06d2656bd64..a1542cb76e4 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,8 +15,9 @@ Please delete any irrelevant options. - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update -- [ ] Documentation update - [ ] Translation update +- [ ] Documentation update + ## Testing @@ -26,6 +27,8 @@ Please describe your test process and include instructions so that we can reprod ## Checklist +Please delete any irrelevant options. + - [ ] My code adheres to this project's style guidelines - [ ] I have performed a self-review of my own code - [ ] I have commented in any complex areas of my code @@ -34,3 +37,4 @@ Please describe your test process and include instructions so that we can reprod - [ ] I have written tests demonstrating that my changes are effective or that my feature works - [ ] Local unit tests pass with my changes - [ ] Any changes dependent on mine have been merged and published in downstream modules. +- [ ] New documents have been locally validated with mkdocs diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index fddb6cdac63..db46653c651 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -30,10 +30,28 @@ jobs: - name: Install Data Provider run: npm run build:data-provider + + - name: Create empty auth.json file + run: | + mkdir -p api/data + echo '{}' > api/data/auth.json + + - name: Check for Circular dependency in rollup + working-directory: ./packages/data-provider + run: | + output=$(npm run rollup:api) + echo "$output" + if echo "$output" | grep -q "Circular dependency"; then + echo "Error: Circular dependency detected!" + exit 1 + fi - name: Run unit tests run: cd api && npm run test:ci + - name: Run librechat-data-provider unit tests + run: cd packages/data-provider && npm run test:ci + - name: Run linters uses: wearerequired/lint-action@v2 with: diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml deleted file mode 100644 index 23c6ad48cc8..00000000000 --- a/.github/workflows/container.yml +++ /dev/null @@ -1,83 +0,0 @@ -name: Docker Compose Build on Tag - -# The workflow is triggered when a tag is pushed -on: - push: - tags: - - "*" - -jobs: - build: - runs-on: ubuntu-latest - - steps: - # Check out the repository - - name: Checkout - uses: actions/checkout@v4 - - # Set up Docker - - name: Set up Docker - uses: docker/setup-buildx-action@v3 - - # Set up QEMU for cross-platform builds - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - # Log in to GitHub Container Registry - - name: Log in to GitHub Container Registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - # Prepare Docker Build - - name: Build Docker images - run: | - cp .env.example .env - - # Tag and push librechat-api - - name: Docker metadata for librechat-api - id: meta-librechat-api - uses: docker/metadata-action@v5 - with: - images: | - ghcr.io/${{ github.repository_owner }}/librechat-api - tags: | - type=raw,value=latest - type=semver,pattern={{version}} - type=semver,pattern={{major}} - type=semver,pattern={{major}}.{{minor}} - - - name: Build and librechat-api - uses: docker/build-push-action@v5 - with: - file: Dockerfile.multi - context: . - push: true - tags: ${{ steps.meta-librechat-api.outputs.tags }} - platforms: linux/amd64,linux/arm64 - target: api-build - - # Tag and push librechat - - name: Docker metadata for librechat - id: meta-librechat - uses: docker/metadata-action@v5 - with: - images: | - ghcr.io/${{ github.repository_owner }}/librechat - tags: | - type=raw,value=latest - type=semver,pattern={{version}} - type=semver,pattern={{major}} - type=semver,pattern={{major}}.{{minor}} - - - name: Build and librechat - uses: docker/build-push-action@v5 - with: - file: Dockerfile - context: . - push: true - tags: ${{ steps.meta-librechat.outputs.tags }} - platforms: linux/amd64,linux/arm64 - target: node diff --git a/.github/workflows/dev-images.yml b/.github/workflows/dev-images.yml index e0149e05e9c..41d427c6c8b 100644 --- a/.github/workflows/dev-images.yml +++ b/.github/workflows/dev-images.yml @@ -2,18 +2,38 @@ name: Docker Dev Images Build on: workflow_dispatch: + push: + branches: + - main + paths: + - 'api/**' + - 'client/**' + - 'packages/**' jobs: build: runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: api-build + file: Dockerfile.multi + image_name: librechat-dev-api + - target: node + file: Dockerfile + image_name: librechat-dev steps: # Check out the repository - name: Checkout uses: actions/checkout@v4 - # Set up Docker - - name: Set up Docker + # Set up QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Set up Docker Buildx + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 # Log in to GitHub Container Registry @@ -24,22 +44,29 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # Build Docker images - - name: Build Docker images + # Login to Docker Hub + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # Prepare the environment + - name: Prepare environment run: | cp .env.example .env - docker build -f Dockerfile.multi --target api-build -t librechat-dev-api . - docker build -f Dockerfile -t librechat-dev . - # Tag and push the images to GitHub Container Registry - - name: Tag and push images - run: | - docker tag librechat-dev-api:latest ghcr.io/${{ github.repository_owner }}/librechat-dev-api:${{ github.sha }} - docker push ghcr.io/${{ github.repository_owner }}/librechat-dev-api:${{ github.sha }} - docker tag librechat-dev-api:latest ghcr.io/${{ github.repository_owner }}/librechat-dev-api:latest - docker push ghcr.io/${{ github.repository_owner }}/librechat-dev-api:latest - - docker tag librechat-dev:latest ghcr.io/${{ github.repository_owner }}/librechat-dev:${{ github.sha }} - docker push ghcr.io/${{ github.repository_owner }}/librechat-dev:${{ github.sha }} - docker tag librechat-dev:latest ghcr.io/${{ github.repository_owner }}/librechat-dev:latest - docker push ghcr.io/${{ github.repository_owner }}/librechat-dev:latest + # Build and push Docker images for each target + - name: Build and push Docker images + uses: docker/build-push-action@v5 + with: + context: . + file: ${{ matrix.file }} + push: true + tags: | + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }} + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest + platforms: linux/amd64,linux/arm64 + target: ${{ matrix.target }} diff --git a/.github/workflows/generate_embeddings.yml b/.github/workflows/generate_embeddings.yml new file mode 100644 index 00000000000..c514f9c1d6b --- /dev/null +++ b/.github/workflows/generate_embeddings.yml @@ -0,0 +1,20 @@ +name: 'generate_embeddings' +on: + workflow_dispatch: + push: + branches: + - main + paths: + - 'docs/**' + +jobs: + generate: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: supabase/embeddings-generator@v0.0.5 + with: + supabase-url: ${{ secrets.SUPABASE_URL }} + supabase-service-role-key: ${{ secrets.SUPABASE_SERVICE_ROLE_KEY }} + openai-key: ${{ secrets.OPENAI_DOC_EMBEDDINGS_KEY }} + docs-root-path: 'docs' \ No newline at end of file diff --git a/.github/workflows/latest-images-main.yml b/.github/workflows/latest-images-main.yml deleted file mode 100644 index 5149cecb0e6..00000000000 --- a/.github/workflows/latest-images-main.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Docker Compose Build on Main Branch - -on: - workflow_dispatch: # This line allows manual triggering - -jobs: - build: - runs-on: ubuntu-latest - - steps: - # Check out the repository - - name: Checkout - uses: actions/checkout@v4 - - # Set up Docker - - name: Set up Docker - uses: docker/setup-buildx-action@v3 - - # Log in to GitHub Container Registry - - name: Log in to GitHub Container Registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - # Run docker-compose build - - name: Build Docker images - run: | - cp .env.example .env - docker-compose build - docker build -f Dockerfile.multi --target api-build -t librechat-api . - - # Tag and push the images with the 'latest' tag - - name: Tag image and push - run: | - docker tag librechat:latest ghcr.io/${{ github.repository_owner }}/librechat:latest - docker push ghcr.io/${{ github.repository_owner }}/librechat:latest - docker tag librechat-api:latest ghcr.io/${{ github.repository_owner }}/librechat-api:latest - docker push ghcr.io/${{ github.repository_owner }}/librechat-api:latest diff --git a/.github/workflows/main-image-workflow.yml b/.github/workflows/main-image-workflow.yml new file mode 100644 index 00000000000..43c9d957534 --- /dev/null +++ b/.github/workflows/main-image-workflow.yml @@ -0,0 +1,69 @@ +name: Docker Compose Build Latest Main Image Tag (Manual Dispatch) + +on: + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: api-build + file: Dockerfile.multi + image_name: librechat-api + - target: node + file: Dockerfile + image_name: librechat + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Fetch tags and set the latest tag + run: | + git fetch --tags + echo "LATEST_TAG=$(git describe --tags `git rev-list --tags --max-count=1`)" >> $GITHUB_ENV + + # Set up QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Set up Docker Buildx + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # Log in to GitHub Container Registry + - name: Log in to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Login to Docker Hub + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # Prepare the environment + - name: Prepare environment + run: | + cp .env.example .env + + # Build and push Docker images for each target + - name: Build and push Docker images + uses: docker/build-push-action@v5 + with: + context: . + file: ${{ matrix.file }} + push: true + tags: | + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ env.LATEST_TAG }} + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ env.LATEST_TAG }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest + platforms: linux/amd64,linux/arm64 + target: ${{ matrix.target }} diff --git a/.github/workflows/tag-images.yml b/.github/workflows/tag-images.yml new file mode 100644 index 00000000000..e90f43978ab --- /dev/null +++ b/.github/workflows/tag-images.yml @@ -0,0 +1,67 @@ +name: Docker Images Build on Tag + +on: + push: + tags: + - '*' + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: api-build + file: Dockerfile.multi + image_name: librechat-api + - target: node + file: Dockerfile + image_name: librechat + + steps: + # Check out the repository + - name: Checkout + uses: actions/checkout@v4 + + # Set up QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Set up Docker Buildx + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # Log in to GitHub Container Registry + - name: Log in to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Login to Docker Hub + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # Prepare the environment + - name: Prepare environment + run: | + cp .env.example .env + + # Build and push Docker images for each target + - name: Build and push Docker images + uses: docker/build-push-action@v5 + with: + context: . + file: ${{ matrix.file }} + push: true + tags: | + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.ref_name }} + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.ref_name }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest + platforms: linux/amd64,linux/arm64 + target: ${{ matrix.target }} diff --git a/.gitignore b/.gitignore index 765de5cb799..c55115988b9 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ bower_components/ #config file librechat.yaml +librechat.yml # Environment .npmrc @@ -74,6 +75,7 @@ src/style - official.css config.local.ts **/storageState.json junit.xml +**/.venv/ # docker override file docker-compose.override.yaml @@ -88,4 +90,10 @@ auth.json /packages/ux-shared/ /images -!client/src/components/Nav/SettingsTabs/Data/ \ No newline at end of file +!client/src/components/Nav/SettingsTabs/Data/ + +# User uploads +uploads/ + +# owner +release/ \ No newline at end of file diff --git a/.husky/pre-commit b/.husky/pre-commit index af85628072b..67f5b002728 100755 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1,4 +1,4 @@ -#!/usr/bin/env sh +#!/usr/bin/env sh set -e . "$(dirname -- "$0")/_/husky.sh" [ -n "$CI" ] && exit 0 diff --git a/Dockerfile b/Dockerfile index edc79c2497a..fd087eae39d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,35 @@ +# v0.7.0 + # Base node image -FROM node:18-alpine AS node +FROM node:18-alpine3.18 AS node + +RUN apk add g++ make py3-pip +RUN npm install -g node-gyp +RUN apk --no-cache add curl -COPY . /app +RUN mkdir -p /app && chown node:node /app WORKDIR /app +USER node + +COPY --chown=node:node . . + # Allow mounting of these files, which have no default # values. RUN touch .env -# Install call deps - Install curl for health check -RUN apk --no-cache add curl && \ - npm ci +RUN npm config set fetch-retry-maxtimeout 600000 +RUN npm config set fetch-retries 5 +RUN npm config set fetch-retry-mintimeout 15000 +RUN npm install --no-audit # React client build ENV NODE_OPTIONS="--max-old-space-size=2048" RUN npm run frontend +# Create directories for the volumes to inherit +# the correct permissions +RUN mkdir -p /app/client/public/images /app/api/logs + # Node API setup EXPOSE 3080 ENV HOST=0.0.0.0 diff --git a/Dockerfile.multi b/Dockerfile.multi index 0d5ebec5e23..00ed37e3ef8 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -1,3 +1,5 @@ +# v0.7.0 + # Build API, Client and Data Provider FROM node:20-alpine AS base @@ -11,11 +13,12 @@ RUN npm run build # React client build FROM data-provider-build AS client-build WORKDIR /app/client -COPY ./client/ ./ +COPY ./client/package*.json ./ # Copy data-provider to client's node_modules RUN mkdir -p /app/client/node_modules/librechat-data-provider/ RUN cp -R /app/packages/data-provider/* /app/client/node_modules/librechat-data-provider/ RUN npm install +COPY ./client/ ./ ENV NODE_OPTIONS="--max-old-space-size=2048" RUN npm run build @@ -24,6 +27,8 @@ FROM data-provider-build AS api-build WORKDIR /app/api COPY api/package*.json ./ COPY api/ ./ +# Copy helper scripts +COPY config/ ./ # Copy data-provider to API's node_modules RUN mkdir -p /app/api/node_modules/librechat-data-provider/ RUN cp -R /app/packages/data-provider/* /app/api/node_modules/librechat-data-provider/ diff --git a/README.md b/README.md index 00cd890b073..901ddbc7c14 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@

- + - -

LibreChat

- +

+ LibreChat +

@@ -27,42 +27,48 @@

- - Deploy on Railway - -

- -

- - Deploy on Sealos - + + Deploy on Railway + + + Deploy on Zeabur + + + Deploy on Sealos +

# 📃 Features - - 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and 11-2023 updates - - 💬 Multimodal Chat: - - Upload and analyze images with GPT-4 and Gemini Vision 📸 - - More filetypes and Assistants API integration in Active Development 🚧 - - 🌎 Multilingual UI: - - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, - - Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands - - 🤖 AI model selection: OpenAI API, Azure, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins - - 💾 Create, Save, & Share Custom Presets - - 🔄 Edit, Resubmit, and Continue messages with conversation branching - - 📤 Export conversations as screenshots, markdown, text, json. - - 🔍 Search all messages/conversations - - 🔌 Plugins, including web access, image generation with DALL-E-3 and more - - 👥 Multi-User, Secure Authentication with Moderation and Token spend tools - - ⚙️ Configure Proxy, Reverse Proxy, Docker, many Deployment options, and completely Open-Source -[For a thorough review of our features, see our docs here](https://docs.librechat.ai/features/plugins/introduction.html) 📚 +- 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and latest updates +- 💬 Multimodal Chat: + - Upload and analyze images with Claude 3, GPT-4, and Gemini Vision 📸 + - Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, & Google. 🗃️ + - Advanced Agents with Files, Code Interpreter, Tools, and API Actions 🔦 + - Available through the [OpenAI Assistants API](https://platform.openai.com/docs/assistants/overview) 🌤️ + - Non-OpenAI Agents in Active Development 🚧 +- 🌎 Multilingual UI: + - English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, + - Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית +- 🤖 AI model selection: OpenAI, Azure OpenAI, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins, Assistants API (including Azure Assistants) +- 💾 Create, Save, & Share Custom Presets +- 🔄 Edit, Resubmit, and Continue messages with conversation branching +- 📤 Export conversations as screenshots, markdown, text, json. +- 🔍 Search all messages/conversations +- 🔌 Plugins, including web access, image generation with DALL-E-3 and more +- 👥 Multi-User, Secure Authentication with Moderation and Token spend tools +- ⚙️ Configure Proxy, Reverse Proxy, Docker, & many Deployment options +- 📖 Completely Open-Source & Built in Public +- 🧑‍🤝‍🧑 Community-driven development, support, and feedback +[For a thorough review of our features, see our docs here](https://docs.librechat.ai/features/plugins/introduction.html) 📚 ## 🪶 All-In-One AI Conversations with LibreChat + LibreChat brings together the future of assistant AIs with the revolutionary technology of OpenAI's ChatGPT. Celebrating the original styling, LibreChat gives you the ability to integrate multiple AI models. It also integrates and enhances original client features such as conversation and message search, prompt templates and plugins. With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform. - + [![Watch the video](https://img.youtube.com/vi/pNIOs1ovsXw/maxresdefault.jpg)](https://youtu.be/pNIOs1ovsXw) @@ -71,11 +77,13 @@ Click on the thumbnail to open the video☝️ --- ## 📚 Documentation + For more information on how to use our advanced features, install and configure our software, and access our guidelines and tutorials, please check out our documentation at [docs.librechat.ai](https://docs.librechat.ai) --- -## 📝 Changelog +## 📝 Changelog + Keep up with the latest updates by visiting the releases page - [Releases](https://github.com/danny-avila/LibreChat/releases) **⚠️ [Breaking Changes](docs/general_info/breaking_changes.md)** @@ -96,14 +104,15 @@ Please consult the breaking changes before updating. --- ## ✨ Contributions + Contributions, suggestions, bug reports and fixes are welcome! -For new features, components, or extensions, please open an issue and discuss before sending a PR. +For new features, components, or extensions, please open an issue and discuss before sending a PR. --- -💖 This project exists in its current state thanks to all the people who contribute ---- +## 💖 This project exists in its current state thanks to all the people who contribute + diff --git a/api/app/chatgpt-browser.js b/api/app/chatgpt-browser.js index 467e67785d3..818661555dc 100644 --- a/api/app/chatgpt-browser.js +++ b/api/app/chatgpt-browser.js @@ -1,5 +1,6 @@ require('dotenv').config(); const { KeyvFile } = require('keyv-file'); +const { Constants } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('../server/services/UserService'); const browserClient = async ({ @@ -48,7 +49,7 @@ const browserClient = async ({ options = { ...options, parentMessageId, conversationId }; } - if (parentMessageId === '00000000-0000-0000-0000-000000000000') { + if (parentMessageId === Constants.NO_PARENT) { delete options.conversationId; } diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 0441a49334e..6d478defab0 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -1,6 +1,19 @@ const Anthropic = require('@anthropic-ai/sdk'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); -const { getResponseSender, EModelEndpoint } = require('librechat-data-provider'); +const { + getResponseSender, + EModelEndpoint, + validateVisionModel, +} = require('librechat-data-provider'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { + titleFunctionPrompt, + parseTitleFromPrompt, + truncateText, + formatMessage, + createContextHandlers, +} = require('./prompts'); +const spendTokens = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -10,12 +23,20 @@ const AI_PROMPT = '\n\nAssistant:'; const tokenizersCache = {}; +/** Helper function to introduce a delay before retrying */ +function delayBeforeRetry(attempts, baseDelay = 1000) { + return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts)); +} + class AnthropicClient extends BaseClient { constructor(apiKey, options = {}) { super(apiKey, options); this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY; this.userLabel = HUMAN_PROMPT; this.assistantLabel = AI_PROMPT; + this.contextStrategy = options.contextStrategy + ? options.contextStrategy.toLowerCase() + : 'discard'; this.setOptions(options); } @@ -47,6 +68,12 @@ class AnthropicClient extends BaseClient { stop: modelOptions.stop, // no stop method for now }; + this.isClaude3 = this.modelOptions.model.includes('claude-3'); + this.useMessages = this.isClaude3 || !!this.options.attachments; + + this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229'; + this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments)); + this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000; this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500; @@ -87,7 +114,12 @@ class AnthropicClient extends BaseClient { return this; } + /** + * Get the initialized Anthropic client. + * @returns {Anthropic} The Anthropic client instance. + */ getClient() { + /** @type {Anthropic.default.RequestOptions} */ const options = { apiKey: this.apiKey, }; @@ -99,6 +131,75 @@ class AnthropicClient extends BaseClient { return new Anthropic(options); } + getTokenCountForResponse(response) { + return this.getTokenCountForMessage({ + role: 'assistant', + content: response.text, + }); + } + + /** + * + * Checks if the model is a vision model based on request attachments and sets the appropriate options: + * - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request. + * - Sets `this.isVisionModel` to `true` if vision request. + * - Deletes `this.modelOptions.stop` if vision request. + * @param {MongoFile[]} attachments + */ + checkVisionRequest(attachments) { + const availableModels = this.options.modelsConfig?.[EModelEndpoint.anthropic]; + this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); + + const visionModelAvailable = availableModels?.includes(this.defaultVisionModel); + if ( + attachments && + attachments.some((file) => file?.type && file?.type?.includes('image')) && + visionModelAvailable && + !this.isVisionModel + ) { + this.modelOptions.model = this.defaultVisionModel; + this.isVisionModel = true; + } + } + + /** + * Calculate the token cost in tokens for an image based on its dimensions and detail level. + * + * For reference, see: https://docs.anthropic.com/claude/docs/vision#image-costs + * + * @param {Object} image - The image object. + * @param {number} image.width - The width of the image. + * @param {number} image.height - The height of the image. + * @returns {number} The calculated token cost measured by tokens. + * + */ + calculateImageTokenCost({ width, height }) { + return Math.ceil((width * height) / 750); + } + + async addImageURLs(message, attachments) { + const { files, image_urls } = await encodeAndFormat( + this.options.req, + attachments, + EModelEndpoint.anthropic, + ); + message.image_urls = image_urls.length ? image_urls : undefined; + return files; + } + + async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) { + await spendTokens( + { + context, + user: this.user, + conversationId: this.conversationId, + model: model ?? this.modelOptions.model, + endpointTokenConfig: this.options.endpointTokenConfig, + }, + { promptTokens, completionTokens }, + ); + } + async buildMessages(messages, parentMessageId) { const orderedMessages = this.constructor.getMessagesForConversation({ messages, @@ -107,28 +208,145 @@ class AnthropicClient extends BaseClient { logger.debug('[AnthropicClient] orderedMessages', { orderedMessages, parentMessageId }); - const formattedMessages = orderedMessages.map((message) => ({ - author: message.isCreatedByUser ? this.userLabel : this.assistantLabel, - content: message?.content ?? message.text, - })); + if (this.options.attachments) { + const attachments = await this.options.attachments; + const images = attachments.filter((file) => file.type.includes('image')); + + if (images.length && !this.isVisionModel) { + throw new Error('Images are only supported with the Claude 3 family of models'); + } + + const latestMessage = orderedMessages[orderedMessages.length - 1]; + + if (this.message_file_map) { + this.message_file_map[latestMessage.messageId] = attachments; + } else { + this.message_file_map = { + [latestMessage.messageId]: attachments, + }; + } + + const files = await this.addImageURLs(latestMessage, attachments); + + this.options.attachments = files; + } + + if (this.message_file_map) { + this.contextHandlers = createContextHandlers( + this.options.req, + orderedMessages[orderedMessages.length - 1].text, + ); + } + + const formattedMessages = orderedMessages.map((message, i) => { + const formattedMessage = this.useMessages + ? formatMessage({ + message, + endpoint: EModelEndpoint.anthropic, + }) + : { + author: message.isCreatedByUser ? this.userLabel : this.assistantLabel, + content: message?.content ?? message.text, + }; + + const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount; + /* If tokens were never counted, or, is a Vision request and the message has files, count again */ + if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) { + orderedMessages[i].tokenCount = this.getTokenCountForMessage(formattedMessage); + } + + /* If message has files, calculate image token cost */ + if (this.message_file_map && this.message_file_map[message.messageId]) { + const attachments = this.message_file_map[message.messageId]; + for (const file of attachments) { + if (file.embedded) { + this.contextHandlers?.processFile(file); + continue; + } + + orderedMessages[i].tokenCount += this.calculateImageTokenCost({ + width: file.width, + height: file.height, + }); + } + } + + formattedMessage.tokenCount = orderedMessages[i].tokenCount; + return formattedMessage; + }); + + if (this.contextHandlers) { + this.augmentedPrompt = await this.contextHandlers.createContext(); + this.options.promptPrefix = this.augmentedPrompt + (this.options.promptPrefix ?? ''); + } + + let { context: messagesInWindow, remainingContextTokens } = + await this.getMessagesWithinTokenLimit(formattedMessages); + + const tokenCountMap = orderedMessages + .slice(orderedMessages.length - messagesInWindow.length) + .reduce((map, message, index) => { + const { messageId } = message; + if (!messageId) { + return map; + } + + map[messageId] = orderedMessages[index].tokenCount; + return map; + }, {}); + + logger.debug('[AnthropicClient]', { + messagesInWindow: messagesInWindow.length, + remainingContextTokens, + }); let lastAuthor = ''; let groupedMessages = []; - for (let message of formattedMessages) { + for (let i = 0; i < messagesInWindow.length; i++) { + const message = messagesInWindow[i]; + const author = message.role ?? message.author; // If last author is not same as current author, add to new group - if (lastAuthor !== message.author) { - groupedMessages.push({ - author: message.author, + if (lastAuthor !== author) { + const newMessage = { content: [message.content], - }); - lastAuthor = message.author; + }; + + if (message.role) { + newMessage.role = message.role; + } else { + newMessage.author = message.author; + } + + groupedMessages.push(newMessage); + lastAuthor = author; // If same author, append content to the last group } else { groupedMessages[groupedMessages.length - 1].content.push(message.content); } } + groupedMessages = groupedMessages.map((msg, i) => { + const isLast = i === groupedMessages.length - 1; + if (msg.content.length === 1) { + const content = msg.content[0]; + return { + ...msg, + // reason: final assistant content cannot end with trailing whitespace + content: + isLast && this.useMessages && msg.role === 'assistant' && typeof content === 'string' + ? content?.trim() + : content, + }; + } + + if (!this.useMessages && msg.tokenCount) { + delete msg.tokenCount; + } + + return msg; + }); + let identityPrefix = ''; if (this.options.userLabel) { identityPrefix = `\nHuman's name: ${this.options.userLabel}`; @@ -154,9 +372,10 @@ class AnthropicClient extends BaseClient { // Prompt AI to respond, empty if last message was from AI let isEdited = lastAuthor === this.assistantLabel; const promptSuffix = isEdited ? '' : `${promptPrefix}${this.assistantLabel}\n`; - let currentTokenCount = isEdited - ? this.getTokenCount(promptPrefix) - : this.getTokenCount(promptSuffix); + let currentTokenCount = + isEdited || this.useMessages + ? this.getTokenCount(promptPrefix) + : this.getTokenCount(promptSuffix); let promptBody = ''; const maxTokenCount = this.maxPromptTokens; @@ -224,7 +443,69 @@ class AnthropicClient extends BaseClient { return true; }; - await buildPromptBody(); + const messagesPayload = []; + const buildMessagesPayload = async () => { + let canContinue = true; + + if (promptPrefix) { + this.systemMessage = promptPrefix; + } + + while (currentTokenCount < maxTokenCount && groupedMessages.length > 0 && canContinue) { + const message = groupedMessages.pop(); + + let tokenCountForMessage = message.tokenCount ?? this.getTokenCountForMessage(message); + + const newTokenCount = currentTokenCount + tokenCountForMessage; + const exceededMaxCount = newTokenCount > maxTokenCount; + + if (exceededMaxCount && messagesPayload.length === 0) { + throw new Error( + `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`, + ); + } else if (exceededMaxCount) { + canContinue = false; + break; + } + + delete message.tokenCount; + messagesPayload.unshift(message); + currentTokenCount = newTokenCount; + + // Switch off isEdited after using it once + if (isEdited && message.role === 'assistant') { + isEdited = false; + } + + // Wait for next tick to avoid blocking the event loop + await new Promise((resolve) => setImmediate(resolve)); + } + }; + + const processTokens = () => { + // Add 2 tokens for metadata after all messages have been counted. + currentTokenCount += 2; + + // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. + this.modelOptions.maxOutputTokens = Math.min( + this.maxContextTokens - currentTokenCount, + this.maxResponseTokens, + ); + }; + + if (this.modelOptions.model.startsWith('claude-3')) { + await buildMessagesPayload(); + processTokens(); + return { + prompt: messagesPayload, + context: messagesInWindow, + promptTokens: currentTokenCount, + tokenCountMap, + }; + } else { + await buildPromptBody(); + processTokens(); + } if (nextMessage.remove) { promptBody = promptBody.replace(nextMessage.messageString, ''); @@ -234,22 +515,26 @@ class AnthropicClient extends BaseClient { let prompt = `${promptBody}${promptSuffix}`; - // Add 2 tokens for metadata after all messages have been counted. - currentTokenCount += 2; - - // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. - this.modelOptions.maxOutputTokens = Math.min( - this.maxContextTokens - currentTokenCount, - this.maxResponseTokens, - ); - - return { prompt, context }; + return { prompt, context, promptTokens: currentTokenCount, tokenCountMap }; } getCompletion() { logger.debug('AnthropicClient doesn\'t use getCompletion (all handled in sendCompletion)'); } + /** + * Creates a message or completion response using the Anthropic client. + * @param {Anthropic} client - The Anthropic client instance. + * @param {Anthropic.default.MessageCreateParams | Anthropic.default.CompletionCreateParams} options - The options for the message or completion. + * @param {boolean} useMessages - Whether to use messages or completions. Defaults to `this.useMessages`. + * @returns {Promise} The response from the Anthropic client. + */ + async createResponse(client, options, useMessages) { + return useMessages ?? this.useMessages + ? await client.messages.create(options) + : await client.completions.create(options); + } + async sendCompletion(payload, { onProgress, abortController }) { if (!abortController) { abortController = new AbortController(); @@ -279,36 +564,88 @@ class AnthropicClient extends BaseClient { topP: top_p, topK: top_k, } = this.modelOptions; + const requestOptions = { - prompt: payload, model, stream: stream || true, - max_tokens_to_sample: maxOutputTokens || 1500, stop_sequences, temperature, metadata, top_p, top_k, }; + + if (this.useMessages) { + requestOptions.messages = payload; + requestOptions.max_tokens = maxOutputTokens || 1500; + } else { + requestOptions.prompt = payload; + requestOptions.max_tokens_to_sample = maxOutputTokens || 1500; + } + + if (this.systemMessage) { + requestOptions.system = this.systemMessage; + } + logger.debug('[AnthropicClient]', { ...requestOptions }); - const response = await client.completions.create(requestOptions); - signal.addEventListener('abort', () => { - logger.debug('[AnthropicClient] message aborted!'); - response.controller.abort(); - }); + const handleChunk = (currentChunk) => { + if (currentChunk) { + text += currentChunk; + onProgress(currentChunk); + } + }; + + const maxRetries = 3; + async function processResponse() { + let attempts = 0; + + while (attempts < maxRetries) { + let response; + try { + response = await this.createResponse(client, requestOptions); + + signal.addEventListener('abort', () => { + logger.debug('[AnthropicClient] message aborted!'); + if (response.controller?.abort) { + response.controller.abort(); + } + }); + + for await (const completion of response) { + // Handle each completion as before + if (completion?.delta?.text) { + handleChunk(completion.delta.text); + } else if (completion.completion) { + handleChunk(completion.completion); + } + } - for await (const completion of response) { - // Uncomment to debug message stream - // logger.debug(completion); - text += completion.completion; - onProgress(completion.completion); + // Successful processing, exit loop + break; + } catch (error) { + attempts += 1; + logger.warn( + `User: ${this.user} | Anthropic Request ${attempts} failed: ${error.message}`, + ); + + if (attempts < maxRetries) { + await delayBeforeRetry(attempts, 350); + } else { + throw new Error(`Operation failed after ${maxRetries} attempts: ${error.message}`); + } + } finally { + signal.removeEventListener('abort', () => { + logger.debug('[AnthropicClient] message aborted!'); + if (response.controller?.abort) { + response.controller.abort(); + } + }); + } + } } - signal.removeEventListener('abort', () => { - logger.debug('[AnthropicClient] message aborted!'); - response.controller.abort(); - }); + await processResponse.bind(this)(); return text.trim(); } @@ -317,6 +654,7 @@ class AnthropicClient extends BaseClient { return { promptPrefix: this.options.promptPrefix, modelLabel: this.options.modelLabel, + resendFiles: this.options.resendFiles, ...this.modelOptions, }; } @@ -342,6 +680,78 @@ class AnthropicClient extends BaseClient { getTokenCount(text) { return this.gptEncoder.encode(text, 'all').length; } + + /** + * Generates a concise title for a conversation based on the user's input text and response. + * Involves sending a chat completion request with specific instructions for title generation. + * + * This function capitlizes on [Anthropic's function calling training](https://docs.anthropic.com/claude/docs/functions-external-tools). + * + * @param {Object} params - The parameters for the conversation title generation. + * @param {string} params.text - The user's input. + * @param {string} [params.responseText=''] - The AI's immediate response to the user. + * + * @returns {Promise} A promise that resolves to the generated conversation title. + * In case of failure, it will return the default title, "New Chat". + */ + async titleConvo({ text, responseText = '' }) { + let title = 'New Chat'; + const convo = ` + ${truncateText(text)} + + + ${JSON.stringify(truncateText(responseText))} + `; + + const { ANTHROPIC_TITLE_MODEL } = process.env ?? {}; + const model = this.options.titleModel ?? ANTHROPIC_TITLE_MODEL ?? 'claude-3-haiku-20240307'; + const system = titleFunctionPrompt; + + const titleChatCompletion = async () => { + const content = ` + ${convo} + + + Please generate a title for this conversation.`; + + const titleMessage = { role: 'user', content }; + const requestOptions = { + model, + temperature: 0.3, + max_tokens: 1024, + system, + stop_sequences: ['\n\nHuman:', '\n\nAssistant', ''], + messages: [titleMessage], + }; + + try { + const response = await this.createResponse(this.getClient(), requestOptions, true); + let promptTokens = response?.usage?.input_tokens; + let completionTokens = response?.usage?.output_tokens; + if (!promptTokens) { + promptTokens = this.getTokenCountForMessage(titleMessage); + promptTokens += this.getTokenCountForMessage({ role: 'system', content: system }); + } + if (!completionTokens) { + completionTokens = this.getTokenCountForMessage(response.content[0]); + } + await this.recordTokenUsage({ + model, + promptTokens, + completionTokens, + context: 'title', + }); + const text = response.content[0].text; + title = parseTitleFromPrompt(text); + } catch (e) { + logger.error('[AnthropicClient] There was an issue generating the title', e); + } + }; + + await titleChatCompletion(); + logger.debug('[AnthropicClient] Convo Title: ' + title); + return title; + } } module.exports = AnthropicClient; diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index aa39084b9fa..f7ed3b9cf18 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -1,8 +1,9 @@ const crypto = require('crypto'); -const { supportsBalanceCheck } = require('librechat-data-provider'); +const { supportsBalanceCheck, Constants } = require('librechat-data-provider'); const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const checkBalance = require('~/models/checkBalance'); +const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -22,7 +23,7 @@ class BaseClient { throw new Error('Method \'setOptions\' must be implemented.'); } - getCompletion() { + async getCompletion() { throw new Error('Method \'getCompletion\' must be implemented.'); } @@ -46,10 +47,6 @@ class BaseClient { logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response); } - async addPreviousAttachments(messages) { - return messages; - } - async recordTokenUsage({ promptTokens, completionTokens }) { logger.debug('`[BaseClient] recordTokenUsage` not implemented.', { promptTokens, @@ -77,7 +74,7 @@ class BaseClient { const saveOptions = this.getSaveOptions(); this.abortController = opts.abortController ?? new AbortController(); const conversationId = opts.conversationId ?? crypto.randomUUID(); - const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000'; + const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT; const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID(); let responseMessageId = opts.responseMessageId ?? crypto.randomUUID(); let head = isEdited ? responseMessageId : parentMessageId; @@ -428,7 +425,10 @@ class BaseClient { await this.saveMessageToDatabase(userMessage, saveOptions, user); } - if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[this.options.endpoint]) { + if ( + isEnabled(process.env.CHECK_BALANCE) && + supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint] + ) { await checkBalance({ req: this.options.req, res: this.options.res, @@ -438,11 +438,14 @@ class BaseClient { amount: promptTokens, model: this.modelOptions.model, endpoint: this.options.endpoint, + endpointTokenConfig: this.options.endpointTokenConfig, }, }); } const completion = await this.sendCompletion(payload, opts); + this.abortController.requestCompleted = true; + const responseMessage = { messageId: responseMessageId, conversationId, @@ -453,6 +456,7 @@ class BaseClient { sender: this.sender, text: addSpaceIfNeeded(generation) + completion, promptTokens, + ...(this.metadata ?? {}), }; if ( @@ -548,7 +552,7 @@ class BaseClient { * * Each message object should have an 'id' or 'messageId' property and may have a 'parentMessageId' property. * The 'parentMessageId' is the ID of the message that the current message is a reply to. - * If 'parentMessageId' is not present, null, or is '00000000-0000-0000-0000-000000000000', + * If 'parentMessageId' is not present, null, or is Constants.NO_PARENT, * the message is considered a root message. * * @param {Object} options - The options for the function. @@ -603,9 +607,7 @@ class BaseClient { } currentMessageId = - message.parentMessageId === '00000000-0000-0000-0000-000000000000' - ? null - : message.parentMessageId; + message.parentMessageId === Constants.NO_PARENT ? null : message.parentMessageId; } orderedMessages.reverse(); @@ -679,6 +681,54 @@ class BaseClient { return await this.sendCompletion(payload, opts); } + + /** + * + * @param {TMessage[]} _messages + * @returns {Promise} + */ + async addPreviousAttachments(_messages) { + if (!this.options.resendFiles) { + return _messages; + } + + /** + * + * @param {TMessage} message + */ + const processMessage = async (message) => { + if (!this.message_file_map) { + /** @type {Record */ + this.message_file_map = {}; + } + + const fileIds = message.files.map((file) => file.file_id); + const files = await getFiles({ + file_id: { $in: fileIds }, + }); + + await this.addImageURLs(message, files); + + this.message_file_map[message.messageId] = files; + return message; + }; + + const promises = []; + + for (const message of _messages) { + if (!message.files) { + promises.push(message); + continue; + } + + promises.push(processMessage(message)); + } + + const messages = await Promise.all(promises); + + this.checkVisionRequest(Object.values(this.message_file_map ?? {}).flat()); + return messages; + } } module.exports = BaseClient; diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index c1ae54fdf08..d218849513a 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -1,9 +1,19 @@ -const crypto = require('crypto'); const Keyv = require('keyv'); +const crypto = require('crypto'); +const { + EModelEndpoint, + resolveHeaders, + CohereConstants, + mapModelToAzureConfig, +} = require('librechat-data-provider'); +const { CohereClient } = require('cohere-ai'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { createCoherePayload } = require('./llm'); const { Agent, ProxyAgent } = require('undici'); const BaseClient = require('./BaseClient'); +const { logger } = require('~/config'); +const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); const CHATGPT_MODEL = 'gpt-3.5-turbo'; const tokenizersCache = {}; @@ -140,11 +150,13 @@ class ChatGPTClient extends BaseClient { return tokenizer; } - async getCompletion(input, onProgress, abortController = null) { + /** @type {getCompletion} */ + async getCompletion(input, onProgress, onTokenProgress, abortController = null) { if (!abortController) { abortController = new AbortController(); } - const modelOptions = { ...this.modelOptions }; + + let modelOptions = { ...this.modelOptions }; if (typeof onProgress === 'function') { modelOptions.stream = true; } @@ -159,56 +171,176 @@ class ChatGPTClient extends BaseClient { } const { debug } = this.options; - const url = this.completionsUrl; + let baseURL = this.completionsUrl; if (debug) { console.debug(); - console.debug(url); + console.debug(baseURL); console.debug(modelOptions); console.debug(); } - if (this.azure || this.options.azure) { - // Azure does not accept `model` in the body, so we need to remove it. - delete modelOptions.model; - } - const opts = { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify(modelOptions), dispatcher: new Agent({ bodyTimeout: 0, headersTimeout: 0, }), }; - if (this.apiKey && this.options.azure) { - opts.headers['api-key'] = this.apiKey; + if (this.isVisionModel) { + modelOptions.max_tokens = 4000; + } + + /** @type {TAzureConfig | undefined} */ + const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; + + const isAzure = this.azure || this.options.azure; + if ( + (isAzure && this.isVisionModel && azureConfig) || + (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) + ) { + const { modelGroupMap, groupMap } = azureConfig; + const { + azureOptions, + baseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName: modelOptions.model, + modelGroupMap, + groupMap, + }); + opts.headers = resolveHeaders(headers); + this.langchainProxy = extractBaseURL(baseURL); + this.apiKey = azureOptions.azureOpenAIApiKey; + + const groupName = modelGroupMap[modelOptions.model].group; + this.options.addParams = azureConfig.groupMap[groupName].addParams; + this.options.dropParams = azureConfig.groupMap[groupName].dropParams; + // Note: `forcePrompt` not re-assigned as only chat models are vision models + + this.azure = !serverless && azureOptions; + this.azureEndpoint = + !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); + } + + if (this.options.headers) { + opts.headers = { ...opts.headers, ...this.options.headers }; + } + + if (isAzure) { + // Azure does not accept `model` in the body, so we need to remove it. + delete modelOptions.model; + + baseURL = this.langchainProxy + ? constructAzureURL({ + baseURL: this.langchainProxy, + azureOptions: this.azure, + }) + : this.azureEndpoint.split(/(? msg.role === 'system'); + + if (systemMessageIndex > 0) { + const [systemMessage] = messages.splice(systemMessageIndex, 1); + messages.unshift(systemMessage); + } + + modelOptions.messages = messages; + + if (messages.length === 1 && messages[0].role === 'system') { + modelOptions.messages[0].role = 'user'; + } + } + + if (this.options.addParams && typeof this.options.addParams === 'object') { + modelOptions = { + ...modelOptions, + ...this.options.addParams, + }; + logger.debug('[ChatGPTClient] chatCompletion: added params', { + addParams: this.options.addParams, + modelOptions, + }); + } + + if (this.options.dropParams && Array.isArray(this.options.dropParams)) { + this.options.dropParams.forEach((param) => { + delete modelOptions[param]; + }); + logger.debug('[ChatGPTClient] chatCompletion: dropped params', { + dropParams: this.options.dropParams, + modelOptions, + }); + } + + if (baseURL.startsWith(CohereConstants.API_URL)) { + const payload = createCoherePayload({ modelOptions }); + return await this.cohereChatCompletion({ payload, onTokenProgress }); + } + + if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) { + baseURL = baseURL.split('v1')[0] + 'v1/completions'; + } else if ( + baseURL.includes('v1') && + !baseURL.includes('/chat/completions') && + this.isChatCompletion + ) { + baseURL = baseURL.split('v1')[0] + 'v1/chat/completions'; + } + + const BASE_URL = new URL(baseURL); + if (opts.defaultQuery) { + Object.entries(opts.defaultQuery).forEach(([key, value]) => { + BASE_URL.searchParams.append(key, value); + }); + delete opts.defaultQuery; + } + + const completionsURL = BASE_URL.toString(); + opts.body = JSON.stringify(modelOptions); + if (modelOptions.stream) { // eslint-disable-next-line no-async-promise-executor return new Promise(async (resolve, reject) => { try { let done = false; - await fetchEventSource(url, { + await fetchEventSource(completionsURL, { ...opts, signal: abortController.signal, async onopen(response) { @@ -236,7 +368,6 @@ class ChatGPTClient extends BaseClient { // workaround for private API not sending [DONE] event if (!done) { onProgress('[DONE]'); - abortController.abort(); resolve(); } }, @@ -249,14 +380,13 @@ class ChatGPTClient extends BaseClient { }, onmessage(message) { if (debug) { - // console.debug(message); + console.debug(message); } if (!message.data || message.event === 'ping') { return; } if (message.data === '[DONE]') { onProgress('[DONE]'); - abortController.abort(); resolve(); done = true; return; @@ -269,7 +399,7 @@ class ChatGPTClient extends BaseClient { } }); } - const response = await fetch(url, { + const response = await fetch(completionsURL, { ...opts, signal: abortController.signal, }); @@ -287,6 +417,35 @@ class ChatGPTClient extends BaseClient { return response.json(); } + /** @type {cohereChatCompletion} */ + async cohereChatCompletion({ payload, onTokenProgress }) { + const cohere = new CohereClient({ + token: this.apiKey, + environment: this.completionsUrl, + }); + + if (!payload.stream) { + const chatResponse = await cohere.chat(payload); + return chatResponse.text; + } + + const chatStream = await cohere.chatStream(payload); + let reply = ''; + for await (const message of chatStream) { + if (!message) { + continue; + } + + if (message.eventType === 'text-generation' && message.text) { + onTokenProgress(message.text); + } else if (message.eventType === 'stream-end' && message.response) { + reply = message.response.text; + } + } + + return reply; + } + async generateTitle(userMessage, botMessage) { const instructionsPayload = { role: 'system', diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 950cc8d1116..c5edcb275a8 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -4,16 +4,17 @@ const { GoogleVertexAI } = require('langchain/llms/googlevertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai'); const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema'); -const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { + validateVisionModel, getResponseSender, - EModelEndpoint, endpointSettings, + EModelEndpoint, AuthKeys, } = require('librechat-data-provider'); +const { encodeAndFormat } = require('~/server/services/Files/images'); +const { formatMessage, createContextHandlers } = require('./prompts'); const { getModelMaxTokens } = require('~/utils'); -const { formatMessage } = require('./prompts'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -123,18 +124,11 @@ class GoogleClient extends BaseClient { // stop: modelOptions.stop // no stop method for now }; - if (this.options.attachments) { - this.modelOptions.model = 'gemini-pro-vision'; - } + this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments)); // TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google this.isGenerativeModel = this.modelOptions.model.includes('gemini'); - this.isVisionModel = validateVisionModel(this.modelOptions.model); const { isGenerativeModel } = this; - if (this.isVisionModel && !this.options.attachments) { - this.modelOptions.model = 'gemini-pro'; - this.isVisionModel = false; - } this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat'); const { isChatModel } = this; this.isTextModel = @@ -219,6 +213,33 @@ class GoogleClient extends BaseClient { return this; } + /** + * + * Checks if the model is a vision model based on request attachments and sets the appropriate options: + * @param {MongoFile[]} attachments + */ + checkVisionRequest(attachments) { + /* Validation vision request */ + this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision'; + const availableModels = this.options.modelsConfig?.[EModelEndpoint.google]; + this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); + + if ( + attachments && + attachments.some((file) => file?.type && file?.type?.includes('image')) && + availableModels?.includes(this.defaultVisionModel) && + !this.isVisionModel + ) { + this.modelOptions.model = this.defaultVisionModel; + this.isVisionModel = true; + } + + if (this.isVisionModel && !attachments) { + this.modelOptions.model = 'gemini-pro'; + this.isVisionModel = false; + } + } + formatMessages() { return ((message) => ({ author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel), @@ -226,18 +247,45 @@ class GoogleClient extends BaseClient { })).bind(this); } - async buildVisionMessages(messages = [], parentMessageId) { - const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId); - const attachments = await this.options.attachments; + /** + * + * Adds image URLs to the message object and returns the files + * + * @param {TMessage[]} messages + * @param {MongoFile[]} files + * @returns {Promise} + */ + async addImageURLs(message, attachments) { const { files, image_urls } = await encodeAndFormat( this.options.req, - attachments.filter((file) => file.type.includes('image')), + attachments, EModelEndpoint.google, ); + message.image_urls = image_urls.length ? image_urls : undefined; + return files; + } + async buildVisionMessages(messages = [], parentMessageId) { + const attachments = await this.options.attachments; const latestMessage = { ...messages[messages.length - 1] }; + this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text); + + if (this.contextHandlers) { + for (const file of attachments) { + if (file.embedded) { + this.contextHandlers?.processFile(file); + continue; + } + } + + this.augmentedPrompt = await this.contextHandlers.createContext(); + this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix; + } + + const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId); + + const files = await this.addImageURLs(latestMessage, attachments); - latestMessage.image_urls = image_urls; this.options.attachments = files; latestMessage.text = prompt; @@ -264,7 +312,7 @@ class GoogleClient extends BaseClient { ); } - if (this.options.attachments) { + if (this.options.attachments && this.isGenerativeModel) { return this.buildVisionMessages(messages, parentMessageId); } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index ca0c8d84248..f66afda4abd 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,21 +1,35 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { getResponseSender, ImageDetailCost, ImageDetail } = require('librechat-data-provider'); +const { + ImageDetail, + EModelEndpoint, + resolveHeaders, + ImageDetailCost, + CohereConstants, + getResponseSender, + validateVisionModel, + mapModelToAzureConfig, +} = require('librechat-data-provider'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { - getModelMaxTokens, - genAzureChatCompletion, extractBaseURL, constructAzureURL, + getModelMaxTokens, + genAzureChatCompletion, } = require('~/utils'); -const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); -const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts'); +const { + truncateText, + formatMessage, + createContextHandlers, + CUT_OFF_PROMPT, + titleInstruction, +} = require('./prompts'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { handleOpenAIErrors } = require('./tools/util'); const spendTokens = require('~/models/spendTokens'); const { createLLM, RunManager } = require('./llm'); const ChatGPTClient = require('./ChatGPTClient'); const { isEnabled } = require('~/server/utils'); -const { getFiles } = require('~/models/File'); const { summaryBuffer } = require('./memory'); const { runTitleChain } = require('./chains'); const { tokenSplit } = require('./document'); @@ -32,7 +46,10 @@ class OpenAIClient extends BaseClient { super(apiKey, options); this.ChatGPTClient = new ChatGPTClient(); this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this); + /** @type {getCompletion} */ this.getCompletion = this.ChatGPTClient.getCompletion.bind(this); + /** @type {cohereChatCompletion} */ + this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this); this.contextStrategy = options.contextStrategy ? options.contextStrategy.toLowerCase() : 'discard'; @@ -40,6 +57,10 @@ class OpenAIClient extends BaseClient { /** @type {AzureOptions} */ this.azure = options.azure || false; this.setOptions(options); + this.metadata = {}; + + /** @type {string | undefined} - The API Completions URL */ + this.completionsUrl; } // TODO: PluginsClient calls this 3x, unneeded @@ -83,7 +104,12 @@ class OpenAIClient extends BaseClient { }; } - this.checkVisionRequest(this.options.attachments); + this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview'; + if (typeof this.options.attachments?.then === 'function') { + this.options.attachments.then((attachments) => this.checkVisionRequest(attachments)); + } else { + this.checkVisionRequest(this.options.attachments); + } const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {}; if (OPENROUTER_API_KEY && !this.azure) { @@ -131,7 +157,13 @@ class OpenAIClient extends BaseClient { const { isChatGptModel } = this; this.isUnofficialChatGptModel = model.startsWith('text-chat') || model.startsWith('text-davinci-002-render'); - this.maxContextTokens = getModelMaxTokens(model) ?? 4095; // 1 less than maximum + + this.maxContextTokens = + getModelMaxTokens( + model, + this.options.endpointType ?? this.options.endpoint, + this.options.endpointTokenConfig, + ) ?? 4095; // 1 less than maximum if (this.shouldSummarize) { this.maxContextTokens = Math.floor(this.maxContextTokens / 2); @@ -208,13 +240,20 @@ class OpenAIClient extends BaseClient { * - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request. * - Sets `this.isVisionModel` to `true` if vision request. * - Deletes `this.modelOptions.stop` if vision request. - * @param {Array | MongoFile[]> | Record} attachments + * @param {MongoFile[]} attachments */ checkVisionRequest(attachments) { - this.isVisionModel = validateVisionModel(this.modelOptions.model); + const availableModels = this.options.modelsConfig?.[this.options.endpoint]; + this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); - if (attachments && !this.isVisionModel) { - this.modelOptions.model = 'gpt-4-vision-preview'; + const visionModelAvailable = availableModels?.includes(this.defaultVisionModel); + if ( + attachments && + attachments.some((file) => file?.type && file?.type?.includes('image')) && + visionModelAvailable && + !this.isVisionModel + ) { + this.modelOptions.model = this.defaultVisionModel; this.isVisionModel = true; } @@ -349,7 +388,7 @@ class OpenAIClient extends BaseClient { return { chatGptLabel: this.options.chatGptLabel, promptPrefix: this.options.promptPrefix, - resendImages: this.options.resendImages, + resendFiles: this.options.resendFiles, imageDetail: this.options.imageDetail, ...this.modelOptions, }; @@ -363,54 +402,6 @@ class OpenAIClient extends BaseClient { }; } - /** - * - * @param {TMessage[]} _messages - * @returns {TMessage[]} - */ - async addPreviousAttachments(_messages) { - if (!this.options.resendImages) { - return _messages; - } - - /** - * - * @param {TMessage} message - */ - const processMessage = async (message) => { - if (!this.message_file_map) { - /** @type {Record */ - this.message_file_map = {}; - } - - const fileIds = message.files.map((file) => file.file_id); - const files = await getFiles({ - file_id: { $in: fileIds }, - }); - - await this.addImageURLs(message, files); - - this.message_file_map[message.messageId] = files; - return message; - }; - - const promises = []; - - for (const message of _messages) { - if (!message.files) { - promises.push(message); - continue; - } - - promises.push(processMessage(message)); - } - - const messages = await Promise.all(promises); - - this.checkVisionRequest(this.message_file_map); - return messages; - } - /** * * Adds image URLs to the message object and returns the files @@ -421,8 +412,7 @@ class OpenAIClient extends BaseClient { */ async addImageURLs(message, attachments) { const { files, image_urls } = await encodeAndFormat(this.options.req, attachments); - - message.image_urls = image_urls; + message.image_urls = image_urls.length ? image_urls : undefined; return files; } @@ -450,23 +440,9 @@ class OpenAIClient extends BaseClient { let promptTokens; promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim(); - if (promptPrefix) { - promptPrefix = `Instructions:\n${promptPrefix}`; - instructions = { - role: 'system', - name: 'instructions', - content: promptPrefix, - }; - - if (this.contextStrategy) { - instructions.tokenCount = this.getTokenCountForMessage(instructions); - } - } if (this.options.attachments) { - const attachments = (await this.options.attachments).filter((file) => - file.type.includes('image'), - ); + const attachments = await this.options.attachments; if (this.message_file_map) { this.message_file_map[orderedMessages[orderedMessages.length - 1].messageId] = attachments; @@ -484,6 +460,13 @@ class OpenAIClient extends BaseClient { this.options.attachments = files; } + if (this.message_file_map) { + this.contextHandlers = createContextHandlers( + this.options.req, + orderedMessages[orderedMessages.length - 1].text, + ); + } + const formattedMessages = orderedMessages.map((message, i) => { const formattedMessage = formatMessage({ message, @@ -502,6 +485,11 @@ class OpenAIClient extends BaseClient { if (this.message_file_map && this.message_file_map[message.messageId]) { const attachments = this.message_file_map[message.messageId]; for (const file of attachments) { + if (file.embedded) { + this.contextHandlers?.processFile(file); + continue; + } + orderedMessages[i].tokenCount += this.calculateImageTokenCost({ width: file.width, height: file.height, @@ -513,6 +501,24 @@ class OpenAIClient extends BaseClient { return formattedMessage; }); + if (this.contextHandlers) { + this.augmentedPrompt = await this.contextHandlers.createContext(); + promptPrefix = this.augmentedPrompt + promptPrefix; + } + + if (promptPrefix) { + promptPrefix = `Instructions:\n${promptPrefix.trim()}`; + instructions = { + role: 'system', + name: 'instructions', + content: promptPrefix, + }; + + if (this.contextStrategy) { + instructions.tokenCount = this.getTokenCountForMessage(instructions); + } + } + // TODO: need to handle interleaving instructions better if (this.contextStrategy) { ({ payload, tokenCountMap, promptTokens, messages } = await this.handleContextStrategy({ @@ -540,15 +546,16 @@ class OpenAIClient extends BaseClient { return result; } + /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { let reply = ''; let result = null; let streamResult = null; this.modelOptions.user = this.user; const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null; - const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion); + const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined'); if (typeof opts.onProgress === 'function' && useOldMethod) { - await this.getCompletion( + const completionResult = await this.getCompletion( payload, (progressMessage) => { if (progressMessage === '[DONE]') { @@ -581,12 +588,16 @@ class OpenAIClient extends BaseClient { opts.onProgress(token); reply += token; }, + opts.onProgress, opts.abortController || new AbortController(), ); + + if (completionResult && typeof completionResult === 'string') { + reply = completionResult; + } } else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) { reply = await this.chatCompletion({ payload, - clientOptions: opts, onProgress: opts.onProgress, abortController: opts.abortController, }); @@ -594,9 +605,14 @@ class OpenAIClient extends BaseClient { result = await this.getCompletion( payload, null, + opts.onProgress, opts.abortController || new AbortController(), ); + if (result && typeof result === 'string') { + return result.trim(); + } + logger.debug('[OpenAIClient] sendCompletion: result', result); if (this.isChatCompletion) { @@ -606,11 +622,11 @@ class OpenAIClient extends BaseClient { } } - if (streamResult && typeof opts.addMetadata === 'function') { + if (streamResult) { const { finish_reason } = streamResult.choices[0]; - opts.addMetadata({ finish_reason }); + this.metadata = { finish_reason }; } - return reply.trim(); + return (reply ?? '').trim(); } initializeLLM({ @@ -624,6 +640,7 @@ class OpenAIClient extends BaseClient { context, tokenBuffer, initialMessageCount, + conversationId, }) { const modelOptions = { modelName: modelName ?? model, @@ -653,6 +670,16 @@ class OpenAIClient extends BaseClient { }; } + const { headers } = this.options; + if (headers && typeof headers === 'object' && !Array.isArray(headers)) { + configOptions.baseOptions = { + headers: resolveHeaders({ + ...headers, + ...configOptions?.baseOptions?.headers, + }), + }; + } + if (this.options.proxy) { configOptions.httpAgent = new HttpsProxyAgent(this.options.proxy); configOptions.httpsAgent = new HttpsProxyAgent(this.options.proxy); @@ -671,7 +698,7 @@ class OpenAIClient extends BaseClient { callbacks: runManager.createCallbacks({ context, tokenBuffer, - conversationId: this.conversationId, + conversationId: this.conversationId ?? conversationId, initialMessageCount, }), }); @@ -687,12 +714,13 @@ class OpenAIClient extends BaseClient { * * @param {Object} params - The parameters for the conversation title generation. * @param {string} params.text - The user's input. + * @param {string} [params.conversationId] - The current conversationId, if not already defined on client initialization. * @param {string} [params.responseText=''] - The AI's immediate response to the user. * * @returns {Promise} A promise that resolves to the generated conversation title. * In case of failure, it will return the default title, "New Chat". */ - async titleConvo({ text, responseText = '' }) { + async titleConvo({ text, conversationId, responseText = '' }) { let title = 'New Chat'; const convo = `||>User: "${truncateText(text)}" @@ -712,6 +740,39 @@ class OpenAIClient extends BaseClient { max_tokens: 16, }; + /** @type {TAzureConfig | undefined} */ + const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; + + const resetTitleOptions = !!( + (this.azure && azureConfig) || + (azureConfig && this.options.endpoint === EModelEndpoint.azureOpenAI) + ); + + if (resetTitleOptions) { + const { modelGroupMap, groupMap } = azureConfig; + const { + azureOptions, + baseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName: modelOptions.model, + modelGroupMap, + groupMap, + }); + + this.options.headers = resolveHeaders(headers); + this.options.reverseProxyUrl = baseURL ?? null; + this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl); + this.apiKey = azureOptions.azureOpenAIApiKey; + + const groupName = modelGroupMap[modelOptions.model].group; + this.options.addParams = azureConfig.groupMap[groupName].addParams; + this.options.dropParams = azureConfig.groupMap[groupName].dropParams; + this.options.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; + this.azure = !serverless && azureOptions; + } + const titleChatCompletion = async () => { modelOptions.model = model; @@ -723,8 +784,7 @@ class OpenAIClient extends BaseClient { const instructionsPayload = [ { role: 'system', - content: `Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect. -Write in the detected language. Title in 5 Words or Less. No Punctuation or Quotation. Do not mention the language. All first letters of every word should be capitalized and write the title in User Language only. + content: `Please generate ${titleInstruction} ${convo} @@ -732,10 +792,18 @@ ${convo} }, ]; + const promptTokens = this.getTokenCountForMessage(instructionsPayload[0]); + try { + let useChatCompletion = true; + if (this.options.reverseProxyUrl === CohereConstants.API_URL) { + useChatCompletion = false; + } title = ( - await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion: true }) + await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion }) ).replaceAll('"', ''); + const completionTokens = this.getTokenCount(title); + this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' }); } catch (e) { logger.error( '[OpenAIClient] There was an issue generating the title with the completion method', @@ -752,7 +820,12 @@ ${convo} try { this.abortController = new AbortController(); - const llm = this.initializeLLM({ ...modelOptions, context: 'title', tokenBuffer: 150 }); + const llm = this.initializeLLM({ + ...modelOptions, + conversationId, + context: 'title', + tokenBuffer: 150, + }); title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal }); } catch (e) { if (e?.message?.toLowerCase()?.includes('abort')) { @@ -779,7 +852,12 @@ ${convo} // TODO: remove the gpt fallback and make it specific to endpoint const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {}; const model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL; - const maxContextTokens = getModelMaxTokens(model) ?? 4095; + const maxContextTokens = + getModelMaxTokens( + model, + this.options.endpointType ?? this.options.endpoint, + this.options.endpointTokenConfig, + ) ?? 4095; // 1 less than maximum // 3 tokens for the assistant label, and 98 for the summarizer prompt (101) let promptBuffer = 101; @@ -877,14 +955,14 @@ ${convo} } } - async recordTokenUsage({ promptTokens, completionTokens }) { - logger.debug('[OpenAIClient] recordTokenUsage:', { promptTokens, completionTokens }); + async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) { await spendTokens( { + context, user: this.user, model: this.modelOptions.model, - context: 'message', conversationId: this.conversationId, + endpointTokenConfig: this.options.endpointTokenConfig, }, { promptTokens, completionTokens }, ); @@ -897,7 +975,7 @@ ${convo} }); } - async chatCompletion({ payload, onProgress, clientOptions, abortController = null }) { + async chatCompletion({ payload, onProgress, abortController = null }) { let error = null; const errorCallback = (err) => (error = err); let intermediateReply = ''; @@ -918,15 +996,6 @@ ${convo} } const baseURL = extractBaseURL(this.completionsUrl); - // let { messages: _msgsToLog, ...modelOptionsToLog } = modelOptions; - // if (modelOptionsToLog.messages) { - // _msgsToLog = modelOptionsToLog.messages.map((msg) => { - // let { content, ...rest } = msg; - - // if (content) - // return { ...rest, content: truncateText(content) }; - // }); - // } logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions }); const opts = { baseURL, @@ -951,6 +1020,38 @@ ${convo} modelOptions.max_tokens = 4000; } + /** @type {TAzureConfig | undefined} */ + const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; + + if ( + (this.azure && this.isVisionModel && azureConfig) || + (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) + ) { + const { modelGroupMap, groupMap } = azureConfig; + const { + azureOptions, + baseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName: modelOptions.model, + modelGroupMap, + groupMap, + }); + opts.defaultHeaders = resolveHeaders(headers); + this.langchainProxy = extractBaseURL(baseURL); + this.apiKey = azureOptions.azureOpenAIApiKey; + + const groupName = modelGroupMap[modelOptions.model].group; + this.options.addParams = azureConfig.groupMap[groupName].addParams; + this.options.dropParams = azureConfig.groupMap[groupName].dropParams; + // Note: `forcePrompt` not re-assigned as only chat models are vision models + + this.azure = !serverless && azureOptions; + this.azureEndpoint = + !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); + } + if (this.azure || this.options.azure) { // Azure does not accept `model` in the body, so we need to remove it. delete modelOptions.model; @@ -958,9 +1059,10 @@ ${convo} opts.baseURL = this.langchainProxy ? constructAzureURL({ baseURL: this.langchainProxy, - azure: this.azure, + azureOptions: this.azure, }) - : this.azureEndpoint.split(/\/(chat|completion)/)[0]; + : this.azureEndpoint.split(/(? msg.role === 'system'); + + if (systemMessageIndex > 0) { + const [systemMessage] = messages.splice(systemMessageIndex, 1); + messages.unshift(systemMessage); + } + + modelOptions.messages = messages; + if (messages.length === 1 && messages[0].role === 'system') { modelOptions.messages[0].role = 'user'; } @@ -988,12 +1104,20 @@ ${convo} ...modelOptions, ...this.options.addParams, }; + logger.debug('[OpenAIClient] chatCompletion: added params', { + addParams: this.options.addParams, + modelOptions, + }); } if (this.options.dropParams && Array.isArray(this.options.dropParams)) { this.options.dropParams.forEach((param) => { delete modelOptions[param]; }); + logger.debug('[OpenAIClient] chatCompletion: dropped params', { + dropParams: this.options.dropParams, + modelOptions, + }); } let UnexpectedRoleError = false; @@ -1009,6 +1133,16 @@ ${convo} .on('error', (err) => { handleOpenAIErrors(err, errorCallback, 'stream'); }) + .on('finalChatCompletion', (finalChatCompletion) => { + const finalMessage = finalChatCompletion?.choices?.[0]?.message; + if (finalMessage && finalMessage?.role !== 'assistant') { + finalChatCompletion.choices[0].message.role = 'assistant'; + } + + if (finalMessage && !finalMessage?.content?.trim()) { + finalChatCompletion.choices[0].message.content = intermediateReply; + } + }) .on('finalMessage', (message) => { if (message?.role !== 'assistant') { stream.messages.push({ role: 'assistant', content: intermediateReply }); @@ -1054,12 +1188,20 @@ ${convo} } const { message, finish_reason } = chatCompletion.choices[0]; - if (chatCompletion && typeof clientOptions.addMetadata === 'function') { - clientOptions.addMetadata({ finish_reason }); + if (chatCompletion) { + this.metadata = { finish_reason }; } logger.debug('[OpenAIClient] chatCompletion response', chatCompletion); + if (!message?.content?.trim() && intermediateReply.length) { + logger.debug( + '[OpenAIClient] chatCompletion: using intermediateReply due to empty message.content', + { intermediateReply }, + ); + return intermediateReply; + } + return message.content; } catch (err) { if ( @@ -1072,6 +1214,9 @@ ${convo} err?.message?.includes( 'OpenAI error: Invalid final message: OpenAI expects final message to include role=assistant', ) || + err?.message?.includes( + 'stream ended without producing a ChatCompletionMessage with role=assistant', + ) || err?.message?.includes('The server had an error processing your request') || err?.message?.includes('missing finish_reason') || err?.message?.includes('missing role') || diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 6118c3547a1..033c122664f 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -3,6 +3,7 @@ const { CallbackManager } = require('langchain/callbacks'); const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); +const { processFileURL } = require('~/server/services/Files/process'); const { EModelEndpoint } = require('librechat-data-provider'); const { formatLangChainMessages } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); @@ -30,10 +31,6 @@ class PluginsClient extends OpenAIClient { super.setOptions(options); - if (this.functionsAgent && this.agentOptions.model && !this.useOpenRouter) { - this.agentOptions.model = this.getFunctionModelName(this.agentOptions.model); - } - this.isGpt3 = this.modelOptions?.model?.includes('gpt-3'); if (this.options.reverseProxyUrl) { @@ -113,6 +110,7 @@ class PluginsClient extends OpenAIClient { openAIApiKey: this.openAIApiKey, conversationId: this.conversationId, fileStrategy: this.options.req.app.locals.fileStrategy, + processFileURL, message, }, }); diff --git a/api/app/clients/llm/createCoherePayload.js b/api/app/clients/llm/createCoherePayload.js new file mode 100644 index 00000000000..58803d76f3c --- /dev/null +++ b/api/app/clients/llm/createCoherePayload.js @@ -0,0 +1,85 @@ +const { CohereConstants } = require('librechat-data-provider'); +const { titleInstruction } = require('../prompts/titlePrompts'); + +// Mapping OpenAI roles to Cohere roles +const roleMap = { + user: CohereConstants.ROLE_USER, + assistant: CohereConstants.ROLE_CHATBOT, + system: CohereConstants.ROLE_SYSTEM, // Recognize and map the system role explicitly +}; + +/** + * Adjusts an OpenAI ChatCompletionPayload to conform with Cohere's expected chat payload format. + * Now includes handling for "system" roles explicitly mentioned. + * + * @param {Object} options - Object containing the model options. + * @param {ChatCompletionPayload} options.modelOptions - The OpenAI model payload options. + * @returns {CohereChatStreamRequest} Cohere-compatible chat API payload. + */ +function createCoherePayload({ modelOptions }) { + /** @type {string | undefined} */ + let preamble; + let latestUserMessageContent = ''; + const { + stream, + stop, + top_p, + temperature, + frequency_penalty, + presence_penalty, + max_tokens, + messages, + model, + ...rest + } = modelOptions; + + // Filter out the latest user message and transform remaining messages to Cohere's chat_history format + let chatHistory = messages.reduce((acc, message, index, arr) => { + const isLastUserMessage = index === arr.length - 1 && message.role === 'user'; + + const messageContent = + typeof message.content === 'string' + ? message.content + : message.content.map((part) => (part.type === 'text' ? part.text : '')).join(' '); + + if (isLastUserMessage) { + latestUserMessageContent = messageContent; + } else { + acc.push({ + role: roleMap[message.role] || CohereConstants.ROLE_USER, + message: messageContent, + }); + } + + return acc; + }, []); + + if ( + chatHistory.length === 1 && + chatHistory[0].role === CohereConstants.ROLE_SYSTEM && + !latestUserMessageContent.length + ) { + const message = chatHistory[0].message; + latestUserMessageContent = message.includes(titleInstruction) + ? CohereConstants.TITLE_MESSAGE + : '.'; + preamble = message; + } + + return { + message: latestUserMessageContent, + model: model, + chatHistory, + stream: stream ?? false, + temperature: temperature, + frequencyPenalty: frequency_penalty, + presencePenalty: presence_penalty, + maxTokens: max_tokens, + stopSequences: stop, + preamble, + p: top_p, + ...rest, + }; +} + +module.exports = createCoherePayload; diff --git a/api/app/clients/llm/createLLM.js b/api/app/clients/llm/createLLM.js index 62f2fe86f95..09b29cca8e9 100644 --- a/api/app/clients/llm/createLLM.js +++ b/api/app/clients/llm/createLLM.js @@ -55,16 +55,18 @@ function createLLM({ } if (azure && configOptions.basePath) { - configOptions.basePath = constructAzureURL({ + const azureURL = constructAzureURL({ baseURL: configOptions.basePath, - azure: azureOptions, + azureOptions, }); + azureOptions.azureOpenAIBasePath = azureURL.split( + `/${azureOptions.azureOpenAIApiDeploymentName}`, + )[0]; } return new ChatOpenAI( { streaming, - verbose: true, credentials, configuration, ...azureOptions, diff --git a/api/app/clients/llm/index.js b/api/app/clients/llm/index.js index 46478ade63b..2e09bbb841b 100644 --- a/api/app/clients/llm/index.js +++ b/api/app/clients/llm/index.js @@ -1,7 +1,9 @@ const createLLM = require('./createLLM'); const RunManager = require('./RunManager'); +const createCoherePayload = require('./createCoherePayload'); module.exports = { createLLM, RunManager, + createCoherePayload, }; diff --git a/api/app/clients/prompts/createContextHandlers.js b/api/app/clients/prompts/createContextHandlers.js new file mode 100644 index 00000000000..e48dfd8e672 --- /dev/null +++ b/api/app/clients/prompts/createContextHandlers.js @@ -0,0 +1,159 @@ +const axios = require('axios'); +const { isEnabled } = require('~/server/utils'); +const { logger } = require('~/config'); + +const footer = `Use the context as your learned knowledge to better answer the user. + +In your response, remember to follow these guidelines: +- If you don't know the answer, simply say that you don't know. +- If you are unsure how to answer, ask for clarification. +- Avoid mentioning that you obtained the information from the context. + +Answer appropriately in the user's language. +`; + +function createContextHandlers(req, userMessageContent) { + if (!process.env.RAG_API_URL) { + return; + } + + const queryPromises = []; + const processedFiles = []; + const processedIds = new Set(); + const jwtToken = req.headers.authorization.split(' ')[1]; + const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT); + + const query = async (file) => { + if (useFullContext) { + return axios.get(`${process.env.RAG_API_URL}/documents/${file.file_id}/context`, { + headers: { + Authorization: `Bearer ${jwtToken}`, + }, + }); + } + + return axios.post( + `${process.env.RAG_API_URL}/query`, + { + file_id: file.file_id, + query: userMessageContent, + k: 4, + }, + { + headers: { + Authorization: `Bearer ${jwtToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + }; + + const processFile = async (file) => { + if (file.embedded && !processedIds.has(file.file_id)) { + try { + const promise = query(file); + queryPromises.push(promise); + processedFiles.push(file); + processedIds.add(file.file_id); + } catch (error) { + logger.error(`Error processing file ${file.filename}:`, error); + } + } + }; + + const createContext = async () => { + try { + if (!queryPromises.length || !processedFiles.length) { + return ''; + } + + const oneFile = processedFiles.length === 1; + const header = `The user has attached ${oneFile ? 'a' : processedFiles.length} file${ + !oneFile ? 's' : '' + } to the conversation:`; + + const files = `${ + oneFile + ? '' + : ` + ` + }${processedFiles + .map( + (file) => ` + + ${file.filename} + ${file.type} + `, + ) + .join('')}${ + oneFile + ? '' + : ` + ` + }`; + + const resolvedQueries = await Promise.all(queryPromises); + + const context = resolvedQueries + .map((queryResult, index) => { + const file = processedFiles[index]; + let contextItems = queryResult.data; + + const generateContext = (currentContext) => + ` + + ${file.filename} + ${currentContext} + + `; + + if (useFullContext) { + return generateContext(`\n${contextItems}`); + } + + contextItems = queryResult.data + .map((item) => { + const pageContent = item[0].page_content; + return ` + + + `; + }) + .join(''); + + return generateContext(contextItems); + }) + .join(''); + + if (useFullContext) { + const prompt = `${header} + ${context} + ${footer}`; + + return prompt; + } + + const prompt = `${header} + ${files} + + A semantic search was executed with the user's message as the query, retrieving the following context inside XML tags. + + ${context} + + + ${footer}`; + + return prompt; + } catch (error) { + logger.error('Error creating context:', error); + throw error; + } + }; + + return { + processFile, + createContext, + }; +} + +module.exports = createContextHandlers; diff --git a/api/app/clients/prompts/createVisionPrompt.js b/api/app/clients/prompts/createVisionPrompt.js new file mode 100644 index 00000000000..5d8a7bbf51b --- /dev/null +++ b/api/app/clients/prompts/createVisionPrompt.js @@ -0,0 +1,34 @@ +/** + * Generates a prompt instructing the user to describe an image in detail, tailored to different types of visual content. + * @param {boolean} pluralized - Whether to pluralize the prompt for multiple images. + * @returns {string} - The generated vision prompt. + */ +const createVisionPrompt = (pluralized = false) => { + return `Please describe the image${ + pluralized ? 's' : '' + } in detail, covering relevant aspects such as: + + For photographs, illustrations, or artwork: + - The main subject(s) and their appearance, positioning, and actions + - The setting, background, and any notable objects or elements + - Colors, lighting, and overall mood or atmosphere + - Any interesting details, textures, or patterns + - The style, technique, or medium used (if discernible) + + For screenshots or images containing text: + - The content and purpose of the text + - The layout, formatting, and organization of the information + - Any notable visual elements, such as logos, icons, or graphics + - The overall context or message conveyed by the screenshot + + For graphs, charts, or data visualizations: + - The type of graph or chart (e.g., bar graph, line chart, pie chart) + - The variables being compared or analyzed + - Any trends, patterns, or outliers in the data + - The axis labels, scales, and units of measurement + - The title, legend, and any additional context provided + + Be as specific and descriptive as possible while maintaining clarity and concision.`; +}; + +module.exports = createVisionPrompt; diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js index 1b97bc7ffa1..c19eee260af 100644 --- a/api/app/clients/prompts/formatMessages.js +++ b/api/app/clients/prompts/formatMessages.js @@ -1,3 +1,4 @@ +const { EModelEndpoint } = require('librechat-data-provider'); const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema'); /** @@ -7,10 +8,16 @@ const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema'); * @param {Object} params.message - The message object to format. * @param {string} [params.message.role] - The role of the message sender (must be 'user'). * @param {string} [params.message.content] - The text content of the message. + * @param {EModelEndpoint} [params.endpoint] - Identifier for specific endpoint handling * @param {Array} [params.image_urls] - The image_urls to attach to the message. * @returns {(Object)} - The formatted message. */ -const formatVisionMessage = ({ message, image_urls }) => { +const formatVisionMessage = ({ message, image_urls, endpoint }) => { + if (endpoint === EModelEndpoint.anthropic) { + message.content = [...image_urls, { type: 'text', text: message.content }]; + return message; + } + message.content = [{ type: 'text', text: message.content }, ...image_urls]; return message; @@ -29,10 +36,11 @@ const formatVisionMessage = ({ message, image_urls }) => { * @param {Array} [params.message.image_urls] - The image_urls attached to the message for Vision API. * @param {string} [params.userName] - The name of the user. * @param {string} [params.assistantName] - The name of the assistant. + * @param {string} [params.endpoint] - Identifier for specific endpoint handling * @param {boolean} [params.langChain=false] - Whether to return a LangChain message object. * @returns {(Object|HumanMessage|AIMessage|SystemMessage)} - The formatted message. */ -const formatMessage = ({ message, userName, assistantName, langChain = false }) => { +const formatMessage = ({ message, userName, assistantName, endpoint, langChain = false }) => { let { role: _role, _name, sender, text, content: _content, lc_id } = message; if (lc_id && lc_id[2] && !langChain) { const roleMapping = { @@ -51,7 +59,11 @@ const formatMessage = ({ message, userName, assistantName, langChain = false }) const { image_urls } = message; if (Array.isArray(image_urls) && image_urls.length > 0 && role === 'user') { - return formatVisionMessage({ message: formattedMessage, image_urls: message.image_urls }); + return formatVisionMessage({ + message: formattedMessage, + image_urls: message.image_urls, + endpoint, + }); } if (_name) { diff --git a/api/app/clients/prompts/formatMessages.spec.js b/api/app/clients/prompts/formatMessages.spec.js index 636cdb1c8e5..8d4956b3811 100644 --- a/api/app/clients/prompts/formatMessages.spec.js +++ b/api/app/clients/prompts/formatMessages.spec.js @@ -1,5 +1,6 @@ -const { formatMessage, formatLangChainMessages, formatFromLangChain } = require('./formatMessages'); +const { Constants } = require('librechat-data-provider'); const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema'); +const { formatMessage, formatLangChainMessages, formatFromLangChain } = require('./formatMessages'); describe('formatMessage', () => { it('formats user message', () => { @@ -61,7 +62,7 @@ describe('formatMessage', () => { isCreatedByUser: true, isEdited: false, model: null, - parentMessageId: '00000000-0000-0000-0000-000000000000', + parentMessageId: Constants.NO_PARENT, sender: 'User', text: 'hi', tokenCount: 5, diff --git a/api/app/clients/prompts/index.js b/api/app/clients/prompts/index.js index 40db3d90439..36bb6f7e283 100644 --- a/api/app/clients/prompts/index.js +++ b/api/app/clients/prompts/index.js @@ -4,6 +4,8 @@ const handleInputs = require('./handleInputs'); const instructions = require('./instructions'); const titlePrompts = require('./titlePrompts'); const truncateText = require('./truncateText'); +const createVisionPrompt = require('./createVisionPrompt'); +const createContextHandlers = require('./createContextHandlers'); module.exports = { ...formatMessages, @@ -12,4 +14,6 @@ module.exports = { ...instructions, ...titlePrompts, truncateText, + createVisionPrompt, + createContextHandlers, }; diff --git a/api/app/clients/prompts/titlePrompts.js b/api/app/clients/prompts/titlePrompts.js index 1e893ba295d..83d8909f3a7 100644 --- a/api/app/clients/prompts/titlePrompts.js +++ b/api/app/clients/prompts/titlePrompts.js @@ -27,7 +27,63 @@ ${convo}`, return titlePrompt; }; +const titleInstruction = + 'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"'; +const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title. + +You may call them like this: + + +$TOOL_NAME + +<$PARAMETER_NAME>$PARAMETER_VALUE +... + + + + +Here are the tools available: + + +submit_title + +Submit a brief title in the conversation's language, following the parameter description closely. + + + +title +string +${titleInstruction} + + + +`; + +/** + * Parses titles from title functions based on the provided prompt. + * @param {string} prompt - The prompt containing the title function. + * @returns {string} The parsed title. "New Chat" if no title is found. + */ +function parseTitleFromPrompt(prompt) { + const titleRegex = /(.+?)<\/title>/; + const titleMatch = prompt.match(titleRegex); + + if (titleMatch && titleMatch[1]) { + const title = titleMatch[1].trim(); + + // // Capitalize the first letter of each word; Note: unnecessary due to title case prompting + // const capitalizedTitle = title.replace(/\b\w/g, (char) => char.toUpperCase()); + + return title; + } + + return 'New Chat'; +} + module.exports = { langPrompt, + titleInstruction, createTitlePrompt, + titleFunctionPrompt, + parseTitleFromPrompt, }; diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 889499fbc29..9ffa7e04f1b 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -1,3 +1,4 @@ +const { Constants } = require('librechat-data-provider'); const { initializeFakeClient } = require('./FakeClient'); jest.mock('../../../lib/db/connectDb'); @@ -307,7 +308,7 @@ describe('BaseClient', () => { const unorderedMessages = [ { id: '3', parentMessageId: '2', text: 'Message 3' }, { id: '2', parentMessageId: '1', text: 'Message 2' }, - { id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' }, + { id: '1', parentMessageId: Constants.NO_PARENT, text: 'Message 1' }, ]; it('should return ordered messages based on parentMessageId', () => { @@ -316,7 +317,7 @@ describe('BaseClient', () => { parentMessageId: '3', }); expect(result).toEqual([ - { id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' }, + { id: '1', parentMessageId: Constants.NO_PARENT, text: 'Message 1' }, { id: '2', parentMessageId: '1', text: 'Message 2' }, { id: '3', parentMessageId: '2', text: 'Message 3' }, ]); diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index b4e42b1fc51..dfd57b23b94 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -1,9 +1,10 @@ +const crypto = require('crypto'); +const { Constants } = require('librechat-data-provider'); const { HumanChatMessage, AIChatMessage } = require('langchain/schema'); const PluginsClient = require('../PluginsClient'); -const crypto = require('crypto'); -jest.mock('../../../lib/db/connectDb'); -jest.mock('../../../models/Conversation', () => { +jest.mock('~/lib/db/connectDb'); +jest.mock('~/models/Conversation', () => { return function () { return { save: jest.fn(), @@ -12,6 +13,12 @@ jest.mock('../../../models/Conversation', () => { }; }); +const defaultAzureOptions = { + azureOpenAIApiInstanceName: 'your-instance-name', + azureOpenAIApiDeploymentName: 'your-deployment-name', + azureOpenAIApiVersion: '2020-07-01-preview', +}; + describe('PluginsClient', () => { let TestAgent; let options = { @@ -60,7 +67,7 @@ describe('PluginsClient', () => { TestAgent.setOptions(opts); } const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000'; + const parentMessageId = opts.parentMessageId || Constants.NO_PARENT; const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); this.pastMessages = await TestAgent.loadHistory( conversationId, @@ -187,4 +194,30 @@ describe('PluginsClient', () => { expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); }); }); + describe('Azure OpenAI tests specific to Plugins', () => { + // TODO: add more tests for Azure OpenAI integration with Plugins + // let client; + // beforeEach(() => { + // client = new PluginsClient('dummy_api_key'); + // }); + + test('should not call getFunctionModelName when azure options are set', () => { + const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName'); + const model = 'gpt-4-turbo'; + + // note, without the azure change in PR #1766, `getFunctionModelName` is called twice + const testClient = new PluginsClient('dummy_api_key', { + agentOptions: { + model, + agent: 'functions', + }, + azure: defaultAzureOptions, + }); + + expect(spy).not.toHaveBeenCalled(); + expect(testClient.agentOptions.model).toBe(model); + + spy.mockRestore(); + }); + }); }); diff --git a/api/app/clients/tools/DALL-E.js b/api/app/clients/tools/DALL-E.js index 4eca7f7932e..4600bdb026e 100644 --- a/api/app/clients/tools/DALL-E.js +++ b/api/app/clients/tools/DALL-E.js @@ -3,42 +3,39 @@ const OpenAI = require('openai'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('langchain/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); -const { processFileURL } = require('~/server/services/Files/process'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); -const { - DALLE2_SYSTEM_PROMPT, - DALLE_REVERSE_PROXY, - PROXY, - DALLE2_AZURE_API_VERSION, - DALLE2_BASEURL, - DALLE2_API_KEY, - DALLE_API_KEY, -} = process.env; class OpenAICreateImage extends Tool { constructor(fields = {}) { super(); this.userId = fields.userId; this.fileStrategy = fields.fileStrategy; + if (fields.processFileURL) { + this.processFileURL = fields.processFileURL.bind(this); + } let apiKey = fields.DALLE2_API_KEY ?? fields.DALLE_API_KEY ?? this.getApiKey(); const config = { apiKey }; - if (DALLE_REVERSE_PROXY) { - config.baseURL = extractBaseURL(DALLE_REVERSE_PROXY); + if (process.env.DALLE_REVERSE_PROXY) { + config.baseURL = extractBaseURL(process.env.DALLE_REVERSE_PROXY); } - if (DALLE2_AZURE_API_VERSION && DALLE2_BASEURL) { - config.baseURL = DALLE2_BASEURL; - config.defaultQuery = { 'api-version': DALLE2_AZURE_API_VERSION }; - config.defaultHeaders = { 'api-key': DALLE2_API_KEY, 'Content-Type': 'application/json' }; - config.apiKey = DALLE2_API_KEY; + if (process.env.DALLE2_AZURE_API_VERSION && process.env.DALLE2_BASEURL) { + config.baseURL = process.env.DALLE2_BASEURL; + config.defaultQuery = { 'api-version': process.env.DALLE2_AZURE_API_VERSION }; + config.defaultHeaders = { + 'api-key': process.env.DALLE2_API_KEY, + 'Content-Type': 'application/json', + }; + config.apiKey = process.env.DALLE2_API_KEY; } - if (PROXY) { - config.httpAgent = new HttpsProxyAgent(PROXY); + if (process.env.PROXY) { + config.httpAgent = new HttpsProxyAgent(process.env.PROXY); } this.openai = new OpenAI(config); @@ -51,7 +48,7 @@ Guidelines: "Subject: [subject], Style: [style], Color: [color], Details: [details], Emotion: [emotion]" - Generate images only once per human query unless explicitly requested by the user`; this.description_for_model = - DALLE2_SYSTEM_PROMPT ?? + process.env.DALLE2_SYSTEM_PROMPT ?? `// Whenever a description of an image is given, generate prompts (following these rules), and use dalle to create the image. If the user does not ask for a specific number of images, default to creating 2 prompts to send to dalle that are written to be as diverse as possible. All prompts sent to dalle must abide by the following policies: // 1. Prompts must be in English. Translate to English if needed. // 2. One image per function call. Create only 1 image per request unless explicitly told to generate more than 1 image. @@ -67,7 +64,7 @@ Guidelines: } getApiKey() { - const apiKey = DALLE2_API_KEY ?? DALLE_API_KEY ?? ''; + const apiKey = process.env.DALLE2_API_KEY ?? process.env.DALLE_API_KEY ?? ''; if (!apiKey) { throw new Error('Missing DALLE_API_KEY environment variable.'); } @@ -86,13 +83,21 @@ Guidelines: } async _call(input) { - const resp = await this.openai.images.generate({ - prompt: this.replaceUnwantedChars(input), - // TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them? - n: 1, - // size: '1024x1024' - size: '512x512', - }); + let resp; + + try { + resp = await this.openai.images.generate({ + prompt: this.replaceUnwantedChars(input), + // TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them? + n: 1, + // size: '1024x1024' + size: '512x512', + }); + } catch (error) { + logger.error('[DALL-E] Problem generating the image:', error); + return `Something went wrong when trying to generate the image. The DALL-E API may be unavailable: +Error Message: ${error.message}`; + } const theImageUrl = resp.data[0].url; @@ -116,15 +121,16 @@ Guidelines: }); try { - const result = await processFileURL({ + const result = await this.processFileURL({ fileStrategy: this.fileStrategy, userId: this.userId, URL: theImageUrl, fileName: imageName, basePath: 'images', + context: FileContext.image_generation, }); - this.result = this.wrapInMarkdown(result); + this.result = this.wrapInMarkdown(result.filepath); } catch (error) { logger.error('Error while saving the image:', error); this.result = `Failed to save the image locally. ${error.message}`; diff --git a/api/app/clients/tools/GoogleSearch.js b/api/app/clients/tools/GoogleSearch.js deleted file mode 100644 index 3d7574b6c19..00000000000 --- a/api/app/clients/tools/GoogleSearch.js +++ /dev/null @@ -1,121 +0,0 @@ -const { google } = require('googleapis'); -const { Tool } = require('langchain/tools'); -const { logger } = require('~/config'); - -/** - * Represents a tool that allows an agent to use the Google Custom Search API. - * @extends Tool - */ -class GoogleSearchAPI extends Tool { - constructor(fields = {}) { - super(); - this.cx = fields.GOOGLE_CSE_ID || this.getCx(); - this.apiKey = fields.GOOGLE_API_KEY || this.getApiKey(); - this.customSearch = undefined; - } - - /** - * The name of the tool. - * @type {string} - */ - name = 'google'; - - /** - * A description for the agent to use - * @type {string} - */ - description = - 'Use the \'google\' tool to retrieve internet search results relevant to your input. The results will return links and snippets of text from the webpages'; - description_for_model = - 'Use the \'google\' tool to retrieve internet search results relevant to your input. The results will return links and snippets of text from the webpages'; - - getCx() { - const cx = process.env.GOOGLE_CSE_ID || ''; - if (!cx) { - throw new Error('Missing GOOGLE_CSE_ID environment variable.'); - } - return cx; - } - - getApiKey() { - const apiKey = process.env.GOOGLE_API_KEY || ''; - if (!apiKey) { - throw new Error('Missing GOOGLE_API_KEY environment variable.'); - } - return apiKey; - } - - getCustomSearch() { - if (!this.customSearch) { - const version = 'v1'; - this.customSearch = google.customsearch(version); - } - return this.customSearch; - } - - resultsToReadableFormat(results) { - let output = 'Results:\n'; - - results.forEach((resultObj, index) => { - output += `Title: ${resultObj.title}\n`; - output += `Link: ${resultObj.link}\n`; - if (resultObj.snippet) { - output += `Snippet: ${resultObj.snippet}\n`; - } - - if (index < results.length - 1) { - output += '\n'; - } - }); - - return output; - } - - /** - * Calls the tool with the provided input and returns a promise that resolves with a response from the Google Custom Search API. - * @param {string} input - The input to provide to the API. - * @returns {Promise<String>} A promise that resolves with a response from the Google Custom Search API. - */ - async _call(input) { - try { - const metadataResults = []; - const response = await this.getCustomSearch().cse.list({ - q: input, - cx: this.cx, - auth: this.apiKey, - num: 5, // Limit the number of results to 5 - }); - - // return response.data; - // logger.debug(response.data); - - if (!response.data.items || response.data.items.length === 0) { - return this.resultsToReadableFormat([ - { title: 'No good Google Search Result was found', link: '' }, - ]); - } - - // const results = response.items.slice(0, numResults); - const results = response.data.items; - - for (const result of results) { - const metadataResult = { - title: result.title || '', - link: result.link || '', - }; - if (result.snippet) { - metadataResult.snippet = result.snippet; - } - metadataResults.push(metadataResult); - } - - return this.resultsToReadableFormat(metadataResults); - } catch (error) { - logger.error('[GoogleSearchAPI]', error); - // throw error; - return 'There was an error searching Google.'; - } - } -} - -module.exports = GoogleSearchAPI; diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index f5410e89eec..f16d229e6b7 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -1,35 +1,44 @@ -const GoogleSearchAPI = require('./GoogleSearch'); +const availableTools = require('./manifest.json'); +// Basic Tools +const CodeBrew = require('./CodeBrew'); +const WolframAlphaAPI = require('./Wolfram'); +const AzureAiSearch = require('./AzureAiSearch'); const OpenAICreateImage = require('./DALL-E'); -const DALLE3 = require('./structured/DALLE3'); -const StructuredSD = require('./structured/StableDiffusion'); const StableDiffusionAPI = require('./StableDiffusion'); -const WolframAlphaAPI = require('./Wolfram'); -const StructuredWolfram = require('./structured/Wolfram'); const SelfReflectionTool = require('./SelfReflection'); -const AzureAiSearch = require('./AzureAiSearch'); -const StructuredACS = require('./structured/AzureAISearch'); + +// Structured Tools +const DALLE3 = require('./structured/DALLE3'); const ChatTool = require('./structured/ChatTool'); const E2BTools = require('./structured/E2BTools'); const CodeSherpa = require('./structured/CodeSherpa'); +const StructuredSD = require('./structured/StableDiffusion'); +const StructuredACS = require('./structured/AzureAISearch'); const CodeSherpaTools = require('./structured/CodeSherpaTools'); -const availableTools = require('./manifest.json'); -const CodeBrew = require('./CodeBrew'); +const GoogleSearchAPI = require('./structured/GoogleSearch'); +const StructuredWolfram = require('./structured/Wolfram'); +const TavilySearchResults = require('./structured/TavilySearchResults'); +const TraversaalSearch = require('./structured/TraversaalSearch'); module.exports = { availableTools, + // Basic Tools + CodeBrew, + AzureAiSearch, GoogleSearchAPI, + WolframAlphaAPI, OpenAICreateImage, - DALLE3, StableDiffusionAPI, - StructuredSD, - WolframAlphaAPI, - StructuredWolfram, SelfReflectionTool, - AzureAiSearch, - StructuredACS, - E2BTools, + // Structured Tools + DALLE3, ChatTool, + E2BTools, CodeSherpa, + StructuredSD, + StructuredACS, CodeSherpaTools, - CodeBrew, + StructuredWolfram, + TavilySearchResults, + TraversaalSearch, }; diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index bde4c8a87a9..3daaf9dd3bc 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -1,4 +1,17 @@ [ + { + "name": "Traversaal", + "pluginKey": "traversaal_search", + "description": "Traversaal is a robust search API tailored for LLM Agents. Get an API key here: https://api.traversaal.ai", + "icon": "https://traversaal.ai/favicon.ico", + "authConfig": [ + { + "authField": "TRAVERSAAL_API_KEY", + "label": "Traversaal API Key", + "description": "Get your API key here: <a href=\"https://api.traversaal.ai\" target=\"_blank\">https://api.traversaal.ai</a>" + } + ] + }, { "name": "Google", "pluginKey": "google", @@ -89,7 +102,7 @@ "icon": "https://i.imgur.com/u2TzXzH.png", "authConfig": [ { - "authField": "DALLE2_API_KEY", + "authField": "DALLE2_API_KEY||DALLE_API_KEY", "label": "OpenAI API Key", "description": "You can use DALL-E with your API Key from OpenAI." } @@ -102,12 +115,25 @@ "icon": "https://i.imgur.com/u2TzXzH.png", "authConfig": [ { - "authField": "DALLE3_API_KEY", + "authField": "DALLE3_API_KEY||DALLE_API_KEY", "label": "OpenAI API Key", "description": "You can use DALL-E with your API Key from OpenAI." } ] }, + { + "name": "Tavily Search", + "pluginKey": "tavily_search_results_json", + "description": "Tavily Search is a robust search API tailored for LLM Agents. It seamlessly integrates with diverse data sources to ensure a superior, relevant search experience.", + "icon": "https://tavily.com/favicon.ico", + "authConfig": [ + { + "authField": "TAVILY_API_KEY", + "label": "Tavily API Key", + "description": "Get your API key here: https://app.tavily.com/" + } + ] + }, { "name": "Calculator", "pluginKey": "calculator", diff --git a/api/app/clients/tools/structured/AzureAISearch.js b/api/app/clients/tools/structured/AzureAISearch.js index 9b50aa2c433..0ce7b43fb21 100644 --- a/api/app/clients/tools/structured/AzureAISearch.js +++ b/api/app/clients/tools/structured/AzureAISearch.js @@ -19,6 +19,13 @@ class AzureAISearch extends StructuredTool { this.name = 'azure-ai-search'; this.description = 'Use the \'azure-ai-search\' tool to retrieve search results relevant to your input'; + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + // Define schema + this.schema = z.object({ + query: z.string().describe('Search word or phrase to Azure AI Search'), + }); // Initialize properties using helper function this.serviceEndpoint = this._initializeField( @@ -51,12 +58,16 @@ class AzureAISearch extends StructuredTool { ); // Check for required fields - if (!this.serviceEndpoint || !this.indexName || !this.apiKey) { + if (!this.override && (!this.serviceEndpoint || !this.indexName || !this.apiKey)) { throw new Error( 'Missing AZURE_AI_SEARCH_SERVICE_ENDPOINT, AZURE_AI_SEARCH_INDEX_NAME, or AZURE_AI_SEARCH_API_KEY environment variable.', ); } + if (this.override) { + return; + } + // Create SearchClient this.client = new SearchClient( this.serviceEndpoint, @@ -64,11 +75,6 @@ class AzureAISearch extends StructuredTool { new AzureKeyCredential(this.apiKey), { apiVersion: this.apiVersion }, ); - - // Define schema - this.schema = z.object({ - query: z.string().describe('Search word or phrase to Azure AI Search'), - }); } // Improved error handling and logging diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index 33df93e7fcf..3155992ca9b 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -4,42 +4,47 @@ const OpenAI = require('openai'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('langchain/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); -const { processFileURL } = require('~/server/services/Files/process'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); -const { - DALLE3_SYSTEM_PROMPT, - DALLE_REVERSE_PROXY, - PROXY, - DALLE3_AZURE_API_VERSION, - DALLE3_BASEURL, - DALLE3_API_KEY, -} = process.env; class DALLE3 extends Tool { constructor(fields = {}) { super(); + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + /** @type {boolean} Necessary for output to contain all image metadata. */ + this.returnMetadata = fields.returnMetadata ?? false; this.userId = fields.userId; this.fileStrategy = fields.fileStrategy; + if (fields.processFileURL) { + /** @type {processFileURL} Necessary for output to contain all image metadata. */ + this.processFileURL = fields.processFileURL.bind(this); + } + let apiKey = fields.DALLE3_API_KEY ?? fields.DALLE_API_KEY ?? this.getApiKey(); const config = { apiKey }; - if (DALLE_REVERSE_PROXY) { - config.baseURL = extractBaseURL(DALLE_REVERSE_PROXY); + if (process.env.DALLE_REVERSE_PROXY) { + config.baseURL = extractBaseURL(process.env.DALLE_REVERSE_PROXY); } - if (DALLE3_AZURE_API_VERSION && DALLE3_BASEURL) { - config.baseURL = DALLE3_BASEURL; - config.defaultQuery = { 'api-version': DALLE3_AZURE_API_VERSION }; - config.defaultHeaders = { 'api-key': DALLE3_API_KEY, 'Content-Type': 'application/json' }; - config.apiKey = DALLE3_API_KEY; + if (process.env.DALLE3_AZURE_API_VERSION && process.env.DALLE3_BASEURL) { + config.baseURL = process.env.DALLE3_BASEURL; + config.defaultQuery = { 'api-version': process.env.DALLE3_AZURE_API_VERSION }; + config.defaultHeaders = { + 'api-key': process.env.DALLE3_API_KEY, + 'Content-Type': 'application/json', + }; + config.apiKey = process.env.DALLE3_API_KEY; } - if (PROXY) { - config.httpAgent = new HttpsProxyAgent(PROXY); + if (process.env.PROXY) { + config.httpAgent = new HttpsProxyAgent(process.env.PROXY); } + /** @type {OpenAI} */ this.openai = new OpenAI(config); this.name = 'dalle'; this.description = `Use DALLE to create images from text descriptions. @@ -47,7 +52,7 @@ class DALLE3 extends Tool { - Create only one image, without repeating or listing descriptions outside the "prompts" field. - Maintains the original intent of the description, with parameters for image style, quality, and size to tailor the output.`; this.description_for_model = - DALLE3_SYSTEM_PROMPT ?? + process.env.DALLE3_SYSTEM_PROMPT ?? `// Whenever a description of an image is given, generate prompts (following these rules), and use dalle to create the image. If the user does not ask for a specific number of images, default to creating 2 prompts to send to dalle that are written to be as diverse as possible. All prompts sent to dalle must abide by the following policies: // 1. Prompts must be in English. Translate to English if needed. // 2. One image per function call. Create only 1 image per request unless explicitly told to generate more than 1 image. @@ -86,7 +91,7 @@ class DALLE3 extends Tool { getApiKey() { const apiKey = process.env.DALLE3_API_KEY ?? process.env.DALLE_API_KEY ?? ''; - if (!apiKey) { + if (!apiKey && !this.override) { throw new Error('Missing DALLE_API_KEY environment variable.'); } return apiKey; @@ -120,6 +125,7 @@ class DALLE3 extends Tool { n: 1, }); } catch (error) { + logger.error('[DALL-E-3] Problem generating the image:', error); return `Something went wrong when trying to generate the image. The DALL-E API may be unavailable: Error Message: ${error.message}`; } @@ -150,15 +156,20 @@ Error Message: ${error.message}`; }); try { - const result = await processFileURL({ + const result = await this.processFileURL({ fileStrategy: this.fileStrategy, userId: this.userId, URL: theImageUrl, fileName: imageName, basePath: 'images', + context: FileContext.image_generation, }); - this.result = this.wrapInMarkdown(result); + if (this.returnMetadata) { + this.result = result; + } else { + this.result = this.wrapInMarkdown(result.filepath); + } } catch (error) { logger.error('Error while saving the image:', error); this.result = `Failed to save the image locally. ${error.message}`; diff --git a/api/app/clients/tools/structured/GoogleSearch.js b/api/app/clients/tools/structured/GoogleSearch.js new file mode 100644 index 00000000000..92d33272c83 --- /dev/null +++ b/api/app/clients/tools/structured/GoogleSearch.js @@ -0,0 +1,65 @@ +const { z } = require('zod'); +const { Tool } = require('@langchain/core/tools'); +const { getEnvironmentVariable } = require('@langchain/core/utils/env'); + +class GoogleSearchResults extends Tool { + static lc_name() { + return 'GoogleSearchResults'; + } + + constructor(fields = {}) { + super(fields); + this.envVarApiKey = 'GOOGLE_API_KEY'; + this.envVarSearchEngineId = 'GOOGLE_CSE_ID'; + this.override = fields.override ?? false; + this.apiKey = fields.apiKey ?? getEnvironmentVariable(this.envVarApiKey); + this.searchEngineId = + fields.searchEngineId ?? getEnvironmentVariable(this.envVarSearchEngineId); + + this.kwargs = fields?.kwargs ?? {}; + this.name = 'google'; + this.description = + 'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events.'; + + this.schema = z.object({ + query: z.string().min(1).describe('The search query string.'), + max_results: z + .number() + .min(1) + .max(10) + .optional() + .describe('The maximum number of search results to return. Defaults to 10.'), + // Note: Google API has its own parameters for search customization, adjust as needed. + }); + } + + async _call(input) { + const validationResult = this.schema.safeParse(input); + if (!validationResult.success) { + throw new Error(`Validation failed: ${JSON.stringify(validationResult.error.issues)}`); + } + + const { query, max_results = 5 } = validationResult.data; + + const response = await fetch( + `https://www.googleapis.com/customsearch/v1?key=${this.apiKey}&cx=${ + this.searchEngineId + }&q=${encodeURIComponent(query)}&num=${max_results}`, + { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + + const json = await response.json(); + if (!response.ok) { + throw new Error(`Request failed with status ${response.status}: ${json.error.message}`); + } + + return JSON.stringify(json); + } +} + +module.exports = GoogleSearchResults; diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 1fc5096730e..e891cbb398a 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -4,12 +4,28 @@ const { z } = require('zod'); const path = require('path'); const axios = require('axios'); const sharp = require('sharp'); +const { v4: uuidv4 } = require('uuid'); const { StructuredTool } = require('langchain/tools'); +const { FileContext } = require('librechat-data-provider'); +const paths = require('~/config/paths'); const { logger } = require('~/config'); class StableDiffusionAPI extends StructuredTool { constructor(fields) { super(); + /** @type {string} User ID */ + this.userId = fields.userId; + /** @type {Express.Request | undefined} Express Request object, only provided by ToolService */ + this.req = fields.req; + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + /** @type {boolean} Necessary for output to contain all image metadata. */ + this.returnMetadata = fields.returnMetadata ?? false; + if (fields.uploadImageBuffer) { + /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */ + this.uploadImageBuffer = fields.uploadImageBuffer.bind(this); + } + this.name = 'stable-diffusion'; this.url = fields.SD_WEBUI_URL || this.getServerURL(); this.description_for_model = `// Generate images and visuals using text. @@ -44,7 +60,7 @@ class StableDiffusionAPI extends StructuredTool { getMarkdownImageUrl(imageName) { const imageUrl = path - .join(this.relativeImageUrl, imageName) + .join(this.relativePath, this.userId, imageName) .replace(/\\/g, '/') .replace('public/', ''); return `![generated image](/${imageUrl})`; @@ -52,7 +68,7 @@ class StableDiffusionAPI extends StructuredTool { getServerURL() { const url = process.env.SD_WEBUI_URL || ''; - if (!url) { + if (!url && !this.override) { throw new Error('Missing SD_WEBUI_URL environment variable.'); } return url; @@ -70,46 +86,67 @@ class StableDiffusionAPI extends StructuredTool { width: 1024, height: 1024, }; - const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload); - const image = response.data.images[0]; - const pngPayload = { image: `data:image/png;base64,${image}` }; - const response2 = await axios.post(`${url}/sdapi/v1/png-info`, pngPayload); - const info = response2.data.info; + const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload); + const image = generationResponse.data.images[0]; + + /** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */ + let info = {}; + try { + info = JSON.parse(generationResponse.data.info); + } catch (error) { + logger.error('[StableDiffusion] Error while getting image metadata:', error); + } - // Generate unique name - const imageName = `${Date.now()}.png`; - this.outputPath = path.resolve( - __dirname, - '..', - '..', - '..', - '..', - '..', - 'client', - 'public', - 'images', - ); - const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client'); - this.relativeImageUrl = path.relative(appRoot, this.outputPath); + const file_id = uuidv4(); + const imageName = `${file_id}.png`; + const { imageOutput: imageOutputPath, clientPath } = paths; + const filepath = path.join(imageOutputPath, this.userId, imageName); + this.relativePath = path.relative(clientPath, imageOutputPath); - // Check if directory exists, if not create it - if (!fs.existsSync(this.outputPath)) { - fs.mkdirSync(this.outputPath, { recursive: true }); + if (!fs.existsSync(path.join(imageOutputPath, this.userId))) { + fs.mkdirSync(path.join(imageOutputPath, this.userId), { recursive: true }); } try { const buffer = Buffer.from(image.split(',', 1)[0], 'base64'); + if (this.returnMetadata && this.uploadImageBuffer && this.req) { + const file = await this.uploadImageBuffer({ + req: this.req, + context: FileContext.image_generation, + resize: false, + metadata: { + buffer, + height: info.height, + width: info.width, + bytes: Buffer.byteLength(buffer), + filename: imageName, + type: 'image/png', + file_id, + }, + }); + + const generationInfo = info.infotexts[0].split('\n').pop(); + return { + ...file, + prompt, + metadata: { + negative_prompt, + seed: info.seed, + info: generationInfo, + }, + }; + } + await sharp(buffer) .withMetadata({ iptcpng: { - parameters: info, + parameters: info.infotexts[0], }, }) - .toFile(this.outputPath + '/' + imageName); + .toFile(filepath); this.result = this.getMarkdownImageUrl(imageName); } catch (error) { logger.error('[StableDiffusion] Error while saving the image:', error); - // this.result = theImageUrl; } return this.result; diff --git a/api/app/clients/tools/structured/TavilySearchResults.js b/api/app/clients/tools/structured/TavilySearchResults.js new file mode 100644 index 00000000000..3945ac1d00f --- /dev/null +++ b/api/app/clients/tools/structured/TavilySearchResults.js @@ -0,0 +1,92 @@ +const { z } = require('zod'); +const { Tool } = require('@langchain/core/tools'); +const { getEnvironmentVariable } = require('@langchain/core/utils/env'); + +class TavilySearchResults extends Tool { + static lc_name() { + return 'TavilySearchResults'; + } + + constructor(fields = {}) { + super(fields); + this.envVar = 'TAVILY_API_KEY'; + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + this.apiKey = fields.apiKey ?? this.getApiKey(); + + this.kwargs = fields?.kwargs ?? {}; + this.name = 'tavily_search_results_json'; + this.description = + 'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events.'; + + this.schema = z.object({ + query: z.string().min(1).describe('The search query string.'), + max_results: z + .number() + .min(1) + .max(10) + .optional() + .describe('The maximum number of search results to return. Defaults to 5.'), + search_depth: z + .enum(['basic', 'advanced']) + .optional() + .describe( + 'The depth of the search, affecting result quality and response time (`basic` or `advanced`). Default is basic for quick results and advanced for indepth high quality results but longer response time. Advanced calls equals 2 requests.', + ), + include_images: z + .boolean() + .optional() + .describe( + 'Whether to include a list of query-related images in the response. Default is False.', + ), + include_answer: z + .boolean() + .optional() + .describe('Whether to include answers in the search results. Default is False.'), + // include_raw_content: z.boolean().optional().describe('Whether to include raw content in the search results. Default is False.'), + // include_domains: z.array(z.string()).optional().describe('A list of domains to specifically include in the search results.'), + // exclude_domains: z.array(z.string()).optional().describe('A list of domains to specifically exclude from the search results.'), + }); + } + + getApiKey() { + const apiKey = getEnvironmentVariable(this.envVar); + if (!apiKey && !this.override) { + throw new Error(`Missing ${this.envVar} environment variable.`); + } + return apiKey; + } + + async _call(input) { + const validationResult = this.schema.safeParse(input); + if (!validationResult.success) { + throw new Error(`Validation failed: ${JSON.stringify(validationResult.error.issues)}`); + } + + const { query, ...rest } = validationResult.data; + + const requestBody = { + api_key: this.apiKey, + query, + ...rest, + ...this.kwargs, + }; + + const response = await fetch('https://api.tavily.com/search', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(requestBody), + }); + + const json = await response.json(); + if (!response.ok) { + throw new Error(`Request failed with status ${response.status}: ${json.error}`); + } + + return JSON.stringify(json); + } +} + +module.exports = TavilySearchResults; diff --git a/api/app/clients/tools/structured/TraversaalSearch.js b/api/app/clients/tools/structured/TraversaalSearch.js new file mode 100644 index 00000000000..e8ceeda134f --- /dev/null +++ b/api/app/clients/tools/structured/TraversaalSearch.js @@ -0,0 +1,89 @@ +const { z } = require('zod'); +const { Tool } = require('@langchain/core/tools'); +const { getEnvironmentVariable } = require('@langchain/core/utils/env'); +const { logger } = require('~/config'); + +/** + * Tool for the Traversaal AI search API, Ares. + */ +class TraversaalSearch extends Tool { + static lc_name() { + return 'TraversaalSearch'; + } + constructor(fields) { + super(fields); + this.name = 'traversaal_search'; + this.description = `An AI search engine optimized for comprehensive, accurate, and trusted results. + Useful for when you need to answer questions about current events. Input should be a search query.`; + this.description_for_model = + '\'Please create a specific sentence for the AI to understand and use as a query to search the web based on the user\'s request. For example, "Find information about the highest mountains in the world." or "Show me the latest news articles about climate change and its impact on polar ice caps."\''; + this.schema = z.object({ + query: z + .string() + .describe( + 'A properly written sentence to be interpreted by an AI to search the web according to the user\'s request.', + ), + }); + + this.apiKey = fields?.TRAVERSAAL_API_KEY ?? this.getApiKey(); + } + + getApiKey() { + const apiKey = getEnvironmentVariable('TRAVERSAAL_API_KEY'); + if (!apiKey && this.override) { + throw new Error( + 'No Traversaal API key found. Either set an environment variable named "TRAVERSAAL_API_KEY" or pass an API key as "apiKey".', + ); + } + return apiKey; + } + + // eslint-disable-next-line no-unused-vars + async _call({ query }, _runManager) { + const body = { + query: [query], + }; + try { + const response = await fetch('https://api-ares.traversaal.ai/live/predict', { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'x-api-key': this.apiKey, + }, + body: JSON.stringify({ ...body }), + }); + const json = await response.json(); + if (!response.ok) { + throw new Error( + `Request failed with status code ${response.status}: ${json.error ?? json.message}`, + ); + } + if (!json.data) { + throw new Error('Could not parse Traversaal API results. Please try again.'); + } + + const baseText = json.data?.response_text ?? ''; + const sources = json.data?.web_url; + const noResponse = 'No response found in Traversaal API results'; + + if (!baseText && !sources) { + return noResponse; + } + + const sourcesText = sources?.length ? '\n\nSources:\n - ' + sources.join('\n - ') : ''; + + const result = baseText + sourcesText; + + if (!result) { + return noResponse; + } + + return result; + } catch (error) { + logger.error('Traversaal API request failed', error); + return `Traversaal API request failed: ${error.message}`; + } + } +} + +module.exports = TraversaalSearch; diff --git a/api/app/clients/tools/structured/Wolfram.js b/api/app/clients/tools/structured/Wolfram.js index 2c5c6e023a1..fc857b35cb2 100644 --- a/api/app/clients/tools/structured/Wolfram.js +++ b/api/app/clients/tools/structured/Wolfram.js @@ -7,6 +7,9 @@ const { logger } = require('~/config'); class WolframAlphaAPI extends StructuredTool { constructor(fields) { super(); + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + this.name = 'wolfram'; this.apiKey = fields.WOLFRAM_APP_ID || this.getAppId(); this.description_for_model = `// Access dynamic computation and curated data from WolframAlpha and Wolfram Cloud. @@ -55,7 +58,7 @@ class WolframAlphaAPI extends StructuredTool { getAppId() { const appId = process.env.WOLFRAM_APP_ID || ''; - if (!appId) { + if (!appId && !this.override) { throw new Error('Missing WOLFRAM_APP_ID environment variable.'); } return appId; diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index 58771b1459e..1b28de2faf1 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -1,14 +1,11 @@ const OpenAI = require('openai'); const DALLE3 = require('../DALLE3'); -const { processFileURL } = require('~/server/services/Files/process'); const { logger } = require('~/config'); jest.mock('openai'); -jest.mock('~/server/services/Files/process', () => ({ - processFileURL: jest.fn(), -})); +const processFileURL = jest.fn(); jest.mock('~/server/services/Files/images', () => ({ getImageBasename: jest.fn().mockImplementation((url) => { @@ -69,7 +66,7 @@ describe('DALLE3', () => { jest.resetModules(); process.env = { ...originalEnv, DALLE_API_KEY: mockApiKey }; // Instantiate DALLE3 for tests that do not depend on DALLE3_SYSTEM_PROMPT - dalle = new DALLE3(); + dalle = new DALLE3({ processFileURL }); }); afterEach(() => { @@ -78,7 +75,8 @@ describe('DALLE3', () => { process.env = originalEnv; }); - it('should throw an error if DALLE_API_KEY is missing', () => { + it('should throw an error if all potential API keys are missing', () => { + delete process.env.DALLE3_API_KEY; delete process.env.DALLE_API_KEY; expect(() => new DALLE3()).toThrow('Missing DALLE_API_KEY environment variable.'); }); @@ -112,7 +110,9 @@ describe('DALLE3', () => { }; generate.mockResolvedValue(mockResponse); - processFileURL.mockResolvedValue('http://example.com/img-test.png'); + processFileURL.mockResolvedValue({ + filepath: 'http://example.com/img-test.png', + }); const result = await dalle._call(mockData); diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 1d9a3a00749..7ed18658711 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -6,19 +6,23 @@ const { OpenAIEmbeddings } = require('langchain/embeddings/openai'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { availableTools, + // Basic Tools + CodeBrew, + AzureAISearch, GoogleSearchAPI, WolframAlphaAPI, - StructuredWolfram, OpenAICreateImage, StableDiffusionAPI, + // Structured Tools DALLE3, - StructuredSD, - AzureAISearch, - StructuredACS, E2BTools, CodeSherpa, + StructuredSD, + StructuredACS, CodeSherpaTools, - CodeBrew, + TraversaalSearch, + StructuredWolfram, + TavilySearchResults, } = require('../'); const { loadToolSuite } = require('./loadToolSuite'); const { loadSpecs } = require('./loadSpecs'); @@ -30,6 +34,14 @@ const getOpenAIKey = async (options, user) => { return openAIApiKey || (await getUserPluginAuthValue(user, 'OPENAI_API_KEY')); }; +/** + * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. + * Tools without required authentication or with valid authentication are considered valid. + * + * @param {Object} user The user object for whom to validate tool access. + * @param {Array<string>} tools An array of tool identifiers to validate. Defaults to an empty array. + * @returns {Promise<Array<string>>} A promise that resolves to an array of valid tool identifiers. + */ const validateTools = async (user, tools = []) => { try { const validToolsSet = new Set(tools); @@ -37,16 +49,34 @@ const validateTools = async (user, tools = []) => { validToolsSet.has(tool.pluginKey), ); + /** + * Validates the credentials for a given auth field or set of alternate auth fields for a tool. + * If valid admin or user authentication is found, the function returns early. Otherwise, it removes the tool from the set of valid tools. + * + * @param {string} authField The authentication field or fields (separated by "||" for alternates) to validate. + * @param {string} toolName The identifier of the tool being validated. + */ const validateCredentials = async (authField, toolName) => { - const adminAuth = process.env[authField]; - if (adminAuth && adminAuth.length > 0) { - return; + const fields = authField.split('||'); + for (const field of fields) { + const adminAuth = process.env[field]; + if (adminAuth && adminAuth.length > 0) { + return; + } + + let userAuth = null; + try { + userAuth = await getUserPluginAuthValue(user, field); + } catch (err) { + if (field === fields[fields.length - 1] && !userAuth) { + throw err; + } + } + if (userAuth && userAuth.length > 0) { + return; + } } - const userAuth = await getUserPluginAuthValue(user, authField); - if (userAuth && userAuth.length > 0) { - return; - } validToolsSet.delete(toolName); }; @@ -63,20 +93,55 @@ const validateTools = async (user, tools = []) => { return Array.from(validToolsSet.values()); } catch (err) { logger.error('[validateTools] There was a problem validating tools', err); - throw new Error(err); + throw new Error('There was a problem validating tools'); } }; -const loadToolWithAuth = async (userId, authFields, ToolConstructor, options = {}) => { +/** + * Initializes a tool with authentication values for the given user, supporting alternate authentication fields. + * Authentication fields can have alternates separated by "||", and the first defined variable will be used. + * + * @param {string} userId The user ID for which the tool is being loaded. + * @param {Array<string>} authFields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". + * @param {typeof import('langchain/tools').Tool} ToolConstructor The constructor function for the tool to be initialized. + * @param {Object} options Optional parameters to be passed to the tool constructor alongside authentication values. + * @returns {Function} An Async function that, when called, asynchronously initializes and returns an instance of the tool with authentication. + */ +const loadToolWithAuth = (userId, authFields, ToolConstructor, options = {}) => { return async function () { let authValues = {}; - for (const authField of authFields) { - let authValue = process.env[authField]; - if (!authValue) { - authValue = await getUserPluginAuthValue(userId, authField); + /** + * Finds the first non-empty value for the given authentication field, supporting alternate fields. + * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". + * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. + */ + const findAuthValue = async (fields) => { + for (const field of fields) { + let value = process.env[field]; + if (value) { + return { authField: field, authValue: value }; + } + try { + value = await getUserPluginAuthValue(userId, field); + } catch (err) { + if (field === fields[fields.length - 1] && !value) { + throw err; + } + } + if (value) { + return { authField: field, authValue: value }; + } + } + return null; + }; + + for (let authField of authFields) { + const fields = authField.split('||'); + const result = await findAuthValue(fields); + if (result) { + authValues[result.authField] = result.authValue; } - authValues[authField] = authValue; } return new ToolConstructor({ ...options, ...authValues, userId }); @@ -90,8 +155,10 @@ const loadTools = async ({ returnMap = false, tools = [], options = {}, + skipSpecs = false, }) => { const toolConstructors = { + tavily_search_results_json: TavilySearchResults, calculator: Calculator, google: GoogleSearchAPI, wolfram: functions ? StructuredWolfram : WolframAlphaAPI, @@ -99,6 +166,7 @@ const loadTools = async ({ 'stable-diffusion': functions ? StructuredSD : StableDiffusionAPI, 'azure-ai-search': functions ? StructuredACS : AzureAISearch, CodeBrew: CodeBrew, + traversaal_search: TraversaalSearch, }; const openAIApiKey = await getOpenAIKey(options, user); @@ -168,10 +236,19 @@ const loadTools = async ({ toolConstructors.codesherpa = CodeSherpa; } + const imageGenOptions = { + req: options.req, + fileStrategy: options.fileStrategy, + processFileURL: options.processFileURL, + returnMetadata: options.returnMetadata, + uploadImageBuffer: options.uploadImageBuffer, + }; + const toolOptions = { serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, - dalle: { fileStrategy: options.fileStrategy }, - 'dall-e': { fileStrategy: options.fileStrategy }, + dalle: imageGenOptions, + 'dall-e': imageGenOptions, + 'stable-diffusion': imageGenOptions, }; const toolAuthFields = {}; @@ -194,7 +271,7 @@ const loadTools = async ({ if (toolConstructors[tool]) { const options = toolOptions[tool] || {}; - const toolInstance = await loadToolWithAuth( + const toolInstance = loadToolWithAuth( user, toolAuthFields[tool], toolConstructors[tool], @@ -210,7 +287,7 @@ const loadTools = async ({ } let specs = null; - if (functions && remainingTools.length > 0) { + if (functions && remainingTools.length > 0 && skipSpecs !== true) { specs = await loadSpecs({ llm: model, user, @@ -237,6 +314,9 @@ const loadTools = async ({ let result = []; for (const tool of tools) { const validTool = requestedTools[tool]; + if (!validTool) { + continue; + } const plugin = await validTool(); if (Array.isArray(plugin)) { @@ -250,6 +330,7 @@ const loadTools = async ({ }; module.exports = { + loadToolWithAuth, validateTools, loadTools, }; diff --git a/api/app/clients/tools/util/handleTools.test.js b/api/app/clients/tools/util/handleTools.test.js index 40d8bc6129e..2c977714275 100644 --- a/api/app/clients/tools/util/handleTools.test.js +++ b/api/app/clients/tools/util/handleTools.test.js @@ -4,26 +4,33 @@ const mockUser = { findByIdAndDelete: jest.fn(), }; -var mockPluginService = { +const mockPluginService = { updateUserPluginAuth: jest.fn(), deleteUserPluginAuth: jest.fn(), getUserPluginAuthValue: jest.fn(), }; -jest.mock('../../../../models/User', () => { +jest.mock('~/models/User', () => { return function () { return mockUser; }; }); -jest.mock('../../../../server/services/PluginService', () => mockPluginService); +jest.mock('~/server/services/PluginService', () => mockPluginService); -const User = require('../../../../models/User'); -const { validateTools, loadTools } = require('./'); -const PluginService = require('../../../../server/services/PluginService'); -const { BaseChatModel } = require('langchain/chat_models/openai'); const { Calculator } = require('langchain/tools/calculator'); -const { availableTools, OpenAICreateImage, GoogleSearchAPI, StructuredSD } = require('../'); +const { BaseChatModel } = require('langchain/chat_models/openai'); + +const User = require('~/models/User'); +const PluginService = require('~/server/services/PluginService'); +const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools'); +const { + availableTools, + OpenAICreateImage, + GoogleSearchAPI, + StructuredSD, + WolframAlphaAPI, +} = require('../'); describe('Tool Handlers', () => { let fakeUser; @@ -44,7 +51,10 @@ describe('Tool Handlers', () => { }); mockPluginService.updateUserPluginAuth.mockImplementation( (userId, authField, _pluginKey, credential) => { - userAuthValues[`${userId}-${authField}`] = credential; + const fields = authField.split('||'); + fields.forEach((field) => { + userAuthValues[`${userId}-${field}`] = credential; + }); }, ); @@ -53,6 +63,7 @@ describe('Tool Handlers', () => { username: 'fakeuser', email: 'fakeuser@example.com', emailVerified: false, + // file deepcode ignore NoHardcodedPasswords/test: fake value password: 'fakepassword123', avatar: '', provider: 'local', @@ -133,6 +144,18 @@ describe('Tool Handlers', () => { loadTool2 = toolFunctions[sampleTools[1]]; loadTool3 = toolFunctions[sampleTools[2]]; }); + + let originalEnv; + + beforeEach(() => { + originalEnv = process.env; + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + it('returns the expected load functions for requested tools', async () => { expect(loadTool1).toBeDefined(); expect(loadTool2).toBeDefined(); @@ -149,6 +172,86 @@ describe('Tool Handlers', () => { expect(authTool).toBeInstanceOf(ToolClass); expect(tool).toBeInstanceOf(ToolClass2); }); + + it('should initialize an authenticated tool with primary auth field', async () => { + process.env.DALLE2_API_KEY = 'mocked_api_key'; + const initToolFunction = loadToolWithAuth( + 'userId', + ['DALLE2_API_KEY||DALLE_API_KEY'], + ToolClass, + ); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(ToolClass); + expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalled(); + }); + + it('should initialize an authenticated tool with alternate auth field when primary is missing', async () => { + delete process.env.DALLE2_API_KEY; // Ensure the primary key is not set + process.env.DALLE_API_KEY = 'mocked_alternate_api_key'; + const initToolFunction = loadToolWithAuth( + 'userId', + ['DALLE2_API_KEY||DALLE_API_KEY'], + ToolClass, + ); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(ToolClass); + expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(1); + expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledWith( + 'userId', + 'DALLE2_API_KEY', + ); + }); + + it('should fallback to getUserPluginAuthValue when env vars are missing', async () => { + mockPluginService.updateUserPluginAuth('userId', 'DALLE_API_KEY', 'dalle', 'mocked_api_key'); + const initToolFunction = loadToolWithAuth( + 'userId', + ['DALLE2_API_KEY||DALLE_API_KEY'], + ToolClass, + ); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(ToolClass); + expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(2); + }); + + it('should initialize an authenticated tool with singular auth field', async () => { + process.env.WOLFRAM_APP_ID = 'mocked_app_id'; + const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(WolframAlphaAPI); + expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalled(); + }); + + it('should initialize an authenticated tool when env var is set', async () => { + process.env.WOLFRAM_APP_ID = 'mocked_app_id'; + const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(WolframAlphaAPI); + expect(mockPluginService.getUserPluginAuthValue).not.toHaveBeenCalledWith( + 'userId', + 'WOLFRAM_APP_ID', + ); + }); + + it('should fallback to getUserPluginAuthValue when singular env var is missing', async () => { + delete process.env.WOLFRAM_APP_ID; // Ensure the environment variable is not set + mockPluginService.getUserPluginAuthValue.mockResolvedValue('mocked_user_auth_value'); + const initToolFunction = loadToolWithAuth('userId', ['WOLFRAM_APP_ID'], WolframAlphaAPI); + const authTool = await initToolFunction(); + + expect(authTool).toBeInstanceOf(WolframAlphaAPI); + expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledTimes(1); + expect(mockPluginService.getUserPluginAuthValue).toHaveBeenCalledWith( + 'userId', + 'WOLFRAM_APP_ID', + ); + }); + it('should throw an error for an unauthenticated tool', async () => { try { await loadTool2(); diff --git a/api/app/clients/tools/util/loadToolSuite.js b/api/app/clients/tools/util/loadToolSuite.js index 2b4500a4f77..4392d61b9a6 100644 --- a/api/app/clients/tools/util/loadToolSuite.js +++ b/api/app/clients/tools/util/loadToolSuite.js @@ -1,17 +1,49 @@ -const { getUserPluginAuthValue } = require('../../../../server/services/PluginService'); +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { availableTools } = require('../'); +const { logger } = require('~/config'); -const loadToolSuite = async ({ pluginKey, tools, user, options }) => { +/** + * Loads a suite of tools with authentication values for a given user, supporting alternate authentication fields. + * Authentication fields can have alternates separated by "||", and the first defined variable will be used. + * + * @param {Object} params Parameters for loading the tool suite. + * @param {string} params.pluginKey Key identifying the plugin whose tools are to be loaded. + * @param {Array<Function>} params.tools Array of tool constructor functions. + * @param {Object} params.user User object for whom the tools are being loaded. + * @param {Object} [params.options={}] Optional parameters to be passed to each tool constructor. + * @returns {Promise<Array>} A promise that resolves to an array of instantiated tools. + */ +const loadToolSuite = async ({ pluginKey, tools, user, options = {} }) => { const authConfig = availableTools.find((tool) => tool.pluginKey === pluginKey).authConfig; const suite = []; const authValues = {}; + const findAuthValue = async (authField) => { + const fields = authField.split('||'); + for (const field of fields) { + let value = process.env[field]; + if (value) { + return value; + } + try { + value = await getUserPluginAuthValue(user, field); + if (value) { + return value; + } + } catch (err) { + logger.error(`Error fetching plugin auth value for ${field}: ${err.message}`); + } + } + return null; + }; + for (const auth of authConfig) { - let authValue = process.env[auth.authField]; - if (!authValue) { - authValue = await getUserPluginAuthValue(user, auth.authField); + const authValue = await findAuthValue(auth.authField); + if (authValue !== null) { + authValues[auth.authField] = authValue; + } else { + logger.warn(`[loadToolSuite] No auth value found for ${auth.authField}`); } - authValues[auth.authField] = authValue; } for (const tool of tools) { diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 016c7700009..786bb1f1f74 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,5 +1,5 @@ const Keyv = require('keyv'); -const { CacheKeys } = require('librechat-data-provider'); +const { CacheKeys, ViolationTypes } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); const { math, isEnabled } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); @@ -23,6 +23,22 @@ const config = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.CONFIG_STORE }); +const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes + ? new Keyv({ store: keyvRedis, ttl: 1800000 }) + : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 }); + +const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes + ? new Keyv({ store: keyvRedis, ttl: 120000 }) + : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 }); + +const modelQueries = isEnabled(process.env.USE_REDIS) + ? new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.MODEL_QUERIES }); + +const abortKeys = isEnabled(USE_REDIS) + ? new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 }); + const namespaces = { [CacheKeys.CONFIG_STORE]: config, pending_req, @@ -31,9 +47,17 @@ const namespaces = { concurrent: createViolationInstance('concurrent'), non_browser: createViolationInstance('non_browser'), message_limit: createViolationInstance('message_limit'), - token_balance: createViolationInstance('token_balance'), + token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE), registrations: createViolationInstance('registrations'), + [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT), + [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance( + ViolationTypes.ILLEGAL_MODEL_REQUEST, + ), logins: createViolationInstance('logins'), + [CacheKeys.ABORT_KEYS]: abortKeys, + [CacheKeys.TOKEN_CONFIG]: tokenConfig, + [CacheKeys.GEN_TITLE]: genTitle, + [CacheKeys.MODEL_QUERIES]: modelQueries, }; /** diff --git a/api/config/parsers.js b/api/config/parsers.js index 59685eab0bf..16c85cba4f4 100644 --- a/api/config/parsers.js +++ b/api/config/parsers.js @@ -33,6 +33,10 @@ function getMatchingSensitivePatterns(valueStr) { * @returns {string} - The redacted console message. */ function redactMessage(str) { + if (!str) { + return ''; + } + const patterns = getMatchingSensitivePatterns(str); if (patterns.length === 0) { diff --git a/api/config/paths.js b/api/config/paths.js index 41e3ac5054f..165e9e6cd4f 100644 --- a/api/config/paths.js +++ b/api/config/paths.js @@ -1,7 +1,14 @@ const path = require('path'); module.exports = { + root: path.resolve(__dirname, '..', '..'), + uploads: path.resolve(__dirname, '..', '..', 'uploads'), + clientPath: path.resolve(__dirname, '..', '..', 'client'), dist: path.resolve(__dirname, '..', '..', 'client', 'dist'), publicPath: path.resolve(__dirname, '..', '..', 'client', 'public'), + fonts: path.resolve(__dirname, '..', '..', 'client', 'public', 'fonts'), + assets: path.resolve(__dirname, '..', '..', 'client', 'public', 'assets'), imageOutput: path.resolve(__dirname, '..', '..', 'client', 'public', 'images'), + structuredTools: path.resolve(__dirname, '..', 'app', 'clients', 'tools', 'structured'), + pluginManifest: path.resolve(__dirname, '..', 'app', 'clients', 'tools', 'manifest.json'), }; diff --git a/api/config/winston.js b/api/config/winston.js index 6cba153f163..81e972fbbc3 100644 --- a/api/config/winston.js +++ b/api/config/winston.js @@ -5,7 +5,15 @@ const { redactFormat, redactMessage, debugTraverse } = require('./parsers'); const logDir = path.join(__dirname, '..', 'logs'); -const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false } = process.env; +const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false, CONSOLE_JSON = false } = process.env; + +const useConsoleJson = + (typeof CONSOLE_JSON === 'string' && CONSOLE_JSON?.toLowerCase() === 'true') || + CONSOLE_JSON === true; + +const useDebugConsole = + (typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') || + DEBUG_CONSOLE === true; const levels = { error: 0, @@ -33,7 +41,7 @@ const level = () => { const fileFormat = winston.format.combine( redactFormat(), - winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }), + winston.format.timestamp({ format: () => new Date().toISOString() }), winston.format.errors({ stack: true }), winston.format.splat(), // redactErrors(), @@ -99,14 +107,20 @@ const consoleFormat = winston.format.combine( }), ); -if ( - (typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') || - DEBUG_CONSOLE === true -) { +if (useDebugConsole) { transports.push( new winston.transports.Console({ level: 'debug', - format: winston.format.combine(consoleFormat, debugTraverse), + format: useConsoleJson + ? winston.format.combine(fileFormat, debugTraverse, winston.format.json()) + : winston.format.combine(fileFormat, debugTraverse), + }), + ); +} else if (useConsoleJson) { + transports.push( + new winston.transports.Console({ + level: 'info', + format: winston.format.combine(fileFormat, winston.format.json()), }), ); } else { diff --git a/api/models/Action.js b/api/models/Action.js new file mode 100644 index 00000000000..5141569c103 --- /dev/null +++ b/api/models/Action.js @@ -0,0 +1,68 @@ +const mongoose = require('mongoose'); +const actionSchema = require('./schema/action'); + +const Action = mongoose.model('action', actionSchema); + +/** + * Update an action with new data without overwriting existing properties, + * or create a new action if it doesn't exist. + * + * @param {Object} searchParams - The search parameters to find the action to update. + * @param {string} searchParams.action_id - The ID of the action to update. + * @param {string} searchParams.user - The user ID of the action's author. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise<Object>} The updated or newly created action document as a plain object. + */ +const updateAction = async (searchParams, updateData) => { + return await Action.findOneAndUpdate(searchParams, updateData, { + new: true, + upsert: true, + }).lean(); +}; + +/** + * Retrieves all actions that match the given search parameters. + * + * @param {Object} searchParams - The search parameters to find matching actions. + * @param {boolean} includeSensitive - Flag to include sensitive data in the metadata. + * @returns {Promise<Array<Object>>} A promise that resolves to an array of action documents as plain objects. + */ +const getActions = async (searchParams, includeSensitive = false) => { + const actions = await Action.find(searchParams).lean(); + + if (!includeSensitive) { + for (let i = 0; i < actions.length; i++) { + const metadata = actions[i].metadata; + if (!metadata) { + continue; + } + + const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; + for (let field of sensitiveFields) { + if (metadata[field]) { + delete metadata[field]; + } + } + } + } + + return actions; +}; + +/** + * Deletes an action by its ID. + * + * @param {Object} searchParams - The search parameters to find the action to update. + * @param {string} searchParams.action_id - The ID of the action to update. + * @param {string} searchParams.user - The user ID of the action's author. + * @returns {Promise<Object>} A promise that resolves to the deleted action document as a plain object, or null if no document was found. + */ +const deleteAction = async (searchParams) => { + return await Action.findOneAndDelete(searchParams).lean(); +}; + +module.exports = { + updateAction, + getActions, + deleteAction, +}; diff --git a/api/models/Assistant.js b/api/models/Assistant.js new file mode 100644 index 00000000000..fa6192eee93 --- /dev/null +++ b/api/models/Assistant.js @@ -0,0 +1,47 @@ +const mongoose = require('mongoose'); +const assistantSchema = require('./schema/assistant'); + +const Assistant = mongoose.model('assistant', assistantSchema); + +/** + * Update an assistant with new data without overwriting existing properties, + * or create a new assistant if it doesn't exist. + * + * @param {Object} searchParams - The search parameters to find the assistant to update. + * @param {string} searchParams.assistant_id - The ID of the assistant to update. + * @param {string} searchParams.user - The user ID of the assistant's author. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise<Object>} The updated or newly created assistant document as a plain object. + */ +const updateAssistant = async (searchParams, updateData) => { + return await Assistant.findOneAndUpdate(searchParams, updateData, { + new: true, + upsert: true, + }).lean(); +}; + +/** + * Retrieves an assistant document based on the provided ID. + * + * @param {Object} searchParams - The search parameters to find the assistant to update. + * @param {string} searchParams.assistant_id - The ID of the assistant to update. + * @param {string} searchParams.user - The user ID of the assistant's author. + * @returns {Promise<Object|null>} The assistant document as a plain object, or null if not found. + */ +const getAssistant = async (searchParams) => await Assistant.findOne(searchParams).lean(); + +/** + * Retrieves all assistants that match the given search parameters. + * + * @param {Object} searchParams - The search parameters to find matching assistants. + * @returns {Promise<Array<Object>>} A promise that resolves to an array of action documents as plain objects. + */ +const getAssistants = async (searchParams) => { + return await Assistant.find(searchParams).lean(); +}; + +module.exports = { + updateAssistant, + getAssistants, + getAssistant, +}; diff --git a/api/models/Balance.js b/api/models/Balance.js index 45dec696304..24d9087b77f 100644 --- a/api/models/Balance.js +++ b/api/models/Balance.js @@ -10,8 +10,9 @@ balanceSchema.statics.check = async function ({ valueKey, tokenType, amount, + endpointTokenConfig, }) { - const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint }); + const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig }); const tokenCost = amount * multiplier; const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {}; @@ -24,6 +25,7 @@ balanceSchema.statics.check = async function ({ amount, balance, multiplier, + endpointTokenConfig: !!endpointTokenConfig, }); if (!balance) { diff --git a/api/models/Conversation.js b/api/models/Conversation.js index f1aa7bfe718..1ef47241cac 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -30,12 +30,12 @@ module.exports = { return { message: 'Error saving conversation' }; } }, - getConvosByPage: async (user, pageNumber = 1, pageSize = 14) => { + getConvosByPage: async (user, pageNumber = 1, pageSize = 25) => { try { const totalConvos = (await Conversation.countDocuments({ user })) || 1; const totalPages = Math.ceil(totalConvos / pageSize); const convos = await Conversation.find({ user }) - .sort({ createdAt: -1 }) + .sort({ updatedAt: -1 }) .skip((pageNumber - 1) * pageSize) .limit(pageSize) .lean(); @@ -45,7 +45,7 @@ module.exports = { return { message: 'Error getting conversations' }; } }, - getConvosQueried: async (user, convoIds, pageNumber = 1, pageSize = 14) => { + getConvosQueried: async (user, convoIds, pageNumber = 1, pageSize = 25) => { try { if (!convoIds || convoIds.length === 0) { return { conversations: [], pages: 1, pageNumber, pageSize }; diff --git a/api/models/File.js b/api/models/File.js index 4c353fd70b0..16e9ab6a0e8 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -14,24 +14,32 @@ const findFileById = async (file_id, options = {}) => { }; /** - * Retrieves files matching a given filter. + * Retrieves files matching a given filter, sorted by the most recently updated. * @param {Object} filter - The filter criteria to apply. + * @param {Object} [_sortOptions] - Optional sort parameters. * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents. */ -const getFiles = async (filter) => { - return await File.find(filter).lean(); +const getFiles = async (filter, _sortOptions) => { + const sortOptions = { updatedAt: -1, ..._sortOptions }; + return await File.find(filter).sort(sortOptions).lean(); }; /** * Creates a new file with a TTL of 1 hour. * @param {MongoFile} data - The file data to be created, must contain file_id. + * @param {boolean} disableTTL - Whether to disable the TTL. * @returns {Promise<MongoFile>} A promise that resolves to the created file document. */ -const createFile = async (data) => { +const createFile = async (data, disableTTL) => { const fileData = { ...data, expiresAt: new Date(Date.now() + 3600 * 1000), }; + + if (disableTTL) { + delete fileData.expiresAt; + } + return await File.findOneAndUpdate({ file_id: data.file_id }, fileData, { new: true, upsert: true, @@ -61,7 +69,7 @@ const updateFileUsage = async (data) => { const { file_id, inc = 1 } = data; const updateOperation = { $inc: { usage: inc }, - $unset: { expiresAt: '' }, + $unset: { expiresAt: '', temp_file_id: '' }, }; return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean(); }; @@ -75,6 +83,15 @@ const deleteFile = async (file_id) => { return await File.findOneAndDelete({ file_id }).lean(); }; +/** + * Deletes a file identified by a filter. + * @param {object} filter - The filter criteria to apply. + * @returns {Promise<MongoFile>} A promise that resolves to the deleted file document or null. + */ +const deleteFileByFilter = async (filter) => { + return await File.findOneAndDelete(filter).lean(); +}; + /** * Deletes multiple files identified by an array of file_ids. * @param {Array<string>} file_ids - The unique identifiers of the files to delete. @@ -93,4 +110,5 @@ module.exports = { updateFileUsage, deleteFile, deleteFiles, + deleteFileByFilter, }; diff --git a/api/models/Message.js b/api/models/Message.js index fe615f3283f..a8e1acdf149 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -72,11 +72,49 @@ module.exports = { throw new Error('Failed to save message.'); } }, + /** + * Records a message in the database. + * + * @async + * @function recordMessage + * @param {Object} params - The message data object. + * @param {string} params.user - The identifier of the user. + * @param {string} params.endpoint - The endpoint where the message originated. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.conversationId - The identifier of the conversation. + * @param {string} [params.parentMessageId] - The identifier of the parent message, if any. + * @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed. + * @returns {Promise<Object>} The updated or newly inserted message document. + * @throws {Error} If there is an error in saving the message. + */ + async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) { + try { + // No parsing of convoId as may use threadId + const message = { + user, + endpoint, + messageId, + conversationId, + parentMessageId, + ...rest, + }; + + return await Message.findOneAndUpdate({ user, messageId }, message, { + upsert: true, + new: true, + }); + } catch (err) { + logger.error('Error saving message:', err); + throw new Error('Failed to save message.'); + } + }, async updateMessage(message) { try { const { messageId, ...update } = message; update.isEdited = true; - const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, { new: true }); + const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, { + new: true, + }); if (!updatedMessage) { throw new Error('Message not found.'); diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 0bc26fc37ee..0d11ab5374c 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -2,6 +2,7 @@ const mongoose = require('mongoose'); const { isEnabled } = require('../server/utils/handleText'); const transactionSchema = require('./schema/transaction'); const { getMultiplier } = require('./tx'); +const { logger } = require('~/config'); const Balance = require('./Balance'); const cancelRate = 1.15; @@ -10,8 +11,8 @@ transactionSchema.methods.calculateTokenValue = function () { if (!this.valueKey || !this.tokenType) { this.tokenValue = this.rawAmount; } - const { valueKey, tokenType, model } = this; - const multiplier = getMultiplier({ valueKey, tokenType, model }); + const { valueKey, tokenType, model, endpointTokenConfig } = this; + const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig })); this.rate = multiplier; this.tokenValue = this.rawAmount * multiplier; if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') { @@ -25,6 +26,7 @@ transactionSchema.statics.create = async function (transactionData) { const Transaction = this; const transaction = new Transaction(transactionData); + transaction.endpointTokenConfig = transactionData.endpointTokenConfig; transaction.calculateTokenValue(); // Save the transaction @@ -34,12 +36,44 @@ transactionSchema.statics.create = async function (transactionData) { return; } - // Adjust the user's balance - return await Balance.findOneAndUpdate( + let balance = await Balance.findOne({ user: transaction.user }).lean(); + let incrementValue = transaction.tokenValue; + + if (balance && balance?.tokenCredits + incrementValue < 0) { + incrementValue = -balance.tokenCredits; + } + + balance = await Balance.findOneAndUpdate( { user: transaction.user }, - { $inc: { tokenCredits: transaction.tokenValue } }, + { $inc: { tokenCredits: incrementValue } }, { upsert: true, new: true }, ).lean(); + + return { + rate: transaction.rate, + user: transaction.user.toString(), + balance: balance.tokenCredits, + [transaction.tokenType]: incrementValue, + }; }; -module.exports = mongoose.model('Transaction', transactionSchema); +const Transaction = mongoose.model('Transaction', transactionSchema); + +/** + * Queries and retrieves transactions based on a given filter. + * @async + * @function getTransactions + * @param {Object} filter - MongoDB filter object to apply when querying transactions. + * @returns {Promise<Array>} A promise that resolves to an array of matched transactions. + * @throws {Error} Throws an error if querying the database fails. + */ +async function getTransactions(filter) { + try { + return await Transaction.find(filter).lean(); + } catch (error) { + logger.error('Error querying transactions:', error); + throw error; + } +} + +module.exports = { Transaction, getTransactions }; diff --git a/api/models/checkBalance.js b/api/models/checkBalance.js index c0bbd060bfb..5af77bbb192 100644 --- a/api/models/checkBalance.js +++ b/api/models/checkBalance.js @@ -1,5 +1,6 @@ +const { ViolationTypes } = require('librechat-data-provider'); +const { logViolation } = require('~/cache'); const Balance = require('./Balance'); -const { logViolation } = require('../cache'); /** * Checks the balance for a user and determines if they can spend a certain amount. * If the user cannot spend the amount, it logs a violation and denies the request. @@ -14,6 +15,7 @@ const { logViolation } = require('../cache'); * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token. * @param {number} params.txData.amount - The amount of tokens. * @param {string} params.txData.model - The model name or identifier. + * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint. * @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request. * @throws {Error} Throws an error if there's an issue with the balance check. */ @@ -24,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => { return true; } - const type = 'token_balance'; + const type = ViolationTypes.TOKEN_BALANCE; const errorMessage = { type, balance, diff --git a/api/models/index.js b/api/models/index.js index 1fa7513540c..bf88193823e 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -1,12 +1,14 @@ const { getMessages, saveMessage, + recordMessage, updateMessage, deleteMessagesSince, deleteMessages, } = require('./Message'); const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); +const { hashPassword, getUser, updateUser } = require('./userMethods'); const { findFileById, createFile, @@ -20,17 +22,20 @@ const Key = require('./Key'); const User = require('./User'); const Session = require('./Session'); const Balance = require('./Balance'); -const Transaction = require('./Transaction'); module.exports = { User, Key, Session, Balance, - Transaction, + + hashPassword, + updateUser, + getUser, getMessages, saveMessage, + recordMessage, updateMessage, deleteMessagesSince, deleteMessages, diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js index abba8486148..79dd30b11cc 100644 --- a/api/models/plugins/mongoMeili.js +++ b/api/models/plugins/mongoMeili.js @@ -183,6 +183,15 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { if (object.conversationId && object.conversationId.includes('|')) { object.conversationId = object.conversationId.replace(/\|/g, '--'); } + + if (object.content && Array.isArray(object.content)) { + object.text = object.content + .filter((item) => item.type === 'text' && item.text && item.text.value) + .map((item) => item.text.value) + .join(' '); + delete object.content; + } + return object; } diff --git a/api/models/schema/action.js b/api/models/schema/action.js new file mode 100644 index 00000000000..9e9109adf78 --- /dev/null +++ b/api/models/schema/action.js @@ -0,0 +1,59 @@ +const mongoose = require('mongoose'); + +const { Schema } = mongoose; + +const AuthSchema = new Schema( + { + authorization_type: String, + custom_auth_header: String, + type: { + type: String, + enum: ['service_http', 'oauth', 'none'], + }, + authorization_content_type: String, + authorization_url: String, + client_url: String, + scope: String, + token_exchange_method: { + type: String, + enum: ['default_post', 'basic_auth_header', null], + }, + }, + { _id: false }, +); + +const actionSchema = new Schema({ + user: { + type: mongoose.Schema.Types.ObjectId, + ref: 'User', + index: true, + required: true, + }, + action_id: { + type: String, + index: true, + required: true, + }, + type: { + type: String, + default: 'action_prototype', + }, + settings: Schema.Types.Mixed, + assistant_id: String, + metadata: { + api_key: String, // private, encrypted + auth: AuthSchema, + domain: { + type: String, + required: true, + }, + // json_schema: Schema.Types.Mixed, + privacy_policy_url: String, + raw_spec: String, + oauth_client_id: String, // private, encrypted + oauth_client_secret: String, // private, encrypted + }, +}); +// }, { minimize: false }); // Prevent removal of empty objects + +module.exports = actionSchema; diff --git a/api/models/schema/assistant.js b/api/models/schema/assistant.js new file mode 100644 index 00000000000..67eb8e8e720 --- /dev/null +++ b/api/models/schema/assistant.js @@ -0,0 +1,33 @@ +const mongoose = require('mongoose'); + +const assistantSchema = mongoose.Schema( + { + user: { + type: mongoose.Schema.Types.ObjectId, + ref: 'User', + required: true, + }, + assistant_id: { + type: String, + index: true, + required: true, + }, + avatar: { + type: { + filepath: String, + source: String, + }, + default: undefined, + }, + access_level: { + type: Number, + }, + file_ids: { type: [String], default: undefined }, + actions: { type: [String], default: undefined }, + }, + { + timestamps: true, + }, +); + +module.exports = assistantSchema; diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js index a282287eccb..4810f68321a 100644 --- a/api/models/schema/convoSchema.js +++ b/api/models/schema/convoSchema.js @@ -55,7 +55,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { }); } -convoSchema.index({ createdAt: 1 }); +convoSchema.index({ createdAt: 1, updatedAt: 1 }); const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema); diff --git a/api/models/schema/defaults.js b/api/models/schema/defaults.js index 39a6430f46b..b2ea3a12c70 100644 --- a/api/models/schema/defaults.js +++ b/api/models/schema/defaults.js @@ -11,152 +11,137 @@ const conversationPreset = { // for azureOpenAI, openAI, chatGPTBrowser only model: { type: String, - // default: null, required: false, }, // for azureOpenAI, openAI only chatGptLabel: { type: String, - // default: null, required: false, }, // for google only modelLabel: { type: String, - // default: null, required: false, }, promptPrefix: { type: String, - // default: null, required: false, }, temperature: { type: Number, - // default: 1, required: false, }, top_p: { type: Number, - // default: 1, required: false, }, // for google only topP: { type: Number, - // default: 0.95, required: false, }, topK: { type: Number, - // default: 40, required: false, }, maxOutputTokens: { type: Number, - // default: 1024, required: false, }, presence_penalty: { type: Number, - // default: 0, required: false, }, frequency_penalty: { type: Number, - // default: 0, required: false, }, // for bingai only jailbreak: { type: Boolean, - // default: false, }, context: { type: String, - // default: null, }, systemMessage: { type: String, - // default: null, }, toneStyle: { type: String, - // default: null, }, + file_ids: { type: [{ type: String }], default: undefined }, + // deprecated resendImages: { type: Boolean, }, + // files + resendFiles: { + type: Boolean, + }, imageDetail: { type: String, }, + /* assistants */ + assistant_id: { + type: String, + }, + instructions: { + type: String, + }, }; const agentOptions = { model: { type: String, - // default: null, required: false, }, // for azureOpenAI, openAI only chatGptLabel: { type: String, - // default: null, required: false, }, modelLabel: { type: String, - // default: null, required: false, }, promptPrefix: { type: String, - // default: null, required: false, }, temperature: { type: Number, - // default: 1, required: false, }, top_p: { type: Number, - // default: 1, required: false, }, // for google only topP: { type: Number, - // default: 0.95, required: false, }, topK: { type: Number, - // default: 40, required: false, }, maxOutputTokens: { type: Number, - // default: 1024, required: false, }, presence_penalty: { type: Number, - // default: 0, required: false, }, frequency_penalty: { type: Number, - // default: 0, required: false, }, context: { type: String, - // default: null, }, systemMessage: { type: String, - // default: null, }, }; diff --git a/api/models/schema/fileSchema.js b/api/models/schema/fileSchema.js index 471b7bfd70a..93a8815e53b 100644 --- a/api/models/schema/fileSchema.js +++ b/api/models/schema/fileSchema.js @@ -3,6 +3,8 @@ const mongoose = require('mongoose'); /** * @typedef {Object} MongoFile + * @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID + * @property {number} [__v] - MongoDB Version Key * @property {mongoose.Schema.Types.ObjectId} user - User ID * @property {string} [conversationId] - Optional conversation ID * @property {string} file_id - File identifier @@ -13,10 +15,15 @@ const mongoose = require('mongoose'); * @property {'file'} object - Type of object, always 'file' * @property {string} type - Type of file * @property {number} usage - Number of uses of the file + * @property {string} [context] - Context of the file origin + * @property {boolean} [embedded] - Whether or not the file is embedded in vector db + * @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting) * @property {string} [source] - The source of the file * @property {number} [width] - Optional width of the file * @property {number} [height] - Optional height of the file * @property {Date} [expiresAt] - Optional height of the file + * @property {Date} [createdAt] - Date when the file was created + * @property {Date} [updatedAt] - Date when the file was updated */ const fileSchema = mongoose.Schema( { @@ -57,10 +64,17 @@ const fileSchema = mongoose.Schema( required: true, default: 'file', }, + embedded: { + type: Boolean, + }, type: { type: String, required: true, }, + context: { + type: String, + // required: true, + }, usage: { type: Number, required: true, @@ -70,6 +84,9 @@ const fileSchema = mongoose.Schema( type: String, default: FileSources.local, }, + model: { + type: String, + }, width: Number, height: Number, expiresAt: { diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index 06da19e476d..fc745499fe5 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -17,6 +17,7 @@ const messageSchema = mongoose.Schema( user: { type: String, index: true, + required: true, default: null, }, model: { @@ -46,12 +47,10 @@ const messageSchema = mongoose.Schema( }, sender: { type: String, - required: true, meiliIndex: true, }, text: { type: String, - required: true, meiliIndex: true, }, summary: { @@ -103,6 +102,14 @@ const messageSchema = mongoose.Schema( default: undefined, }, plugins: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, + content: { + type: [{ type: mongoose.Schema.Types.Mixed }], + default: undefined, + meiliIndex: true, + }, + thread_id: { + type: String, + }, }, { timestamps: true }, ); diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index fe3a2be87ae..e37aa41d0cc 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -1,4 +1,4 @@ -const Transaction = require('./Transaction'); +const { Transaction } = require('./Transaction'); const { logger } = require('~/config'); /** @@ -11,6 +11,7 @@ const { logger } = require('~/config'); * @param {String} txData.conversationId - The ID of the conversation. * @param {String} txData.model - The model name. * @param {String} txData.context - The context in which the transaction is made. + * @param {String} [txData.endpointTokenConfig] - The current endpoint token config. * @param {String} [txData.valueKey] - The value key (optional). * @param {Object} tokenUsage - The number of tokens used. * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used. @@ -20,6 +21,15 @@ const { logger } = require('~/config'); */ const spendTokens = async (txData, tokenUsage) => { const { promptTokens, completionTokens } = tokenUsage; + logger.debug( + `[spendTokens] conversationId: ${txData.conversationId}${ + txData?.context ? ` | Context: ${txData?.context}` : '' + } | Token usage: `, + { + promptTokens, + completionTokens, + }, + ); let prompt, completion; try { if (promptTokens >= 0) { @@ -41,7 +51,16 @@ const spendTokens = async (txData, tokenUsage) => { rawAmount: -completionTokens, }); - logger.debug('[spendTokens] post-transaction', { prompt, completion }); + prompt && + completion && + logger.debug('[spendTokens] Transaction data record against balance:', { + user: prompt.user, + prompt: prompt.prompt, + promptRate: prompt.rate, + completion: completion.completion, + completionRate: completion.rate, + balance: completion.balance, + }); } catch (err) { logger.error('[spendTokens]', err); } diff --git a/api/models/tx.js b/api/models/tx.js index f6f3b7f5522..bc993290f44 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -3,6 +3,7 @@ const defaultRate = 6; /** * Mapping of model token sizes to their respective multipliers for prompt and completion. + * The rates are 1 USD per 1M tokens. * @type {Object.<string, {prompt: number, completion: number}>} */ const tokenValues = { @@ -12,6 +13,18 @@ const tokenValues = { '16k': { prompt: 3, completion: 4 }, 'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 }, 'gpt-4-1106': { prompt: 10, completion: 30 }, + 'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 }, + 'claude-3-opus': { prompt: 15, completion: 75 }, + 'claude-3-sonnet': { prompt: 3, completion: 15 }, + 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, + 'claude-2.1': { prompt: 8, completion: 24 }, + 'claude-2': { prompt: 8, completion: 24 }, + 'claude-': { prompt: 0.8, completion: 2.4 }, + 'command-r-plus': { prompt: 3, completion: 15 }, + 'command-r': { prompt: 0.5, completion: 1.5 }, + /* cohere doesn't have rates for the older command models, + so this was from https://artificialanalysis.ai/models/command-light/providers */ + command: { prompt: 0.38, completion: 0.38 }, }; /** @@ -29,16 +42,24 @@ const getValueKey = (model, endpoint) => { if (modelName.includes('gpt-3.5-turbo-16k')) { return '16k'; + } else if (modelName.includes('gpt-3.5-turbo-0125')) { + return 'gpt-3.5-turbo-0125'; } else if (modelName.includes('gpt-3.5-turbo-1106')) { return 'gpt-3.5-turbo-1106'; } else if (modelName.includes('gpt-3.5')) { return '4k'; } else if (modelName.includes('gpt-4-1106')) { return 'gpt-4-1106'; + } else if (modelName.includes('gpt-4-0125')) { + return 'gpt-4-1106'; + } else if (modelName.includes('gpt-4-turbo')) { + return 'gpt-4-1106'; } else if (modelName.includes('gpt-4-32k')) { return '32k'; } else if (modelName.includes('gpt-4')) { return '8k'; + } else if (tokenValues[modelName]) { + return modelName; } return undefined; @@ -53,9 +74,14 @@ const getValueKey = (model, endpoint) => { * @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion'). * @param {string} [params.model] - The model name to derive the value key from if not provided. * @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided. + * @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint. * @returns {number} The multiplier for the given parameters, or a default value if not found. */ -const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => { +const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConfig }) => { + if (endpointTokenConfig) { + return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate; + } + if (valueKey && tokenType) { return tokenValues[valueKey][tokenType] ?? defaultRate; } diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index 135298bf2b2..36533a11dd4 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -84,6 +84,15 @@ describe('getMultiplier', () => { expect(getMultiplier({ tokenType: 'completion', model: 'gpt-4-1106-vision-preview' })).toBe( tokenValues['gpt-4-1106'].completion, ); + expect(getMultiplier({ tokenType: 'completion', model: 'gpt-4-0125-preview' })).toBe( + tokenValues['gpt-4-1106'].completion, + ); + expect(getMultiplier({ tokenType: 'completion', model: 'gpt-4-turbo-vision-preview' })).toBe( + tokenValues['gpt-4-1106'].completion, + ); + expect(getMultiplier({ tokenType: 'completion', model: 'gpt-3.5-turbo-0125' })).toBe( + tokenValues['gpt-3.5-turbo-0125'].completion, + ); }); it('should return defaultRate if derived valueKey does not match any known patterns', () => { diff --git a/api/models/userMethods.js b/api/models/userMethods.js new file mode 100644 index 00000000000..c1ccce5b523 --- /dev/null +++ b/api/models/userMethods.js @@ -0,0 +1,46 @@ +const bcrypt = require('bcryptjs'); +const User = require('./User'); + +const hashPassword = async (password) => { + const hashedPassword = await new Promise((resolve, reject) => { + bcrypt.hash(password, 10, function (err, hash) { + if (err) { + reject(err); + } else { + resolve(hash); + } + }); + }); + + return hashedPassword; +}; + +/** + * Retrieve a user by ID and convert the found user document to a plain object. + * + * @param {string} userId - The ID of the user to find and return as a plain object. + * @returns {Promise<Object>} A plain object representing the user document, or `null` if no user is found. + */ +const getUser = async function (userId) { + return await User.findById(userId).lean(); +}; + +/** + * Update a user with new data without overwriting existing properties. + * + * @param {string} userId - The ID of the user to update. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise<Object>} The updated user document as a plain object, or `null` if no user is found. + */ +const updateUser = async function (userId, updateData) { + return await User.findByIdAndUpdate(userId, updateData, { + new: true, + runValidators: true, + }).lean(); +}; + +module.exports = { + hashPassword, + updateUser, + getUser, +}; diff --git a/api/package.json b/api/package.json index 292a3f5a1c3..31df31f7c5f 100644 --- a/api/package.json +++ b/api/package.json @@ -1,13 +1,19 @@ { "name": "@librechat/backend", - "version": "0.6.6", + "version": "0.7.0", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", "server-dev": "echo 'please run this from the root directory'", "test": "cross-env NODE_ENV=test jest", "b:test": "NODE_ENV=test bun jest", - "test:ci": "jest --ci" + "test:ci": "jest --ci", + "add-balance": "node ./add-balance.js", + "list-balances": "node ./list-balances.js", + "user-stats": "node ./user-stats.js", + "create-user": "node ./create-user.js", + "ban-user": "node ./ban-user.js", + "delete-user": "node ./delete-user.js" }, "repository": { "type": "git", @@ -25,17 +31,18 @@ "bugs": { "url": "https://github.com/danny-avila/LibreChat/issues" }, - "homepage": "https://github.com/danny-avila/LibreChat#readme", + "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.5.4", + "@anthropic-ai/sdk": "^0.16.1", "@azure/search-documents": "^12.0.0", "@keyv/mongo": "^2.1.8", "@keyv/redis": "^2.8.1", - "@langchain/google-genai": "^0.0.7", + "@langchain/community": "^0.0.17", + "@langchain/google-genai": "^0.0.8", "axios": "^1.3.4", "bcryptjs": "^2.4.3", "cheerio": "^1.0.0-rc.12", - "cohere-ai": "^6.0.0", + "cohere-ai": "^7.9.1", "connect-redis": "^7.1.0", "cookie": "^0.5.0", "cors": "^2.8.5", @@ -44,6 +51,7 @@ "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^6.9.0", "express-session": "^1.17.3", + "file-type": "^18.7.0", "firebase": "^10.6.0", "googleapis": "^126.0.1", "handlebars": "^4.7.7", @@ -57,13 +65,14 @@ "langchain": "^0.0.214", "librechat-data-provider": "*", "lodash": "^4.17.21", - "meilisearch": "^0.33.0", + "meilisearch": "^0.38.0", + "mime": "^3.0.0", "module-alias": "^2.2.3", "mongoose": "^7.1.1", "multer": "^1.4.5-lts.1", "nodejs-gpt": "^1.37.4", "nodemailer": "^6.9.4", - "openai": "^4.20.1", + "openai": "^4.29.0", "openai-chat-tokens": "^0.2.8", "openid-client": "^5.4.2", "passport": "^0.6.0", diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 67d7c67e9f7..171e7aaae74 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,7 +1,8 @@ -const { getResponseSender } = require('librechat-data-provider'); -const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage, getConvoTitle, getConvo } = require('~/models'); +const throttle = require('lodash/throttle'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); +const { sendMessage, createOnProgress } = require('~/server/utils'); +const { saveMessage, getConvo } = require('~/models'); const { logger } = require('~/config'); const AskController = async (req, res, next, initializeClient, addTitle) => { @@ -16,13 +17,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { logger.debug('[AskController]', { text, conversationId, ...endpointOption }); - let metadata; let userMessage; let promptTokens; let userMessageId; let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model, @@ -31,8 +29,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { const newConvo = !conversationId; const user = req.user.id; - const addMetadata = (data) => (metadata = data); - const getReqData = (data = {}) => { for (let key in data) { if (key === 'userMessage') { @@ -54,11 +50,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { const { client } = await initializeClient({ req, res, endpointOption }); const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; + onProgress: throttle( + ({ text: partialText }) => { saveMessage({ messageId: responseMessageId, sender, @@ -70,12 +63,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { error: false, user, }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, + }, + 3000, + { trailing: false }, + ), }); getText = getPartialText; @@ -92,6 +83,20 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { const { abortController, onStart } = createAbortController(req, res, getAbortData); + res.on('close', () => { + logger.debug('[AskController] Request closed'); + if (!abortController) { + return; + } else if (abortController.signal.aborted) { + return; + } else if (abortController.requestCompleted) { + return; + } + + abortController.abort(); + logger.debug('[AskController] Request aborted on close'); + }); + const messageOptions = { user, parentMessageId, @@ -99,7 +104,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { overrideParentMessageId, getReqData, onStart, - addMetadata, abortController, onProgress: progressCallback.call(null, { res, @@ -114,22 +118,23 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { response.parentMessageId = overrideParentMessageId; } - if (metadata) { - response = { ...response, ...metadata }; - } - response.endpoint = endpointOption.endpoint; + const conversation = await getConvo(user, conversationId); + conversation.title = + conversation && !conversation.title ? null : conversation?.title || 'New Chat'; + if (client.options.attachments) { userMessage.files = client.options.attachments; + conversation.model = endpointOption.modelOptions.model; delete userMessage.image_urls; } if (!abortController.signal.aborted) { sendMessage(res, { - title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(user, conversationId), + conversation, + title: conversation.title, requestMessage: userMessage, responseMessage: response, }); @@ -140,7 +145,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { await saveMessage(userMessage); - if (addTitle && parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { + if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) { addTitle(req, { text, response, diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 921ba3d8388..ee1751442c1 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -76,14 +76,14 @@ const refreshController = async (req, res) => { } try { - let payload; - payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET); - const userId = payload.id; - const user = await User.findOne({ _id: userId }); + const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET); + const user = await User.findOne({ _id: payload.id }); if (!user) { return res.status(401).redirect('/login'); } + const userId = payload.id; + if (process.env.NODE_ENV === 'CI') { const token = await setAuthTokens(userId, res); const userObj = user.toJSON(); @@ -118,6 +118,6 @@ module.exports = { getUserController, refreshController, registrationController, - resetPasswordRequestController, resetPasswordController, + resetPasswordRequestController, }; diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index 43b82e7193f..28a35185ffb 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -1,7 +1,8 @@ +const throttle = require('lodash/throttle'); const { getResponseSender } = require('librechat-data-provider'); -const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage, getConvoTitle, getConvo } = require('~/models'); const { createAbortController, handleAbortError } = require('~/server/middleware'); +const { sendMessage, createOnProgress } = require('~/server/utils'); +const { saveMessage, getConvo } = require('~/models'); const { logger } = require('~/config'); const EditController = async (req, res, next, initializeClient) => { @@ -25,11 +26,8 @@ const EditController = async (req, res, next, initializeClient) => { ...endpointOption, }); - let metadata; let userMessage; let promptTokens; - let lastSavedTimestamp = 0; - let saveDelay = 100; const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model, @@ -38,7 +36,6 @@ const EditController = async (req, res, next, initializeClient) => { const userMessageId = parentMessageId; const user = req.user.id; - const addMetadata = (data) => (metadata = data); const getReqData = (data = {}) => { for (let key in data) { if (key === 'userMessage') { @@ -53,11 +50,8 @@ const EditController = async (req, res, next, initializeClient) => { const { onProgress: progressCallback, getPartialText } = createOnProgress({ generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; + onProgress: throttle( + ({ text: partialText }) => { saveMessage({ messageId: responseMessageId, sender, @@ -70,12 +64,10 @@ const EditController = async (req, res, next, initializeClient) => { error: false, user, }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, + }, + 3000, + { trailing: false }, + ), }); const getAbortData = () => ({ @@ -90,6 +82,20 @@ const EditController = async (req, res, next, initializeClient) => { const { abortController, onStart } = createAbortController(req, res, getAbortData); + res.on('close', () => { + logger.debug('[EditController] Request closed'); + if (!abortController) { + return; + } else if (abortController.signal.aborted) { + return; + } else if (abortController.requestCompleted) { + return; + } + + abortController.abort(); + logger.debug('[EditController] Request aborted on close'); + }); + try { const { client } = await initializeClient({ req, res, endpointOption }); @@ -104,7 +110,6 @@ const EditController = async (req, res, next, initializeClient) => { overrideParentMessageId, getReqData, onStart, - addMetadata, abortController, onProgress: progressCallback.call(null, { res, @@ -113,15 +118,19 @@ const EditController = async (req, res, next, initializeClient) => { }), }); - if (metadata) { - response = { ...response, ...metadata }; + const conversation = await getConvo(user, conversationId); + conversation.title = + conversation && !conversation.title ? null : conversation?.title || 'New Chat'; + + if (client.options.attachments) { + conversation.model = endpointOption.modelOptions.model; } if (!abortController.signal.aborted) { sendMessage(res, { - title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(user, conversationId), + conversation, + title: conversation.title, requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/controllers/EndpointController.js b/api/server/controllers/EndpointController.js index 5069bb33e0b..b99dd5eda9c 100644 --- a/api/server/controllers/EndpointController.js +++ b/api/server/controllers/EndpointController.js @@ -1,4 +1,4 @@ -const { CacheKeys } = require('librechat-data-provider'); +const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider'); const { loadDefaultEndpointsConfig, loadConfigEndpoints } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -10,10 +10,23 @@ async function endpointController(req, res) { return; } - const defaultEndpointsConfig = await loadDefaultEndpointsConfig(); - const customConfigEndpoints = await loadConfigEndpoints(); + const defaultEndpointsConfig = await loadDefaultEndpointsConfig(req); + const customConfigEndpoints = await loadConfigEndpoints(req); - const endpointsConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints }; + /** @type {TEndpointsConfig} */ + const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints }; + if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) { + const { disableBuilder, retrievalModels, capabilities, ..._rest } = + req.app.locals[EModelEndpoint.assistants]; + mergedConfig[EModelEndpoint.assistants] = { + ...mergedConfig[EModelEndpoint.assistants], + retrievalModels, + disableBuilder, + capabilities, + }; + } + + const endpointsConfig = orderEndpointsConfig(mergedConfig); await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); res.send(JSON.stringify(endpointsConfig)); diff --git a/api/server/controllers/ErrorController.js b/api/server/controllers/ErrorController.js index 1308527b8cd..234cb90fb37 100644 --- a/api/server/controllers/ErrorController.js +++ b/api/server/controllers/ErrorController.js @@ -3,23 +3,24 @@ const { logger } = require('~/config'); //handle duplicates const handleDuplicateKeyError = (err, res) => { logger.error('Duplicate key error:', err.keyValue); - const field = Object.keys(err.keyValue); + const field = `${JSON.stringify(Object.keys(err.keyValue))}`; const code = 409; - const error = `An document with that ${field} already exists.`; - res.status(code).send({ messages: error, fields: field }); + res + .status(code) + .send({ messages: `An document with that ${field} already exists.`, fields: field }); }; //handle validation errors const handleValidationError = (err, res) => { logger.error('Validation error:', err.errors); let errors = Object.values(err.errors).map((el) => el.message); - let fields = Object.values(err.errors).map((el) => el.path); + let fields = `${JSON.stringify(Object.values(err.errors).map((el) => el.path))}`; let code = 400; if (errors.length > 1) { - const formattedErrors = errors.join(' '); - res.status(code).send({ messages: formattedErrors, fields: fields }); + errors = errors.join(' '); + res.status(code).send({ messages: `${JSON.stringify(errors)}`, fields: fields }); } else { - res.status(code).send({ messages: errors, fields: fields }); + res.status(code).send({ messages: `${JSON.stringify(errors)}`, fields: fields }); } }; diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 2d23961e154..022ece4c103 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -2,20 +2,39 @@ const { CacheKeys } = require('librechat-data-provider'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); -async function modelController(req, res) { +const getModelsConfig = async (req) => { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + if (!modelsConfig) { + modelsConfig = await loadModels(req); + } + + return modelsConfig; +}; + +/** + * Loads the models from the config. + * @param {Express.Request} req - The Express request object. + * @returns {Promise<TModelsConfig>} The models config. + */ +async function loadModels(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); if (cachedModelsConfig) { - res.send(cachedModelsConfig); - return; + return cachedModelsConfig; } - const defaultModelsConfig = await loadDefaultModels(); - const customModelsConfig = await loadConfigModels(); + const defaultModelsConfig = await loadDefaultModels(req); + const customModelsConfig = await loadConfigModels(req); const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); + return modelConfig; +} + +async function modelController(req, res) { + const modelConfig = await loadModels(req); res.send(modelConfig); } -module.exports = modelController; +module.exports = { modelController, loadModels, getModelsConfig }; diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index c37b36974e0..803d89923ba 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,9 +1,14 @@ -const path = require('path'); const { promises: fs } = require('fs'); const { CacheKeys } = require('librechat-data-provider'); const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs'); const { getLogStores } = require('~/cache'); +/** + * Filters out duplicate plugins from the list of plugins. + * + * @param {TPlugin[]} plugins The list of plugins to filter. + * @returns {TPlugin[]} The list of plugins with duplicates removed. + */ const filterUniquePlugins = (plugins) => { const seen = new Set(); return plugins.filter((plugin) => { @@ -13,17 +18,31 @@ const filterUniquePlugins = (plugins) => { }); }; +/** + * Determines if a plugin is authenticated by checking if all required authentication fields have non-empty values. + * Supports alternate authentication fields, allowing validation against multiple possible environment variables. + * + * @param {TPlugin} plugin The plugin object containing the authentication configuration. + * @returns {boolean} True if the plugin is authenticated for all required fields, false otherwise. + */ const isPluginAuthenticated = (plugin) => { if (!plugin.authConfig || plugin.authConfig.length === 0) { return false; } return plugin.authConfig.every((authFieldObj) => { - const envValue = process.env[authFieldObj.authField]; - if (envValue === 'user_provided') { - return false; + const authFieldOptions = authFieldObj.authField.split('||'); + let isFieldAuthenticated = false; + + for (const fieldOption of authFieldOptions) { + const envValue = process.env[fieldOption]; + if (envValue && envValue.trim() !== '' && envValue !== 'user_provided') { + isFieldAuthenticated = true; + break; + } } - return envValue && envValue.trim() !== ''; + + return isFieldAuthenticated; }); }; @@ -36,12 +55,10 @@ const getAvailablePluginsController = async (req, res) => { return; } - const manifestFile = await fs.readFile( - path.join(__dirname, '..', '..', 'app', 'clients', 'tools', 'manifest.json'), - 'utf8', - ); + const pluginManifest = await fs.readFile(req.app.locals.paths.pluginManifest, 'utf8'); - const jsonData = JSON.parse(manifestFile); + const jsonData = JSON.parse(pluginManifest); + /** @type {TPlugin[]} */ const uniquePlugins = filterUniquePlugins(jsonData); const authenticatedPlugins = uniquePlugins.map((plugin) => { if (isPluginAuthenticated(plugin)) { @@ -58,6 +75,53 @@ const getAvailablePluginsController = async (req, res) => { } }; +/** + * Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file. + * + * This function first attempts to retrieve the list of tools from a cache. If the tools are not found in the cache, + * it reads a plugin manifest file, filters for unique plugins, and determines if each plugin is authenticated. + * Only plugins that are marked as available in the application's local state are included in the final list. + * The resulting list of tools is then cached and sent to the client. + * + * @param {object} req - The request object, containing information about the HTTP request. + * @param {object} res - The response object, used to send back the desired HTTP response. + * @returns {Promise<void>} A promise that resolves when the function has completed. + */ +const getAvailableTools = async (req, res) => { + try { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedTools = await cache.get(CacheKeys.TOOLS); + if (cachedTools) { + res.status(200).json(cachedTools); + return; + } + + const pluginManifest = await fs.readFile(req.app.locals.paths.pluginManifest, 'utf8'); + + const jsonData = JSON.parse(pluginManifest); + /** @type {TPlugin[]} */ + const uniquePlugins = filterUniquePlugins(jsonData); + + const authenticatedPlugins = uniquePlugins.map((plugin) => { + if (isPluginAuthenticated(plugin)) { + return { ...plugin, authenticated: true }; + } else { + return plugin; + } + }); + + const tools = authenticatedPlugins.filter( + (plugin) => req.app.locals.availableTools[plugin.pluginKey] !== undefined, + ); + + await cache.set(CacheKeys.TOOLS, tools); + res.status(200).json(tools); + } catch (error) { + res.status(500).json({ message: error.message }); + } +}; + module.exports = { + getAvailableTools, getAvailablePluginsController, }; diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index fa08cd54529..ac20ca627a1 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -8,16 +8,19 @@ const getUserController = async (req, res) => { const updateUserPluginsController = async (req, res) => { const { user } = req; - const { pluginKey, action, auth } = req.body; + const { pluginKey, action, auth, isAssistantTool } = req.body; let authService; try { - const userPluginsService = await updateUserPluginsService(user, pluginKey, action); + if (!isAssistantTool) { + const userPluginsService = await updateUserPluginsService(user, pluginKey, action); - if (userPluginsService instanceof Error) { - logger.error('[userPluginsService]', userPluginsService); - const { status, message } = userPluginsService; - res.status(status).send({ message }); + if (userPluginsService instanceof Error) { + logger.error('[userPluginsService]', userPluginsService); + const { status, message } = userPluginsService; + res.status(status).send({ message }); + } } + if (auth) { const keys = Object.keys(auth); const values = Object.values(auth); diff --git a/api/server/index.js b/api/server/index.js index 86806b59146..4e85c508010 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -2,15 +2,18 @@ require('dotenv').config(); const path = require('path'); require('module-alias')({ base: path.resolve(__dirname, '..') }); const cors = require('cors'); +const axios = require('axios'); const express = require('express'); const passport = require('passport'); const mongoSanitize = require('express-mongo-sanitize'); +const validateImageRequest = require('./middleware/validateImageRequest'); const errorController = require('./controllers/ErrorController'); const { jwtLogin, passportLogin } = require('~/strategies'); const configureSocialLogins = require('./socialLogins'); const { connectDb, indexSync } = require('~/lib/db'); const AppService = require('./services/AppService'); const noIndex = require('./middleware/noIndex'); +const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); const routes = require('./routes'); @@ -21,13 +24,19 @@ const port = Number(PORT) || 3080; const host = HOST || 'localhost'; const startServer = async () => { + if (typeof Bun !== 'undefined') { + axios.defaults.headers.common['Accept-Encoding'] = 'gzip'; + } await connectDb(); logger.info('Connected to MongoDB'); await indexSync(); const app = express(); + app.disable('x-powered-by'); await AppService(app); + app.get('/health', (_req, res) => res.status(200).send('OK')); + // Middleware app.use(noIndex); app.use(errorController); @@ -35,7 +44,8 @@ const startServer = async () => { app.use(mongoSanitize()); app.use(express.urlencoded({ extended: true, limit: '3mb' })); app.use(express.static(app.locals.paths.dist)); - app.use(express.static(app.locals.paths.publicPath)); + app.use(express.static(app.locals.paths.fonts)); + app.use(express.static(app.locals.paths.assets)); app.set('trust proxy', 1); // trust first proxy app.use(cors()); @@ -50,7 +60,7 @@ const startServer = async () => { passport.use(await jwtLogin()); passport.use(passportLogin()); - if (ALLOW_SOCIAL_LOGIN?.toLowerCase() === 'true') { + if (isEnabled(ALLOW_SOCIAL_LOGIN)) { configureSocialLogins(app); } @@ -73,7 +83,8 @@ const startServer = async () => { app.use('/api/plugins', routes.plugins); app.use('/api/config', routes.config); app.use('/api/assistants', routes.assistants); - app.use('/api/files', routes.files); + app.use('/api/files', await routes.files.initialize()); + app.use('/images/', validateImageRequest, routes.staticRoute); app.use((req, res) => { res.status(404).sendFile(path.join(app.locals.paths.dist, 'index.html')); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index cc9b9fc0513..a2be50ee82d 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,18 +1,24 @@ +const { EModelEndpoint } = require('librechat-data-provider'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { saveMessage, getConvo, getConvoTitle } = require('~/models'); const clearPendingReq = require('~/cache/clearPendingReq'); const abortControllers = require('./abortControllers'); const { redactMessage } = require('~/config/parsers'); const spendTokens = require('~/models/spendTokens'); +const { abortRun } = require('./abortRun'); const { logger } = require('~/config'); async function abortMessage(req, res) { - let { abortKey, conversationId } = req.body; + let { abortKey, conversationId, endpoint } = req.body; if (!abortKey && conversationId) { abortKey = conversationId; } + if (endpoint === EModelEndpoint.assistants) { + return await abortRun(req, res); + } + if (!abortControllers.has(abortKey) && !res.headersSent) { return res.status(204).send({ message: 'Request not found' }); } @@ -104,7 +110,7 @@ const handleAbortError = async (res, req, error, data) => { } const respondWithError = async (partialText) => { - const options = { + let options = { sender, messageId, conversationId, @@ -115,7 +121,8 @@ const handleAbortError = async (res, req, error, data) => { }; if (partialText) { - options.overrideProps = { + options = { + ...options, error: false, unfinished: true, text: partialText, diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js new file mode 100644 index 00000000000..6db6329d44c --- /dev/null +++ b/api/server/middleware/abortRun.js @@ -0,0 +1,92 @@ +const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { checkMessageGaps, recordUsage } = require('~/server/services/Threads'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { sendMessage } = require('~/server/utils'); +const { logger } = require('~/config'); + +const three_minutes = 1000 * 60 * 3; + +async function abortRun(req, res) { + res.setHeader('Content-Type', 'application/json'); + const { abortKey } = req.body; + const [conversationId, latestMessageId] = abortKey.split(':'); + const conversation = await getConvo(req.user.id, conversationId); + + if (conversation?.model) { + req.body.model = conversation.model; + } + + if (!isUUID.safeParse(conversationId).success) { + logger.error('[abortRun] Invalid conversationId', { conversationId }); + return res.status(400).send({ message: 'Invalid conversationId' }); + } + + const cacheKey = `${req.user.id}:${conversationId}`; + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const runValues = await cache.get(cacheKey); + const [thread_id, run_id] = runValues.split(':'); + + if (!run_id) { + logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id }); + return res.status(204).send({ message: 'Run not found' }); + } else if (run_id === 'cancelled') { + logger.warn('[abortRun] Run already cancelled', { thread_id }); + return res.status(204).send({ message: 'Run already cancelled' }); + } + + let runMessages = []; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + try { + await cache.set(cacheKey, 'cancelled', three_minutes); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug('[abortRun] Cancelled run:', cancelledRun); + } catch (error) { + logger.error('[abortRun] Error cancelling run', error); + if ( + error?.message?.includes(RunStatus.CANCELLED) || + error?.message?.includes(RunStatus.CANCELLING) + ) { + return res.end(); + } + } + + try { + const run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error('[abortRun] Error fetching or processing run', error); + } + + runMessages = await checkMessageGaps({ + openai, + latestMessageId, + thread_id, + run_id, + conversationId, + }); + + const finalEvent = { + final: true, + conversation, + runMessages, + }; + + if (res.headersSent && finalEvent) { + return sendMessage(res, finalEvent); + } + + res.json(finalEvent); +} + +module.exports = { + abortRun, +}; diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 91d0caceaaf..e0ae6c8534d 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -1,6 +1,8 @@ const { parseConvo, EModelEndpoint } = require('librechat-data-provider'); -const { processFiles } = require('~/server/services/Files/process'); +const { getModelsConfig } = require('~/server/controllers/ModelController'); +const assistants = require('~/server/services/Endpoints/assistants'); const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); +const { processFiles } = require('~/server/services/Files/process'); const anthropic = require('~/server/services/Endpoints/anthropic'); const openAI = require('~/server/services/Endpoints/openAI'); const custom = require('~/server/services/Endpoints/custom'); @@ -13,9 +15,10 @@ const buildFunction = { [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, + [EModelEndpoint.assistants]: assistants.buildOptions, }; -function buildEndpointOption(req, res, next) { +async function buildEndpointOption(req, res, next) { const { endpoint, endpointType } = req.body; const parsedBody = parseConvo({ endpoint, endpointType, conversation: req.body }); req.body.endpointOption = buildFunction[endpointType ?? endpoint]( @@ -23,6 +26,10 @@ function buildEndpointOption(req, res, next) { parsedBody, endpointType, ); + + const modelsConfig = await getModelsConfig(req); + req.body.endpointOption.modelsConfig = modelsConfig; + if (req.body.files) { // hold the promise req.body.endpointOption.attachments = processFiles(req.body.files); diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js index 8000aa2b107..37952176bfa 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -1,7 +1,7 @@ const crypto = require('crypto'); -const { saveMessage } = require('~/models'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { sendMessage, sendError } = require('~/server/utils'); -const { getResponseSender } = require('librechat-data-provider'); +const { saveMessage } = require('~/models'); /** * Denies a request by sending an error message and optionally saves the user's message. @@ -38,8 +38,7 @@ const denyRequest = async (req, res, errorMessage) => { }; sendMessage(res, { message: userMessage, created: true }); - const shouldSaveMessage = - _convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000'; + const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT; if (shouldSaveMessage) { await saveMessage({ ...userMessage, user: req.user.id }); diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 77afd971650..b9960a237af 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -3,7 +3,9 @@ const checkBan = require('./checkBan'); const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); +const validateModel = require('./validateModel'); const requireJwtAuth = require('./requireJwtAuth'); +const uploadLimiters = require('./uploadLimiters'); const registerLimiter = require('./registerLimiter'); const messageLimiters = require('./messageLimiters'); const requireLocalAuth = require('./requireLocalAuth'); @@ -12,10 +14,12 @@ const concurrentLimiter = require('./concurrentLimiter'); const validateMessageReq = require('./validateMessageReq'); const buildEndpointOption = require('./buildEndpointOption'); const validateRegistration = require('./validateRegistration'); +const validateImageRequest = require('./validateImageRequest'); const moderateText = require('./moderateText'); const noIndex = require('./noIndex'); module.exports = { + ...uploadLimiters, ...abortMiddleware, ...messageLimiters, checkBan, @@ -30,6 +34,8 @@ module.exports = { validateMessageReq, buildEndpointOption, validateRegistration, + validateImageRequest, + validateModel, moderateText, noIndex, }; diff --git a/api/server/middleware/moderateText.js b/api/server/middleware/moderateText.js index c4bfd8a13ae..40bc5e9430b 100644 --- a/api/server/middleware/moderateText.js +++ b/api/server/middleware/moderateText.js @@ -1,5 +1,6 @@ const axios = require('axios'); const denyRequest = require('./denyRequest'); +const { logger } = require('~/config'); async function moderateText(req, res, next) { if (process.env.OPENAI_MODERATION === 'true') { @@ -28,7 +29,7 @@ async function moderateText(req, res, next) { return await denyRequest(req, res, errorMessage); } } catch (error) { - console.error('Error in moderateText:', error); + logger.error('Error in moderateText:', error); const errorMessage = 'error in moderation check'; return await denyRequest(req, res, errorMessage); } diff --git a/api/server/middleware/uploadLimiters.js b/api/server/middleware/uploadLimiters.js new file mode 100644 index 00000000000..71af164fde4 --- /dev/null +++ b/api/server/middleware/uploadLimiters.js @@ -0,0 +1,75 @@ +const rateLimit = require('express-rate-limit'); +const { ViolationTypes } = require('librechat-data-provider'); +const logViolation = require('~/cache/logViolation'); + +const getEnvironmentVariables = () => { + const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100; + const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15; + const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50; + const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15; + + const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000; + const fileUploadIpMax = FILE_UPLOAD_IP_MAX; + const fileUploadIpWindowInMinutes = fileUploadIpWindowMs / 60000; + + const fileUploadUserWindowMs = FILE_UPLOAD_USER_WINDOW * 60 * 1000; + const fileUploadUserMax = FILE_UPLOAD_USER_MAX; + const fileUploadUserWindowInMinutes = fileUploadUserWindowMs / 60000; + + return { + fileUploadIpWindowMs, + fileUploadIpMax, + fileUploadIpWindowInMinutes, + fileUploadUserWindowMs, + fileUploadUserMax, + fileUploadUserWindowInMinutes, + }; +}; + +const createFileUploadHandler = (ip = true) => { + const { + fileUploadIpMax, + fileUploadIpWindowInMinutes, + fileUploadUserMax, + fileUploadUserWindowInMinutes, + } = getEnvironmentVariables(); + + return async (req, res) => { + const type = ViolationTypes.FILE_UPLOAD_LIMIT; + const errorMessage = { + type, + max: ip ? fileUploadIpMax : fileUploadUserMax, + limiter: ip ? 'ip' : 'user', + windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes, + }; + + await logViolation(req, res, type, errorMessage); + res.status(429).json({ message: 'Too many file upload requests. Try again later' }); + }; +}; + +const createFileLimiters = () => { + const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } = + getEnvironmentVariables(); + + const fileUploadIpLimiter = rateLimit({ + windowMs: fileUploadIpWindowMs, + max: fileUploadIpMax, + handler: createFileUploadHandler(), + }); + + const fileUploadUserLimiter = rateLimit({ + windowMs: fileUploadUserWindowMs, + max: fileUploadUserMax, + handler: createFileUploadHandler(false), + keyGenerator: function (req) { + return req.user?.id; // Use the user ID or NULL if not available + }, + }); + + return { fileUploadIpLimiter, fileUploadUserLimiter }; +}; + +module.exports = { + createFileLimiters, +}; diff --git a/api/server/middleware/validateImageRequest.js b/api/server/middleware/validateImageRequest.js new file mode 100644 index 00000000000..03482b4b118 --- /dev/null +++ b/api/server/middleware/validateImageRequest.js @@ -0,0 +1,37 @@ +const cookies = require('cookie'); +const jwt = require('jsonwebtoken'); +const { logger } = require('~/config'); + +/** + * Middleware to validate image request + */ +function validateImageRequest(req, res, next) { + const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null; + if (!refreshToken) { + logger.warn('[validateImageRequest] Refresh token not provided'); + return res.status(401).send('Unauthorized'); + } + + let payload; + try { + payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET); + } catch (err) { + logger.warn('[validateImageRequest]', err); + return res.status(403).send('Access Denied'); + } + + const currentTimeInSeconds = Math.floor(Date.now() / 1000); + if (payload.exp < currentTimeInSeconds) { + logger.warn('[validateImageRequest] Refresh token expired'); + return res.status(403).send('Access Denied'); + } + + if (req.path.includes(payload.id)) { + logger.debug('[validateImageRequest] Image request validated'); + next(); + } else { + res.status(403).send('Access Denied'); + } +} + +module.exports = validateImageRequest; diff --git a/api/server/middleware/validateModel.js b/api/server/middleware/validateModel.js new file mode 100644 index 00000000000..dacbb826297 --- /dev/null +++ b/api/server/middleware/validateModel.js @@ -0,0 +1,47 @@ +const { ViolationTypes } = require('librechat-data-provider'); +const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { handleError } = require('~/server/utils'); +const { logViolation } = require('~/cache'); +/** + * Validates the model of the request. + * + * @async + * @param {Express.Request} req - The Express request object. + * @param {Express.Response} res - The Express response object. + * @param {Function} next - The Express next function. + */ +const validateModel = async (req, res, next) => { + const { model, endpoint } = req.body; + if (!model) { + return handleError(res, { text: 'Model not provided' }); + } + + const modelsConfig = await getModelsConfig(req); + + if (!modelsConfig) { + return handleError(res, { text: 'Models not loaded' }); + } + + const availableModels = modelsConfig[endpoint]; + if (!availableModels) { + return handleError(res, { text: 'Endpoint models not loaded' }); + } + + let validModel = !!availableModels.find((availableModel) => availableModel === model); + + if (validModel) { + return next(); + } + + const { ILLEGAL_MODEL_REQ_SCORE: score = 5 } = process.env ?? {}; + + const type = ViolationTypes.ILLEGAL_MODEL_REQUEST; + const errorMessage = { + type, + }; + + await logViolation(req, res, type, errorMessage, score); + return handleError(res, { text: 'Illegal model request' }); +}; + +module.exports = validateModel; diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js index 4833b83d105..bc3742dfffc 100644 --- a/api/server/routes/__tests__/config.spec.js +++ b/api/server/routes/__tests__/config.spec.js @@ -1,7 +1,9 @@ const request = require('supertest'); const express = require('express'); const routes = require('../'); +// file deepcode ignore UseCsurfForExpress/test: test const app = express(); +app.disable('x-powered-by'); app.use('/api/config', routes.config); afterEach(() => { @@ -54,13 +56,14 @@ describe.skip('GET /', () => { expect(response.statusCode).toBe(200); expect(response.body).toEqual({ appTitle: 'Test Title', - googleLoginEnabled: true, + socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'], + discordLoginEnabled: true, facebookLoginEnabled: true, + githubLoginEnabled: true, + googleLoginEnabled: true, openidLoginEnabled: true, openidLabel: 'Test OpenID', openidImageUrl: 'http://test-server.com', - githubLoginEnabled: true, - discordLoginEnabled: true, serverDomain: 'http://test-server.com', emailLoginEnabled: 'true', registrationEnabled: 'true', diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index e0ea0f9857f..a08d1d25705 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -1,9 +1,10 @@ const express = require('express'); const AskController = require('~/server/controllers/AskController'); -const { initializeClient } = require('~/server/services/Endpoints/anthropic'); +const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index 34f1096a871..4ce1770b8ed 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -1,5 +1,6 @@ const crypto = require('crypto'); const express = require('express'); +const { Constants } = require('librechat-data-provider'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('~/models'); const { handleError, sendMessage, createOnProgress, handleText } = require('~/server/utils'); const { setHeaders } = require('~/server/middleware'); @@ -27,7 +28,7 @@ router.post('/', setHeaders, async (req, res) => { const conversationId = oldConversationId || crypto.randomUUID(); const isNewConversation = !oldConversationId; const userMessageId = crypto.randomUUID(); - const userParentMessageId = parentMessageId || '00000000-0000-0000-0000-000000000000'; + const userParentMessageId = parentMessageId || Constants.NO_PARENT; const userMessage = { messageId: userMessageId, sender: 'User', @@ -209,7 +210,7 @@ const ask = async ({ }); res.end(); - if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { + if (userParentMessageId == Constants.NO_PARENT) { // const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await response.details.title; await saveConvo(user, { diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index 1281b56ae35..916cda4b10f 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -1,5 +1,6 @@ -const express = require('express'); const crypto = require('crypto'); +const express = require('express'); +const { Constants } = require('librechat-data-provider'); const { handleError, sendMessage, createOnProgress, handleText } = require('~/server/utils'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('~/models'); const { setHeaders } = require('~/server/middleware'); @@ -28,7 +29,7 @@ router.post('/', setHeaders, async (req, res) => { const conversationId = oldConversationId || crypto.randomUUID(); const isNewConversation = !oldConversationId; const userMessageId = messageId; - const userParentMessageId = parentMessageId || '00000000-0000-0000-0000-000000000000'; + const userParentMessageId = parentMessageId || Constants.NO_PARENT; let userMessage = { messageId: userMessageId, sender: 'User', @@ -238,7 +239,7 @@ const ask = async ({ }); res.end(); - if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { + if (userParentMessageId == Constants.NO_PARENT) { const title = await titleConvoBing({ text, response: responseMessage, diff --git a/api/server/routes/ask/custom.js b/api/server/routes/ask/custom.js index ef979bf0000..668a9902cb9 100644 --- a/api/server/routes/ask/custom.js +++ b/api/server/routes/ask/custom.js @@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI'); const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -13,8 +14,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index 78c648495ff..b5425d67649 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 85616cd1b31..f93a5a953bb 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -1,81 +1,88 @@ const express = require('express'); -const router = express.Router(); -const { getResponseSender } = require('librechat-data-provider'); -const { validateTools } = require('~/app'); -const { addTitle } = require('~/server/services/Endpoints/openAI'); +const throttle = require('lodash/throttle'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { saveMessage, getConvoTitle, getConvo } = require('~/models'); const { sendMessage, createOnProgress } = require('~/server/utils'); +const { addTitle } = require('~/server/services/Endpoints/openAI'); const { handleAbort, createAbortController, handleAbortError, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, } = require('~/server/middleware'); +const { validateTools } = require('~/app'); const { logger } = require('~/config'); +const router = express.Router(); + router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); - let metadata; - let userMessage; - let promptTokens; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model }); - const newConvo = !conversationId; - const user = req.user.id; - - const plugins = []; - - const addMetadata = (data) => (metadata = data); - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - userMessageId = data[key].messageId; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } else if (!conversationId && key === 'conversationId') { - conversationId = data[key]; - } - } - }; - - let streaming = null; - let timer = null; +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); - if (timer) { - clearTimeout(timer); + let userMessage; + let promptTokens; + let userMessageId; + let responseMessageId; + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.modelOptions.model, + }); + const newConvo = !conversationId; + const user = req.user.id; + + const plugins = []; + + const getReqData = (data = {}) => { + for (let key in data) { + if (key === 'userMessage') { + userMessage = data[key]; + userMessageId = data[key].messageId; + } else if (key === 'responseMessageId') { + responseMessageId = data[key]; + } else if (key === 'promptTokens') { + promptTokens = data[key]; + } else if (!conversationId && key === 'conversationId') { + conversationId = data[key]; + } } + }; + + const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false }); + let streaming = null; + let timer = null; - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + onProgress: ({ text: partialText }) => { + if (timer) { + clearTimeout(timer); + } + + throttledSaveMessage({ messageId: responseMessageId, sender, conversationId, @@ -87,140 +94,131 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, plugins, user, }); + + streaming = new Promise((resolve) => { + timer = setTimeout(() => { + resolve(); + }, 250); + }); + }, + }); + + const pluginMap = new Map(); + const onAgentAction = async (action, runId) => { + pluginMap.set(runId, action.tool); + sendIntermediateMessage(res, { plugins }); + }; + + const onToolStart = async (tool, input, runId, parentRunId) => { + const pluginName = pluginMap.get(parentRunId); + const latestPlugin = { + runId, + loading: true, + inputs: [input], + latest: pluginName, + outputs: null, + }; + + if (streaming) { + await streaming; } + const extraTokens = ':::plugin:::\n'; + plugins.push(latestPlugin); + sendIntermediateMessage(res, { plugins }, extraTokens); + }; - if (saveDelay < 500) { - saveDelay = 500; + const onToolEnd = async (output, runId) => { + if (streaming) { + await streaming; } - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { plugins }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; + const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage(res, { plugins }, extraTokens); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; - } + if (pluginIndex !== -1) { + plugins[pluginIndex].loading = false; + plugins[pluginIndex].outputs = output; + } + }; - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); + const onChainEnd = () => { + saveMessage({ ...userMessage, user }); + sendIntermediateMessage(res, { plugins }); + }; - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const onChainEnd = () => { - saveMessage({ ...userMessage, user }); - sendIntermediateMessage(res, { plugins }); - }; - - const getAbortData = () => ({ - sender, - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - let response = await client.sendMessage(text, { - user, + const getAbortData = () => ({ + sender, conversationId, - parentMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - addMetadata, - getPartialText, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - plugins, - }), - abortController, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugins: plugins.map((p) => ({ ...p, loading: false })), + userMessage, + promptTokens, }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient({ req, res, endpointOption }); + + let response = await client.sendMessage(text, { + user, + conversationId, + parentMessageId, + overrideParentMessageId, + getReqData, + onAgentAction, + onChainEnd, + onToolStart, + onToolEnd, + onStart, + getPartialText, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + plugins, + }), + abortController, + }); - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - logger.debug('[/ask/gptPlugins]', response); + logger.debug('[/ask/gptPlugins]', response); - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage({ ...response, user }); + response.plugins = plugins.map((p) => ({ ...p, loading: false })); + await saveMessage({ ...response, user }); - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); - if (parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { - addTitle(req, { - text, - response, - client, + if (parentMessageId === Constants.NO_PARENT && newConvo) { + addTitle(req, { + text, + response, + client, + }); + } + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender, + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } -}); + }, +); module.exports = router; diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index 31b3111077f..5083a08b104 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -4,6 +4,7 @@ const { addTitle, initializeClient } = require('~/server/services/Endpoints/open const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, @@ -13,8 +14,15 @@ const router = express.Router(); router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js new file mode 100644 index 00000000000..33db6ce803a --- /dev/null +++ b/api/server/routes/assistants/actions.js @@ -0,0 +1,202 @@ +const { v4 } = require('uuid'); +const express = require('express'); +const { actionDelimiter } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { updateAction, getActions, deleteAction } = require('~/models/Action'); +const { updateAssistant, getAssistant } = require('~/models/Assistant'); +const { logger } = require('~/config'); + +const router = express.Router(); + +/** + * Retrieves all user's actions + * @route GET /actions/ + * @param {string} req.params.id - Assistant identifier. + * @returns {Action[]} 200 - success response - application/json + */ +router.get('/', async (req, res) => { + try { + res.json(await getActions()); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * Adds or updates actions for a specific assistant. + * @route POST /actions/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {FunctionTool[]} req.body.functions - The functions to be added or updated. + * @param {string} [req.body.action_id] - Optional ID for the action. + * @param {ActionMetadata} req.body.metadata - Metadata for the action. + * @returns {Object} 200 - success response - application/json + */ +router.post('/:assistant_id', async (req, res) => { + try { + const { assistant_id } = req.params; + + /** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */ + const { functions, action_id: _action_id, metadata: _metadata } = req.body; + if (!functions.length) { + return res.status(400).json({ message: 'No functions provided' }); + } + + let metadata = encryptMetadata(_metadata); + + let { domain } = metadata; + /* Azure doesn't support periods in function names */ + domain = domainParser(req, domain, true); + + if (!domain) { + return res.status(400).json({ message: 'No domain provided' }); + } + + const action_id = _action_id ?? v4(); + const initialPromises = []; + + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + initialPromises.push(getAssistant({ assistant_id })); + initialPromises.push(openai.beta.assistants.retrieve(assistant_id)); + !!_action_id && initialPromises.push(getActions({ action_id }, true)); + + /** @type {[AssistantDocument, Assistant, [Action|undefined]]} */ + const [assistant_data, assistant, actions_result] = await Promise.all(initialPromises); + + if (actions_result && actions_result.length) { + const action = actions_result[0]; + metadata = { ...action.metadata, ...metadata }; + } + + if (!assistant) { + return res.status(404).json({ message: 'Assistant not found' }); + } + + const { actions: _actions = [] } = assistant_data ?? {}; + const actions = []; + for (const action of _actions) { + const [_action_domain, current_action_id] = action.split(actionDelimiter); + if (current_action_id === action_id) { + continue; + } + + actions.push(action); + } + + actions.push(`${domain}${actionDelimiter}${action_id}`); + + /** @type {{ tools: FunctionTool[] | { type: 'code_interpreter'|'retrieval'}[]}} */ + const { tools: _tools = [] } = assistant; + + const tools = _tools + .filter( + (tool) => + !( + tool.function && + (tool.function.name.includes(domain) || tool.function.name.includes(action_id)) + ), + ) + .concat( + functions.map((tool) => ({ + ...tool, + function: { + ...tool.function, + name: `${tool.function.name}${actionDelimiter}${domain}`, + }, + })), + ); + + const promises = []; + promises.push( + updateAssistant( + { assistant_id }, + { + actions, + user: req.user.id, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { tools })); + promises.push(updateAction({ action_id }, { metadata, assistant_id, user: req.user.id })); + + /** @type {[AssistantDocument, Assistant, Action]} */ + const resolved = await Promise.all(promises); + const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; + for (let field of sensitiveFields) { + if (resolved[2].metadata[field]) { + delete resolved[2].metadata[field]; + } + } + res.json(resolved); + } catch (error) { + const message = 'Trouble updating the Assistant Action'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + +/** + * Deletes an action for a specific assistant. + * @route DELETE /actions/:assistant_id/:action_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {string} req.params.action_id - The ID of the action to delete. + * @returns {Object} 200 - success response - application/json + */ +router.delete('/:assistant_id/:action_id/:model', async (req, res) => { + try { + const { assistant_id, action_id, model } = req.params; + req.body.model = model; + + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const initialPromises = []; + initialPromises.push(getAssistant({ assistant_id })); + initialPromises.push(openai.beta.assistants.retrieve(assistant_id)); + + /** @type {[AssistantDocument, Assistant]} */ + const [assistant_data, assistant] = await Promise.all(initialPromises); + + const { actions = [] } = assistant_data ?? {}; + const { tools = [] } = assistant ?? {}; + + let domain = ''; + const updatedActions = actions.filter((action) => { + if (action.includes(action_id)) { + [domain] = action.split(actionDelimiter); + return false; + } + return true; + }); + + domain = domainParser(req, domain, true); + + const updatedTools = tools.filter( + (tool) => !(tool.function && tool.function.name.includes(domain)), + ); + + const promises = []; + promises.push( + updateAssistant( + { assistant_id }, + { + actions: updatedActions, + user: req.user.id, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { tools: updatedTools })); + promises.push(deleteAction({ action_id })); + + await Promise.all(promises); + res.status(200).json({ message: 'Action deleted successfully' }); + } catch (error) { + const message = 'Trouble deleting the Assistant Action'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/assistants/assistants.js b/api/server/routes/assistants/assistants.js index b911c685aa9..70c685a97a0 100644 --- a/api/server/routes/assistants/assistants.js +++ b/api/server/routes/assistants/assistants.js @@ -1,9 +1,35 @@ -const OpenAI = require('openai'); +const multer = require('multer'); const express = require('express'); +const { FileContext, EModelEndpoint } = require('librechat-data-provider'); +const { + initializeClient, + listAssistantsForAzure, + listAssistants, +} = require('~/server/services/Endpoints/assistants'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { uploadImageBuffer } = require('~/server/services/Files/process'); +const { updateAssistant, getAssistants } = require('~/models/Assistant'); +const { deleteFileByFilter } = require('~/models/File'); const { logger } = require('~/config'); +const actions = require('./actions'); +const tools = require('./tools'); +const upload = multer(); const router = express.Router(); +/** + * Assistant actions route. + * @route GET|POST /assistants/actions + */ +router.use('/actions', actions); + +/** + * Create an assistant. + * @route GET /assistants/tools + * @returns {TPlugin[]} 200 - application/json + */ +router.use('/tools', tools); + /** * Create an assistant. * @route POST /assistants @@ -12,12 +38,29 @@ const router = express.Router(); */ router.post('/', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); - const assistantData = req.body; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const { tools = [], ...assistantData } = req.body; + assistantData.tools = tools + .map((tool) => { + if (typeof tool !== 'string') { + return tool; + } + + return req.app.locals.availableTools[tool]; + }) + .filter((tool) => tool); + + if (openai.locals?.azureOptions) { + assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName; + } + const assistant = await openai.beta.assistants.create(assistantData); logger.debug('/assistants/', assistant); res.status(201).json(assistant); } catch (error) { + logger.error('[/assistants] Error creating assistant', error); res.status(500).json({ error: error.message }); } }); @@ -30,11 +73,14 @@ router.post('/', async (req, res) => { */ router.get('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const assistant = await openai.beta.assistants.retrieve(assistant_id); res.json(assistant); } catch (error) { + logger.error('[/assistants/:id] Error retrieving assistant', error); res.status(500).json({ error: error.message }); } }); @@ -48,12 +94,29 @@ router.get('/:id', async (req, res) => { */ router.patch('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const updateData = req.body; + updateData.tools = (updateData.tools ?? []) + .map((tool) => { + if (typeof tool !== 'string') { + return tool; + } + + return req.app.locals.availableTools[tool]; + }) + .filter((tool) => tool); + + if (openai.locals?.azureOptions && updateData.model) { + updateData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName; + } + const updatedAssistant = await openai.beta.assistants.update(assistant_id, updateData); res.json(updatedAssistant); } catch (error) { + logger.error('[/assistants/:id] Error updating assistant', error); res.status(500).json({ error: error.message }); } }); @@ -66,12 +129,15 @@ router.patch('/:id', async (req, res) => { */ router.delete('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const deletionStatus = await openai.beta.assistants.del(assistant_id); res.json(deletionStatus); } catch (error) { - res.status(500).json({ error: error.message }); + logger.error('[/assistants/:id] Error deleting assistant', error); + res.status(500).json({ error: 'Error deleting assistant' }); } }); @@ -79,22 +145,121 @@ router.delete('/:id', async (req, res) => { * Returns a list of assistants. * @route GET /assistants * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting. - * @returns {Array<Assistant>} 200 - success response - application/json + * @returns {AssistantListResponse} 200 - success response - application/json */ router.get('/', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); - const { limit, order, after, before } = req.query; - const assistants = await openai.beta.assistants.list({ - limit, - order, - after, - before, - }); - res.json(assistants); + const { limit = 100, order = 'desc', after, before } = req.query; + const query = { limit, order, after, before }; + + const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; + /** @type {AssistantListResponse} */ + let body; + + if (azureConfig?.assistants) { + body = await listAssistantsForAzure({ req, res, azureConfig, query }); + } else { + ({ body } = await listAssistants({ req, res, query })); + } + + if (req.app.locals?.[EModelEndpoint.assistants]) { + /** @type {Partial<TAssistantEndpoint>} */ + const assistantsConfig = req.app.locals[EModelEndpoint.assistants]; + const { supportedIds, excludedIds } = assistantsConfig; + if (supportedIds?.length) { + body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id)); + } else if (excludedIds?.length) { + body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id)); + } + } + + res.json(body); } catch (error) { + logger.error('[/assistants] Error listing assistants', error); + res.status(500).json({ message: 'Error listing assistants' }); + } +}); + +/** + * Returns a list of the user's assistant documents (metadata saved to database). + * @route GET /assistants/documents + * @returns {AssistantDocument[]} 200 - success response - application/json + */ +router.get('/documents', async (req, res) => { + try { + res.json(await getAssistants({ user: req.user.id })); + } catch (error) { + logger.error('[/assistants/documents] Error listing assistant documents', error); res.status(500).json({ error: error.message }); } }); +/** + * Uploads and updates an avatar for a specific assistant. + * @route POST /avatar/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {Express.Multer.File} req.file - The avatar image file. + * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar. + * @returns {Object} 200 - success response - application/json + */ +router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => { + try { + const { assistant_id } = req.params; + if (!assistant_id) { + return res.status(400).json({ message: 'Assistant ID is required' }); + } + + let { metadata: _metadata = '{}' } = req.body; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const image = await uploadImageBuffer({ req, context: FileContext.avatar }); + + try { + _metadata = JSON.parse(_metadata); + } catch (error) { + logger.error('[/avatar/:assistant_id] Error parsing metadata', error); + _metadata = {}; + } + + if (_metadata.avatar && _metadata.avatar_source) { + const { deleteFile } = getStrategyFunctions(_metadata.avatar_source); + try { + await deleteFile(req, { filepath: _metadata.avatar }); + await deleteFileByFilter({ filepath: _metadata.avatar }); + } catch (error) { + logger.error('[/avatar/:assistant_id] Error deleting old avatar', error); + } + } + + const metadata = { + ..._metadata, + avatar: image.filepath, + avatar_source: req.app.locals.fileStrategy, + }; + + const promises = []; + promises.push( + updateAssistant( + { assistant_id }, + { + avatar: { + filepath: image.filepath, + source: req.app.locals.fileStrategy, + }, + user: req.user.id, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { metadata })); + + const resolved = await Promise.all(promises); + res.status(201).json(resolved[1]); + } catch (error) { + const message = 'An error occurred while updating the Assistant Avatar'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + module.exports = router; diff --git a/api/server/routes/assistants/chat.js b/api/server/routes/assistants/chat.js index e45bad191e9..69be8a7b3e4 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/routes/assistants/chat.js @@ -1,108 +1,658 @@ -const crypto = require('crypto'); -const OpenAI = require('openai'); -const { logger } = require('~/config'); -const { sendMessage } = require('../../utils'); -const { initThread, createRun, handleRun } = require('../../services/AssistantService'); +const { v4 } = require('uuid'); const express = require('express'); +const { + Constants, + RunStatus, + CacheKeys, + FileSources, + ContentTypes, + EModelEndpoint, + ViolationTypes, + ImageVisionTool, + AssistantStreamEvents, +} = require('librechat-data-provider'); +const { + initThread, + recordUsage, + saveUserMessage, + checkMessageGaps, + addThreadMetadata, + saveAssistantMessage, +} = require('~/server/services/Threads'); +const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); +const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); +const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants'); +const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts'); +const { createRun, StreamRunManager } = require('~/server/services/Runs'); +const { getTransactions } = require('~/models/Transaction'); +const checkBalance = require('~/models/checkBalance'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { getModelMaxTokens } = require('~/utils'); +const { logger } = require('~/config'); + const router = express.Router(); const { setHeaders, - // handleAbort, - // handleAbortError, + handleAbort, + validateModel, + handleAbortError, // validateEndpoint, - // buildEndpointOption, - // createAbortController, -} = require('../../middleware'); + buildEndpointOption, +} = require('~/server/middleware'); + +router.post('/abort', handleAbort()); -// const thread = { -// id: 'thread_LexzJUVugYFqfslS7c7iL3Zo', -// "thread_nZoiCbPauU60LqY1Q0ME1elg" -// }; +const ten_minutes = 1000 * 60 * 10; /** - * Chat with an assistant. + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. + * @returns {void} */ -router.post('/', setHeaders, async (req, res) => { +router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => { + logger.debug('[/assistants/chat/] req.body', req.body); + + const { + text, + model, + files = [], + promptPrefix, + assistant_id, + instructions, + thread_id: _thread_id, + messageId: _messageId, + conversationId: convoId, + parentMessageId: _parentId = Constants.NO_PARENT, + } = req.body; + + /** @type {Partial<TAssistantEndpoint>} */ + const assistantsConfig = req.app.locals?.[EModelEndpoint.assistants]; + + if (assistantsConfig) { + const { supportedIds, excludedIds } = assistantsConfig; + const error = { message: 'Assistant not supported' }; + if (supportedIds?.length && !supportedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + error, + }); + } else if (excludedIds?.length && excludedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + }); + } + } + + /** @type {OpenAIClient} */ + let openai; + /** @type {string|undefined} - the current thread id */ + let thread_id = _thread_id; + /** @type {string|undefined} - the current run id */ + let run_id; + /** @type {string|undefined} - the parent messageId */ + let parentMessageId = _parentId; + /** @type {TMessage[]} */ + let previousMessages = []; + /** @type {import('librechat-data-provider').TConversation | null} */ + let conversation = null; + /** @type {string[]} */ + let file_ids = []; + /** @type {Set<string>} */ + let attachedFileIds = new Set(); + /** @type {TMessage | null} */ + let requestMessage = null; + /** @type {undefined | Promise<ChatCompletion>} */ + let visionPromise; + + const userMessageId = v4(); + const responseMessageId = v4(); + + /** @type {string} - The conversation UUID - created if undefined */ + const conversationId = convoId ?? v4(); + + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const cacheKey = `${req.user.id}:${conversationId}`; + + /** @type {Run | undefined} - The completed run, undefined if incomplete */ + let completedRun; + + const handleError = async (error) => { + const defaultErrorMessage = + 'The Assistant run failed to initialize. Try sending a message in a new conversation.'; + const messageData = { + thread_id, + assistant_id, + conversationId, + parentMessageId, + sender: 'System', + user: req.user.id, + shouldSaveMessage: false, + messageId: responseMessageId, + endpoint: EModelEndpoint.assistants, + }; + + if (error.message === 'Run cancelled') { + return res.end(); + } else if (error.message === 'Request closed' && completedRun) { + return; + } else if (error.message === 'Request closed') { + logger.debug('[/assistants/chat/] Request aborted on close'); + } else if (/Files.*are invalid/.test(error.message)) { + const errorMessage = `Files are invalid, or may not have uploaded yet.${ + req.app.locals?.[EModelEndpoint.azureOpenAI].assistants + ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + : '' + }`; + return sendResponse(res, messageData, errorMessage); + } else if (error?.message?.includes('string too long')) { + return sendResponse( + res, + messageData, + 'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.', + ); + } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { + return sendResponse(res, messageData, error.message); + } else { + logger.error('[/assistants/chat/]', error); + } + + if (!openai || !thread_id || !run_id) { + return sendResponse(res, messageData, defaultErrorMessage); + } + + await sleep(2000); + + try { + const status = await cache.get(cacheKey); + if (status === 'cancelled') { + logger.debug('[/assistants/chat/] Run already cancelled'); + return res.end(); + } + await cache.delete(cacheKey); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun); + } catch (error) { + logger.error('[/assistants/chat/] Error cancelling run', error); + } + + await sleep(2000); + + let run; + try { + run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error('[/assistants/chat/] Error fetching or processing run', error); + } + + let finalEvent; + try { + const runMessages = await checkMessageGaps({ + openai, + run_id, + thread_id, + conversationId, + latestMessageId: responseMessageId, + }); + + const errorContentPart = { + text: { + value: + error?.message ?? 'There was an error processing your request. Please try again later.', + }, + type: ContentTypes.ERROR, + }; + + if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) { + runMessages[runMessages.length - 1].content = [errorContentPart]; + } else { + const contentParts = runMessages[runMessages.length - 1].content; + for (let i = 0; i < contentParts.length; i++) { + const currentPart = contentParts[i]; + /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */ + const toolCall = currentPart?.[ContentTypes.TOOL_CALL]; + if ( + toolCall && + toolCall?.function && + !(toolCall?.function?.output || toolCall?.function?.output?.length) + ) { + contentParts[i] = { + ...currentPart, + [ContentTypes.TOOL_CALL]: { + ...toolCall, + function: { + ...toolCall.function, + output: 'error processing tool', + }, + }, + }; + } + } + runMessages[runMessages.length - 1].content.push(errorContentPart); + } + + finalEvent = { + final: true, + conversation: await getConvo(req.user.id, conversationId), + runMessages, + }; + } catch (error) { + logger.error('[/assistants/chat/] Error finalizing error process', error); + return sendResponse(res, messageData, 'The Assistant run failed'); + } + + return sendResponse(res, finalEvent); + }; + try { - logger.debug('[/assistants/chat/] req.body', req.body); - // test message: - // How many polls of 500 ms intervals are there in 18 seconds? + res.on('close', async () => { + if (!completedRun) { + await handleError(new Error('Request closed')); + } + }); - const { assistant_id, messages, text: userMessage, messageId } = req.body; - const conversationId = req.body.conversationId || crypto.randomUUID(); - // let thread_id = req.body.thread_id ?? 'thread_nZoiCbPauU60LqY1Q0ME1elg'; // for testing - let thread_id = req.body.thread_id; + if (convoId && !_thread_id) { + completedRun = true; + throw new Error('Missing thread_id for existing conversation'); + } if (!assistant_id) { + completedRun = true; throw new Error('Missing assistant_id'); } - const openai = new OpenAI(process.env.OPENAI_API_KEY); - console.log(messages); - - const initThreadBody = { - messages: [ - { - role: 'user', - content: userMessage, - metadata: { - messageId, - }, + const checkBalanceBeforeRun = async () => { + if (!isEnabled(process.env.CHECK_BALANCE)) { + return; + } + const transactions = + (await getTransactions({ + user: req.user.id, + context: 'message', + conversationId, + })) ?? []; + + const totalPreviousTokens = Math.abs( + transactions.reduce((acc, curr) => acc + curr.rawAmount, 0), + ); + + // TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions + const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0; + // 5 is added for labels + let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5; + promptTokens += totalPreviousTokens + promptBuffer; + // Count tokens up to the current context window + promptTokens = Math.min(promptTokens, getModelMaxTokens(model)); + + await checkBalance({ + req, + res, + txData: { + model, + user: req.user.id, + tokenType: 'prompt', + amount: promptTokens, }, - ], + }); + }; + + /** @type {{ openai: OpenAIClient }} */ + const { openai: _openai, client } = await initializeClient({ + req, + res, + endpointOption: req.body.endpointOption, + initAppClient: true, + }); + + openai = _openai; + + if (previousMessages.length) { + parentMessageId = previousMessages[previousMessages.length - 1].messageId; + } + + let userMessage = { + role: 'user', + content: text, metadata: { - conversationId, + messageId: userMessageId, }, }; - const result = await initThread({ openai, body: initThreadBody, thread_id }); - // const { messages: _messages } = result; - thread_id = result.thread_id; + /** @type {CreateRunBody | undefined} */ + const body = { + assistant_id, + model, + }; - /* NOTE: - * By default, a Run will use the model and tools configuration specified in Assistant object, - * but you can override most of these when creating the Run for added flexibility: - */ - const run = await createRun({ - openai, - thread_id, - body: { assistant_id, model: 'gpt-3.5-turbo-1106' }, + if (promptPrefix) { + body.additional_instructions = promptPrefix; + } + + if (instructions) { + body.instructions = instructions; + } + + const getRequestFileIds = async () => { + let thread_file_ids = []; + if (convoId) { + const convo = await getConvo(req.user.id, convoId); + if (convo && convo.file_ids) { + thread_file_ids = convo.file_ids; + } + } + + file_ids = files.map(({ file_id }) => file_id); + if (file_ids.length || thread_file_ids.length) { + userMessage.file_ids = file_ids; + attachedFileIds = new Set([...file_ids, ...thread_file_ids]); + } + }; + + const addVisionPrompt = async () => { + if (!req.body.endpointOption.attachments) { + return; + } + + /** @type {MongoFile[]} */ + const attachments = await req.body.endpointOption.attachments; + if ( + attachments && + attachments.every((attachment) => attachment.source === FileSources.openai) + ) { + return; + } + + const assistant = await openai.beta.assistants.retrieve(assistant_id); + const visionToolIndex = assistant.tools.findIndex( + (tool) => tool?.function && tool?.function?.name === ImageVisionTool.function.name, + ); + + if (visionToolIndex === -1) { + return; + } + + let visionMessage = { + role: 'user', + content: '', + }; + const files = await client.addImageURLs(visionMessage, attachments); + if (!visionMessage.image_urls?.length) { + return; + } + + const imageCount = visionMessage.image_urls.length; + const plural = imageCount > 1; + visionMessage.content = createVisionPrompt(plural); + visionMessage = formatMessage({ message: visionMessage, endpoint: EModelEndpoint.openAI }); + + visionPromise = openai.chat.completions.create({ + model: 'gpt-4-vision-preview', + messages: [visionMessage], + max_tokens: 4000, + }); + + const pluralized = plural ? 's' : ''; + body.additional_instructions = `${ + body.additional_instructions ? `${body.additional_instructions}\n` : '' + }The user has uploaded ${imageCount} image${pluralized}. + Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${ + plural ? '' : 'a ' +}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`; + + return files; + }; + + const initializeThread = async () => { + /** @type {[ undefined | MongoFile[]]}*/ + const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]); + // TODO: may allow multiple messages to be created beforehand in a future update + const initThreadBody = { + messages: [userMessage], + metadata: { + user: req.user.id, + conversationId, + }, + }; + + if (processedFiles) { + for (const file of processedFiles) { + if (file.source !== FileSources.openai) { + attachedFileIds.delete(file.file_id); + const index = file_ids.indexOf(file.file_id); + if (index > -1) { + file_ids.splice(index, 1); + } + } + } + + userMessage.file_ids = file_ids; + } + + const result = await initThread({ openai, body: initThreadBody, thread_id }); + thread_id = result.thread_id; + + createOnTextProgress({ + openai, + conversationId, + userMessageId, + messageId: responseMessageId, + thread_id, + }); + + requestMessage = { + user: req.user.id, + text, + messageId: userMessageId, + parentMessageId, + // TODO: make sure client sends correct format for `files`, use zod + files, + file_ids, + conversationId, + isCreatedByUser: true, + assistant_id, + thread_id, + model: assistant_id, + }; + + previousMessages.push(requestMessage); + + /* asynchronous */ + saveUserMessage({ ...requestMessage, model }); + + conversation = { + conversationId, + endpoint: EModelEndpoint.assistants, + promptPrefix: promptPrefix, + instructions: instructions, + assistant_id, + // model, + }; + + if (file_ids.length) { + conversation.file_ids = file_ids; + } + }; + + const promises = [initializeThread(), checkBalanceBeforeRun()]; + await Promise.all(promises); + + const sendInitialResponse = () => { + sendMessage(res, { + sync: true, + conversationId, + // messages: previousMessages, + requestMessage, + responseMessage: { + user: req.user.id, + messageId: openai.responseMessage.messageId, + parentMessageId: userMessageId, + conversationId, + assistant_id, + thread_id, + model: assistant_id, + }, + }); + }; + + /** @type {RunResponse | typeof StreamRunManager | undefined} */ + let response; + + const processRun = async (retry = false) => { + if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + openai.attachedFileIds = attachedFileIds; + openai.visionPromise = visionPromise; + if (retry) { + response = await runAssistant({ + openai, + thread_id, + run_id, + in_progress: openai.in_progress, + }); + return; + } + + /* NOTE: + * By default, a Run will use the model and tools configuration specified in Assistant object, + * but you can override most of these when creating the Run for added flexibility: + */ + const run = await createRun({ + openai, + thread_id, + body, + }); + + run_id = run.id; + await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes); + sendInitialResponse(); + + // todo: retry logic + response = await runAssistant({ openai, thread_id, run_id }); + return; + } + + /** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */ + const handlers = { + [AssistantStreamEvents.ThreadRunCreated]: async (event) => { + await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes); + run_id = event.data.id; + sendInitialResponse(); + }, + }; + + const streamRunManager = new StreamRunManager({ + req, + res, + openai, + handlers, + thread_id, + visionPromise, + attachedFileIds, + responseMessage: openai.responseMessage, + // streamOptions: { + + // }, + }); + + await streamRunManager.runAssistant({ + thread_id, + body, + }); + + response = streamRunManager; + }; + + await processRun(); + logger.debug('[/assistants/chat/] response', { + run: response.run, + steps: response.steps, }); - const response = await handleRun({ openai, thread_id, run_id: run.id }); - // TODO: parse responses, save to db, send to user + if (response.run.status === RunStatus.CANCELLED) { + logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`'); + return res.end(); + } + + if (response.run.status === RunStatus.IN_PROGRESS) { + processRun(true); + } + + completedRun = response.run; + + /** @type {ResponseMessage} */ + const responseMessage = { + ...(response.responseMessage ?? response.finalMessage), + parentMessageId: userMessageId, + conversationId, + user: req.user.id, + assistant_id, + thread_id, + model: assistant_id, + }; sendMessage(res, { - title: 'New Chat', final: true, - conversation: { - conversationId: 'fake-convo-id', - title: 'New Chat', - }, + conversation, requestMessage: { - messageId: 'fake-user-message-id', - parentMessageId: '00000000-0000-0000-0000-000000000000', - conversationId: 'fake-convo-id', - sender: 'User', - text: req.body.text, - isCreatedByUser: true, - }, - responseMessage: { - messageId: 'fake-response-id', - conversationId: 'fake-convo-id', - parentMessageId: 'fake-user-message-id', - isCreatedByUser: false, - isEdited: false, - model: 'gpt-3.5-turbo-1106', - sender: 'Assistant', - text: response.choices[0].text, + parentMessageId, + thread_id, }, }); res.end(); + + await saveAssistantMessage({ ...responseMessage, model }); + + if (parentMessageId === Constants.NO_PARENT && !_thread_id) { + addTitle(req, { + text, + responseText: response.text, + conversationId, + client, + }); + } + + await addThreadMetadata({ + openai, + thread_id, + messageId: responseMessage.messageId, + messages: response.messages, + }); + + if (!response.run.usage) { + await sleep(3000); + completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id); + if (completedRun.usage) { + await recordUsage({ + ...completedRun.usage, + user: req.user.id, + model: completedRun.model ?? model, + conversationId, + }); + } + } else { + await recordUsage({ + ...response.run.usage, + user: req.user.id, + model: response.run.model ?? model, + conversationId, + }); + } } catch (error) { - // res.status(500).json({ error: error.message }); - logger.error('[/assistants/chat/]', error); - res.end(); + await handleError(error); } }); diff --git a/api/server/routes/assistants/tools.js b/api/server/routes/assistants/tools.js new file mode 100644 index 00000000000..324b6209589 --- /dev/null +++ b/api/server/routes/assistants/tools.js @@ -0,0 +1,8 @@ +const express = require('express'); +const { getAvailableTools } = require('~/server/controllers/PluginController'); + +const router = express.Router(); + +router.get('/', getAvailableTools); + +module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 85889f4b818..a9f6772deea 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,4 +1,5 @@ const express = require('express'); +const { defaultSocialLogins } = require('librechat-data-provider'); const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); @@ -7,21 +8,27 @@ const emailLoginEnabled = process.env.ALLOW_EMAIL_LOGIN === undefined || isEnabled(process.env.ALLOW_EMAIL_LOGIN); router.get('/', async function (req, res) { + const isBirthday = () => { + const today = new Date(); + return today.getMonth() === 1 && today.getDate() === 11; + }; + try { const payload = { appTitle: process.env.APP_TITLE || 'LibreChat', - googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, + socialLogins: req.app.locals.socialLogins ?? defaultSocialLogins, + discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET, facebookLoginEnabled: !!process.env.FACEBOOK_CLIENT_ID && !!process.env.FACEBOOK_CLIENT_SECRET, + githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, + googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, openidLoginEnabled: !!process.env.OPENID_CLIENT_ID && !!process.env.OPENID_CLIENT_SECRET && !!process.env.OPENID_ISSUER && !!process.env.OPENID_SESSION_SECRET, - openidLabel: process.env.OPENID_BUTTON_LABEL || 'Login with OpenID', + openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', openidImageUrl: process.env.OPENID_IMAGE_URL, - githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, - discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET, serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', emailLoginEnabled, registrationEnabled: isEnabled(process.env.ALLOW_REGISTRATION), @@ -32,6 +39,12 @@ router.get('/', async function (req, res) { !!process.env.EMAIL_PASSWORD && !!process.env.EMAIL_FROM, checkBalance: isEnabled(process.env.CHECK_BALANCE), + showBirthdayIcon: + isBirthday() || + isEnabled(process.env.SHOW_BIRTHDAY_ICON) || + process.env.SHOW_BIRTHDAY_ICON === '', + helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai', + interface: req.app.locals.interface, }; if (typeof process.env.CUSTOM_FOOTER === 'string') { diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 4395df0fee1..0fa45223805 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,14 +1,23 @@ const express = require('express'); -const router = express.Router(); -const { getConvosByPage, deleteConvos } = require('~/models/Conversation'); +const { CacheKeys } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { getConvo, saveConvo } = require('~/models'); +const getLogStores = require('~/cache/getLogStores'); +const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); +const router = express.Router(); router.use(requireJwtAuth); router.get('/', async (req, res) => { - const pageNumber = req.query.pageNumber || 1; + let pageNumber = req.query.pageNumber || 1; + pageNumber = parseInt(pageNumber, 10); + + if (isNaN(pageNumber) || pageNumber < 1) { + return res.status(400).json({ error: 'Invalid page number' }); + } + res.status(200).send(await getConvosByPage(req.user.id, pageNumber)); }); @@ -17,32 +26,64 @@ router.get('/:conversationId', async (req, res) => { const convo = await getConvo(req.user.id, conversationId); if (convo) { - res.status(200).send(convo); + res.status(200).json(convo); } else { res.status(404).end(); } }); +router.post('/gen_title', async (req, res) => { + const { conversationId } = req.body; + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${conversationId}`; + let title = await titleCache.get(key); + + if (!title) { + await sleep(2500); + title = await titleCache.get(key); + } + + if (title) { + await titleCache.delete(key); + res.status(200).json({ title }); + } else { + res.status(404).json({ + message: 'Title not found or method not implemented for the conversation\'s endpoint', + }); + } +}); + router.post('/clear', async (req, res) => { let filter = {}; - const { conversationId, source } = req.body.arg; + const { conversationId, source, thread_id } = req.body.arg; if (conversationId) { filter = { conversationId }; } - // for debugging deletion source - // logger.debug('source:', source); - if (source === 'button' && !conversationId) { return res.status(200).send('No conversationId provided'); } + if (thread_id) { + /** @type {{ openai: OpenAI}} */ + const { openai } = await initializeClient({ req, res }); + try { + const response = await openai.beta.threads.del(thread_id); + logger.debug('Deleted OpenAI thread:', response); + } catch (error) { + logger.error('Error deleting OpenAI thread:', error); + } + } + + // for debugging deletion source + // logger.debug('source:', source); + try { const dbResponse = await deleteConvos(req.user.id, filter); - res.status(201).send(dbResponse); + res.status(201).json(dbResponse); } catch (error) { logger.error('Error clearing conversations', error); - res.status(500).send(error); + res.status(500).send('Error clearing conversations'); } }); @@ -51,10 +92,10 @@ router.post('/update', async (req, res) => { try { const dbResponse = await saveConvo(req.user.id, update); - res.status(201).send(dbResponse); + res.status(201).json(dbResponse); } catch (error) { logger.error('Error updating conversation', error); - res.status(500).send(error); + res.status(500).send('Error updating conversation'); } }); diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 34dd9d6dfac..c7bf128d7cb 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/edit/custom.js b/api/server/routes/edit/custom.js index dd63c96c8f9..0bf97ba1800 100644 --- a/api/server/routes/edit/custom.js +++ b/api/server/routes/edit/custom.js @@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI'); const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -13,8 +14,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/edit/google.js b/api/server/routes/edit/google.js index e4dfbcd1412..7482f11b4c0 100644 --- a/api/server/routes/edit/google.js +++ b/api/server/routes/edit/google.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index 8ddf92c2507..61d76178f4f 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -1,88 +1,94 @@ const express = require('express'); -const router = express.Router(); -const { validateTools } = require('~/app'); +const throttle = require('lodash/throttle'); const { getResponseSender } = require('librechat-data-provider'); -const { saveMessage, getConvoTitle, getConvo } = require('~/models'); -const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); const { handleAbort, createAbortController, handleAbortError, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, } = require('~/server/middleware'); +const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); +const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); +const { saveMessage, getConvoTitle, getConvo } = require('~/models'); +const { validateTools } = require('~/app'); const { logger } = require('~/config'); +const router = express.Router(); + router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/edit/gptPlugins]', { - text, - generation, - isContinued, - conversationId, - ...endpointOption, - }); - let metadata; - let userMessage; - let promptTokens; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model }); - const userMessageId = parentMessageId; - const user = req.user.id; - - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; - - const addMetadata = (data) => (metadata = data); - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } - } - }; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (plugin.loading === true) { - plugin.loading = false; +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + + logger.debug('[/edit/gptPlugins]', { + text, + generation, + isContinued, + conversationId, + ...endpointOption, + }); + + let userMessage; + let promptTokens; + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.modelOptions.model, + }); + const userMessageId = parentMessageId; + const user = req.user.id; + + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; + + const getReqData = (data = {}) => { + for (let key in data) { + if (key === 'userMessage') { + userMessage = data[key]; + } else if (key === 'responseMessageId') { + responseMessageId = data[key]; + } else if (key === 'promptTokens') { + promptTokens = data[key]; + } } + }; + + const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false }); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + if (plugin.loading === true) { + plugin.loading = false; + } - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ + throttledSaveMessage({ messageId: responseMessageId, sender, conversationId, @@ -94,104 +100,95 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, error: false, user, }); - } + }, + }); - if (saveDelay < 500) { - saveDelay = 500; + const onAgentAction = (action, start = false) => { + const formattedAction = formatAction(action); + plugin.inputs.push(formattedAction); + plugin.latest = formattedAction.plugin; + if (!start) { + saveMessage({ ...userMessage, user }); } - }, - }); - - const onAgentAction = (action, start = false) => { - const formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start) { + sendIntermediateMessage(res, { plugin }); + // logger.debug('PLUGIN ACTION', formattedAction); + }; + + const onChainEnd = (data) => { + let { intermediateSteps: steps } = data; + plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; + plugin.loading = false; saveMessage({ ...userMessage, user }); - } - sendIntermediateMessage(res, { plugin }); - // logger.debug('PLUGIN ACTION', formattedAction); - }; - - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage({ ...userMessage, user }); - sendIntermediateMessage(res, { plugin }); - // logger.debug('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender, - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - let response = await client.sendMessage(text, { - user, - generation, - isContinued, - isEdited: true, + sendIntermediateMessage(res, { plugin }); + // logger.debug('CHAIN END', plugin.outputs); + }; + + const getAbortData = () => ({ + sender, conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onStart, - addMetadata, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - plugin, - parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, + promptTokens, }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient({ req, res, endpointOption }); + + let response = await client.sendMessage(text, { + user, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getReqData, + onAgentAction, + onChainEnd, + onStart, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - if (metadata) { - response = { ...response, ...metadata }; + logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); + response.plugin = { ...plugin, loading: false }; + await saveMessage({ ...response, user }); + + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender, + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); } - - logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); - response.plugin = { ...plugin, loading: false }; - await saveMessage({ ...response, user }); - - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } -}); + }, +); module.exports = router; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index e54881148dc..ae26b235c79 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/openAI'); const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, @@ -13,8 +14,15 @@ const router = express.Router(); router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/files/avatar.js b/api/server/routes/files/avatar.js index 5abba85f9e8..71ade965cde 100644 --- a/api/server/routes/files/avatar.js +++ b/api/server/routes/files/avatar.js @@ -1,38 +1,36 @@ -const express = require('express'); const multer = require('multer'); - -const uploadAvatar = require('~/server/services/Files/images/avatar'); -const { requireJwtAuth } = require('~/server/middleware/'); -const User = require('~/models/User'); +const express = require('express'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); +const { logger } = require('~/config'); const upload = multer(); const router = express.Router(); -router.post('/', requireJwtAuth, upload.single('input'), async (req, res) => { +router.post('/', upload.single('input'), async (req, res) => { try { const userId = req.user.id; const { manual } = req.body; const input = req.file.buffer; + if (!userId) { throw new Error('User ID is undefined'); } - // TODO: do not use Model directly, instead use a service method that uses the model - const user = await User.findById(userId).lean(); - - if (!user) { - throw new Error('User not found'); - } - const url = await uploadAvatar({ - input, + const fileStrategy = req.app.locals.fileStrategy; + const webPBuffer = await resizeAvatar({ userId, - manual, - fileStrategy: req.app.locals.fileStrategy, + input, }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + const url = await processAvatar({ buffer: webPBuffer, userId, manual }); + res.json({ url }); } catch (error) { - res.status(500).json({ message: 'An error occurred while uploading the profile picture' }); + const message = 'An error occurred while uploading the profile picture'; + logger.error(message, error); + res.status(500).json({ message }); } }); diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index 3fea2e5d07b..812d4bd33d7 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -1,14 +1,18 @@ -const { z } = require('zod'); +const fs = require('fs').promises; const express = require('express'); -const { FileSources } = require('librechat-data-provider'); +const { isUUID, FileSources } = require('librechat-data-provider'); +const { + filterFile, + processFileUpload, + processDeleteRequest, +} = require('~/server/services/Files/process'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { deleteFiles, getFiles } = require('~/models'); +const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); const router = express.Router(); -const isUUID = z.string().uuid(); - router.get('/', async (req, res) => { try { const files = await getFiles({ user: req.user.id }); @@ -19,6 +23,15 @@ router.get('/', async (req, res) => { } }); +router.get('/config', async (req, res) => { + try { + res.status(200).json(req.app.locals.fileConfig); + } catch (error) { + logger.error('[/files] Error getting fileConfig', error); + res.status(400).json({ message: 'Error in request', error: error.message }); + } +}); + router.delete('/', async (req, res) => { try { const { files: _files } = req.body; @@ -31,6 +44,11 @@ router.delete('/', async (req, res) => { if (!file.filepath) { return false; } + + if (/^(file|assistant)-/.test(file.file_id)) { + return true; + } + return isUUID.safeParse(file.file_id).success; }); @@ -39,33 +57,114 @@ router.delete('/', async (req, res) => { return; } - const file_ids = files.map((file) => file.file_id); - const deletionMethods = {}; - const promises = []; - promises.push(await deleteFiles(file_ids)); + await processDeleteRequest({ req, files }); + + res.status(200).json({ message: 'Files deleted successfully' }); + } catch (error) { + logger.error('[/files] Error deleting files:', error); + res.status(400).json({ message: 'Error in request', error: error.message }); + } +}); + +router.get('/download/:userId/:file_id', async (req, res) => { + try { + const { userId, file_id } = req.params; + logger.debug(`File download requested by user ${userId}: ${file_id}`); - for (const file of files) { - const source = file.source ?? FileSources.local; + if (userId !== req.user.id) { + logger.warn(`${errorPrefix} forbidden: ${file_id}`); + return res.status(403).send('Forbidden'); + } - if (deletionMethods[source]) { - promises.push(deletionMethods[source](req, file)); - continue; - } + const [file] = await getFiles({ file_id }); + const errorPrefix = `File download requested by user ${userId}`; - const { deleteFile } = getStrategyFunctions(source); - if (!deleteFile) { - throw new Error(`Delete function not implemented for ${source}`); - } + if (!file) { + logger.warn(`${errorPrefix} not found: ${file_id}`); + return res.status(404).send('File not found'); + } - deletionMethods[source] = deleteFile; - promises.push(deleteFile(req, file)); + if (!file.filepath.includes(userId)) { + logger.warn(`${errorPrefix} forbidden: ${file_id}`); + return res.status(403).send('Forbidden'); } - await Promise.all(promises); - res.status(200).json({ message: 'Files deleted successfully' }); + if (file.source === FileSources.openai && !file.model) { + logger.warn(`${errorPrefix} has no associated model: ${file_id}`); + return res.status(400).send('The model used when creating this file is not available'); + } + + const { getDownloadStream } = getStrategyFunctions(file.source); + if (!getDownloadStream) { + logger.warn(`${errorPrefix} has no stream method implemented: ${file.source}`); + return res.status(501).send('Not Implemented'); + } + + const setHeaders = () => { + res.setHeader('Content-Disposition', `attachment; filename="${file.filename}"`); + res.setHeader('Content-Type', 'application/octet-stream'); + res.setHeader('X-File-Metadata', JSON.stringify(file)); + }; + + /** @type {{ body: import('stream').PassThrough } | undefined} */ + let passThrough; + /** @type {ReadableStream | undefined} */ + let fileStream; + if (file.source === FileSources.openai) { + req.body = { model: file.model }; + const { openai } = await initializeClient({ req, res }); + logger.debug(`Downloading file ${file_id} from OpenAI`); + passThrough = await getDownloadStream(file_id, openai); + setHeaders(); + logger.debug(`File ${file_id} downloaded from OpenAI`); + passThrough.body.pipe(res); + } else { + fileStream = getDownloadStream(file_id); + setHeaders(); + fileStream.pipe(res); + } } catch (error) { - logger.error('[/files] Error deleting files:', error); - res.status(400).json({ message: 'Error in request', error: error.message }); + logger.error('Error downloading file:', error); + res.status(500).send('Error downloading file'); + } +}); + +router.post('/', async (req, res) => { + const file = req.file; + const metadata = req.body; + let cleanup = true; + + try { + filterFile({ req, file }); + + metadata.temp_file_id = metadata.file_id; + metadata.file_id = req.file_id; + + await processFileUpload({ req, res, file, metadata }); + } catch (error) { + let message = 'Error processing file'; + logger.error('[/files] Error processing file:', error); + cleanup = false; + + if (error.message?.includes('file_ids')) { + message += ': ' + error.message; + } + + // TODO: delete remote file if it exists + try { + await fs.unlink(file.path); + } catch (error) { + logger.error('[/files] Error deleting file:', error); + } + res.status(500).json({ message }); + } + + if (cleanup) { + try { + await fs.unlink(file.path); + } catch (error) { + logger.error('[/files/images] Error deleting file after file processing:', error); + } } }); diff --git a/api/server/routes/files/images.js b/api/server/routes/files/images.js index 30d3c3cac60..374711c4acd 100644 --- a/api/server/routes/files/images.js +++ b/api/server/routes/files/images.js @@ -1,60 +1,36 @@ -const { z } = require('zod'); +const path = require('path'); const fs = require('fs').promises; const express = require('express'); -const upload = require('./multer'); -const { processImageUpload } = require('~/server/services/Files/process'); +const { filterFile, processImageFile } = require('~/server/services/Files/process'); const { logger } = require('~/config'); const router = express.Router(); -router.post('/', upload.single('file'), async (req, res) => { - const file = req.file; +router.post('/', async (req, res) => { const metadata = req.body; - // TODO: add file size/type validation - - const uuidSchema = z.string().uuid(); try { - if (!file) { - throw new Error('No file provided'); - } + filterFile({ req, file: req.file, image: true }); - if (!metadata.file_id) { - throw new Error('No file_id provided'); - } - - if (!metadata.width) { - throw new Error('No width provided'); - } - - if (!metadata.height) { - throw new Error('No height provided'); - } - /* parse to validate api call */ - uuidSchema.parse(metadata.file_id); metadata.temp_file_id = metadata.file_id; metadata.file_id = req.file_id; - await processImageUpload({ req, res, file, metadata }); + await processImageFile({ req, res, file: req.file, metadata }); } catch (error) { + // TODO: delete remote file if it exists logger.error('[/files/images] Error processing file:', error); try { - await fs.unlink(file.path); + const filepath = path.join( + req.app.locals.paths.imageOutput, + req.user.id, + path.basename(req.file.filename), + ); + await fs.unlink(filepath); } catch (error) { logger.error('[/files/images] Error deleting file:', error); } res.status(500).json({ message: 'Error processing file' }); } - - // do this if strategy is not local - // finally { - // try { - // // await fs.unlink(file.path); - // } catch (error) { - // logger.error('[/files/images] Error deleting file:', error); - - // } - // } }); module.exports = router; diff --git a/api/server/routes/files/index.js b/api/server/routes/files/index.js index 9afb900bbe6..c9f5ce1679e 100644 --- a/api/server/routes/files/index.js +++ b/api/server/routes/files/index.js @@ -1,24 +1,27 @@ const express = require('express'); -const router = express.Router(); -const { - uaParser, - checkBan, - requireJwtAuth, - // concurrentLimiter, - // messageIpLimiter, - // messageUserLimiter, -} = require('../../middleware'); +const createMulterInstance = require('./multer'); +const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware'); const files = require('./files'); const images = require('./images'); const avatar = require('./avatar'); -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); +const initialize = async () => { + const router = express.Router(); + router.use(requireJwtAuth); + router.use(checkBan); + router.use(uaParser); -router.use('/', files); -router.use('/images', images); -router.use('/images/avatar', avatar); + const upload = await createMulterInstance(); + const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters(); + router.post('*', fileUploadIpLimiter, fileUploadUserLimiter); + router.post('/', upload.single('file')); + router.post('/images', upload.single('file')); -module.exports = router; + router.use('/', files); + router.use('/images', images); + router.use('/images/avatar', avatar); + return router; +}; + +module.exports = { initialize }; diff --git a/api/server/routes/files/multer.js b/api/server/routes/files/multer.js index d5aea05a373..2162a0d8075 100644 --- a/api/server/routes/files/multer.js +++ b/api/server/routes/files/multer.js @@ -2,13 +2,12 @@ const fs = require('fs'); const path = require('path'); const crypto = require('crypto'); const multer = require('multer'); - -const supportedTypes = ['image/jpeg', 'image/jpg', 'image/png', 'image/webp']; -const sizeLimit = 20 * 1024 * 1024; // 20 MB +const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const storage = multer.diskStorage({ destination: function (req, file, cb) { - const outputPath = path.join(req.app.locals.paths.imageOutput, 'temp'); + const outputPath = path.join(req.app.locals.paths.uploads, 'temp', req.user.id); if (!fs.existsSync(outputPath)) { fs.mkdirSync(outputPath, { recursive: true }); } @@ -16,22 +15,31 @@ const storage = multer.diskStorage({ }, filename: function (req, file, cb) { req.file_id = crypto.randomUUID(); - const fileExt = path.extname(file.originalname); - cb(null, `img-${req.file_id}${fileExt}`); + file.originalname = decodeURIComponent(file.originalname); + cb(null, `${file.originalname}`); }, }); const fileFilter = (req, file, cb) => { - if (!supportedTypes.includes(file.mimetype)) { - return cb( - new Error('Unsupported file type. Only JPEG, JPG, PNG, and WEBP files are allowed.'), - false, - ); + if (!file) { + return cb(new Error('No file provided'), false); + } + + if (!defaultFileConfig.checkType(file.mimetype)) { + return cb(new Error('Unsupported file type: ' + file.mimetype), false); } cb(null, true); }; -const upload = multer({ storage, fileFilter, limits: { fileSize: sizeLimit } }); +const createMulterInstance = async () => { + const customConfig = await getCustomConfig(); + const fileConfig = mergeFileConfig(customConfig?.fileConfig); + return multer({ + storage, + fileFilter, + limits: { fileSize: fileConfig.serverFileSizeLimit }, + }); +}; -module.exports = upload; +module.exports = createMulterInstance; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 05a4595b02d..8b1ffd8fe8c 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -17,6 +17,7 @@ const user = require('./user'); const config = require('./config'); const assistants = require('./assistants'); const files = require('./files'); +const staticRoute = require('./static'); module.exports = { search, @@ -38,4 +39,5 @@ module.exports = { config, assistants, files, + staticRoute, }; diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 1e2faafe7bf..d53dacae495 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -36,7 +36,7 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = const { messageId, model } = req.params; const { text } = req.body; const tokenCount = await countTokens(text, model); - res.status(201).send(await updateMessage({ messageId, text, tokenCount })); + res.status(201).json(await updateMessage({ messageId, text, tokenCount })); }); // DELETE diff --git a/api/server/routes/models.js b/api/server/routes/models.js index 383a63c1136..e3272087a76 100644 --- a/api/server/routes/models.js +++ b/api/server/routes/models.js @@ -1,8 +1,8 @@ const express = require('express'); -const router = express.Router(); -const controller = require('../controllers/ModelController'); -const { requireJwtAuth } = require('../middleware/'); +const { modelController } = require('~/server/controllers/ModelController'); +const { requireJwtAuth } = require('~/server/middleware/'); -router.get('/', requireJwtAuth, controller); +const router = express.Router(); +router.get('/', requireJwtAuth, modelController); module.exports = router; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 816fc7200f3..e85d83d8883 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -1,3 +1,5 @@ +// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware + const passport = require('passport'); const express = require('express'); const router = express.Router(); diff --git a/api/server/routes/presets.js b/api/server/routes/presets.js index 76aaed698cd..19214a3a7d1 100644 --- a/api/server/routes/presets.js +++ b/api/server/routes/presets.js @@ -5,27 +5,28 @@ const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const { logger } = require('~/config'); const router = express.Router(); +router.use(requireJwtAuth); -router.get('/', requireJwtAuth, async (req, res) => { +router.get('/', async (req, res) => { const presets = (await getPresets(req.user.id)).map((preset) => preset); - res.status(200).send(presets); + res.status(200).json(presets); }); -router.post('/', requireJwtAuth, async (req, res) => { +router.post('/', async (req, res) => { const update = req.body || {}; update.presetId = update?.presetId || crypto.randomUUID(); try { const preset = await savePreset(req.user.id, update); - res.status(201).send(preset); + res.status(201).json(preset); } catch (error) { logger.error('[/presets] error saving preset', error); - res.status(500).send(error); + res.status(500).send('There was an error when saving the preset'); } }); -router.post('/delete', requireJwtAuth, async (req, res) => { +router.post('/delete', async (req, res) => { let filter = {}; const { presetId } = req.body || {}; @@ -37,10 +38,10 @@ router.post('/delete', requireJwtAuth, async (req, res) => { try { const deleteCount = await deletePresets(req.user.id, filter); - res.status(201).send(deleteCount); + res.status(201).json(deleteCount); } catch (error) { logger.error('[/presets/delete] error deleting presets', error); - res.status(500).send(error); + res.status(500).send('There was an error deleting the presets'); } }); diff --git a/api/server/routes/static.js b/api/server/routes/static.js new file mode 100644 index 00000000000..116f7c8dd06 --- /dev/null +++ b/api/server/routes/static.js @@ -0,0 +1,7 @@ +const express = require('express'); +const paths = require('~/config/paths'); + +const router = express.Router(); +router.use(express.static(paths.imageOutput)); + +module.exports = router; diff --git a/api/server/routes/tokenizer.js b/api/server/routes/tokenizer.js index 581f82bf2ad..e12a86bde16 100644 --- a/api/server/routes/tokenizer.js +++ b/api/server/routes/tokenizer.js @@ -11,7 +11,7 @@ router.post('/', requireJwtAuth, async (req, res) => { res.send({ count }); } catch (e) { logger.error('[/tokenizer] Error counting tokens', e); - res.status(500).send(e.message); + res.status(500).json('Error counting tokens'); } }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js new file mode 100644 index 00000000000..22770f15500 --- /dev/null +++ b/api/server/services/ActionService.js @@ -0,0 +1,148 @@ +const { AuthTypeEnum, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider'); +const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); +const { getActions } = require('~/models/Action'); +const { logger } = require('~/config'); + +/** + * Parses the domain for an action. + * + * Azure OpenAI Assistants API doesn't support periods in function + * names due to `[a-zA-Z0-9_-]*` Regex Validation. + * + * @param {Express.Request} req - Express Request object + * @param {string} domain - The domain for the actoin + * @param {boolean} inverse - If true, replaces periods with `actionDomainSeparator` + * @returns {string} The parsed domain + */ +function domainParser(req, domain, inverse = false) { + if (!domain) { + return; + } + + if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + return domain; + } + + if (inverse) { + return domain.replace(/\./g, actionDomainSeparator); + } + + return domain.replace(actionDomainSeparator, '.'); +} + +/** + * Loads action sets based on the user and assistant ID. + * + * @param {Object} searchParams - The parameters for loading action sets. + * @param {string} searchParams.user - The user identifier. + * @param {string} searchParams.assistant_id - The assistant identifier. + * @returns {Promise<Action[] | null>} A promise that resolves to an array of actions or `null` if no match. + */ +async function loadActionSets(searchParams) { + return await getActions(searchParams, true); +} + +/** + * Creates a general tool for an entire action set. + * + * @param {Object} params - The parameters for loading action sets. + * @param {Action} params.action - The action set. Necessary for decrypting authentication values. + * @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call. + * @returns { { _call: (toolInput: Object) => unknown} } An object with `_call` method to execute the tool input. + */ +function createActionTool({ action, requestBuilder }) { + action.metadata = decryptMetadata(action.metadata); + const _call = async (toolInput) => { + try { + requestBuilder.setParams(toolInput); + if (action.metadata.auth && action.metadata.auth.type !== AuthTypeEnum.None) { + await requestBuilder.setAuth(action.metadata); + } + const res = await requestBuilder.execute(); + if (typeof res.data === 'object') { + return JSON.stringify(res.data); + } + return res.data; + } catch (error) { + logger.error(`API call to ${action.metadata.domain} failed`, error); + if (error.response) { + const { status, data } = error.response; + return `API call to ${ + action.metadata.domain + } failed with status ${status}: ${JSON.stringify(data)}`; + } + + return `API call to ${action.metadata.domain} failed.`; + } + }; + + return { + _call, + }; +} + +/** + * Encrypts sensitive metadata values for an action. + * + * @param {ActionMetadata} metadata - The action metadata to encrypt. + * @returns {ActionMetadata} The updated action metadata with encrypted values. + */ +function encryptMetadata(metadata) { + const encryptedMetadata = { ...metadata }; + + // ServiceHttp + if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { + if (metadata.api_key) { + encryptedMetadata.api_key = encryptV2(metadata.api_key); + } + } + + // OAuth + else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { + if (metadata.oauth_client_id) { + encryptedMetadata.oauth_client_id = encryptV2(metadata.oauth_client_id); + } + if (metadata.oauth_client_secret) { + encryptedMetadata.oauth_client_secret = encryptV2(metadata.oauth_client_secret); + } + } + + return encryptedMetadata; +} + +/** + * Decrypts sensitive metadata values for an action. + * + * @param {ActionMetadata} metadata - The action metadata to decrypt. + * @returns {ActionMetadata} The updated action metadata with decrypted values. + */ +function decryptMetadata(metadata) { + const decryptedMetadata = { ...metadata }; + + // ServiceHttp + if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { + if (metadata.api_key) { + decryptedMetadata.api_key = decryptV2(metadata.api_key); + } + } + + // OAuth + else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { + if (metadata.oauth_client_id) { + decryptedMetadata.oauth_client_id = decryptV2(metadata.oauth_client_id); + } + if (metadata.oauth_client_secret) { + decryptedMetadata.oauth_client_secret = decryptV2(metadata.oauth_client_secret); + } + } + + return decryptedMetadata; +} + +module.exports = { + loadActionSets, + createActionTool, + encryptMetadata, + decryptMetadata, + domainParser, +}; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index b1f7cf57d8b..e4cb416b416 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,7 +1,28 @@ -const { FileSources } = require('librechat-data-provider'); +const { + Constants, + FileSources, + Capabilities, + EModelEndpoint, + defaultSocialLogins, + validateAzureGroups, + mapModelToAzureConfig, + assistantEndpointSchema, + deprecatedAzureVariables, + conflictingAzureVariables, +} = require('librechat-data-provider'); const { initializeFirebase } = require('./Files/Firebase/initialize'); const loadCustomConfig = require('./Config/loadCustomConfig'); +const handleRateLimits = require('./Config/handleRateLimits'); +const { loadAndFormatTools } = require('./ToolService'); const paths = require('~/config/paths'); +const { logger } = require('~/config'); + +const secretDefaults = { + CREDS_KEY: 'f34be427ebb29de8d88c107a71546019685ed8b241d8f2ed00c3df97ad2566f0', + CREDS_IV: 'e2341419ec3dd3d19b13a1a87fafcbfb', + JWT_SECRET: '16f8c0ef4a5d391b26034086c628469d3f9f497f08163ab9b40137092f2909ef', + JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418', +}; /** * @@ -10,7 +31,9 @@ const paths = require('~/config/paths'); * @param {Express.Application} app - The Express application object. */ const AppService = async (app) => { + /** @type {TCustomConfig}*/ const config = (await loadCustomConfig()) ?? {}; + const fileStrategy = config.fileStrategy ?? FileSources.local; process.env.CDN_PROVIDER = fileStrategy; @@ -18,10 +41,168 @@ const AppService = async (app) => { initializeFirebase(); } + /** @type {Record<string, FunctionTool} */ + const availableTools = loadAndFormatTools({ + directory: paths.structuredTools, + filter: new Set([ + 'ChatTool.js', + 'CodeSherpa.js', + 'CodeSherpaTools.js', + 'E2BTools.js', + 'extractionChain.js', + ]), + }); + + const socialLogins = config?.registration?.socialLogins ?? defaultSocialLogins; + + if (!Object.keys(config).length) { + app.locals = { + availableTools, + fileStrategy, + socialLogins, + paths, + }; + + return; + } + + if (config.version !== Constants.CONFIG_VERSION) { + logger.info( + `\nOutdated Config version: ${config.version}. Current version: ${Constants.CONFIG_VERSION}\n\nCheck out the latest config file guide for new options and features.\nhttps://docs.librechat.ai/install/configuration/custom_config.html\n\n`, + ); + } + + handleRateLimits(config?.rateLimits); + + const endpointLocals = {}; + + if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) { + const { groups, ...azureConfiguration } = config.endpoints[EModelEndpoint.azureOpenAI]; + const { isValid, modelNames, modelGroupMap, groupMap, errors } = validateAzureGroups(groups); + + if (!isValid) { + const errorString = errors.join('\n'); + const errorMessage = 'Invalid Azure OpenAI configuration:\n' + errorString; + logger.error(errorMessage); + throw new Error(errorMessage); + } + + const assistantModels = []; + const assistantGroups = new Set(); + for (const modelName of modelNames) { + mapModelToAzureConfig({ modelName, modelGroupMap, groupMap }); + const groupName = modelGroupMap?.[modelName]?.group; + const modelGroup = groupMap?.[groupName]; + let supportsAssistants = modelGroup?.assistants || modelGroup?.[modelName]?.assistants; + if (supportsAssistants) { + assistantModels.push(modelName); + !assistantGroups.has(groupName) && assistantGroups.add(groupName); + } + } + + if (azureConfiguration.assistants && assistantModels.length === 0) { + throw new Error( + 'No Azure models are configured to support assistants. Please remove the `assistants` field or configure at least one model to support assistants.', + ); + } + + endpointLocals[EModelEndpoint.azureOpenAI] = { + modelNames, + modelGroupMap, + groupMap, + assistantModels, + assistantGroups: Array.from(assistantGroups), + ...azureConfiguration, + }; + + deprecatedAzureVariables.forEach(({ key, description }) => { + if (process.env[key]) { + logger.warn( + `The \`${key}\` environment variable (related to ${description}) should not be used in combination with the \`azureOpenAI\` endpoint configuration, as you will experience conflicts and errors.`, + ); + } + }); + + conflictingAzureVariables.forEach(({ key }) => { + if (process.env[key]) { + logger.warn( + `The \`${key}\` environment variable should not be used in combination with the \`azureOpenAI\` endpoint configuration, as you may experience with the defined placeholders for mapping to the current model grouping using the same name.`, + ); + } + }); + + if (azureConfiguration.assistants) { + endpointLocals[EModelEndpoint.assistants] = { + // Note: may need to add retrieval models here in the future + capabilities: [Capabilities.tools, Capabilities.actions, Capabilities.code_interpreter], + }; + } + } + + if (config?.endpoints?.[EModelEndpoint.assistants]) { + const assistantsConfig = config.endpoints[EModelEndpoint.assistants]; + const parsedConfig = assistantEndpointSchema.parse(assistantsConfig); + if (assistantsConfig.supportedIds?.length && assistantsConfig.excludedIds?.length) { + logger.warn( + `Both \`supportedIds\` and \`excludedIds\` are defined for the ${EModelEndpoint.assistants} endpoint; \`excludedIds\` field will be ignored.`, + ); + } + + const prevConfig = endpointLocals[EModelEndpoint.assistants] ?? {}; + + /** @type {Partial<TAssistantEndpoint>} */ + endpointLocals[EModelEndpoint.assistants] = { + ...prevConfig, + retrievalModels: parsedConfig.retrievalModels, + disableBuilder: parsedConfig.disableBuilder, + pollIntervalMs: parsedConfig.pollIntervalMs, + supportedIds: parsedConfig.supportedIds, + capabilities: parsedConfig.capabilities, + excludedIds: parsedConfig.excludedIds, + timeoutMs: parsedConfig.timeoutMs, + }; + } + + try { + const response = await fetch(`${process.env.RAG_API_URL}/health`); + if (response?.ok && response?.status === 200) { + logger.info(`RAG API is running and reachable at ${process.env.RAG_API_URL}.`); + } + } catch (error) { + logger.warn( + `RAG API is either not running or not reachable at ${process.env.RAG_API_URL}, you may experience errors with file uploads.`, + ); + } + app.locals = { + socialLogins, + availableTools, fileStrategy, + fileConfig: config?.fileConfig, + interface: config?.interface, paths, + ...endpointLocals, }; + + let hasDefaultSecrets = false; + for (const [key, value] of Object.entries(secretDefaults)) { + if (process.env[key] === value) { + logger.warn(`Default value for ${key} is being used.`); + !hasDefaultSecrets && (hasDefaultSecrets = true); + } + } + + if (hasDefaultSecrets) { + logger.info( + `Please replace any default secret values. + + For your conveninence, fork & run this replit to generate your own secret values: + + https://replit.com/@daavila/crypto#index.js + + `, + ); + } }; module.exports = AppService; diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js new file mode 100644 index 00000000000..3a40a49b3e7 --- /dev/null +++ b/api/server/services/AppService.spec.js @@ -0,0 +1,436 @@ +const { + FileSources, + EModelEndpoint, + defaultSocialLogins, + validateAzureGroups, + deprecatedAzureVariables, + conflictingAzureVariables, +} = require('librechat-data-provider'); + +const AppService = require('./AppService'); + +jest.mock('./Config/loadCustomConfig', () => { + return jest.fn(() => + Promise.resolve({ + registration: { socialLogins: ['testLogin'] }, + fileStrategy: 'testStrategy', + }), + ); +}); +jest.mock('./Files/Firebase/initialize', () => ({ + initializeFirebase: jest.fn(), +})); +jest.mock('./ToolService', () => ({ + loadAndFormatTools: jest.fn().mockReturnValue({ + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }, + }), +})); + +const azureGroups = [ + { + group: 'librechat-westus', + apiKey: '${WESTUS_API_KEY}', + instanceName: 'librechat-westus', + version: '2023-12-01-preview', + models: { + 'gpt-4-vision-preview': { + deploymentName: 'gpt-4-vision-preview', + version: '2024-02-15-preview', + }, + 'gpt-3.5-turbo': { + deploymentName: 'gpt-35-turbo', + }, + 'gpt-3.5-turbo-1106': { + deploymentName: 'gpt-35-turbo-1106', + }, + 'gpt-4': { + deploymentName: 'gpt-4', + }, + 'gpt-4-1106-preview': { + deploymentName: 'gpt-4-1106-preview', + }, + }, + }, + { + group: 'librechat-eastus', + apiKey: '${EASTUS_API_KEY}', + instanceName: 'librechat-eastus', + deploymentName: 'gpt-4-turbo', + version: '2024-02-15-preview', + models: { + 'gpt-4-turbo': true, + }, + }, +]; + +describe('AppService', () => { + let app; + + beforeEach(() => { + app = { locals: {} }; + process.env.CDN_PROVIDER = undefined; + }); + + it('should correctly assign process.env and app.locals based on custom config', async () => { + await AppService(app); + + expect(process.env.CDN_PROVIDER).toEqual('testStrategy'); + + expect(app.locals).toEqual({ + socialLogins: ['testLogin'], + fileStrategy: 'testStrategy', + availableTools: { + ExampleTool: { + type: 'function', + function: expect.objectContaining({ + description: 'Example tool function', + name: 'exampleFunction', + parameters: expect.objectContaining({ + type: 'object', + properties: expect.any(Object), + required: expect.arrayContaining(['param1']), + }), + }), + }, + }, + paths: expect.anything(), + }); + }); + + it('should log a warning if the config version is outdated', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + version: '0.9.0', // An outdated version for this test + registration: { socialLogins: ['testLogin'] }, + fileStrategy: 'testStrategy', + }), + ); + + await AppService(app); + + const { logger } = require('~/config'); + expect(logger.info).toHaveBeenCalledWith(expect.stringContaining('Outdated Config version')); + }); + + it('should initialize Firebase when fileStrategy is firebase', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + fileStrategy: FileSources.firebase, + }), + ); + + await AppService(app); + + const { initializeFirebase } = require('./Files/Firebase/initialize'); + expect(initializeFirebase).toHaveBeenCalled(); + + expect(process.env.CDN_PROVIDER).toEqual(FileSources.firebase); + }); + + it('should load and format tools accurately with defined structure', async () => { + const { loadAndFormatTools } = require('./ToolService'); + await AppService(app); + + expect(loadAndFormatTools).toHaveBeenCalledWith({ + directory: expect.anything(), + filter: expect.anything(), + }); + + expect(app.locals.availableTools.ExampleTool).toBeDefined(); + expect(app.locals.availableTools.ExampleTool).toEqual({ + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }); + }); + + it('should correctly configure Assistants endpoint based on custom config', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.assistants]: { + disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: ['id1', 'id2'], + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.assistants); + expect(app.locals[EModelEndpoint.assistants]).toEqual( + expect.objectContaining({ + disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: expect.arrayContaining(['id1', 'id2']), + }), + ); + }); + + it('should correctly configure Azure OpenAI endpoint based on custom config', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.azureOpenAI]: { + groups: azureGroups, + }, + }, + }), + ); + + process.env.WESTUS_API_KEY = 'westus-key'; + process.env.EASTUS_API_KEY = 'eastus-key'; + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.azureOpenAI); + const azureConfig = app.locals[EModelEndpoint.azureOpenAI]; + expect(azureConfig).toHaveProperty('modelNames'); + expect(azureConfig).toHaveProperty('modelGroupMap'); + expect(azureConfig).toHaveProperty('groupMap'); + + const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(azureGroups); + expect(azureConfig.modelNames).toEqual(modelNames); + expect(azureConfig.modelGroupMap).toEqual(modelGroupMap); + expect(azureConfig.groupMap).toEqual(groupMap); + }); + + it('should not modify FILE_UPLOAD environment variables without rate limits', async () => { + // Setup initial environment variables + process.env.FILE_UPLOAD_IP_MAX = '10'; + process.env.FILE_UPLOAD_IP_WINDOW = '15'; + process.env.FILE_UPLOAD_USER_MAX = '5'; + process.env.FILE_UPLOAD_USER_WINDOW = '20'; + + const initialEnv = { ...process.env }; + + await AppService(app); + + // Expect environment variables to remain unchanged + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual(initialEnv.FILE_UPLOAD_IP_MAX); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual(initialEnv.FILE_UPLOAD_IP_WINDOW); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual(initialEnv.FILE_UPLOAD_USER_MAX); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual(initialEnv.FILE_UPLOAD_USER_WINDOW); + }); + + it('should correctly set FILE_UPLOAD environment variables based on rate limits', async () => { + // Define and mock a custom configuration with rate limits + const rateLimitsConfig = { + rateLimits: { + fileUploads: { + ipMax: '100', + ipWindowInMinutes: '60', + userMax: '50', + userWindowInMinutes: '30', + }, + }, + }; + + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve(rateLimitsConfig), + ); + + await AppService(app); + + // Verify that process.env has been updated according to the rate limits config + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual('100'); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual('60'); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual('50'); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual('30'); + }); + + it('should fallback to default FILE_UPLOAD environment variables when rate limits are unspecified', async () => { + // Setup initial environment variables to non-default values + process.env.FILE_UPLOAD_IP_MAX = 'initialMax'; + process.env.FILE_UPLOAD_IP_WINDOW = 'initialWindow'; + process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax'; + process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow'; + + // Mock a custom configuration without specific rate limits + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); + + await AppService(app); + + // Verify that process.env falls back to the initial values + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual('initialMax'); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual('initialWindow'); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual('initialUserMax'); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual('initialUserWindow'); + }); +}); + +describe('AppService updating app.locals and issuing warnings', () => { + let app; + let initialEnv; + + beforeEach(() => { + // Store initial environment variables to restore them after each test + initialEnv = { ...process.env }; + + app = { locals: {} }; + process.env.CDN_PROVIDER = undefined; + }); + + afterEach(() => { + // Restore initial environment variables + process.env = { ...initialEnv }; + }); + + it('should update app.locals with default values if loadCustomConfig returns undefined', async () => { + // Mock loadCustomConfig to return undefined + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(undefined)); + + await AppService(app); + + expect(app.locals).toBeDefined(); + expect(app.locals.paths).toBeDefined(); + expect(app.locals.availableTools).toBeDefined(); + expect(app.locals.fileStrategy).toEqual(FileSources.local); + expect(app.locals.socialLogins).toEqual(defaultSocialLogins); + }); + + it('should update app.locals with values from loadCustomConfig', async () => { + // Mock loadCustomConfig to return a specific config object + const customConfig = { + fileStrategy: 'firebase', + registration: { socialLogins: ['testLogin'] }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve(customConfig), + ); + + await AppService(app); + + expect(app.locals).toBeDefined(); + expect(app.locals.paths).toBeDefined(); + expect(app.locals.availableTools).toBeDefined(); + expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy); + expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins); + }); + + it('should apply the assistants endpoint configuration correctly to app.locals', async () => { + const mockConfig = { + endpoints: { + assistants: { + disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: ['id1', 'id2'], + }, + }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + const app = { locals: {} }; + await AppService(app); + + expect(app.locals).toHaveProperty('assistants'); + const { assistants } = app.locals; + expect(assistants.disableBuilder).toBe(true); + expect(assistants.pollIntervalMs).toBe(5000); + expect(assistants.timeoutMs).toBe(30000); + expect(assistants.supportedIds).toEqual(['id1', 'id2']); + expect(assistants.excludedIds).toBeUndefined(); + }); + + it('should log a warning when both supportedIds and excludedIds are provided', async () => { + const mockConfig = { + endpoints: { + assistants: { + disableBuilder: false, + pollIntervalMs: 3000, + timeoutMs: 20000, + supportedIds: ['id1', 'id2'], + excludedIds: ['id3'], + }, + }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + const app = { locals: {} }; + await require('./AppService')(app); + + const { logger } = require('~/config'); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Both `supportedIds` and `excludedIds` are defined'), + ); + }); + + it('should issue expected warnings when loading Azure Groups with deprecated Environment Variables', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.azureOpenAI]: { + groups: azureGroups, + }, + }, + }), + ); + + deprecatedAzureVariables.forEach((varInfo) => { + process.env[varInfo.key] = 'test'; + }); + + const app = { locals: {} }; + await require('./AppService')(app); + + const { logger } = require('~/config'); + deprecatedAzureVariables.forEach(({ key, description }) => { + expect(logger.warn).toHaveBeenCalledWith( + `The \`${key}\` environment variable (related to ${description}) should not be used in combination with the \`azureOpenAI\` endpoint configuration, as you will experience conflicts and errors.`, + ); + }); + }); + + it('should issue expected warnings when loading conflicting Azure Envrionment Variables', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.azureOpenAI]: { + groups: azureGroups, + }, + }, + }), + ); + + conflictingAzureVariables.forEach((varInfo) => { + process.env[varInfo.key] = 'test'; + }); + + const app = { locals: {} }; + await require('./AppService')(app); + + const { logger } = require('~/config'); + conflictingAzureVariables.forEach(({ key }) => { + expect(logger.warn).toHaveBeenCalledWith( + `The \`${key}\` environment variable should not be used in combination with the \`azureOpenAI\` endpoint configuration, as you may experience with the defined placeholders for mapping to the current model grouping using the same name.`, + ); + }); + }); +}); diff --git a/api/server/services/AssistantService.js b/api/server/services/AssistantService.js index 4b929193481..41e88dc8bdb 100644 --- a/api/server/services/AssistantService.js +++ b/api/server/services/AssistantService.js @@ -1,256 +1,91 @@ -const RunManager = require('./Runs/RunMananger'); +const { klona } = require('klona'); +const { + StepTypes, + RunStatus, + StepStatus, + ContentTypes, + ToolCallTypes, + imageGenTools, + EModelEndpoint, + defaultOrderQuery, +} = require('librechat-data-provider'); +const { retrieveAndProcessFile } = require('~/server/services/Files/process'); +const { processRequiredActions } = require('~/server/services/ToolService'); +const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); +const { RunManager, waitForRun } = require('~/server/services/Runs'); +const { processMessages } = require('~/server/services/Threads'); +const { TextStream } = require('~/app/clients'); +const { logger } = require('~/config'); /** - * @typedef {Object} Message - * @property {string} id - The identifier of the message. - * @property {string} object - The object type, always 'thread.message'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the message was created. - * @property {string} thread_id - The thread ID that this message belongs to. - * @property {string} role - The entity that produced the message. One of 'user' or 'assistant'. - * @property {Object[]} content - The content of the message in an array of text and/or images. - * @property {string} content[].type - The type of content, either 'text' or 'image_file'. - * @property {Object} [content[].text] - The text content, present if type is 'text'. - * @property {string} content[].text.value - The data that makes up the text. - * @property {Object[]} [content[].text.annotations] - Annotations for the text content. - * @property {Object} [content[].image_file] - The image file content, present if type is 'image_file'. - * @property {string} content[].image_file.file_id - The File ID of the image in the message content. - * @property {string[]} [file_ids] - Optional list of File IDs for the message. - * @property {string|null} [assistant_id] - If applicable, the ID of the assistant that authored this message. - * @property {string|null} [run_id] - If applicable, the ID of the run associated with the authoring of this message. - * @property {Object} [metadata] - Optional metadata for the message, a map of key-value pairs. - */ - -/** - * @typedef {Object} FunctionTool - * @property {string} type - The type of tool, 'function'. - * @property {Object} function - The function definition. - * @property {string} function.description - A description of what the function does. - * @property {string} function.name - The name of the function to be called. - * @property {Object} function.parameters - The parameters the function accepts, described as a JSON Schema object. - */ - -/** - * @typedef {Object} Tool - * @property {string} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. - * @property {FunctionTool} [function] - The function tool, present if type is 'function'. - */ - -/** - * @typedef {Object} Run - * @property {string} id - The identifier of the run. - * @property {string} object - The object type, always 'thread.run'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run was created. - * @property {string} thread_id - The ID of the thread that was executed on as a part of this run. - * @property {string} assistant_id - The ID of the assistant used for execution of this run. - * @property {string} status - The status of the run (e.g., 'queued', 'completed'). - * @property {Object} [required_action] - Details on the action required to continue the run. - * @property {string} required_action.type - The type of required action, always 'submit_tool_outputs'. - * @property {Object} required_action.submit_tool_outputs - Details on the tool outputs needed for the run to continue. - * @property {Object[]} required_action.submit_tool_outputs.tool_calls - A list of the relevant tool calls. - * @property {string} required_action.submit_tool_outputs.tool_calls[].id - The ID of the tool call. - * @property {string} required_action.submit_tool_outputs.tool_calls[].type - The type of tool call the output is required for, always 'function'. - * @property {Object} required_action.submit_tool_outputs.tool_calls[].function - The function definition. - * @property {string} required_action.submit_tool_outputs.tool_calls[].function.name - The name of the function. - * @property {string} required_action.submit_tool_outputs.tool_calls[].function.arguments - The arguments that the model expects you to pass to the function. - * @property {Object} [last_error] - The last error associated with this run. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expires_at] - The Unix timestamp (in seconds) for when the run will expire. - * @property {number} [started_at] - The Unix timestamp (in seconds) for when the run was started. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run was completed. - * @property {string} [model] - The model that the assistant used for this run. - * @property {string} [instructions] - The instructions that the assistant used for this run. - * @property {Tool[]} [tools] - The list of tools used for this run. - * @property {string[]} [file_ids] - The list of File IDs used for this run. - * @property {Object} [metadata] - Metadata associated with this run. - */ - -/** - * @typedef {Object} RunStep - * @property {string} id - The identifier of the run step. - * @property {string} object - The object type, always 'thread.run.step'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run step was created. - * @property {string} assistant_id - The ID of the assistant associated with the run step. - * @property {string} thread_id - The ID of the thread that was run. - * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. - * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run step failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run step completed. - * @property {Object} [metadata] - Metadata associated with this run step, a map of up to 16 key-value pairs. - */ - -/** - * @typedef {Object} StepMessage - * @property {Message} message - The complete message object created by the step. - * @property {string} id - The identifier of the run step. - * @property {string} object - The object type, always 'thread.run.step'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run step was created. - * @property {string} assistant_id - The ID of the assistant associated with the run step. - * @property {string} thread_id - The ID of the thread that was run. - * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. - * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run step failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run step completed. - * @property {Object} [metadata] - Metadata associated with this run step, a map of up to 16 key-value pairs. - */ - -/** - * Initializes a new thread or adds messages to an existing thread. - * - * @param {Object} params - The parameters for initializing a thread. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {Object} params.body - The body of the request. - * @param {Message[]} params.body.messages - A list of messages to start the thread with. - * @param {Object} [params.body.metadata] - Optional metadata for the thread. - * @param {string} [params.thread_id] - Optional existing thread ID. If provided, a message will be added to this thread. - * @return {Promise<Thread>} A promise that resolves to the newly created thread object or the updated thread object. - */ -async function initThread({ openai, body, thread_id: _thread_id }) { - let thread = {}; - const messages = []; - if (_thread_id) { - const message = await openai.beta.threads.messages.create(_thread_id, body.messages[0]); - messages.push(message); - } else { - thread = await openai.beta.threads.create(body); - } - - const thread_id = _thread_id ?? thread.id; - return { messages, thread_id, ...thread }; -} - -/** - * Creates a run on a thread using the OpenAI API. + * Sorts, processes, and flattens messages to a single string. * - * @param {Object} params - The parameters for creating a run. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {string} params.thread_id - The ID of the thread to run. - * @param {Object} params.body - The body of the request to create a run. - * @param {string} params.body.assistant_id - The ID of the assistant to use for this run. - * @param {string} [params.body.model] - Optional. The ID of the model to be used for this run. - * @param {string} [params.body.instructions] - Optional. Override the default system message of the assistant. - * @param {Object[]} [params.body.tools] - Optional. Override the tools the assistant can use for this run. - * @param {string[]} [params.body.file_ids] - Optional. List of File IDs the assistant can use for this run. - * @param {Object} [params.body.metadata] - Optional. Metadata for the run. - * @return {Promise<Run>} A promise that resolves to the created run object. + * @param {Object} params - Params for creating the onTextProgress function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.userMessageId - The user message ID; response's `parentMessageId`. + * @param {string} params.messageId - The response message ID. + * @param {string} params.thread_id - The current thread ID. + * @returns {void} */ -async function createRun({ openai, thread_id, body }) { - const run = await openai.beta.threads.runs.create(thread_id, body); - return run; -} - -// /** -// * Retrieves all steps of a run. -// * -// * @param {Object} params - The parameters for the retrieveRunSteps function. -// * @param {OpenAI} params.openai - The OpenAI client instance. -// * @param {string} params.thread_id - The ID of the thread associated with the run. -// * @param {string} params.run_id - The ID of the run to retrieve steps for. -// * @return {Promise<RunStep[]>} A promise that resolves to an array of RunStep objects. -// */ -// async function retrieveRunSteps({ openai, thread_id, run_id }) { -// const runSteps = await openai.beta.threads.runs.steps.list(thread_id, run_id); -// return runSteps; -// } - -/** - * Delays the execution for a specified number of milliseconds. - * - * @param {number} ms - The number of milliseconds to delay. - * @return {Promise<void>} A promise that resolves after the specified delay. - */ -function sleep(ms) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - -/** - * Waits for a run to complete by repeatedly checking its status. It uses a RunManager instance to fetch and manage run steps based on the run status. - * - * @param {Object} params - The parameters for the waitForRun function. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {string} params.run_id - The ID of the run to wait for. - * @param {string} params.thread_id - The ID of the thread associated with the run. - * @param {RunManager} params.runManager - The RunManager instance to manage run steps. - * @param {number} params.pollIntervalMs - The interval for polling the run status, default is 500 milliseconds. - * @return {Promise<Run>} A promise that resolves to the last fetched run object. - */ -async function waitForRun({ openai, run_id, thread_id, runManager, pollIntervalMs = 500 }) { - const timeout = 18000; // 18 seconds - let timeElapsed = 0; - let run; - - // this runManager will be passed in from the caller - // const runManager = new RunManager({ - // 'in_progress': (step) => { /* ... */ }, - // 'queued': (step) => { /* ... */ }, - // }); - - while (timeElapsed < timeout) { - run = await openai.beta.threads.runs.retrieve(thread_id, run_id); - console.log(`Run status: ${run.status}`); - - if (!['in_progress', 'queued'].includes(run.status)) { - await runManager.fetchRunSteps({ - openai, - thread_id: thread_id, - run_id: run_id, - runStatus: run.status, - final: true, - }); - break; +async function createOnTextProgress({ + openai, + conversationId, + userMessageId, + messageId, + thread_id, +}) { + openai.responseMessage = { + conversationId, + parentMessageId: userMessageId, + role: 'assistant', + messageId, + content: [], + }; + + openai.responseText = ''; + + openai.addContentData = (data) => { + const { type, index } = data; + openai.responseMessage.content[index] = { type, [type]: data[type] }; + + if (type === ContentTypes.TEXT) { + openai.responseText += data[type].value; + return; } - // may use in future - // await runManager.fetchRunSteps({ - // openai, - // thread_id: thread_id, - // run_id: run_id, - // runStatus: run.status, - // }); - - await sleep(pollIntervalMs); - timeElapsed += pollIntervalMs; - } - - return run; + const contentData = { + index, + type, + [type]: data[type], + messageId, + thread_id, + conversationId, + }; + + logger.debug('Content data:', contentData); + sendMessage(openai.res, contentData); + }; } /** * Retrieves the response from an OpenAI run. * * @param {Object} params - The parameters for getting the response. - * @param {OpenAI} params.openai - The OpenAI client instance. + * @param {OpenAIClient} params.openai - The OpenAI client instance. * @param {string} params.run_id - The ID of the run to get the response for. * @param {string} params.thread_id - The ID of the thread associated with the run. - * @return {Promise<OpenAIAssistantFinish | OpenAIAssistantAction[] | Message[] | RequiredActionFunctionToolCall[]>} + * @return {Promise<OpenAIAssistantFinish | OpenAIAssistantAction[] | ThreadMessage[] | RequiredActionFunctionToolCall[]>} */ async function getResponse({ openai, run_id, thread_id }) { const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 500 }); - if (run.status === 'completed') { - const messages = await openai.beta.threads.messages.list(thread_id, { - order: 'asc', - }); + if (run.status === RunStatus.COMPLETED) { + const messages = await openai.beta.threads.messages.list(thread_id, defaultOrderQuery); const newMessages = messages.data.filter((msg) => msg.run_id === run_id); return newMessages; - } else if (run.status === 'requires_action') { + } else if (run.status === RunStatus.REQUIRES_ACTION) { const actions = []; run.required_action?.submit_tool_outputs.tool_calls.forEach((item) => { const functionCall = item.function; @@ -259,7 +94,6 @@ async function getResponse({ openai, run_id, thread_id }) { tool: functionCall.name, toolInput: args, toolCallId: item.id, - log: '', run_id, thread_id, }); @@ -272,90 +106,349 @@ async function getResponse({ openai, run_id, thread_id }) { throw new Error(`Unexpected run status ${run.status}.\nFull run info:\n\n${runInfo}`); } +/** + * Filters the steps to keep only the most recent instance of each unique step. + * @param {RunStep[]} steps - The array of RunSteps to filter. + * @return {RunStep[]} The filtered array of RunSteps. + */ +function filterSteps(steps = []) { + if (steps.length <= 1) { + return steps; + } + const stepMap = new Map(); + + steps.forEach((step) => { + if (!step) { + return; + } + + const effectiveTimestamp = Math.max( + step.created_at, + step.expired_at || 0, + step.cancelled_at || 0, + step.failed_at || 0, + step.completed_at || 0, + ); + + if (!stepMap.has(step.id) || effectiveTimestamp > stepMap.get(step.id).effectiveTimestamp) { + const latestStep = { ...step, effectiveTimestamp }; + if (latestStep.last_error) { + // testing to see if we ever step into this + } + stepMap.set(step.id, latestStep); + } + }); + + return Array.from(stepMap.values()).map((step) => { + delete step.effectiveTimestamp; + return step; + }); +} + +/** + * @callback InProgressFunction + * @param {Object} params - The parameters for the in progress step. + * @param {RunStep} params.step - The step object with details about the message creation. + * @returns {Promise<void>} - A promise that resolves when the step is processed. + */ + +function hasToolCallChanged(previousCall, currentCall) { + return JSON.stringify(previousCall) !== JSON.stringify(currentCall); +} + +/** + * Creates a handler function for steps in progress, specifically for + * processing messages and managing seen completed messages. + * + * @param {OpenAIClient} openai - The OpenAI client instance. + * @param {string} thread_id - The ID of the thread the run is in. + * @param {ThreadMessage[]} messages - The accumulated messages for the run. + * @return {InProgressFunction} a function to handle steps in progress. + */ +function createInProgressHandler(openai, thread_id, messages) { + openai.index = 0; + openai.mappedOrder = new Map(); + openai.seenToolCalls = new Map(); + openai.processedFileIds = new Set(); + openai.completeToolCallSteps = new Set(); + openai.seenCompletedMessages = new Set(); + + /** + * The in_progress function for handling message creation steps. + * + * @type {InProgressFunction} + */ + async function in_progress({ step }) { + if (step.type === StepTypes.TOOL_CALLS) { + const { tool_calls } = step.step_details; + + for (const _toolCall of tool_calls) { + /** @type {StepToolCall} */ + const toolCall = _toolCall; + const previousCall = openai.seenToolCalls.get(toolCall.id); + + // If the tool call isn't new and hasn't changed + if (previousCall && !hasToolCallChanged(previousCall, toolCall)) { + continue; + } + + let toolCallIndex = openai.mappedOrder.get(toolCall.id); + if (toolCallIndex === undefined) { + // New tool call + toolCallIndex = openai.index; + openai.mappedOrder.set(toolCall.id, openai.index); + openai.index++; + } + + if (step.status === StepStatus.IN_PROGRESS) { + toolCall.progress = + previousCall && previousCall.progress + ? Math.min(previousCall.progress + 0.2, 0.95) + : 0.01; + } else { + toolCall.progress = 1; + openai.completeToolCallSteps.add(step.id); + } + + if ( + toolCall.type === ToolCallTypes.CODE_INTERPRETER && + step.status === StepStatus.COMPLETED + ) { + const { outputs } = toolCall[toolCall.type]; + + for (const output of outputs) { + if (output.type !== 'image') { + continue; + } + + if (openai.processedFileIds.has(output.image?.file_id)) { + continue; + } + + const { file_id } = output.image; + const file = await retrieveAndProcessFile({ + openai, + client: openai, + file_id, + basename: `${file_id}.png`, + }); + + const prelimImage = file; + + // check if every key has a value before adding to content + const prelimImageKeys = Object.keys(prelimImage); + const validImageFile = prelimImageKeys.every((key) => prelimImage[key]); + + if (!validImageFile) { + continue; + } + + const image_file = { + [ContentTypes.IMAGE_FILE]: prelimImage, + type: ContentTypes.IMAGE_FILE, + index: openai.index, + }; + openai.addContentData(image_file); + openai.processedFileIds.add(file_id); + openai.index++; + } + } else if ( + toolCall.type === ToolCallTypes.FUNCTION && + step.status === StepStatus.COMPLETED && + imageGenTools.has(toolCall[toolCall.type].name) + ) { + /* If a change is detected, skip image generation tools as already processed */ + openai.seenToolCalls.set(toolCall.id, toolCall); + continue; + } + + openai.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + }); + + // Update the stored tool call + openai.seenToolCalls.set(toolCall.id, toolCall); + } + } else if (step.type === StepTypes.MESSAGE_CREATION && step.status === StepStatus.COMPLETED) { + const { message_id } = step.step_details.message_creation; + if (openai.seenCompletedMessages.has(message_id)) { + return; + } + + openai.seenCompletedMessages.add(message_id); + + const message = await openai.beta.threads.messages.retrieve(thread_id, message_id); + if (!message?.content?.length) { + return; + } + messages.push(message); + + let messageIndex = openai.mappedOrder.get(step.id); + if (messageIndex === undefined) { + // New message + messageIndex = openai.index; + openai.mappedOrder.set(step.id, openai.index); + openai.index++; + } + + const result = await processMessages({ openai, client: openai, messages: [message] }); + openai.addContentData({ + [ContentTypes.TEXT]: { value: result.text }, + type: ContentTypes.TEXT, + index: messageIndex, + }); + + // Create the Factory Function to stream the message + const { onProgress: progressCallback } = createOnProgress({ + // todo: add option to save partialText to db + // onProgress: () => {}, + }); + + // This creates a function that attaches all of the parameters + // specified here to each SSE message generated by the TextStream + const onProgress = progressCallback({ + res: openai.res, + index: messageIndex, + messageId: openai.responseMessage.messageId, + conversationId: openai.responseMessage.conversationId, + type: ContentTypes.TEXT, + thread_id, + }); + + // Create a small buffer before streaming begins + await sleep(500); + + const stream = new TextStream(result.text, { delay: 9 }); + await stream.processTextStream(onProgress); + } + } + + return in_progress; +} + /** * Initializes a RunManager with handlers, then invokes waitForRun to monitor and manage an OpenAI run. * * @param {Object} params - The parameters for managing and monitoring the run. - * @param {OpenAI} params.openai - The OpenAI client instance. + * @param {OpenAIClient} params.openai - The OpenAI client instance. * @param {string} params.run_id - The ID of the run to manage and monitor. * @param {string} params.thread_id - The ID of the thread associated with the run. - * @return {Promise<Object>} A promise that resolves to an object containing the run and managed steps. + * @param {RunStep[]} params.accumulatedSteps - The accumulated steps for the run. + * @param {ThreadMessage[]} params.accumulatedMessages - The accumulated messages for the run. + * @param {InProgressFunction} [params.in_progress] - The `in_progress` function from a previous run. + * @return {Promise<RunResponse>} A promise that resolves to an object containing the run and managed steps. */ -async function handleRun({ openai, run_id, thread_id }) { - let steps; - let messages; +async function runAssistant({ + openai, + run_id, + thread_id, + accumulatedSteps = [], + accumulatedMessages = [], + in_progress: inProgress, +}) { + let steps = accumulatedSteps; + let messages = accumulatedMessages; + const in_progress = inProgress ?? createInProgressHandler(openai, thread_id, messages); + openai.in_progress = in_progress; + const runManager = new RunManager({ - // 'in_progress': async ({ step, final, isLast }) => { - // // Define logic for handling steps with 'in_progress' status - // }, - // 'queued': async ({ step, final, isLast }) => { - // // Define logic for handling steps with 'queued' status - // }, + in_progress, final: async ({ step, runStatus, stepsByStatus }) => { - console.log(`Final step for ${run_id} with status ${runStatus}`); - console.dir(step, { depth: null }); + logger.debug(`[runAssistant] Final step for ${run_id} with status ${runStatus}`, step); const promises = []; - promises.push( - openai.beta.threads.messages.list(thread_id, { - order: 'asc', - }), - ); + // promises.push( + // openai.beta.threads.messages.list(thread_id, defaultOrderQuery), + // ); - const finalSteps = stepsByStatus[runStatus]; - - // loop across all statuses, may use in the future - // for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { - // promises.push(...stepsPromises); + // const finalSteps = stepsByStatus[runStatus]; + // for (const stepPromise of finalSteps) { + // promises.push(stepPromise); // } - for (const stepPromise of finalSteps) { - promises.push(stepPromise); + + // loop across all statuses + for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { + promises.push(...stepsPromises); } const resolved = await Promise.all(promises); - const res = resolved.shift(); - messages = res.data.filter((msg) => msg.run_id === run_id); + const finalSteps = filterSteps(steps.concat(resolved)); + + if (step.type === StepTypes.MESSAGE_CREATION) { + const incompleteToolCallSteps = finalSteps.filter( + (s) => s && s.type === StepTypes.TOOL_CALLS && !openai.completeToolCallSteps.has(s.id), + ); + for (const incompleteToolCallStep of incompleteToolCallSteps) { + await in_progress({ step: incompleteToolCallStep }); + } + } + await in_progress({ step }); + // const res = resolved.shift(); + // messages = messages.concat(res.data.filter((msg) => msg && msg.run_id === run_id)); resolved.push(step); - steps = resolved; + /* Note: no issues without deep cloning, but it's safer to do so */ + steps = klona(finalSteps); }, }); - const run = await waitForRun({ openai, run_id, thread_id, runManager, pollIntervalMs: 500 }); + /** @type {TCustomConfig.endpoints.assistants} */ + const assistantsEndpointConfig = openai.req.app.locals?.[EModelEndpoint.assistants] ?? {}; + const { pollIntervalMs, timeoutMs } = assistantsEndpointConfig; + + const run = await waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs, + timeout: timeoutMs, + }); - return { run, steps, messages }; -} + if (!run.required_action) { + // const { messages: sortedMessages, text } = await processMessages(openai, messages); + // return { run, steps, messages: sortedMessages, text }; + const sortedMessages = messages.sort((a, b) => a.created_at - b.created_at); + return { + run, + steps, + messages: sortedMessages, + finalMessage: openai.responseMessage, + text: openai.responseText, + }; + } -/** - * Maps messages to their corresponding steps. Steps with message creation will be paired with their messages, - * while steps without message creation will be returned as is. - * - * @param {RunStep[]} steps - An array of steps from the run. - * @param {Message[]} messages - An array of message objects. - * @returns {(StepMessage | RunStep)[]} An array where each element is either a step with its corresponding message (StepMessage) or a step without a message (RunStep). - */ -function mapMessagesToSteps(steps, messages) { - // Create a map of messages indexed by their IDs for efficient lookup - const messageMap = messages.reduce((acc, msg) => { - acc[msg.id] = msg; - return acc; - }, {}); - - // Map each step to its corresponding message, or return the step as is if no message ID is present - return steps.map((step) => { - const messageId = step.step_details?.message_creation?.message_id; - - if (messageId && messageMap[messageId]) { - return { step, message: messageMap[messageId] }; - } - return step; + const { submit_tool_outputs } = run.required_action; + const actions = submit_tool_outputs.tool_calls.map((item) => { + const functionCall = item.function; + const args = JSON.parse(functionCall.arguments); + return { + tool: functionCall.name, + toolInput: args, + toolCallId: item.id, + run_id, + thread_id, + }; + }); + + const outputs = await processRequiredActions(openai, actions); + + const toolRun = await openai.beta.threads.runs.submitToolOutputs(run.thread_id, run.id, outputs); + + // Recursive call with accumulated steps and messages + return await runAssistant({ + openai, + run_id: toolRun.id, + thread_id, + accumulatedSteps: steps, + accumulatedMessages: messages, + in_progress, }); } module.exports = { - initThread, - createRun, - waitForRun, getResponse, - handleRun, - mapMessagesToSteps, + runAssistant, + createOnTextProgress, }; diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index a60ae370efe..098110df0d4 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,6 +1,8 @@ const crypto = require('crypto'); const bcrypt = require('bcryptjs'); -const { registerSchema, errorsToString } = require('~/strategies/validators'); +const { errorsToString } = require('librechat-data-provider'); +const { registerSchema } = require('~/strategies/validators'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const Token = require('~/models/schema/tokenSchema'); const { sendEmail } = require('~/server/utils'); const Session = require('~/models/Session'); @@ -12,6 +14,27 @@ const domains = { server: process.env.DOMAIN_SERVER, }; +async function isDomainAllowed(email) { + if (!email) { + return false; + } + + const domain = email.split('@')[1]; + + if (!domain) { + return false; + } + + const customConfig = await getCustomConfig(); + if (!customConfig) { + return true; + } else if (!customConfig?.registration?.allowedDomains) { + return true; + } + + return customConfig.registration.allowedDomains.includes(domain); +} + const isProduction = process.env.NODE_ENV === 'production'; /** @@ -80,6 +103,12 @@ const registerUser = async (user) => { return { status: 500, message: 'Something went wrong' }; } + if (!(await isDomainAllowed(email))) { + const errorMessage = 'Registration from this domain is not allowed.'; + logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`); + return { status: 403, message: errorMessage }; + } + //determine if this is the first registered user (not counting anonymous_user) const isFirstRegisteredUser = (await User.countDocuments({})) === 0; @@ -143,8 +172,10 @@ const requestPasswordReset = async (email) => { user.email, 'Password Reset Request', { + appName: process.env.APP_TITLE || 'LibreChat', name: user.name, link: link, + year: new Date().getFullYear(), }, 'requestPasswordReset.handlebars', ); @@ -185,7 +216,9 @@ const resetPassword = async (userId, token, password) => { user.email, 'Password Reset Successfully', { + appName: process.env.APP_TITLE || 'LibreChat', name: user.name, + year: new Date().getFullYear(), }, 'passwordReset.handlebars', ); @@ -239,6 +272,7 @@ const setAuthTokens = async (userId, res, sessionId = null) => { module.exports = { registerUser, logoutUser, + isDomainAllowed, requestPasswordReset, resetPassword, setAuthTokens, diff --git a/api/server/services/AuthService.spec.js b/api/server/services/AuthService.spec.js new file mode 100644 index 00000000000..fb5d8e2533c --- /dev/null +++ b/api/server/services/AuthService.spec.js @@ -0,0 +1,39 @@ +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); +const { isDomainAllowed } = require('./AuthService'); + +jest.mock('~/server/services/Config/getCustomConfig', () => jest.fn()); + +describe('isDomainAllowed', () => { + it('should allow domain when customConfig is not available', async () => { + getCustomConfig.mockResolvedValue(null); + await expect(isDomainAllowed('test@domain1.com')).resolves.toBe(true); + }); + + it('should allow domain when allowedDomains is not defined in customConfig', async () => { + getCustomConfig.mockResolvedValue({}); + await expect(isDomainAllowed('test@domain1.com')).resolves.toBe(true); + }); + + it('should reject an email if it is falsy', async () => { + getCustomConfig.mockResolvedValue({}); + await expect(isDomainAllowed('')).resolves.toBe(false); + }); + + it('should allow a domain if it is included in the allowedDomains', async () => { + getCustomConfig.mockResolvedValue({ + registration: { + allowedDomains: ['domain1.com', 'domain2.com'], + }, + }); + await expect(isDomainAllowed('user@domain1.com')).resolves.toBe(true); + }); + + it('should reject a domain if it is not included in the allowedDomains', async () => { + getCustomConfig.mockResolvedValue({ + registration: { + allowedDomains: ['domain1.com', 'domain2.com'], + }, + }); + await expect(isDomainAllowed('user@domain3.com')).resolves.toBe(false); + }); +}); diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index 998e7a83d03..987fbb88517 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -1,24 +1,25 @@ const { EModelEndpoint } = require('librechat-data-provider'); +const { isUserProvided, generateConfig } = require('~/server/utils'); const { OPENAI_API_KEY: openAIApiKey, + ASSISTANTS_API_KEY: assistantsApiKey, AZURE_API_KEY: azureOpenAIApiKey, ANTHROPIC_API_KEY: anthropicApiKey, CHATGPT_TOKEN: chatGPTToken, BINGAI_TOKEN: bingToken, PLUGINS_USE_AZURE, GOOGLE_KEY: googleKey, + OPENAI_REVERSE_PROXY, + AZURE_OPENAI_BASEURL, + ASSISTANTS_BASE_URL, } = process.env ?? {}; const useAzurePlugins = !!PLUGINS_USE_AZURE; const userProvidedOpenAI = useAzurePlugins - ? azureOpenAIApiKey === 'user_provided' - : openAIApiKey === 'user_provided'; - -function isUserProvided(key) { - return key ? { userProvide: key === 'user_provided' } : false; -} + ? isUserProvided(azureOpenAIApiKey) + : isUserProvided(openAIApiKey); module.exports = { config: { @@ -27,11 +28,11 @@ module.exports = { useAzurePlugins, userProvidedOpenAI, googleKey, - [EModelEndpoint.openAI]: isUserProvided(openAIApiKey), - [EModelEndpoint.assistant]: isUserProvided(openAIApiKey), - [EModelEndpoint.azureOpenAI]: isUserProvided(azureOpenAIApiKey), - [EModelEndpoint.chatGPTBrowser]: isUserProvided(chatGPTToken), - [EModelEndpoint.anthropic]: isUserProvided(anthropicApiKey), - [EModelEndpoint.bingAI]: isUserProvided(bingToken), + [EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY), + [EModelEndpoint.assistants]: generateConfig(assistantsApiKey, ASSISTANTS_BASE_URL, true), + [EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL), + [EModelEndpoint.chatGPTBrowser]: generateConfig(chatGPTToken), + [EModelEndpoint.anthropic]: generateConfig(anthropicApiKey), + [EModelEndpoint.bingAI]: generateConfig(bingToken), }, }; diff --git a/api/cache/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js similarity index 69% rename from api/cache/getCustomConfig.js rename to api/server/services/Config/getCustomConfig.js index 62082c5cbae..a479ca37b71 100644 --- a/api/cache/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,10 +1,12 @@ const { CacheKeys } = require('librechat-data-provider'); -const loadCustomConfig = require('~/server/services/Config/loadCustomConfig'); -const getLogStores = require('./getLogStores'); +const loadCustomConfig = require('./loadCustomConfig'); +const getLogStores = require('~/cache/getLogStores'); /** * Retrieves the configuration object - * @function getCustomConfig */ + * @function getCustomConfig + * @returns {Promise<TCustomConfig | null>} + * */ async function getCustomConfig() { const cache = getLogStores(CacheKeys.CONFIG_STORE); let customConfig = await cache.get(CacheKeys.CUSTOM_CONFIG); diff --git a/api/server/services/Config/handleRateLimits.js b/api/server/services/Config/handleRateLimits.js new file mode 100644 index 00000000000..d40ccfb4f33 --- /dev/null +++ b/api/server/services/Config/handleRateLimits.js @@ -0,0 +1,22 @@ +/** + * + * @param {TCustomConfig['rateLimits'] | undefined} rateLimits + */ +const handleRateLimits = (rateLimits) => { + if (!rateLimits) { + return; + } + const { fileUploads } = rateLimits; + if (!fileUploads) { + return; + } + + process.env.FILE_UPLOAD_IP_MAX = fileUploads.ipMax ?? process.env.FILE_UPLOAD_IP_MAX; + process.env.FILE_UPLOAD_IP_WINDOW = + fileUploads.ipWindowInMinutes ?? process.env.FILE_UPLOAD_IP_WINDOW; + process.env.FILE_UPLOAD_USER_MAX = fileUploads.userMax ?? process.env.FILE_UPLOAD_USER_MAX; + process.env.FILE_UPLOAD_USER_WINDOW = + fileUploads.userWindowInMinutes ?? process.env.FILE_UPLOAD_USER_WINDOW; +}; + +module.exports = handleRateLimits; diff --git a/api/server/services/Config/index.js b/api/server/services/Config/index.js index 57a00bf515e..2e8ccb1433c 100644 --- a/api/server/services/Config/index.js +++ b/api/server/services/Config/index.js @@ -1,4 +1,5 @@ const { config } = require('./EndpointService'); +const getCustomConfig = require('./getCustomConfig'); const loadCustomConfig = require('./loadCustomConfig'); const loadConfigModels = require('./loadConfigModels'); const loadDefaultModels = require('./loadDefaultModels'); @@ -9,6 +10,7 @@ const loadDefaultEndpointsConfig = require('./loadDefaultEConfig'); module.exports = { config, + getCustomConfig, loadCustomConfig, loadConfigModels, loadDefaultModels, diff --git a/api/server/services/Config/loadAsyncEndpoints.js b/api/server/services/Config/loadAsyncEndpoints.js index 9e92f487fad..409b9485de2 100644 --- a/api/server/services/Config/loadAsyncEndpoints.js +++ b/api/server/services/Config/loadAsyncEndpoints.js @@ -1,12 +1,16 @@ -const { availableTools } = require('~/app/clients/tools'); +const { EModelEndpoint } = require('librechat-data-provider'); const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs'); -const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = - require('./EndpointService').config; +const { availableTools } = require('~/app/clients/tools'); +const { isUserProvided } = require('~/server/utils'); +const { config } = require('./EndpointService'); + +const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config; /** * Load async endpoints and return a configuration object + * @param {Express.Request} req - The request object */ -async function loadAsyncEndpoints() { +async function loadAsyncEndpoints(req) { let i = 0; let serviceKey, googleUserProvides; try { @@ -17,7 +21,7 @@ async function loadAsyncEndpoints() { } } - if (googleKey === 'user_provided') { + if (isUserProvided(googleKey)) { googleUserProvides = true; if (i <= 1) { i++; @@ -35,13 +39,18 @@ async function loadAsyncEndpoints() { const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false; + const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins; const gptPlugins = - openAIApiKey || azureOpenAIApiKey + useAzure || openAIApiKey || azureOpenAIApiKey ? { plugins, availableAgents: ['classic', 'functions'], - userProvide: userProvidedOpenAI, - azure: useAzurePlugins, + userProvide: useAzure ? false : userProvidedOpenAI, + userProvideURL: useAzure + ? false + : config[EModelEndpoint.openAI]?.userProvideURL || + config[EModelEndpoint.azureOpenAI]?.userProvideURL, + azure: useAzurePlugins || useAzure, } : false; diff --git a/api/server/services/Config/loadConfigEndpoints.js b/api/server/services/Config/loadConfigEndpoints.js index 1b435e144e9..cd05cb9acb4 100644 --- a/api/server/services/Config/loadConfigEndpoints.js +++ b/api/server/services/Config/loadConfigEndpoints.js @@ -1,18 +1,14 @@ -const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); -const { isUserProvided, extractEnvVariable } = require('~/server/utils'); -const loadCustomConfig = require('./loadCustomConfig'); -const { getLogStores } = require('~/cache'); +const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider'); +const { isUserProvided } = require('~/server/utils'); +const getCustomConfig = require('./getCustomConfig'); /** * Load config endpoints from the cached configuration object - * @function loadConfigEndpoints */ -async function loadConfigEndpoints() { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - let customConfig = await cache.get(CacheKeys.CUSTOM_CONFIG); - - if (!customConfig) { - customConfig = await loadCustomConfig(); - } + * @param {Express.Request} req - The request object + * @returns {Promise<TEndpointsConfig>} A promise that resolves to an object containing the endpoints configuration + */ +async function loadConfigEndpoints(req) { + const customConfig = await getCustomConfig(); if (!customConfig) { return {}; @@ -48,6 +44,20 @@ async function loadConfigEndpoints() { } } + if (req.app.locals[EModelEndpoint.azureOpenAI]) { + /** @type {Omit<TConfig, 'order'>} */ + endpointsConfig[EModelEndpoint.azureOpenAI] = { + userProvide: false, + }; + } + + if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + /** @type {Omit<TConfig, 'order'>} */ + endpointsConfig[EModelEndpoint.assistants] = { + userProvide: false, + }; + } + return endpointsConfig; } diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 0abe15a8a1f..b3997a2ada0 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -1,19 +1,15 @@ -const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); -const { isUserProvided, extractEnvVariable } = require('~/server/utils'); +const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider'); const { fetchModels } = require('~/server/services/ModelService'); -const loadCustomConfig = require('./loadCustomConfig'); -const { getLogStores } = require('~/cache'); +const { isUserProvided } = require('~/server/utils'); +const getCustomConfig = require('./getCustomConfig'); /** * Load config endpoints from the cached configuration object - * @function loadConfigModels */ -async function loadConfigModels() { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - let customConfig = await cache.get(CacheKeys.CUSTOM_CONFIG); - - if (!customConfig) { - customConfig = await loadCustomConfig(); - } + * @function loadConfigModels + * @param {Express.Request} req - The Express request object. + */ +async function loadConfigModels(req) { + const customConfig = await getCustomConfig(); if (!customConfig) { return {}; @@ -21,6 +17,21 @@ async function loadConfigModels() { const { endpoints = {} } = customConfig ?? {}; const modelsConfig = {}; + const azureEndpoint = endpoints[EModelEndpoint.azureOpenAI]; + const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; + const { modelNames } = azureConfig ?? {}; + + if (modelNames && azureEndpoint) { + modelsConfig[EModelEndpoint.azureOpenAI] = modelNames; + } + + if (modelNames && azureEndpoint && azureEndpoint.plugins) { + modelsConfig[EModelEndpoint.gptPlugins] = modelNames; + } + + if (azureEndpoint?.assistants && azureConfig.assistantModels) { + modelsConfig[EModelEndpoint.assistants] = azureConfig.assistantModels; + } if (!Array.isArray(endpoints[EModelEndpoint.custom])) { return modelsConfig; @@ -35,23 +46,43 @@ async function loadConfigModels() { (endpoint.models.fetch || endpoint.models.default), ); - const fetchPromisesMap = {}; // Map for promises keyed by baseURL - const baseUrlToNameMap = {}; // Map to associate baseURLs with names + /** + * @type {Record<string, string[]>} + * Map for promises keyed by unique combination of baseURL and apiKey */ + const fetchPromisesMap = {}; + /** + * @type {Record<string, string[]>} + * Map to associate unique keys with endpoint names; note: one key may can correspond to multiple endpoints */ + const uniqueKeyToEndpointsMap = {}; + /** + * @type {Record<string, Partial<TEndpoint>>} + * Map to associate endpoint names to their configurations */ + const endpointsMap = {}; for (let i = 0; i < customEndpoints.length; i++) { const endpoint = customEndpoints[i]; const { models, name, baseURL, apiKey } = endpoint; + endpointsMap[name] = endpoint; const API_KEY = extractEnvVariable(apiKey); const BASE_URL = extractEnvVariable(baseURL); + const uniqueKey = `${BASE_URL}__${API_KEY}`; + modelsConfig[name] = []; if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) { - fetchPromisesMap[BASE_URL] = - fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY }); - baseUrlToNameMap[BASE_URL] = baseUrlToNameMap[BASE_URL] || []; - baseUrlToNameMap[BASE_URL].push(name); + fetchPromisesMap[uniqueKey] = + fetchPromisesMap[uniqueKey] || + fetchModels({ + user: req.user.id, + baseURL: BASE_URL, + apiKey: API_KEY, + name, + userIdQuery: models.userIdQuery, + }); + uniqueKeyToEndpointsMap[uniqueKey] = uniqueKeyToEndpointsMap[uniqueKey] || []; + uniqueKeyToEndpointsMap[uniqueKey].push(name); continue; } @@ -61,15 +92,16 @@ async function loadConfigModels() { } const fetchedData = await Promise.all(Object.values(fetchPromisesMap)); - const baseUrls = Object.keys(fetchPromisesMap); + const uniqueKeys = Object.keys(fetchPromisesMap); for (let i = 0; i < fetchedData.length; i++) { - const currentBaseUrl = baseUrls[i]; + const currentKey = uniqueKeys[i]; const modelData = fetchedData[i]; - const associatedNames = baseUrlToNameMap[currentBaseUrl]; + const associatedNames = uniqueKeyToEndpointsMap[currentKey]; for (const name of associatedNames) { - modelsConfig[name] = modelData; + const endpoint = endpointsMap[name]; + modelsConfig[name] = !modelData?.length ? endpoint.models.default ?? [] : modelData; } } diff --git a/api/server/services/Config/loadConfigModels.spec.js b/api/server/services/Config/loadConfigModels.spec.js new file mode 100644 index 00000000000..1b7dec5fd71 --- /dev/null +++ b/api/server/services/Config/loadConfigModels.spec.js @@ -0,0 +1,329 @@ +const { fetchModels } = require('~/server/services/ModelService'); +const loadConfigModels = require('./loadConfigModels'); +const getCustomConfig = require('./getCustomConfig'); + +jest.mock('~/server/services/ModelService'); +jest.mock('./getCustomConfig'); + +const exampleConfig = { + endpoints: { + custom: [ + { + name: 'Mistral', + apiKey: '${MY_PRECIOUS_MISTRAL_KEY}', + baseURL: 'https://api.mistral.ai/v1', + models: { + default: ['mistral-tiny', 'mistral-small', 'mistral-medium', 'mistral-large-latest'], + fetch: true, + }, + dropParams: ['stop', 'user', 'frequency_penalty', 'presence_penalty'], + }, + { + name: 'OpenRouter', + apiKey: '${MY_OPENROUTER_API_KEY}', + baseURL: 'https://openrouter.ai/api/v1', + models: { + default: ['gpt-3.5-turbo'], + fetch: true, + }, + dropParams: ['stop'], + }, + { + name: 'groq', + apiKey: 'user_provided', + baseURL: 'https://api.groq.com/openai/v1/', + models: { + default: ['llama2-70b-4096', 'mixtral-8x7b-32768'], + fetch: false, + }, + }, + { + name: 'Ollama', + apiKey: 'user_provided', + baseURL: 'http://localhost:11434/v1/', + models: { + default: ['mistral', 'llama2:13b'], + fetch: false, + }, + }, + ], + }, +}; + +describe('loadConfigModels', () => { + const mockRequest = { app: { locals: {} }, user: { id: 'testUserId' } }; + + const originalEnv = process.env; + + beforeEach(() => { + jest.resetAllMocks(); + jest.resetModules(); + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should return an empty object if customConfig is null', async () => { + getCustomConfig.mockResolvedValue(null); + const result = await loadConfigModels(mockRequest); + expect(result).toEqual({}); + }); + + it('handles azure models and endpoint correctly', async () => { + mockRequest.app.locals.azureOpenAI = { modelNames: ['model1', 'model2'] }; + getCustomConfig.mockResolvedValue({ + endpoints: { + azureOpenAI: { + models: ['model1', 'model2'], + }, + }, + }); + + const result = await loadConfigModels(mockRequest); + expect(result.azureOpenAI).toEqual(['model1', 'model2']); + }); + + it('fetches custom models based on the unique key', async () => { + process.env.BASE_URL = 'http://example.com'; + process.env.API_KEY = 'some-api-key'; + const customEndpoints = { + custom: [ + { + baseURL: '${BASE_URL}', + apiKey: '${API_KEY}', + name: 'CustomModel', + models: { fetch: true }, + }, + ], + }; + + getCustomConfig.mockResolvedValue({ endpoints: customEndpoints }); + fetchModels.mockResolvedValue(['customModel1', 'customModel2']); + + const result = await loadConfigModels(mockRequest); + expect(fetchModels).toHaveBeenCalled(); + expect(result.CustomModel).toEqual(['customModel1', 'customModel2']); + }); + + it('correctly associates models to names using unique keys', async () => { + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + baseURL: 'http://example.com', + apiKey: 'API_KEY1', + name: 'Model1', + models: { fetch: true }, + }, + { + baseURL: 'http://example.com', + apiKey: 'API_KEY2', + name: 'Model2', + models: { fetch: true }, + }, + ], + }, + }); + fetchModels.mockImplementation(({ apiKey }) => + Promise.resolve(apiKey === 'API_KEY1' ? ['model1Data'] : ['model2Data']), + ); + + const result = await loadConfigModels(mockRequest); + expect(result.Model1).toEqual(['model1Data']); + expect(result.Model2).toEqual(['model2Data']); + }); + + it('correctly handles multiple endpoints with the same baseURL but different apiKeys', async () => { + // Mock the custom configuration to simulate the user's scenario + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'LiteLLM', + apiKey: '${LITELLM_ALL_MODELS}', + baseURL: '${LITELLM_HOST}', + models: { fetch: true }, + }, + { + name: 'OpenAI', + apiKey: '${LITELLM_OPENAI_MODELS}', + baseURL: '${LITELLM_SECOND_HOST}', + models: { fetch: true }, + }, + { + name: 'Google', + apiKey: '${LITELLM_GOOGLE_MODELS}', + baseURL: '${LITELLM_SECOND_HOST}', + models: { fetch: true }, + }, + ], + }, + }); + + // Mock `fetchModels` to return different models based on the apiKey + fetchModels.mockImplementation(({ apiKey }) => { + switch (apiKey) { + case '${LITELLM_ALL_MODELS}': + return Promise.resolve(['AllModel1', 'AllModel2']); + case '${LITELLM_OPENAI_MODELS}': + return Promise.resolve(['OpenAIModel']); + case '${LITELLM_GOOGLE_MODELS}': + return Promise.resolve(['GoogleModel']); + default: + return Promise.resolve([]); + } + }); + + const result = await loadConfigModels(mockRequest); + + // Assert that the models are correctly fetched and mapped based on unique keys + expect(result.LiteLLM).toEqual(['AllModel1', 'AllModel2']); + expect(result.OpenAI).toEqual(['OpenAIModel']); + expect(result.Google).toEqual(['GoogleModel']); + + // Ensure that fetchModels was called with correct parameters + expect(fetchModels).toHaveBeenCalledTimes(3); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_ALL_MODELS}' }), + ); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_OPENAI_MODELS}' }), + ); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_GOOGLE_MODELS}' }), + ); + }); + + it('loads models based on custom endpoint configuration respecting fetch rules', async () => { + process.env.MY_PRECIOUS_MISTRAL_KEY = 'actual_mistral_api_key'; + process.env.MY_OPENROUTER_API_KEY = 'actual_openrouter_api_key'; + // Setup custom configuration with specific API keys for Mistral and OpenRouter + // and "user_provided" for groq and Ollama, indicating no fetch for the latter two + getCustomConfig.mockResolvedValue(exampleConfig); + + // Assuming fetchModels would be called only for Mistral and OpenRouter + fetchModels.mockImplementation(({ name }) => { + switch (name) { + case 'Mistral': + return Promise.resolve([ + 'mistral-tiny', + 'mistral-small', + 'mistral-medium', + 'mistral-large-latest', + ]); + case 'OpenRouter': + return Promise.resolve(['gpt-3.5-turbo']); + default: + return Promise.resolve([]); + } + }); + + const result = await loadConfigModels(mockRequest); + + // Since fetch is true and apiKey is not "user_provided", fetching occurs for Mistral and OpenRouter + expect(result.Mistral).toEqual([ + 'mistral-tiny', + 'mistral-small', + 'mistral-medium', + 'mistral-large-latest', + ]); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Mistral', + apiKey: process.env.MY_PRECIOUS_MISTRAL_KEY, + }), + ); + + expect(result.OpenRouter).toEqual(['gpt-3.5-turbo']); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'OpenRouter', + apiKey: process.env.MY_OPENROUTER_API_KEY, + }), + ); + + // For groq and Ollama, since the apiKey is "user_provided", models should not be fetched + // Depending on your implementation's behavior regarding "default" models without fetching, + // you may need to adjust the following assertions: + expect(result.groq).toBe(exampleConfig.endpoints.custom[2].models.default); + expect(result.Ollama).toBe(exampleConfig.endpoints.custom[3].models.default); + + // Verifying fetchModels was not called for groq and Ollama + expect(fetchModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + name: 'groq', + }), + ); + expect(fetchModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Ollama', + }), + ); + }); + + it('falls back to default models if fetching returns an empty array', async () => { + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'EndpointWithSameFetchKey', + apiKey: 'API_KEY', + baseURL: 'http://example.com', + models: { + fetch: true, + default: ['defaultModel1'], + }, + }, + { + name: 'EmptyFetchModel', + apiKey: 'API_KEY', + baseURL: 'http://example.com', + models: { + fetch: true, + default: ['defaultModel1', 'defaultModel2'], + }, + }, + ], + }, + }); + + fetchModels.mockResolvedValue([]); + + const result = await loadConfigModels(mockRequest); + expect(fetchModels).toHaveBeenCalledTimes(1); + expect(result.EmptyFetchModel).toEqual(['defaultModel1', 'defaultModel2']); + }); + + it('falls back to default models if fetching returns a falsy value', async () => { + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'FalsyFetchModel', + apiKey: 'API_KEY', + baseURL: 'http://example.com', + models: { + fetch: true, + default: ['defaultModel1', 'defaultModel2'], + }, + }, + ], + }, + }); + + fetchModels.mockResolvedValue(false); + + const result = await loadConfigModels(mockRequest); + + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'FalsyFetchModel', + apiKey: 'API_KEY', + }), + ); + + expect(result.FalsyFetchModel).toEqual(['defaultModel1', 'defaultModel2']); + }); +}); diff --git a/api/server/services/Config/loadCustomConfig.js b/api/server/services/Config/loadCustomConfig.js index c17d3283b47..617cd7d9469 100644 --- a/api/server/services/Config/loadCustomConfig.js +++ b/api/server/services/Config/loadCustomConfig.js @@ -1,31 +1,68 @@ const path = require('path'); const { CacheKeys, configSchema } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); const loadYaml = require('~/utils/loadYaml'); -const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); +const axios = require('axios'); +const yaml = require('js-yaml'); const projectRoot = path.resolve(__dirname, '..', '..', '..', '..'); -const configPath = path.resolve(projectRoot, 'librechat.yaml'); +const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml'); + +let i = 0; /** * Load custom configuration files and caches the object if the `cache` field at root is true. * Validation via parsing the config file with the config schema. * @function loadCustomConfig - * @returns {Promise<null | Object>} A promise that resolves to null or the custom config object. + * @returns {Promise<TCustomConfig | null>} A promise that resolves to null or the custom config object. * */ - async function loadCustomConfig() { - const customConfig = loadYaml(configPath); - if (!customConfig) { - return null; + // Use CONFIG_PATH if set, otherwise fallback to defaultConfigPath + const configPath = process.env.CONFIG_PATH || defaultConfigPath; + + let customConfig; + + if (/^https?:\/\//.test(configPath)) { + try { + const response = await axios.get(configPath); + customConfig = response.data; + } catch (error) { + i === 0 && logger.error(`Failed to fetch the remote config file from ${configPath}`, error); + i === 0 && i++; + return null; + } + } else { + customConfig = loadYaml(configPath); + if (!customConfig) { + i === 0 && + logger.info( + 'Custom config file missing or YAML format invalid.\n\nCheck out the latest config file guide for configurable options and features.\nhttps://docs.librechat.ai/install/configuration/custom_config.html\n\n', + ); + i === 0 && i++; + return null; + } + } + + if (typeof customConfig === 'string') { + try { + customConfig = yaml.load(customConfig); + } catch (parseError) { + i === 0 && logger.info(`Failed to parse the YAML config from ${configPath}`, parseError); + i === 0 && i++; + return null; + } } const result = configSchema.strict().safeParse(customConfig); if (!result.success) { - logger.error(`Invalid custom config file at ${configPath}`, result.error); + i === 0 && logger.error(`Invalid custom config file at ${configPath}`, result.error); + i === 0 && i++; return null; } else { - logger.info('Loaded custom config file'); + logger.info('Custom config file loaded:'); + logger.info(JSON.stringify(customConfig, null, 2)); + logger.debug('Custom config:', customConfig); } if (customConfig.cache) { @@ -33,8 +70,6 @@ async function loadCustomConfig() { await cache.set(CacheKeys.CUSTOM_CONFIG, customConfig); } - // TODO: handle remote config - return customConfig; } diff --git a/api/server/services/Config/loadCustomConfig.spec.js b/api/server/services/Config/loadCustomConfig.spec.js new file mode 100644 index 00000000000..24553b9f3ea --- /dev/null +++ b/api/server/services/Config/loadCustomConfig.spec.js @@ -0,0 +1,153 @@ +jest.mock('axios'); +jest.mock('~/cache/getLogStores'); +jest.mock('~/utils/loadYaml'); + +const axios = require('axios'); +const loadCustomConfig = require('./loadCustomConfig'); +const getLogStores = require('~/cache/getLogStores'); +const loadYaml = require('~/utils/loadYaml'); +const { logger } = require('~/config'); + +describe('loadCustomConfig', () => { + const mockSet = jest.fn(); + const mockCache = { set: mockSet }; + + beforeEach(() => { + jest.resetAllMocks(); + delete process.env.CONFIG_PATH; + getLogStores.mockReturnValue(mockCache); + }); + + it('should return null and log error if remote config fetch fails', async () => { + process.env.CONFIG_PATH = 'http://example.com/config.yaml'; + axios.get.mockRejectedValue(new Error('Network error')); + const result = await loadCustomConfig(); + expect(logger.error).toHaveBeenCalledTimes(1); + expect(result).toBeNull(); + }); + + it('should return null for an invalid local config file', async () => { + process.env.CONFIG_PATH = 'localConfig.yaml'; + loadYaml.mockReturnValueOnce(null); + const result = await loadCustomConfig(); + expect(result).toBeNull(); + }); + + it('should parse, validate, and cache a valid local configuration', async () => { + const mockConfig = { + version: '1.0', + cache: true, + endpoints: { + custom: [ + { + name: 'mistral', + apiKey: 'user_provided', + baseURL: 'https://api.mistral.ai/v1', + }, + ], + }, + }; + process.env.CONFIG_PATH = 'validConfig.yaml'; + loadYaml.mockReturnValueOnce(mockConfig); + const result = await loadCustomConfig(); + + expect(result).toEqual(mockConfig); + expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig); + }); + + it('should return null and log if config schema validation fails', async () => { + const invalidConfig = { invalidField: true }; + process.env.CONFIG_PATH = 'invalidConfig.yaml'; + loadYaml.mockReturnValueOnce(invalidConfig); + + const result = await loadCustomConfig(); + + expect(result).toBeNull(); + }); + + it('should handle and return null on YAML parse error for a string response from remote', async () => { + process.env.CONFIG_PATH = 'http://example.com/config.yaml'; + axios.get.mockResolvedValue({ data: 'invalidYAMLContent' }); + + const result = await loadCustomConfig(); + + expect(result).toBeNull(); + }); + + it('should return the custom config object for a valid remote config file', async () => { + const mockConfig = { + version: '1.0', + cache: true, + endpoints: { + custom: [ + { + name: 'mistral', + apiKey: 'user_provided', + baseURL: 'https://api.mistral.ai/v1', + }, + ], + }, + }; + process.env.CONFIG_PATH = 'http://example.com/config.yaml'; + axios.get.mockResolvedValue({ data: mockConfig }); + const result = await loadCustomConfig(); + expect(result).toEqual(mockConfig); + expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig); + }); + + it('should return null if the remote config file is not found', async () => { + process.env.CONFIG_PATH = 'http://example.com/config.yaml'; + axios.get.mockRejectedValue({ response: { status: 404 } }); + const result = await loadCustomConfig(); + expect(result).toBeNull(); + }); + + it('should return null if the local config file is not found', async () => { + process.env.CONFIG_PATH = 'nonExistentConfig.yaml'; + loadYaml.mockReturnValueOnce(null); + const result = await loadCustomConfig(); + expect(result).toBeNull(); + }); + + it('should not cache the config if cache is set to false', async () => { + const mockConfig = { + version: '1.0', + cache: false, + endpoints: { + custom: [ + { + name: 'mistral', + apiKey: 'user_provided', + baseURL: 'https://api.mistral.ai/v1', + }, + ], + }, + }; + process.env.CONFIG_PATH = 'validConfig.yaml'; + loadYaml.mockReturnValueOnce(mockConfig); + await loadCustomConfig(); + expect(mockSet).not.toHaveBeenCalled(); + }); + + it('should log the loaded custom config', async () => { + const mockConfig = { + version: '1.0', + cache: true, + endpoints: { + custom: [ + { + name: 'mistral', + apiKey: 'user_provided', + baseURL: 'https://api.mistral.ai/v1', + }, + ], + }, + }; + process.env.CONFIG_PATH = 'validConfig.yaml'; + loadYaml.mockReturnValueOnce(mockConfig); + await loadCustomConfig(); + expect(logger.info).toHaveBeenCalledWith('Custom config file loaded:'); + expect(logger.info).toHaveBeenCalledWith(JSON.stringify(mockConfig, null, 2)); + expect(logger.debug).toHaveBeenCalledWith('Custom config:', mockConfig); + }); +}); diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index 34ab05d8ab8..960dfb4c77a 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -1,36 +1,21 @@ -const { EModelEndpoint } = require('librechat-data-provider'); +const { EModelEndpoint, getEnabledEndpoints } = require('librechat-data-provider'); const loadAsyncEndpoints = require('./loadAsyncEndpoints'); const { config } = require('./EndpointService'); /** * Load async endpoints and return a configuration object - * @function loadDefaultEndpointsConfig + * @param {Express.Request} req - The request object * @returns {Promise<Object.<string, EndpointWithOrder>>} An object whose keys are endpoint names and values are objects that contain the endpoint configuration and an order. */ -async function loadDefaultEndpointsConfig() { - const { google, gptPlugins } = await loadAsyncEndpoints(); - const { openAI, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config; +async function loadDefaultEndpointsConfig(req) { + const { google, gptPlugins } = await loadAsyncEndpoints(req); + const { openAI, assistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config; - let enabledEndpoints = [ - EModelEndpoint.openAI, - EModelEndpoint.azureOpenAI, - EModelEndpoint.google, - EModelEndpoint.bingAI, - EModelEndpoint.chatGPTBrowser, - EModelEndpoint.gptPlugins, - EModelEndpoint.anthropic, - ]; - - const endpointsEnv = process.env.ENDPOINTS || ''; - if (endpointsEnv) { - enabledEndpoints = endpointsEnv - .split(',') - .filter((endpoint) => endpoint?.trim()) - .map((endpoint) => endpoint.trim()); - } + const enabledEndpoints = getEnabledEndpoints(); const endpointConfig = { [EModelEndpoint.openAI]: openAI, + [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.google]: google, [EModelEndpoint.bingAI]: bingAI, diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 665aa714790..e0b2ca0e4f9 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -7,17 +7,24 @@ const { getChatGPTBrowserModels, } = require('~/server/services/ModelService'); -const fitlerAssistantModels = (str) => { - return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str); -}; - -async function loadDefaultModels() { +/** + * Loads the default models for the application. + * @async + * @function + * @param {Express.Request} req - The Express request object. + */ +async function loadDefaultModels(req) { const google = getGoogleModels(); - const openAI = await getOpenAIModels(); + const openAI = await getOpenAIModels({ user: req.user.id }); const anthropic = getAnthropicModels(); const chatGPTBrowser = getChatGPTBrowserModels(); - const azureOpenAI = await getOpenAIModels({ azure: true }); - const gptPlugins = await getOpenAIModels({ azure: useAzurePlugins, plugins: true }); + const azureOpenAI = await getOpenAIModels({ user: req.user.id, azure: true }); + const gptPlugins = await getOpenAIModels({ + user: req.user.id, + azure: useAzurePlugins, + plugins: true, + }); + const assistants = await getOpenAIModels({ assistants: true }); return { [EModelEndpoint.openAI]: openAI, @@ -27,7 +34,7 @@ async function loadDefaultModels() { [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'], [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, - [EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels), + [EModelEndpoint.assistants]: assistants, }; } diff --git a/api/server/services/Endpoints/anthropic/addTitle.js b/api/server/services/Endpoints/anthropic/addTitle.js new file mode 100644 index 00000000000..30dddd1c3f8 --- /dev/null +++ b/api/server/services/Endpoints/anthropic/addTitle.js @@ -0,0 +1,32 @@ +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const { isEnabled } = require('~/server/utils'); +const { saveConvo } = require('~/models'); + +const addTitle = async (req, { text, response, client }) => { + const { TITLE_CONVO = 'true' } = process.env ?? {}; + if (!isEnabled(TITLE_CONVO)) { + return; + } + + if (client.options.titleConvo === false) { + return; + } + + // If the request was aborted, don't generate the title. + if (client.abortController.signal.aborted) { + return; + } + + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${response.conversationId}`; + + const title = await client.titleConvo({ text, responseText: response?.text }); + await titleCache.set(key, title, 120000); + await saveConvo(req.user.id, { + conversationId: response.conversationId, + title, + }); +}; + +module.exports = addTitle; diff --git a/api/server/services/Endpoints/anthropic/buildOptions.js b/api/server/services/Endpoints/anthropic/buildOptions.js index 2b0143d2b07..4cd9ba8b925 100644 --- a/api/server/services/Endpoints/anthropic/buildOptions.js +++ b/api/server/services/Endpoints/anthropic/buildOptions.js @@ -1,9 +1,10 @@ const buildOptions = (endpoint, parsedBody) => { - const { modelLabel, promptPrefix, ...rest } = parsedBody; + const { modelLabel, promptPrefix, resendFiles, ...rest } = parsedBody; const endpointOption = { endpoint, modelLabel, promptPrefix, + resendFiles, modelOptions: { ...rest, }, diff --git a/api/server/services/Endpoints/anthropic/index.js b/api/server/services/Endpoints/anthropic/index.js index 84e4bd5973a..772b1efb118 100644 --- a/api/server/services/Endpoints/anthropic/index.js +++ b/api/server/services/Endpoints/anthropic/index.js @@ -1,8 +1,9 @@ +const addTitle = require('./addTitle'); const buildOptions = require('./buildOptions'); const initializeClient = require('./initializeClient'); module.exports = { - // addTitle, // todo + addTitle, buildOptions, initializeClient, }; diff --git a/api/server/services/Endpoints/assistants/addTitle.js b/api/server/services/Endpoints/assistants/addTitle.js new file mode 100644 index 00000000000..7cca98cc7bc --- /dev/null +++ b/api/server/services/Endpoints/assistants/addTitle.js @@ -0,0 +1,28 @@ +const { CacheKeys } = require('librechat-data-provider'); +const { saveConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { isEnabled } = require('~/server/utils'); + +const addTitle = async (req, { text, responseText, conversationId, client }) => { + const { TITLE_CONVO = 'true' } = process.env ?? {}; + if (!isEnabled(TITLE_CONVO)) { + return; + } + + if (client.options.titleConvo === false) { + return; + } + + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${conversationId}`; + + const title = await client.titleConvo({ text, conversationId, responseText }); + await titleCache.set(key, title, 120000); + + await saveConvo(req.user.id, { + conversationId, + title, + }); +}; + +module.exports = addTitle; diff --git a/api/server/services/Endpoints/assistants/buildOptions.js b/api/server/services/Endpoints/assistants/buildOptions.js new file mode 100644 index 00000000000..c670953539d --- /dev/null +++ b/api/server/services/Endpoints/assistants/buildOptions.js @@ -0,0 +1,16 @@ +const buildOptions = (endpoint, parsedBody) => { + // eslint-disable-next-line no-unused-vars + const { promptPrefix, assistant_id, ...rest } = parsedBody; + const endpointOption = { + endpoint, + promptPrefix, + assistant_id, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/services/Endpoints/assistants/index.js b/api/server/services/Endpoints/assistants/index.js new file mode 100644 index 00000000000..10e94f2cd4f --- /dev/null +++ b/api/server/services/Endpoints/assistants/index.js @@ -0,0 +1,96 @@ +const addTitle = require('./addTitle'); +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +/** + * Asynchronously lists assistants based on provided query parameters. + * + * Initializes the client with the current request and response objects and lists assistants + * according to the query parameters. This function abstracts the logic for non-Azure paths. + * + * @async + * @param {object} params - The parameters object. + * @param {object} params.req - The request object, used for initializing the client. + * @param {object} params.res - The response object, used for initializing the client. + * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). + * @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call. + */ +const listAssistants = async ({ req, res, query }) => { + const { openai } = await initializeClient({ req, res }); + return openai.beta.assistants.list(query); +}; + +/** + * Asynchronously lists assistants for Azure configured groups. + * + * Iterates through Azure configured assistant groups, initializes the client with the current request and response objects, + * lists assistants based on the provided query parameters, and merges their data alongside the model information into a single array. + * + * @async + * @param {object} params - The parameters object. + * @param {object} params.req - The request object, used for initializing the client and manipulating the request body. + * @param {object} params.res - The response object, used for initializing the client. + * @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap. + * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). + * @returns {Promise<AssistantListResponse>} A promise that resolves to an array of assistant data merged with their respective model information. + */ +const listAssistantsForAzure = async ({ req, res, azureConfig = {}, query }) => { + /** @type {Array<[string, TAzureModelConfig]>} */ + const groupModelTuples = []; + const promises = []; + /** @type {Array<TAzureGroup>} */ + const groups = []; + + const { groupMap, assistantGroups } = azureConfig; + + for (const groupName of assistantGroups) { + const group = groupMap[groupName]; + groups.push(group); + + const currentModelTuples = Object.entries(group?.models); + groupModelTuples.push(currentModelTuples); + + /* The specified model is only necessary to + fetch assistants for the shared instance */ + req.body.model = currentModelTuples[0][0]; + promises.push(listAssistants({ req, res, query })); + } + + const resolvedQueries = await Promise.all(promises); + const data = resolvedQueries.flatMap((res, i) => + res.data.map((assistant) => { + const deploymentName = assistant.model; + const currentGroup = groups[i]; + const currentModelTuples = groupModelTuples[i]; + const firstModel = currentModelTuples[0][0]; + + if (currentGroup.deploymentName === deploymentName) { + return { ...assistant, model: firstModel }; + } + + for (const [model, modelConfig] of currentModelTuples) { + if (modelConfig.deploymentName === deploymentName) { + return { ...assistant, model }; + } + } + + return { ...assistant, model: firstModel }; + }), + ); + + return { + first_id: data[0]?.id, + last_id: data[data.length - 1]?.id, + object: 'list', + has_more: false, + data, + }; +}; + +module.exports = { + addTitle, + buildOptions, + initializeClient, + listAssistants, + listAssistantsForAzure, +}; diff --git a/api/server/services/Endpoints/assistants/initializeClient.js b/api/server/services/Endpoints/assistants/initializeClient.js new file mode 100644 index 00000000000..05a9232f9fe --- /dev/null +++ b/api/server/services/Endpoints/assistants/initializeClient.js @@ -0,0 +1,148 @@ +const OpenAI = require('openai'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { + EModelEndpoint, + resolveHeaders, + mapModelToAzureConfig, +} = require('librechat-data-provider'); +const { + getUserKey, + getUserKeyExpiry, + checkUserKeyExpiry, +} = require('~/server/services/UserService'); +const OpenAIClient = require('~/app/clients/OpenAIClient'); +const { isUserProvided } = require('~/server/utils'); +const { constructAzureURL } = require('~/utils'); + +const initializeClient = async ({ req, res, endpointOption, initAppClient = false }) => { + const { PROXY, OPENAI_ORGANIZATION, ASSISTANTS_API_KEY, ASSISTANTS_BASE_URL } = process.env; + + const userProvidesKey = isUserProvided(ASSISTANTS_API_KEY); + const userProvidesURL = isUserProvided(ASSISTANTS_BASE_URL); + + let userValues = null; + if (userProvidesKey || userProvidesURL) { + const expiresAt = await getUserKeyExpiry({ + userId: req.user.id, + name: EModelEndpoint.assistants, + }); + checkUserKeyExpiry( + expiresAt, + 'Your Assistants API key has expired. Please provide your API key again.', + ); + userValues = await getUserKey({ userId: req.user.id, name: EModelEndpoint.assistants }); + try { + userValues = JSON.parse(userValues); + } catch (e) { + throw new Error( + 'Invalid JSON provided for Assistants API user values. Please provide them again.', + ); + } + } + + let apiKey = userProvidesKey ? userValues.apiKey : ASSISTANTS_API_KEY; + let baseURL = userProvidesURL ? userValues.baseURL : ASSISTANTS_BASE_URL; + + const opts = {}; + + const clientOptions = { + reverseProxyUrl: baseURL ?? null, + proxy: PROXY ?? null, + req, + res, + ...endpointOption, + }; + + /** @type {TAzureConfig | undefined} */ + const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; + + /** @type {AzureOptions | undefined} */ + let azureOptions; + + if (azureConfig && azureConfig.assistants) { + const { modelGroupMap, groupMap, assistantModels } = azureConfig; + const modelName = req.body.model ?? req.query.model ?? assistantModels[0]; + const { + azureOptions: currentOptions, + baseURL: azureBaseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName, + modelGroupMap, + groupMap, + }); + + azureOptions = currentOptions; + + baseURL = constructAzureURL({ + baseURL: azureBaseURL ?? 'https://${INSTANCE_NAME}.openai.azure.com/openai', + azureOptions, + }); + + apiKey = azureOptions.azureOpenAIApiKey; + opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion }; + opts.defaultHeaders = resolveHeaders({ ...headers, 'api-key': apiKey }); + opts.model = azureOptions.azureOpenAIApiDeploymentName; + + if (initAppClient) { + clientOptions.titleConvo = azureConfig.titleConvo; + clientOptions.titleModel = azureConfig.titleModel; + clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; + + const groupName = modelGroupMap[modelName].group; + clientOptions.addParams = azureConfig.groupMap[groupName].addParams; + clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; + clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; + + clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; + clientOptions.headers = opts.defaultHeaders; + clientOptions.azure = !serverless && azureOptions; + } + } + + if (!apiKey) { + throw new Error('Assistants API key not provided. Please provide it again.'); + } + + if (baseURL) { + opts.baseURL = baseURL; + } + + if (PROXY) { + opts.httpAgent = new HttpsProxyAgent(PROXY); + } + + if (OPENAI_ORGANIZATION) { + opts.organization = OPENAI_ORGANIZATION; + } + + /** @type {OpenAIClient} */ + const openai = new OpenAI({ + apiKey, + ...opts, + }); + + openai.req = req; + openai.res = res; + + if (azureOptions) { + openai.locals = { ...(openai.locals ?? {}), azureOptions }; + } + + if (endpointOption && initAppClient) { + const client = new OpenAIClient(apiKey, clientOptions); + return { + client, + openai, + openAIApiKey: apiKey, + }; + } + + return { + openai, + openAIApiKey: apiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/services/Endpoints/assistants/initializeClient.spec.js b/api/server/services/Endpoints/assistants/initializeClient.spec.js new file mode 100644 index 00000000000..3a1e4692738 --- /dev/null +++ b/api/server/services/Endpoints/assistants/initializeClient.spec.js @@ -0,0 +1,99 @@ +// const OpenAI = require('openai'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { getUserKey, getUserKeyExpiry } = require('~/server/services/UserService'); +const initializeClient = require('./initializeClient'); +// const { OpenAIClient } = require('~/app'); + +jest.mock('~/server/services/UserService', () => ({ + getUserKey: jest.fn(), + getUserKeyExpiry: jest.fn(), + checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, +})); + +const today = new Date(); +const tenDaysFromToday = new Date(today.setDate(today.getDate() + 10)); +const isoString = tenDaysFromToday.toISOString(); + +describe('initializeClient', () => { + // Set up environment variables + const originalEnvironment = process.env; + const app = { + locals: {}, + }; + + beforeEach(() => { + jest.resetModules(); // Clears the cache + process.env = { ...originalEnvironment }; // Make a copy + }); + + afterAll(() => { + process.env = originalEnvironment; // Restore original env vars + }); + + test('initializes OpenAI client with default API key and URL', async () => { + process.env.ASSISTANTS_API_KEY = 'default-api-key'; + process.env.ASSISTANTS_BASE_URL = 'https://default.api.url'; + + // Assuming 'isUserProvided' to return false for this test case + jest.mock('~/server/utils', () => ({ + isUserProvided: jest.fn().mockReturnValueOnce(false), + })); + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai, openAIApiKey } = await initializeClient({ req, res }); + expect(openai.apiKey).toBe('default-api-key'); + expect(openAIApiKey).toBe('default-api-key'); + expect(openai.baseURL).toBe('https://default.api.url'); + }); + + test('initializes OpenAI client with user-provided API key and URL', async () => { + process.env.ASSISTANTS_API_KEY = 'user_provided'; + process.env.ASSISTANTS_BASE_URL = 'user_provided'; + + getUserKey.mockResolvedValue( + JSON.stringify({ apiKey: 'user-api-key', baseURL: 'https://user.api.url' }), + ); + getUserKeyExpiry.mockResolvedValue(isoString); + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai, openAIApiKey } = await initializeClient({ req, res }); + expect(openAIApiKey).toBe('user-api-key'); + expect(openai.apiKey).toBe('user-api-key'); + expect(openai.baseURL).toBe('https://user.api.url'); + }); + + test('throws error for invalid JSON in user-provided values', async () => { + process.env.ASSISTANTS_API_KEY = 'user_provided'; + getUserKey.mockResolvedValue('invalid-json'); + getUserKeyExpiry.mockResolvedValue(isoString); + + const req = { user: { id: 'user123' } }; + const res = {}; + + await expect(initializeClient({ req, res })).rejects.toThrow(/Invalid JSON/); + }); + + test('throws error if API key is not provided', async () => { + delete process.env.ASSISTANTS_API_KEY; // Simulate missing API key + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + await expect(initializeClient({ req, res })).rejects.toThrow(/Assistants API key not/); + }); + + test('initializes OpenAI client with proxy configuration', async () => { + process.env.ASSISTANTS_API_KEY = 'test-key'; + process.env.PROXY = 'http://proxy.server'; + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai } = await initializeClient({ req, res }); + expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent); + }); +}); diff --git a/api/server/services/Endpoints/custom/buildOptions.js b/api/server/services/Endpoints/custom/buildOptions.js index 0bba48e2b95..3d937957323 100644 --- a/api/server/services/Endpoints/custom/buildOptions.js +++ b/api/server/services/Endpoints/custom/buildOptions.js @@ -1,11 +1,11 @@ const buildOptions = (endpoint, parsedBody, endpointType) => { - const { chatGptLabel, promptPrefix, resendImages, imageDetail, ...rest } = parsedBody; + const { chatGptLabel, promptPrefix, resendFiles, imageDetail, ...rest } = parsedBody; const endpointOption = { endpoint, endpointType, chatGptLabel, promptPrefix, - resendImages, + resendFiles, imageDetail, modelOptions: { ...rest, diff --git a/api/server/services/Endpoints/custom/initializeClient.js b/api/server/services/Endpoints/custom/initializeClient.js index 978506b7b47..a80f5efaa7a 100644 --- a/api/server/services/Endpoints/custom/initializeClient.js +++ b/api/server/services/Endpoints/custom/initializeClient.js @@ -1,11 +1,17 @@ -const { EModelEndpoint } = require('librechat-data-provider'); +const { + CacheKeys, + envVarRegex, + EModelEndpoint, + FetchTokenConfig, + extractEnvVariable, +} = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { isUserProvided, extractEnvVariable } = require('~/server/utils'); -const getCustomConfig = require('~/cache/getCustomConfig'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); +const { fetchModels } = require('~/server/services/ModelService'); +const getLogStores = require('~/cache/getLogStores'); +const { isUserProvided } = require('~/server/utils'); const { OpenAIClient } = require('~/app'); -const envVarRegex = /^\${(.+)}$/; - const { PROXY } = process.env; const initializeClient = async ({ req, res, endpointOption }) => { @@ -37,24 +43,11 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error(`Missing Base URL for ${endpoint}.`); } - const customOptions = { - headers: resolvedHeaders, - addParams: endpointConfig.addParams, - dropParams: endpointConfig.dropParams, - titleConvo: endpointConfig.titleConvo, - titleModel: endpointConfig.titleModel, - forcePrompt: endpointConfig.forcePrompt, - summaryModel: endpointConfig.summaryModel, - modelDisplayLabel: endpointConfig.modelDisplayLabel, - titleMethod: endpointConfig.titleMethod ?? 'completion', - contextStrategy: endpointConfig.summarize ? 'summarize' : null, - }; - - const useUserKey = isUserProvided(CUSTOM_API_KEY); - const useUserURL = isUserProvided(CUSTOM_BASE_URL); + const userProvidesKey = isUserProvided(CUSTOM_API_KEY); + const userProvidesURL = isUserProvided(CUSTOM_BASE_URL); let userValues = null; - if (expiresAt && (useUserKey || useUserURL)) { + if (expiresAt && (userProvidesKey || userProvidesURL)) { checkUserKeyExpiry( expiresAt, `Your API values for ${endpoint} have expired. Please configure them again.`, @@ -67,8 +60,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { } } - let apiKey = useUserKey ? userValues.apiKey : CUSTOM_API_KEY; - let baseURL = useUserURL ? userValues.baseURL : CUSTOM_BASE_URL; + let apiKey = userProvidesKey ? userValues?.apiKey : CUSTOM_API_KEY; + let baseURL = userProvidesURL ? userValues?.baseURL : CUSTOM_BASE_URL; if (!apiKey) { throw new Error(`${endpoint} API key not provided.`); @@ -78,6 +71,41 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error(`${endpoint} Base URL not provided.`); } + const cache = getLogStores(CacheKeys.TOKEN_CONFIG); + const tokenKey = + !endpointConfig.tokenConfig && (userProvidesKey || userProvidesURL) + ? `${endpoint}:${req.user.id}` + : endpoint; + + let endpointTokenConfig = + !endpointConfig.tokenConfig && + FetchTokenConfig[endpoint.toLowerCase()] && + (await cache.get(tokenKey)); + + if ( + FetchTokenConfig[endpoint.toLowerCase()] && + endpointConfig && + endpointConfig.models.fetch && + !endpointTokenConfig + ) { + await fetchModels({ apiKey, baseURL, name: endpoint, user: req.user.id, tokenKey }); + endpointTokenConfig = await cache.get(tokenKey); + } + + const customOptions = { + headers: resolvedHeaders, + addParams: endpointConfig.addParams, + dropParams: endpointConfig.dropParams, + titleConvo: endpointConfig.titleConvo, + titleModel: endpointConfig.titleModel, + forcePrompt: endpointConfig.forcePrompt, + summaryModel: endpointConfig.summaryModel, + modelDisplayLabel: endpointConfig.modelDisplayLabel, + titleMethod: endpointConfig.titleMethod ?? 'completion', + contextStrategy: endpointConfig.summarize ? 'summarize' : null, + endpointTokenConfig, + }; + const clientOptions = { reverseProxyUrl: baseURL ?? null, proxy: PROXY ?? null, diff --git a/api/server/services/Endpoints/google/initializeClient.spec.js b/api/server/services/Endpoints/google/initializeClient.spec.js index 8587c71e2d8..e39e51b8571 100644 --- a/api/server/services/Endpoints/google/initializeClient.spec.js +++ b/api/server/services/Endpoints/google/initializeClient.spec.js @@ -1,3 +1,5 @@ +// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets + const initializeClient = require('./initializeClient'); const { GoogleClient } = require('~/app'); const { checkUserKeyExpiry, getUserKey } = require('../../UserService'); diff --git a/api/server/services/Endpoints/gptPlugins/initializeClient.js b/api/server/services/Endpoints/gptPlugins/initializeClient.js index 54ea822e494..2920a589176 100644 --- a/api/server/services/Endpoints/gptPlugins/initializeClient.js +++ b/api/server/services/Endpoints/gptPlugins/initializeClient.js @@ -1,7 +1,11 @@ -const { EModelEndpoint } = require('librechat-data-provider'); +const { + EModelEndpoint, + mapModelToAzureConfig, + resolveHeaders, +} = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); +const { isEnabled, isUserProvided } = require('~/server/utils'); const { getAzureCredentials } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); const { PluginsClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption }) => { @@ -16,57 +20,96 @@ const initializeClient = async ({ req, res, endpointOption }) => { DEBUG_PLUGINS, } = process.env; - const { key: expiresAt } = req.body; + const { key: expiresAt, model: modelName } = req.body; const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null; - const useAzure = isEnabled(PLUGINS_USE_AZURE); - const endpoint = useAzure ? EModelEndpoint.azureOpenAI : EModelEndpoint.openAI; + let useAzure = isEnabled(PLUGINS_USE_AZURE); + let endpoint = useAzure ? EModelEndpoint.azureOpenAI : EModelEndpoint.openAI; + + /** @type {false | TAzureConfig} */ + const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; + useAzure = useAzure || azureConfig?.plugins; + + if (useAzure && endpoint !== EModelEndpoint.azureOpenAI) { + endpoint = EModelEndpoint.azureOpenAI; + } + + const credentials = { + [EModelEndpoint.openAI]: OPENAI_API_KEY, + [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, + }; const baseURLOptions = { [EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY, [EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL, }; - const reverseProxyUrl = baseURLOptions[endpoint] ?? null; + const userProvidesKey = isUserProvided(credentials[endpoint]); + const userProvidesURL = isUserProvided(baseURLOptions[endpoint]); + + let userValues = null; + if (expiresAt && (userProvidesKey || userProvidesURL)) { + checkUserKeyExpiry( + expiresAt, + 'Your OpenAI API values have expired. Please provide them again.', + ); + userValues = await getUserKey({ userId: req.user.id, name: endpoint }); + try { + userValues = JSON.parse(userValues); + } catch (e) { + throw new Error( + `Invalid JSON provided for ${endpoint} user values. Please provide them again.`, + ); + } + } + + let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint]; + let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint]; const clientOptions = { contextStrategy, debug: isEnabled(DEBUG_PLUGINS), - reverseProxyUrl, + reverseProxyUrl: baseURL ? baseURL : null, proxy: PROXY ?? null, req, res, ...endpointOption, }; - const credentials = { - [EModelEndpoint.openAI]: OPENAI_API_KEY, - [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, - }; + if (useAzure && azureConfig) { + const { modelGroupMap, groupMap } = azureConfig; + const { + azureOptions, + baseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName, + modelGroupMap, + groupMap, + }); - const isUserProvided = credentials[endpoint] === 'user_provided'; + clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; + clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); - let userKey = null; - if (expiresAt && isUserProvided) { - checkUserKeyExpiry( - expiresAt, - 'Your OpenAI API key has expired. Please provide your API key again.', - ); - userKey = await getUserKey({ - userId: req.user.id, - name: endpoint, - }); - } + clientOptions.titleConvo = azureConfig.titleConvo; + clientOptions.titleModel = azureConfig.titleModel; + clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; - let apiKey = isUserProvided ? userKey : credentials[endpoint]; + const groupName = modelGroupMap[modelName].group; + clientOptions.addParams = azureConfig.groupMap[groupName].addParams; + clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; + clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; - if (useAzure || (apiKey && apiKey.includes('azure') && !clientOptions.azure)) { - clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials(); + apiKey = azureOptions.azureOpenAIApiKey; + clientOptions.azure = !serverless && azureOptions; + } else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) { + clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); apiKey = clientOptions.azure.azureOpenAIApiKey; } if (!apiKey) { - throw new Error('API key not provided.'); + throw new Error(`${endpoint} API key not provided. Please provide it again.`); } const client = new PluginsClient(apiKey, clientOptions); diff --git a/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js b/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js index 5b772209c64..280acf5aadb 100644 --- a/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js +++ b/api/server/services/Endpoints/gptPlugins/initializeClient.spec.js @@ -1,7 +1,8 @@ // gptPlugins/initializeClient.spec.js -const { PluginsClient } = require('~/app'); +const { EModelEndpoint, validateAzureGroups } = require('librechat-data-provider'); +const { getUserKey } = require('~/server/services/UserService'); const initializeClient = require('./initializeClient'); -const { getUserKey } = require('../../UserService'); +const { PluginsClient } = require('~/app'); // Mock getUserKey since it's the only function we want to mock jest.mock('~/server/services/UserService', () => ({ @@ -12,6 +13,72 @@ jest.mock('~/server/services/UserService', () => ({ describe('gptPlugins/initializeClient', () => { // Set up environment variables const originalEnvironment = process.env; + const app = { + locals: {}, + }; + + const validAzureConfigs = [ + { + group: 'librechat-westus', + apiKey: 'WESTUS_API_KEY', + instanceName: 'librechat-westus', + version: '2023-12-01-preview', + models: { + 'gpt-4-vision-preview': { + deploymentName: 'gpt-4-vision-preview', + version: '2024-02-15-preview', + }, + 'gpt-3.5-turbo': { + deploymentName: 'gpt-35-turbo', + }, + 'gpt-3.5-turbo-1106': { + deploymentName: 'gpt-35-turbo-1106', + }, + 'gpt-4': { + deploymentName: 'gpt-4', + }, + 'gpt-4-1106-preview': { + deploymentName: 'gpt-4-1106-preview', + }, + }, + }, + { + group: 'librechat-eastus', + apiKey: 'EASTUS_API_KEY', + instanceName: 'librechat-eastus', + deploymentName: 'gpt-4-turbo', + version: '2024-02-15-preview', + models: { + 'gpt-4-turbo': true, + }, + baseURL: 'https://eastus.example.com', + additionalHeaders: { + 'x-api-key': 'x-api-key-value', + }, + }, + { + group: 'mistral-inference', + apiKey: 'AZURE_MISTRAL_API_KEY', + baseURL: + 'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions', + serverless: true, + models: { + 'mistral-large': true, + }, + }, + { + group: 'llama-70b-chat', + apiKey: 'AZURE_LLAMA2_70B_API_KEY', + baseURL: + 'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions', + serverless: true, + models: { + 'llama-70b-chat': true, + }, + }, + ]; + + const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(validAzureConfigs); beforeEach(() => { jest.resetModules(); // Clears the cache @@ -31,6 +98,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -55,6 +123,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'test-model' } }; @@ -72,6 +141,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -88,6 +158,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -107,12 +178,13 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - 'API key not provided.', + `${EModelEndpoint.openAI} API key not provided.`, ); }); @@ -128,11 +200,12 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: futureDate }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; - getUserKey.mockResolvedValue('test-user-provided-openai-api-key'); + getUserKey.mockResolvedValue(JSON.stringify({ apiKey: 'test-user-provided-openai-api-key' })); const { openAIApiKey } = await initializeClient({ req, res, endpointOption }); @@ -147,14 +220,17 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: futureDate }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'test-model' } }; getUserKey.mockResolvedValue( JSON.stringify({ - azureOpenAIApiKey: 'test-user-provided-azure-api-key', - azureOpenAIApiDeploymentName: 'test-deployment', + apiKey: JSON.stringify({ + azureOpenAIApiKey: 'test-user-provided-azure-api-key', + azureOpenAIApiDeploymentName: 'test-deployment', + }), }), ); @@ -170,13 +246,12 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: expiresAt }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /Your OpenAI API key has expired/, - ); + await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(/Your OpenAI API/); }); test('should throw an error if the user-provided Azure key is invalid JSON', async () => { @@ -186,6 +261,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: new Date(Date.now() + 10000).toISOString() }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -194,7 +270,7 @@ describe('gptPlugins/initializeClient', () => { getUserKey.mockResolvedValue('invalid-json'); await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /Unexpected token/, + /Invalid JSON provided/, ); }); @@ -206,6 +282,7 @@ describe('gptPlugins/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -215,4 +292,90 @@ describe('gptPlugins/initializeClient', () => { expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy'); expect(client.options.proxy).toBe('http://proxy'); }); + + test('should throw an error when user-provided values are not valid JSON', async () => { + process.env.OPENAI_API_KEY = 'user_provided'; + const req = { + body: { key: new Date(Date.now() + 10000).toISOString(), endpoint: 'openAI' }, + user: { id: '123' }, + app, + }; + const res = {}; + const endpointOption = {}; + + // Mock getUserKey to return a non-JSON string + getUserKey.mockResolvedValue('not-a-json'); + + await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( + /Invalid JSON provided for openAI user values/, + ); + }); + + test('should initialize client correctly for Azure OpenAI with valid configuration', async () => { + const req = { + body: { + key: null, + endpoint: EModelEndpoint.gptPlugins, + model: modelNames[0], + }, + user: { id: '123' }, + app: { + locals: { + [EModelEndpoint.azureOpenAI]: { + plugins: true, + modelNames, + modelGroupMap, + groupMap, + }, + }, + }, + }; + const res = {}; + const endpointOption = {}; + + const client = await initializeClient({ req, res, endpointOption }); + expect(client.client.options.azure).toBeDefined(); + }); + + test('should initialize client with default options when certain env vars are not set', async () => { + delete process.env.OPENAI_SUMMARIZE; + process.env.OPENAI_API_KEY = 'some-api-key'; + + const req = { + body: { key: null, endpoint: EModelEndpoint.gptPlugins }, + user: { id: '123' }, + app, + }; + const res = {}; + const endpointOption = {}; + + const client = await initializeClient({ req, res, endpointOption }); + expect(client.client.options.contextStrategy).toBe(null); + }); + + test('should correctly use user-provided apiKey and baseURL when provided', async () => { + process.env.OPENAI_API_KEY = 'user_provided'; + process.env.OPENAI_REVERSE_PROXY = 'user_provided'; + const req = { + body: { + key: new Date(Date.now() + 10000).toISOString(), + endpoint: 'openAI', + }, + user: { + id: '123', + }, + app, + }; + const res = {}; + const endpointOption = {}; + + getUserKey.mockResolvedValue( + JSON.stringify({ apiKey: 'test', baseURL: 'https://user-provided-url.com' }), + ); + + const result = await initializeClient({ req, res, endpointOption }); + + expect(result.openAIApiKey).toBe('test'); + expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); + }); }); diff --git a/api/server/services/Endpoints/openAI/addTitle.js b/api/server/services/Endpoints/openAI/addTitle.js index ab15443f942..7bd3fc07a2c 100644 --- a/api/server/services/Endpoints/openAI/addTitle.js +++ b/api/server/services/Endpoints/openAI/addTitle.js @@ -1,5 +1,7 @@ -const { saveConvo } = require('~/models'); +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); const { isEnabled } = require('~/server/utils'); +const { saveConvo } = require('~/models'); const addTitle = async (req, { text, response, client }) => { const { TITLE_CONVO = 'true' } = process.env ?? {}; @@ -16,7 +18,11 @@ const addTitle = async (req, { text, response, client }) => { return; } + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${response.conversationId}`; + const title = await client.titleConvo({ text, responseText: response?.text }); + await titleCache.set(key, title, 120000); await saveConvo(req.user.id, { conversationId: response.conversationId, title, diff --git a/api/server/services/Endpoints/openAI/buildOptions.js b/api/server/services/Endpoints/openAI/buildOptions.js index 80037fb4b8e..0b1fb3eabd8 100644 --- a/api/server/services/Endpoints/openAI/buildOptions.js +++ b/api/server/services/Endpoints/openAI/buildOptions.js @@ -1,10 +1,10 @@ const buildOptions = (endpoint, parsedBody) => { - const { chatGptLabel, promptPrefix, resendImages, imageDetail, ...rest } = parsedBody; + const { chatGptLabel, promptPrefix, resendFiles, imageDetail, ...rest } = parsedBody; const endpointOption = { endpoint, chatGptLabel, promptPrefix, - resendImages, + resendFiles, imageDetail, modelOptions: { ...rest, diff --git a/api/server/services/Endpoints/openAI/initializeClient.js b/api/server/services/Endpoints/openAI/initializeClient.js index b6427823e12..10a541526bd 100644 --- a/api/server/services/Endpoints/openAI/initializeClient.js +++ b/api/server/services/Endpoints/openAI/initializeClient.js @@ -1,7 +1,11 @@ -const { EModelEndpoint } = require('librechat-data-provider'); +const { + EModelEndpoint, + resolveHeaders, + mapModelToAzureConfig, +} = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); +const { isEnabled, isUserProvided } = require('~/server/utils'); const { getAzureCredentials } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); const { OpenAIClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption }) => { @@ -14,51 +18,89 @@ const initializeClient = async ({ req, res, endpointOption }) => { OPENAI_SUMMARIZE, DEBUG_OPENAI, } = process.env; - const { key: expiresAt, endpoint } = req.body; + const { key: expiresAt, endpoint, model: modelName } = req.body; const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null; + const credentials = { + [EModelEndpoint.openAI]: OPENAI_API_KEY, + [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, + }; + const baseURLOptions = { [EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY, [EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL, }; - const reverseProxyUrl = baseURLOptions[endpoint] ?? null; + const userProvidesKey = isUserProvided(credentials[endpoint]); + const userProvidesURL = isUserProvided(baseURLOptions[endpoint]); + + let userValues = null; + if (expiresAt && (userProvidesKey || userProvidesURL)) { + checkUserKeyExpiry( + expiresAt, + 'Your OpenAI API values have expired. Please provide them again.', + ); + userValues = await getUserKey({ userId: req.user.id, name: endpoint }); + try { + userValues = JSON.parse(userValues); + } catch (e) { + throw new Error( + `Invalid JSON provided for ${endpoint} user values. Please provide them again.`, + ); + } + } + + let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint]; + let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint]; const clientOptions = { debug: isEnabled(DEBUG_OPENAI), contextStrategy, - reverseProxyUrl, + reverseProxyUrl: baseURL ? baseURL : null, proxy: PROXY ?? null, req, res, ...endpointOption, }; - const credentials = { - [EModelEndpoint.openAI]: OPENAI_API_KEY, - [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, - }; + const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI; + /** @type {false | TAzureConfig} */ + const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI]; - const isUserProvided = credentials[endpoint] === 'user_provided'; + if (isAzureOpenAI && azureConfig) { + const { modelGroupMap, groupMap } = azureConfig; + const { + azureOptions, + baseURL, + headers = {}, + serverless, + } = mapModelToAzureConfig({ + modelName, + modelGroupMap, + groupMap, + }); - let userKey = null; - if (expiresAt && isUserProvided) { - checkUserKeyExpiry( - expiresAt, - 'Your OpenAI API key has expired. Please provide your API key again.', - ); - userKey = await getUserKey({ userId: req.user.id, name: endpoint }); - } + clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; + clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); + + clientOptions.titleConvo = azureConfig.titleConvo; + clientOptions.titleModel = azureConfig.titleModel; + clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; - let apiKey = isUserProvided ? userKey : credentials[endpoint]; + const groupName = modelGroupMap[modelName].group; + clientOptions.addParams = azureConfig.groupMap[groupName].addParams; + clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; + clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; - if (endpoint === EModelEndpoint.azureOpenAI) { - clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials(); + apiKey = azureOptions.azureOpenAIApiKey; + clientOptions.azure = !serverless && azureOptions; + } else if (isAzureOpenAI) { + clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); apiKey = clientOptions.azure.azureOpenAIApiKey; } if (!apiKey) { - throw new Error('API key not provided.'); + throw new Error(`${endpoint} API key not provided. Please provide it again.`); } const client = new OpenAIClient(apiKey, clientOptions); diff --git a/api/server/services/Endpoints/openAI/initializeClient.spec.js b/api/server/services/Endpoints/openAI/initializeClient.spec.js index 03f5677441c..1a53f95b3de 100644 --- a/api/server/services/Endpoints/openAI/initializeClient.spec.js +++ b/api/server/services/Endpoints/openAI/initializeClient.spec.js @@ -1,6 +1,7 @@ -const { OpenAIClient } = require('~/app'); -const initializeClient = require('./initializeClient'); +const { EModelEndpoint, validateAzureGroups } = require('librechat-data-provider'); const { getUserKey } = require('~/server/services/UserService'); +const initializeClient = require('./initializeClient'); +const { OpenAIClient } = require('~/app'); // Mock getUserKey since it's the only function we want to mock jest.mock('~/server/services/UserService', () => ({ @@ -11,6 +12,72 @@ jest.mock('~/server/services/UserService', () => ({ describe('initializeClient', () => { // Set up environment variables const originalEnvironment = process.env; + const app = { + locals: {}, + }; + + const validAzureConfigs = [ + { + group: 'librechat-westus', + apiKey: 'WESTUS_API_KEY', + instanceName: 'librechat-westus', + version: '2023-12-01-preview', + models: { + 'gpt-4-vision-preview': { + deploymentName: 'gpt-4-vision-preview', + version: '2024-02-15-preview', + }, + 'gpt-3.5-turbo': { + deploymentName: 'gpt-35-turbo', + }, + 'gpt-3.5-turbo-1106': { + deploymentName: 'gpt-35-turbo-1106', + }, + 'gpt-4': { + deploymentName: 'gpt-4', + }, + 'gpt-4-1106-preview': { + deploymentName: 'gpt-4-1106-preview', + }, + }, + }, + { + group: 'librechat-eastus', + apiKey: 'EASTUS_API_KEY', + instanceName: 'librechat-eastus', + deploymentName: 'gpt-4-turbo', + version: '2024-02-15-preview', + models: { + 'gpt-4-turbo': true, + }, + baseURL: 'https://eastus.example.com', + additionalHeaders: { + 'x-api-key': 'x-api-key-value', + }, + }, + { + group: 'mistral-inference', + apiKey: 'AZURE_MISTRAL_API_KEY', + baseURL: + 'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions', + serverless: true, + models: { + 'mistral-large': true, + }, + }, + { + group: 'llama-70b-chat', + apiKey: 'AZURE_LLAMA2_70B_API_KEY', + baseURL: + 'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions', + serverless: true, + models: { + 'llama-70b-chat': true, + }, + }, + ]; + + const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(validAzureConfigs); beforeEach(() => { jest.resetModules(); // Clears the cache @@ -27,16 +94,17 @@ describe('initializeClient', () => { process.env.OPENAI_SUMMARIZE = 'false'; const req = { - body: { key: null, endpoint: 'openAI' }, + body: { key: null, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; - const client = await initializeClient({ req, res, endpointOption }); + const result = await initializeClient({ req, res, endpointOption }); - expect(client.openAIApiKey).toBe('test-openai-api-key'); - expect(client.client).toBeInstanceOf(OpenAIClient); + expect(result.openAIApiKey).toBe('test-openai-api-key'); + expect(result.client).toBeInstanceOf(OpenAIClient); }); test('should initialize client with Azure credentials when endpoint is azureOpenAI', async () => { @@ -53,6 +121,7 @@ describe('initializeClient', () => { const req = { body: { key: null, endpoint: 'azureOpenAI' }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'test-model' } }; @@ -68,8 +137,9 @@ describe('initializeClient', () => { process.env.DEBUG_OPENAI = 'true'; const req = { - body: { key: null, endpoint: 'openAI' }, + body: { key: null, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; @@ -84,8 +154,9 @@ describe('initializeClient', () => { process.env.OPENAI_SUMMARIZE = 'true'; const req = { - body: { key: null, endpoint: 'openAI' }, + body: { key: null, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; @@ -101,8 +172,9 @@ describe('initializeClient', () => { process.env.PROXY = 'http://proxy'; const req = { - body: { key: null, endpoint: 'openAI' }, + body: { key: null, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; @@ -121,15 +193,14 @@ describe('initializeClient', () => { const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired const req = { - body: { key: expiresAt, endpoint: 'openAI' }, + body: { key: expiresAt, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - 'Your OpenAI API key has expired. Please provide your API key again.', - ); + await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(/Your OpenAI API/); }); test('should throw an error if no API keys are provided in the environment', async () => { @@ -138,14 +209,15 @@ describe('initializeClient', () => { delete process.env.AZURE_API_KEY; const req = { - body: { key: null, endpoint: 'openAI' }, + body: { key: null, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - 'API key not provided.', + `${EModelEndpoint.openAI} API key not provided.`, ); }); @@ -154,11 +226,12 @@ describe('initializeClient', () => { const req = { body: { key: new Date(Date.now() + 10000).toISOString(), - endpoint: 'openAI', + endpoint: EModelEndpoint.openAI, }, user: { id: '123', }, + app, }; const res = {}; @@ -168,7 +241,7 @@ describe('initializeClient', () => { process.env.OPENAI_API_KEY = 'user_provided'; // Mock getUserKey to return the expected key - getUserKey.mockResolvedValue('test-user-provided-openai-api-key'); + getUserKey.mockResolvedValue(JSON.stringify({ apiKey: 'test-user-provided-openai-api-key' })); // Call the initializeClient function const result = await initializeClient({ req, res, endpointOption }); @@ -180,8 +253,9 @@ describe('initializeClient', () => { test('should throw an error if the user-provided key is invalid', async () => { const invalidKey = new Date(Date.now() - 100000).toISOString(); const req = { - body: { key: invalidKey, endpoint: 'openAI' }, + body: { key: invalidKey, endpoint: EModelEndpoint.openAI }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = {}; @@ -192,8 +266,94 @@ describe('initializeClient', () => { // Mock getUserKey to return an invalid key getUserKey.mockResolvedValue(invalidKey); + await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(/Your OpenAI API/); + }); + + test('should throw an error when user-provided values are not valid JSON', async () => { + process.env.OPENAI_API_KEY = 'user_provided'; + const req = { + body: { key: new Date(Date.now() + 10000).toISOString(), endpoint: EModelEndpoint.openAI }, + user: { id: '123' }, + app, + }; + const res = {}; + const endpointOption = {}; + + // Mock getUserKey to return a non-JSON string + getUserKey.mockResolvedValue('not-a-json'); + await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /Your OpenAI API key has expired/, + /Invalid JSON provided for openAI user values/, + ); + }); + + test('should initialize client correctly for Azure OpenAI with valid configuration', async () => { + const req = { + body: { + key: null, + endpoint: EModelEndpoint.azureOpenAI, + model: modelNames[0], + }, + user: { id: '123' }, + app: { + locals: { + [EModelEndpoint.azureOpenAI]: { + modelNames, + modelGroupMap, + groupMap, + }, + }, + }, + }; + const res = {}; + const endpointOption = {}; + + const client = await initializeClient({ req, res, endpointOption }); + expect(client.client.options.azure).toBeDefined(); + }); + + test('should initialize client with default options when certain env vars are not set', async () => { + delete process.env.DEBUG_OPENAI; + delete process.env.OPENAI_SUMMARIZE; + process.env.OPENAI_API_KEY = 'some-api-key'; + + const req = { + body: { key: null, endpoint: EModelEndpoint.openAI }, + user: { id: '123' }, + app, + }; + const res = {}; + const endpointOption = {}; + + const client = await initializeClient({ req, res, endpointOption }); + + expect(client.client.options.debug).toBe(false); + expect(client.client.options.contextStrategy).toBe(null); + }); + + test('should correctly use user-provided apiKey and baseURL when provided', async () => { + process.env.OPENAI_API_KEY = 'user_provided'; + process.env.OPENAI_REVERSE_PROXY = 'user_provided'; + const req = { + body: { + key: new Date(Date.now() + 10000).toISOString(), + endpoint: EModelEndpoint.openAI, + }, + user: { + id: '123', + }, + app, + }; + const res = {}; + const endpointOption = {}; + + getUserKey.mockResolvedValue( + JSON.stringify({ apiKey: 'test', baseURL: 'https://user-provided-url.com' }), ); + + const result = await initializeClient({ req, res, endpointOption }); + + expect(result.openAIApiKey).toBe('test'); + expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); }); }); diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 68f534bcb6d..43b5ec9b252 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -1,6 +1,11 @@ +const fs = require('fs'); +const path = require('path'); +const axios = require('axios'); const fetch = require('node-fetch'); -const { ref, uploadBytes, getDownloadURL, deleteObject } = require('firebase/storage'); +const { ref, uploadBytes, getDownloadURL, getStream, deleteObject } = require('firebase/storage'); +const { getBufferMetadata } = require('~/server/utils'); const { getFirebaseStorage } = require('./initialize'); +const { logger } = require('~/config'); /** * Deletes a file from Firebase Storage. @@ -11,7 +16,7 @@ const { getFirebaseStorage } = require('./initialize'); async function deleteFile(basePath, fileName) { const storage = getFirebaseStorage(); if (!storage) { - console.error('Firebase is not initialized. Cannot delete file from Firebase Storage.'); + logger.error('Firebase is not initialized. Cannot delete file from Firebase Storage.'); throw new Error('Firebase is not initialized'); } @@ -19,9 +24,9 @@ async function deleteFile(basePath, fileName) { try { await deleteObject(storageRef); - console.log('File deleted successfully from Firebase Storage'); + logger.debug('File deleted successfully from Firebase Storage'); } catch (error) { - console.error('Error deleting file from Firebase Storage:', error.message); + logger.error('Error deleting file from Firebase Storage:', error.message); throw error; } } @@ -41,24 +46,25 @@ async function deleteFile(basePath, fileName) { * @param {string} [params.basePath='images'] - Optional. The base basePath in Firebase Storage where the file will * be stored. Defaults to 'images' if not specified. * - * @returns {Promise<string|null>} - * A promise that resolves to the file name if the file is successfully uploaded, or null if there - * is an error in initialization or upload. + * @returns {Promise<{ bytes: number, type: string, dimensions: Record<string, number>} | null>} + * A promise that resolves to the file metadata if the file is successfully saved, or null if there is an error. */ async function saveURLToFirebase({ userId, URL, fileName, basePath = 'images' }) { const storage = getFirebaseStorage(); if (!storage) { - console.error('Firebase is not initialized. Cannot save file to Firebase Storage.'); + logger.error('Firebase is not initialized. Cannot save file to Firebase Storage.'); return null; } const storageRef = ref(storage, `${basePath}/${userId.toString()}/${fileName}`); + const response = await fetch(URL); + const buffer = await response.buffer(); try { - await uploadBytes(storageRef, await fetch(URL).then((response) => response.buffer())); - return fileName; + await uploadBytes(storageRef, buffer); + return await getBufferMetadata(buffer); } catch (error) { - console.error('Error uploading file to Firebase Storage:', error.message); + logger.error('Error uploading file to Firebase Storage:', error.message); return null; } } @@ -82,7 +88,7 @@ async function saveURLToFirebase({ userId, URL, fileName, basePath = 'images' }) async function getFirebaseURL({ fileName, basePath = 'images' }) { const storage = getFirebaseStorage(); if (!storage) { - console.error('Firebase is not initialized. Cannot get image URL from Firebase Storage.'); + logger.error('Firebase is not initialized. Cannot get image URL from Firebase Storage.'); return null; } @@ -91,7 +97,7 @@ async function getFirebaseURL({ fileName, basePath = 'images' }) { try { return await getDownloadURL(storageRef); } catch (error) { - console.error('Error fetching file URL from Firebase Storage:', error.message); + logger.error('Error fetching file URL from Firebase Storage:', error.message); return null; } } @@ -158,6 +164,18 @@ function extractFirebaseFilePath(urlString) { * Throws an error if there is an issue with deletion. */ const deleteFirebaseFile = async (req, file) => { + if (file.embedded && process.env.RAG_API_URL) { + const jwtToken = req.headers.authorization.split(' ')[1]; + axios.delete(`${process.env.RAG_API_URL}/documents`, { + headers: { + Authorization: `Bearer ${jwtToken}`, + 'Content-Type': 'application/json', + accept: 'application/json', + }, + data: [file.file_id], + }); + } + const fileName = extractFirebaseFilePath(file.filepath); if (!fileName.includes(req.user.id)) { throw new Error('Invalid file path'); @@ -165,10 +183,62 @@ const deleteFirebaseFile = async (req, file) => { await deleteFile('', fileName); }; +/** + * Uploads a file to Firebase Storage. + * + * @param {Object} params - The params object. + * @param {Express.Request} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user. + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should + * have a `path` property that points to the location of the uploaded file. + * @param {string} params.file_id - The file ID. + * + * @returns {Promise<{ filepath: string, bytes: number }>} + * A promise that resolves to an object containing: + * - filepath: The download URL of the uploaded file. + * - bytes: The size of the uploaded file in bytes. + */ +async function uploadFileToFirebase({ req, file, file_id }) { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const bytes = Buffer.byteLength(inputBuffer); + const userId = req.user.id; + + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + + const downloadURL = await saveBufferToFirebase({ userId, buffer: inputBuffer, fileName }); + + await fs.promises.unlink(inputFilePath); + + return { filepath: downloadURL, bytes }; +} + +/** + * Retrieves a readable stream for a file from Firebase storage. + * + * @param {string} filepath - The filepath. + * @returns {ReadableStream} A readable stream of the file. + */ +function getFirebaseFileStream(filepath) { + try { + const storage = getFirebaseStorage(); + if (!storage) { + throw new Error('Firebase is not initialized'); + } + const fileRef = ref(storage, filepath); + return getStream(fileRef); + } catch (error) { + logger.error('Error getting Firebase file stream:', error); + throw error; + } +} + module.exports = { deleteFile, getFirebaseURL, saveURLToFirebase, deleteFirebaseFile, + uploadFileToFirebase, saveBufferToFirebase, + getFirebaseFileStream, }; diff --git a/api/server/services/Files/Firebase/images.js b/api/server/services/Files/Firebase/images.js index 95b600962f6..f06718063c9 100644 --- a/api/server/services/Files/Firebase/images.js +++ b/api/server/services/Files/Firebase/images.js @@ -1,7 +1,8 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); -const { resizeImage } = require('../images/resize'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); const { saveBufferToFirebase } = require('./crud'); const { updateFile } = require('~/models/File'); const { logger } = require('~/config'); @@ -10,12 +11,13 @@ const { logger } = require('~/config'); * Converts an image file to the WebP format. The function first resizes the image based on the specified * resolution. * - * - * @param {Object} req - The request object from Express. It should have a `user` property with an `id` + * @param {Object} params - The params object. + * @param {Express.Request} params.req - The request object from Express. It should have a `user` property with an `id` * representing the user, and an `app.locals.paths` object with an `imageOutput` path. - * @param {Express.Multer.File} file - The file object, which is part of the request. The file object should + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should * have a `path` property that points to the location of the uploaded file. - * @param {string} [resolution='high'] - Optional. The desired resolution for the image resizing. Default is 'high'. + * @param {EModelEndpoint} params.endpoint - The params object. + * @param {string} [params.resolution='high'] - Optional. The desired resolution for the image resizing. Default is 'high'. * * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number}>} * A promise that resolves to an object containing: @@ -24,14 +26,19 @@ const { logger } = require('~/config'); * - width: The width of the converted image. * - height: The height of the converted image. */ -async function uploadImageToFirebase(req, file, resolution = 'high') { +async function uploadImageToFirebase({ req, file, file_id, endpoint, resolution = 'high' }) { const inputFilePath = file.path; - const { buffer: resizedBuffer, width, height } = await resizeImage(inputFilePath, resolution); + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); const extension = path.extname(inputFilePath); const userId = req.user.id; let webPBuffer; - let fileName = path.basename(inputFilePath); + let fileName = `${file_id}__${path.basename(inputFilePath)}`; if (extension.toLowerCase() === '.webp') { webPBuffer = resizedBuffer; } else { @@ -73,15 +80,15 @@ async function prepareImageURL(req, file) { * * @param {object} params - The parameters object. * @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format. - * @param {object} params.User - The User document (mongoose); TODO: remove direct use of Model, `User` + * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). * @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processFirebaseAvatar({ buffer, User, manual }) { +async function processFirebaseAvatar({ buffer, userId, manual }) { try { const downloadURL = await saveBufferToFirebase({ - userId: User._id.toString(), + userId, buffer, fileName: 'avatar.png', }); @@ -91,8 +98,7 @@ async function processFirebaseAvatar({ buffer, User, manual }) { const url = `${downloadURL}?manual=${isManual}`; if (isManual) { - User.avatar = url; - await User.save(); + await updateUser(userId, { avatar: url }); } return url; diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index d81c063031a..18bf5127fd4 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -1,8 +1,9 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); -const { logger } = require('~/config'); +const { getBufferMetadata } = require('~/server/utils'); const paths = require('~/config/paths'); +const { logger } = require('~/config'); /** * Saves a file to a specified output path with a new filename. @@ -13,7 +14,7 @@ const paths = require('~/config/paths'); * @returns {Promise<string>} The full path of the saved file. * @throws Will throw an error if the file saving process fails. */ -async function saveFile(file, outputPath, outputFilename) { +async function saveLocalFile(file, outputPath, outputFilename) { try { if (!fs.existsSync(outputPath)) { fs.mkdirSync(outputPath, { recursive: true }); @@ -44,9 +45,41 @@ async function saveFile(file, outputPath, outputFilename) { const saveLocalImage = async (req, file, filename) => { const imagePath = req.app.locals.paths.imageOutput; const outputPath = path.join(imagePath, req.user.id ?? ''); - await saveFile(file, outputPath, filename); + await saveLocalFile(file, outputPath, filename); }; +/** + * Saves a buffer to a specified directory on the local file system. + * + * @param {Object} params - The parameters object. + * @param {string} params.userId - The user's unique identifier. This is used to create a user-specific directory. + * @param {Buffer} params.buffer - The buffer to be saved. + * @param {string} params.fileName - The name of the file to be saved. + * @param {string} [params.basePath='images'] - Optional. The base path where the file will be stored. + * Defaults to 'images' if not specified. + * @returns {Promise<string>} - A promise that resolves to the path of the saved file. + */ +async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' }) { + try { + const { publicPath, uploads } = paths; + + const directoryPath = path.join(basePath === 'images' ? publicPath : uploads, basePath, userId); + + if (!fs.existsSync(directoryPath)) { + fs.mkdirSync(directoryPath, { recursive: true }); + } + + fs.writeFileSync(path.join(directoryPath, fileName), buffer); + + const filePath = path.posix.join('/', basePath, userId, fileName); + + return filePath; + } catch (error) { + logger.error('[saveLocalBuffer] Error while saving the buffer:', error); + throw error; + } +} + /** * Saves a file from a given URL to a local directory. The function fetches the file using the provided URL, * determines the content type, and saves it to a specified local directory with the correct file extension. @@ -62,20 +95,18 @@ const saveLocalImage = async (req, file, filename) => { * @param {string} [params.basePath='images'] - Optional. The base directory where the file will be saved. * Defaults to 'images' if not specified. * - * @returns {Promise<string|null>} - * A promise that resolves to the file name if the file is successfully saved, or null if there is an error. + * @returns {Promise<{ bytes: number, type: string, dimensions: Record<string, number>} | null>} + * A promise that resolves to the file metadata if the file is successfully saved, or null if there is an error. */ async function saveFileFromURL({ userId, URL, fileName, basePath = 'images' }) { try { - // Fetch the file from the URL const response = await axios({ url: URL, - responseType: 'stream', + responseType: 'arraybuffer', }); - // Get the content type from the response headers - const contentType = response.headers['content-type']; - let extension = contentType.split('/').pop(); + const buffer = Buffer.from(response.data, 'binary'); + const { bytes, type, dimensions, extension } = await getBufferMetadata(buffer); // Construct the outputPath based on the basePath and userId const outputPath = path.join(paths.publicPath, basePath, userId.toString()); @@ -92,17 +123,15 @@ async function saveFileFromURL({ userId, URL, fileName, basePath = 'images' }) { fileName += `.${extension}`; } - // Create a writable stream for the output path + // Save the file to the output path const outputFilePath = path.join(outputPath, fileName); - const writer = fs.createWriteStream(outputFilePath); + fs.writeFileSync(outputFilePath, buffer); - // Pipe the response data to the output file - response.data.pipe(writer); - - return new Promise((resolve, reject) => { - writer.on('finish', () => resolve(fileName)); - writer.on('error', reject); - }); + return { + bytes, + type, + dimensions, + }; } catch (error) { logger.error('[saveFileFromURL] Error while saving the file:', error); return null; @@ -159,7 +188,26 @@ const isValidPath = (req, base, subfolder, filepath) => { * file path is invalid or if there is an error in deletion. */ const deleteLocalFile = async (req, file) => { - const { publicPath } = req.app.locals.paths; + const { publicPath, uploads } = req.app.locals.paths; + if (file.embedded && process.env.RAG_API_URL) { + const jwtToken = req.headers.authorization.split(' ')[1]; + axios.delete(`${process.env.RAG_API_URL}/documents`, { + headers: { + Authorization: `Bearer ${jwtToken}`, + 'Content-Type': 'application/json', + accept: 'application/json', + }, + data: [file.file_id], + }); + } + + if (file.filepath.startsWith(`/uploads/${req.user.id}`)) { + const basePath = file.filepath.split('/uploads/')[1]; + const filepath = path.join(uploads, basePath); + await fs.promises.unlink(filepath); + return; + } + const parts = file.filepath.split(path.sep); const subfolder = parts[1]; const filepath = path.join(publicPath, file.filepath); @@ -171,4 +219,64 @@ const deleteLocalFile = async (req, file) => { await fs.promises.unlink(filepath); }; -module.exports = { saveFile, saveLocalImage, saveFileFromURL, getLocalFileURL, deleteLocalFile }; +/** + * Uploads a file to the specified upload directory. + * + * @param {Object} params - The params object. + * @param {Object} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user, and an `app.locals.paths` object with an `uploads` path. + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should + * have a `path` property that points to the location of the uploaded file. + * @param {string} params.file_id - The file ID. + * + * @returns {Promise<{ filepath: string, bytes: number }>} + * A promise that resolves to an object containing: + * - filepath: The path where the file is saved. + * - bytes: The size of the file in bytes. + */ +async function uploadLocalFile({ req, file, file_id }) { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const bytes = Buffer.byteLength(inputBuffer); + + const { uploads } = req.app.locals.paths; + const userPath = path.join(uploads, req.user.id); + + if (!fs.existsSync(userPath)) { + fs.mkdirSync(userPath, { recursive: true }); + } + + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const newPath = path.join(userPath, fileName); + + await fs.promises.writeFile(newPath, inputBuffer); + const filepath = path.posix.join('/', 'uploads', req.user.id, path.basename(newPath)); + + return { filepath, bytes }; +} + +/** + * Retrieves a readable stream for a file from local storage. + * + * @param {string} filepath - The filepath. + * @returns {ReadableStream} A readable stream of the file. + */ +function getLocalFileStream(filepath) { + try { + return fs.createReadStream(filepath); + } catch (error) { + logger.error('Error getting local file stream:', error); + throw error; + } +} + +module.exports = { + saveLocalFile, + saveLocalImage, + saveLocalBuffer, + saveFileFromURL, + getLocalFileURL, + deleteLocalFile, + uploadLocalFile, + getLocalFileStream, +}; diff --git a/api/server/services/Files/Local/images.js b/api/server/services/Files/Local/images.js index 63ed5b2f64b..4d5b9565f1f 100644 --- a/api/server/services/Files/Local/images.js +++ b/api/server/services/Files/Local/images.js @@ -1,7 +1,8 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); -const { resizeImage } = require('../images/resize'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); const { updateFile } = require('~/models/File'); /** @@ -12,12 +13,14 @@ const { updateFile } = require('~/models/File'); * it converts the image to WebP format before saving. * * The original image is deleted after conversion. - * - * @param {Object} req - The request object from Express. It should have a `user` property with an `id` + * @param {Object} params - The params object. + * @param {Object} params.req - The request object from Express. It should have a `user` property with an `id` * representing the user, and an `app.locals.paths` object with an `imageOutput` path. - * @param {Express.Multer.File} file - The file object, which is part of the request. The file object should + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should * have a `path` property that points to the location of the uploaded file. - * @param {string} [resolution='high'] - Optional. The desired resolution for the image resizing. Default is 'high'. + * @param {string} params.file_id - The file ID. + * @param {EModelEndpoint} params.endpoint - The params object. + * @param {string} [params.resolution='high'] - Optional. The desired resolution for the image resizing. Default is 'high'. * * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number}>} * A promise that resolves to an object containing: @@ -26,9 +29,14 @@ const { updateFile } = require('~/models/File'); * - width: The width of the converted image. * - height: The height of the converted image. */ -async function uploadLocalImage(req, file, resolution = 'high') { +async function uploadLocalImage({ req, file, file_id, endpoint, resolution = 'high' }) { const inputFilePath = file.path; - const { buffer: resizedBuffer, width, height } = await resizeImage(inputFilePath, resolution); + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); const extension = path.extname(inputFilePath); const { imageOutput } = req.app.locals.paths; @@ -38,7 +46,8 @@ async function uploadLocalImage(req, file, resolution = 'high') { fs.mkdirSync(userPath, { recursive: true }); } - const newPath = path.join(userPath, path.basename(inputFilePath)); + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const newPath = path.join(userPath, fileName); if (extension.toLowerCase() === '.webp') { const bytes = Buffer.byteLength(resizedBuffer); @@ -96,17 +105,17 @@ async function prepareImagesLocal(req, file) { } /** - * Uploads a user's avatar to Firebase Storage and returns the URL. + * Uploads a user's avatar to local server storage and returns the URL. * If the 'manual' flag is set to 'true', it also updates the user's avatar URL in the database. * * @param {object} params - The parameters object. * @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format. - * @param {object} params.User - The User document (mongoose); TODO: remove direct use of Model, `User` + * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). * @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processLocalAvatar({ buffer, User, manual }) { +async function processLocalAvatar({ buffer, userId, manual }) { const userDir = path.resolve( __dirname, '..', @@ -117,10 +126,11 @@ async function processLocalAvatar({ buffer, User, manual }) { 'client', 'public', 'images', - User._id.toString(), + userId, ); + const fileName = `avatar-${new Date().getTime()}.png`; - const urlRoute = `/images/${User._id.toString()}/${fileName}`; + const urlRoute = `/images/${userId}/${fileName}`; const avatarPath = path.join(userDir, fileName); await fs.promises.mkdir(userDir, { recursive: true }); @@ -130,8 +140,7 @@ async function processLocalAvatar({ buffer, User, manual }) { let url = `${urlRoute}?manual=${isManual}`; if (isManual) { - User.avatar = url; - await User.save(); + await updateUser(userId, { avatar: url }); } return url; diff --git a/api/server/services/Files/OpenAI/crud.js b/api/server/services/Files/OpenAI/crud.js new file mode 100644 index 00000000000..346259e8215 --- /dev/null +++ b/api/server/services/Files/OpenAI/crud.js @@ -0,0 +1,79 @@ +const fs = require('fs'); +const { FilePurpose } = require('librechat-data-provider'); +const { sleep } = require('~/server/utils'); +const { logger } = require('~/config'); + +/** + * Uploads a file that can be used across various OpenAI services. + * + * @param {Object} params - The params object. + * @param {Express.Request} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user, and an `app.locals.paths` object with an `imageOutput` path. + * @param {Express.Multer.File} params.file - The file uploaded to the server via multer. + * @param {OpenAIClient} params.openai - The initialized OpenAI client. + * @returns {Promise<OpenAIFile>} + */ +async function uploadOpenAIFile({ req, file, openai }) { + const uploadedFile = await openai.files.create({ + file: fs.createReadStream(file.path), + purpose: FilePurpose.Assistants, + }); + + logger.debug( + `[uploadOpenAIFile] User ${req.user.id} successfully uploaded file to OpenAI`, + uploadedFile, + ); + + if (uploadedFile.status !== 'processed') { + const sleepTime = 2500; + logger.debug( + `[uploadOpenAIFile] File ${ + uploadedFile.id + } is not yet processed. Waiting for it to be processed (${sleepTime / 1000}s)...`, + ); + await sleep(sleepTime); + } + + return uploadedFile; +} + +/** + * Deletes a file previously uploaded to OpenAI. + * + * @param {Express.Request} req - The request object from Express. + * @param {MongoFile} file - The database representation of the uploaded file. + * @param {OpenAI} openai - The initialized OpenAI client. + * @returns {Promise<void>} + */ +async function deleteOpenAIFile(req, file, openai) { + try { + const res = await openai.files.del(file.file_id); + if (!res.deleted) { + throw new Error('OpenAI returned `false` for deleted status'); + } + logger.debug( + `[deleteOpenAIFile] User ${req.user.id} successfully deleted ${file.file_id} from OpenAI`, + ); + } catch (error) { + logger.error('[deleteOpenAIFile] Error deleting file from OpenAI: ' + error.message); + throw error; + } +} + +/** + * Retrieves a readable stream for a file from local storage. + * + * @param {string} file_id - The file_id. + * @param {OpenAI} openai - The initialized OpenAI client. + * @returns {Promise<ReadableStream>} A readable stream of the file. + */ +async function getOpenAIFileStream(file_id, openai) { + try { + return await openai.files.content(file_id); + } catch (error) { + logger.error('Error getting OpenAI file download stream:', error); + throw error; + } +} + +module.exports = { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream }; diff --git a/api/server/services/Files/OpenAI/index.js b/api/server/services/Files/OpenAI/index.js new file mode 100644 index 00000000000..a6223d1ee5d --- /dev/null +++ b/api/server/services/Files/OpenAI/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js new file mode 100644 index 00000000000..c9a8c315834 --- /dev/null +++ b/api/server/services/Files/VectorDB/crud.js @@ -0,0 +1,102 @@ +const fs = require('fs'); +const axios = require('axios'); +const FormData = require('form-data'); +const { FileSources } = require('librechat-data-provider'); +const { logger } = require('~/config'); + +/** + * Deletes a file from the vector database. This function takes a file object, constructs the full path, and + * verifies the path's validity before deleting the file. If the path is invalid, an error is thrown. + * + * @param {Express.Request} req - The request object from Express. It should have an `app.locals.paths` object with + * a `publicPath` property. + * @param {MongoFile} file - The file object to be deleted. It should have a `filepath` property that is + * a string representing the path of the file relative to the publicPath. + * + * @returns {Promise<void>} + * A promise that resolves when the file has been successfully deleted, or throws an error if the + * file path is invalid or if there is an error in deletion. + */ +const deleteVectors = async (req, file) => { + if (!file.embedded || !process.env.RAG_API_URL) { + return; + } + try { + const jwtToken = req.headers.authorization.split(' ')[1]; + return await axios.delete(`${process.env.RAG_API_URL}/documents`, { + headers: { + Authorization: `Bearer ${jwtToken}`, + 'Content-Type': 'application/json', + accept: 'application/json', + }, + data: [file.file_id], + }); + } catch (error) { + logger.error('Error deleting vectors', error); + throw new Error(error.message || 'An error occurred during file deletion.'); + } +}; + +/** + * Uploads a file to the configured Vector database + * + * @param {Object} params - The params object. + * @param {Object} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user, and an `app.locals.paths` object with an `uploads` path. + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should + * have a `path` property that points to the location of the uploaded file. + * @param {string} params.file_id - The file ID. + * + * @returns {Promise<{ filepath: string, bytes: number }>} + * A promise that resolves to an object containing: + * - filepath: The path where the file is saved. + * - bytes: The size of the file in bytes. + */ +async function uploadVectors({ req, file, file_id }) { + if (!process.env.RAG_API_URL) { + throw new Error('RAG_API_URL not defined'); + } + + try { + const jwtToken = req.headers.authorization.split(' ')[1]; + const formData = new FormData(); + formData.append('file_id', file_id); + formData.append('file', fs.createReadStream(file.path)); + + const formHeaders = formData.getHeaders(); // Automatically sets the correct Content-Type + + const response = await axios.post(`${process.env.RAG_API_URL}/embed`, formData, { + headers: { + Authorization: `Bearer ${jwtToken}`, + accept: 'application/json', + ...formHeaders, + }, + }); + + const responseData = response.data; + logger.debug('Response from embedding file', responseData); + + if (responseData.known_type === false) { + throw new Error(`File embedding failed. The filetype ${file.mimetype} is not supported`); + } + + if (!responseData.status) { + throw new Error('File embedding failed.'); + } + + return { + bytes: file.size, + filename: file.originalname, + filepath: FileSources.vectordb, + embedded: Boolean(responseData.known_type), + }; + } catch (error) { + logger.error('Error embedding file', error); + throw new Error(error.message || 'An error occurred during file upload.'); + } +} + +module.exports = { + deleteVectors, + uploadVectors, +}; diff --git a/api/server/services/Files/VectorDB/index.js b/api/server/services/Files/VectorDB/index.js new file mode 100644 index 00000000000..a6223d1ee5d --- /dev/null +++ b/api/server/services/Files/VectorDB/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/images/avatar.js b/api/server/services/Files/images/avatar.js index 490fc86171d..8f4f65b8e29 100644 --- a/api/server/services/Files/images/avatar.js +++ b/api/server/services/Files/images/avatar.js @@ -1,42 +1,29 @@ const sharp = require('sharp'); const fs = require('fs').promises; const fetch = require('node-fetch'); -const User = require('~/models/User'); -const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAndConvert } = require('./resize'); const { logger } = require('~/config'); -async function convertToWebP(inputBuffer) { - return sharp(inputBuffer).resize({ width: 150 }).toFormat('webp').toBuffer(); -} - /** * Uploads an avatar image for a user. This function can handle various types of input (URL, Buffer, or File object), - * processes the image to a square format, converts it to WebP format, and then uses a specified file strategy for - * further processing. It performs validation on the user ID and the input type. The function can throw errors for - * invalid input types, fetching issues, or other processing errors. + * processes the image to a square format, converts it to WebP format, and returns the resized buffer. * * @param {Object} params - The parameters object. * @param {string} params.userId - The unique identifier of the user for whom the avatar is being uploaded. - * @param {FileSources} params.fileStrategy - The file handling strategy to use, determining how the avatar is processed. * @param {(string|Buffer|File)} params.input - The input representing the avatar image. Can be a URL (string), * a Buffer, or a File object. - * @param {string} params.manual - A string flag indicating whether the upload process is manual. * * @returns {Promise<any>} - * A promise that resolves to the result of the `processAvatar` function, specific to the chosen file - * strategy. Throws an error if any step in the process fails. + * A promise that resolves to a resized buffer. * * @throws {Error} Throws an error if the user ID is undefined, the input type is invalid, the image fetching fails, * or any other error occurs during the processing. */ -async function uploadAvatar({ userId, fileStrategy, input, manual }) { +async function resizeAvatar({ userId, input }) { try { if (userId === undefined) { throw new Error('User ID is undefined'); } - const _id = userId; - // TODO: remove direct use of Model, `User` - const oldUser = await User.findOne({ _id }); let imageBuffer; if (typeof input === 'string') { @@ -66,13 +53,12 @@ async function uploadAvatar({ userId, fileStrategy, input, manual }) { }) .toBuffer(); - const webPBuffer = await convertToWebP(squaredBuffer); - const { processAvatar } = getStrategyFunctions(fileStrategy); - return await processAvatar({ buffer: webPBuffer, User: oldUser, manual }); + const { buffer } = await resizeAndConvert(squaredBuffer); + return buffer; } catch (error) { logger.error('Error uploading the avatar:', error); throw error; } } -module.exports = uploadAvatar; +module.exports = { resizeAvatar }; diff --git a/api/server/services/Files/images/convert.js b/api/server/services/Files/images/convert.js new file mode 100644 index 00000000000..744e591717f --- /dev/null +++ b/api/server/services/Files/images/convert.js @@ -0,0 +1,70 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('./resize'); +const { getStrategyFunctions } = require('../strategies'); +const { logger } = require('~/config'); + +/** + * Converts an image file or buffer to WebP format with specified resolution. + * + * @param {Express.Request} req - The request object, containing user and app configuration data. + * @param {Buffer | Express.Multer.File} file - The file object, containing either a path or a buffer. + * @param {'low' | 'high'} [resolution='high'] - The desired resolution for the output image. + * @param {string} [basename=''] - The basename of the input file, if it is a buffer. + * @returns {Promise<{filepath: string, bytes: number, width: number, height: number}>} An object containing the path, size, and dimensions of the converted image. + * @throws Throws an error if there is an issue during the conversion process. + */ +async function convertToWebP(req, file, resolution = 'high', basename = '') { + try { + let inputBuffer; + let outputBuffer; + let extension = path.extname(file.path ?? basename).toLowerCase(); + + // Check if the input is a buffer or a file path + if (Buffer.isBuffer(file)) { + inputBuffer = file; + } else if (file && file.path) { + const inputFilePath = file.path; + inputBuffer = await fs.promises.readFile(inputFilePath); + } else { + throw new Error('Invalid input: file must be a buffer or contain a valid path.'); + } + + // Resize the image buffer + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution); + + // Check if the file is already in WebP format + // If it isn't, convert it: + if (extension === '.webp') { + outputBuffer = resizedBuffer; + } else { + outputBuffer = await sharp(resizedBuffer).toFormat('webp').toBuffer(); + extension = '.webp'; + } + + // Generate a new filename for the output file + const newFileName = + path.basename(file.path ?? basename, path.extname(file.path ?? basename)) + extension; + + const { saveBuffer } = getStrategyFunctions(req.app.locals.fileStrategy); + + const savedFilePath = await saveBuffer({ + userId: req.user.id, + buffer: outputBuffer, + fileName: newFileName, + }); + + const bytes = Buffer.byteLength(outputBuffer); + return { filepath: savedFilePath, bytes, width, height }; + } catch (err) { + logger.error(err); + throw err; + } +} + +module.exports = { convertToWebP }; diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 809ec0e8401..ade39ac2e72 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -1,5 +1,29 @@ +const axios = require('axios'); const { EModelEndpoint, FileSources } = require('librechat-data-provider'); const { getStrategyFunctions } = require('../strategies'); +const { logger } = require('~/config'); + +/** + * Fetches an image from a URL and returns its base64 representation. + * + * @async + * @param {string} url The URL of the image. + * @returns {Promise<string>} The base64-encoded string of the image. + * @throws {Error} If there's an issue fetching the image or encoding it. + */ +async function fetchImageToBase64(url) { + try { + const response = await axios.get(url, { + responseType: 'arraybuffer', + }); + return Buffer.from(response.data).toString('base64'); + } catch (error) { + logger.error('Error fetching image to convert to base64', error); + throw error; + } +} + +const base64Only = new Set([EModelEndpoint.google, EModelEndpoint.anthropic]); /** * Encodes and formats the given files. @@ -15,18 +39,29 @@ async function encodeAndFormat(req, files, endpoint) { for (let file of files) { const source = file.source ?? FileSources.local; - if (encodingMethods[source]) { - promises.push(encodingMethods[source](req, file)); + if (!file.height) { + promises.push([file, null]); continue; } - const { prepareImagePayload } = getStrategyFunctions(source); - if (!prepareImagePayload) { - throw new Error(`Encoding function not implemented for ${source}`); + if (!encodingMethods[source]) { + const { prepareImagePayload } = getStrategyFunctions(source); + if (!prepareImagePayload) { + throw new Error(`Encoding function not implemented for ${source}`); + } + + encodingMethods[source] = prepareImagePayload; } - encodingMethods[source] = prepareImagePayload; - promises.push(prepareImagePayload(req, file)); + const preparePayload = encodingMethods[source]; + + /* Google & Anthropic don't support passing URLs to payload */ + if (source !== FileSources.local && base64Only.has(endpoint)) { + const [_file, imageURL] = await preparePayload(req, file); + promises.push([_file, await fetchImageToBase64(imageURL)]); + continue; + } + promises.push(preparePayload(req, file)); } const detail = req.body.imageDetail ?? 'auto'; @@ -40,6 +75,24 @@ async function encodeAndFormat(req, files, endpoint) { }; for (const [file, imageContent] of formattedImages) { + const fileMetadata = { + type: file.type, + file_id: file.file_id, + filepath: file.filepath, + filename: file.filename, + embedded: !!file.embedded, + }; + + if (file.height && file.width) { + fileMetadata.height = file.height; + fileMetadata.width = file.width; + } + + if (!imageContent) { + result.files.push(fileMetadata); + continue; + } + const imagePart = { type: 'image_url', image_url: { @@ -52,18 +105,18 @@ async function encodeAndFormat(req, files, endpoint) { if (endpoint && endpoint === EModelEndpoint.google) { imagePart.image_url = imagePart.image_url.url; + } else if (endpoint && endpoint === EModelEndpoint.anthropic) { + imagePart.type = 'image'; + imagePart.source = { + type: 'base64', + media_type: file.type, + data: imageContent, + }; + delete imagePart.image_url; } result.image_urls.push(imagePart); - - result.files.push({ - file_id: file.file_id, - // filepath: file.filepath, - // filename: file.filename, - // type: file.type, - // height: file.height, - // width: file.width, - }); + result.files.push(fileMetadata); } return result; } diff --git a/api/server/services/Files/images/index.js b/api/server/services/Files/images/index.js index 1438887e6d1..889b19f2060 100644 --- a/api/server/services/Files/images/index.js +++ b/api/server/services/Files/images/index.js @@ -1,13 +1,13 @@ const avatar = require('./avatar'); +const convert = require('./convert'); const encode = require('./encode'); const parse = require('./parse'); const resize = require('./resize'); -const validate = require('./validate'); module.exports = { + ...convert, ...encode, ...parse, ...resize, - ...validate, avatar, }; diff --git a/api/server/services/Files/images/resize.js b/api/server/services/Files/images/resize.js index dd6f24ceeab..ac05ba99463 100644 --- a/api/server/services/Files/images/resize.js +++ b/api/server/services/Files/images/resize.js @@ -1,9 +1,21 @@ const sharp = require('sharp'); +const { EModelEndpoint } = require('librechat-data-provider'); -async function resizeImage(inputFilePath, resolution) { +/** + * Resizes an image from a given buffer based on the specified resolution. + * + * @param {Buffer} inputBuffer - The buffer of the image to be resized. + * @param {'low' | 'high'} resolution - The resolution to resize the image to. + * 'low' for a maximum of 512x512 resolution, + * 'high' for a maximum of 768x2000 resolution. + * @param {EModelEndpoint} endpoint - Identifier for specific endpoint handling + * @returns {Promise<{buffer: Buffer, width: number, height: number}>} An object containing the resized image buffer and its dimensions. + * @throws Will throw an error if the resolution parameter is invalid. + */ +async function resizeImageBuffer(inputBuffer, resolution, endpoint) { const maxLowRes = 512; const maxShortSideHighRes = 768; - const maxLongSideHighRes = 2000; + const maxLongSideHighRes = endpoint === EModelEndpoint.anthropic ? 1568 : 2000; let newWidth, newHeight; let resizeOptions = { fit: 'inside', withoutEnlargement: true }; @@ -12,7 +24,7 @@ async function resizeImage(inputFilePath, resolution) { resizeOptions.width = maxLowRes; resizeOptions.height = maxLowRes; } else if (resolution === 'high') { - const metadata = await sharp(inputFilePath).metadata(); + const metadata = await sharp(inputBuffer).metadata(); const isWidthShorter = metadata.width < metadata.height; if (isWidthShorter) { @@ -43,10 +55,28 @@ async function resizeImage(inputFilePath, resolution) { throw new Error('Invalid resolution parameter'); } - const resizedBuffer = await sharp(inputFilePath).rotate().resize(resizeOptions).toBuffer(); + const resizedBuffer = await sharp(inputBuffer).rotate().resize(resizeOptions).toBuffer(); const resizedMetadata = await sharp(resizedBuffer).metadata(); return { buffer: resizedBuffer, width: resizedMetadata.width, height: resizedMetadata.height }; } -module.exports = { resizeImage }; +/** + * Resizes an image buffer to webp format as well as reduces by specified or default 150 px width. + * + * @param {Buffer} inputBuffer - The buffer of the image to be resized. + * @returns {Promise<{ buffer: Buffer, width: number, height: number, bytes: number }>} An object containing the resized image buffer, its size and dimensions. + * @throws Will throw an error if the resolution parameter is invalid. + */ +async function resizeAndConvert(inputBuffer, width = 150) { + const resizedBuffer = await sharp(inputBuffer).resize({ width }).toFormat('webp').toBuffer(); + const resizedMetadata = await sharp(resizedBuffer).metadata(); + return { + buffer: resizedBuffer, + width: resizedMetadata.width, + height: resizedMetadata.height, + bytes: Buffer.byteLength(resizedBuffer), + }; +} + +module.exports = { resizeImageBuffer, resizeAndConvert }; diff --git a/api/server/services/Files/images/validate.js b/api/server/services/Files/images/validate.js deleted file mode 100644 index 97ae73cf91a..00000000000 --- a/api/server/services/Files/images/validate.js +++ /dev/null @@ -1,13 +0,0 @@ -const { visionModels } = require('librechat-data-provider'); - -function validateVisionModel(model) { - if (!model) { - return false; - } - - return visionModels.some((visionModel) => model.includes(visionModel)); -} - -module.exports = { - validateVisionModel, -}; diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 4ee9510b4f1..66a5e454996 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -1,5 +1,23 @@ -const { updateFileUsage, createFile } = require('~/models'); +const path = require('path'); +const mime = require('mime'); +const { v4 } = require('uuid'); +const { + isUUID, + megabyte, + FileContext, + FileSources, + imageExtRegex, + EModelEndpoint, + mergeFileConfig, + hostImageIdSuffix, + hostImageNamePrefix, +} = require('librechat-data-provider'); +const { convertToWebP, resizeAndConvert } = require('~/server/services/Files/images'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); +const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); +const { determineFileType } = require('~/server/utils'); const { logger } = require('~/config'); const processFiles = async (files) => { @@ -13,6 +31,99 @@ const processFiles = async (files) => { return await Promise.all(promises); }; +/** + * Enqueues the delete operation to the leaky bucket queue if necessary, or adds it directly to promises. + * + * @param {Express.Request} req - The express request object. + * @param {MongoFile} file - The file object to delete. + * @param {Function} deleteFile - The delete file function. + * @param {Promise[]} promises - The array of promises to await. + * @param {OpenAI | undefined} [openai] - If an OpenAI file, the initialized OpenAI client. + */ +function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { + if (file.source === FileSources.openai) { + // Enqueue to leaky bucket + promises.push( + new Promise((resolve, reject) => { + LB_QueueAsyncCall( + () => deleteFile(req, file, openai), + [], + (err, result) => { + if (err) { + logger.error('Error deleting file from OpenAI source', err); + reject(err); + } else { + resolve(result); + } + }, + ); + }), + ); + } else { + // Add directly to promises + promises.push( + deleteFile(req, file).catch((err) => { + logger.error('Error deleting file', err); + return Promise.reject(err); + }), + ); + } +} + +// TODO: refactor as currently only image files can be deleted this way +// as other filetypes will not reside in public path +/** + * Deletes a list of files from the server filesystem and the database. + * + * @param {Object} params - The params object. + * @param {MongoFile[]} params.files - The file objects to delete. + * @param {Express.Request} params.req - The express request object. + * @param {DeleteFilesBody} params.req.body - The request body. + * @param {string} [params.req.body.assistant_id] - The assistant ID if file uploaded is associated to an assistant. + * + * @returns {Promise<void>} + */ +const processDeleteRequest = async ({ req, files }) => { + const file_ids = files.map((file) => file.file_id); + + const deletionMethods = {}; + const promises = []; + promises.push(deleteFiles(file_ids)); + + /** @type {OpenAI | undefined} */ + let openai; + if (req.body.assistant_id) { + ({ openai } = await initializeClient({ req })); + } + + for (const file of files) { + const source = file.source ?? FileSources.local; + + if (source === FileSources.openai && !openai) { + ({ openai } = await initializeClient({ req })); + } + + if (req.body.assistant_id) { + promises.push(openai.beta.assistants.files.del(req.body.assistant_id, file.file_id)); + } + + if (deletionMethods[source]) { + enqueueDeleteOperation(req, file, deletionMethods[source], promises, openai); + continue; + } + + const { deleteFile } = getStrategyFunctions(source); + if (!deleteFile) { + throw new Error(`Delete function not implemented for ${source}`); + } + + deletionMethods[source] = deleteFile; + enqueueDeleteOperation(req, file, deleteFile, promises, openai); + } + + await Promise.allSettled(promises); +}; + /** * Processes a file URL using a specified file handling strategy. This function accepts a strategy name, * fetches the corresponding file processing functions (for saving and retrieving file URLs), and then @@ -21,25 +132,42 @@ const processFiles = async (files) => { * exception with an appropriate message. * * @param {Object} params - The parameters object. - * @param {FileSources} params.fileStrategy - The file handling strategy to use. Must be a value from the - * `FileSources` enum, which defines different file handling - * strategies (like saving to Firebase, local storage, etc.). + * @param {FileSources} params.fileStrategy - The file handling strategy to use. + * Must be a value from the `FileSources` enum, which defines different file + * handling strategies (like saving to Firebase, local storage, etc.). * @param {string} params.userId - The user's unique identifier. Used for creating user-specific paths or - * references in the file handling process. + * references in the file handling process. * @param {string} params.URL - The URL of the file to be processed. - * @param {string} params.fileName - The name that will be used to save the file. This should include the - * file extension. + * @param {string} params.fileName - The name that will be used to save the file (including extension) * @param {string} params.basePath - The base path or directory where the file will be saved or retrieved from. - * - * @returns {Promise<string>} - * A promise that resolves to the URL of the processed file. It throws an error if the file processing - * fails at any stage. + * @param {FileContext} params.context - The context of the file (e.g., 'avatar', 'image_generation', etc.) + * @returns {Promise<MongoFile>} A promise that resolves to the DB representation (MongoFile) + * of the processed file. It throws an error if the file processing fails at any stage. */ -const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) => { +const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath, context }) => { const { saveURL, getFileURL } = getStrategyFunctions(fileStrategy); try { - await saveURL({ userId, URL, fileName, basePath }); - return await getFileURL({ fileName: `${userId}/${fileName}`, basePath }); + const { + bytes = 0, + type = '', + dimensions = {}, + } = (await saveURL({ userId, URL, fileName, basePath })) || {}; + const filepath = await getFileURL({ fileName: `${userId}/${fileName}`, basePath }); + return await createFile( + { + user: userId, + file_id: v4(), + bytes, + filepath, + filename: fileName, + source: fileStrategy, + type, + context, + width: dimensions.width, + height: dimensions.height, + }, + true, + ); } catch (error) { logger.error(`Error while processing the image with ${fileStrategy}:`, error); throw new Error(`Failed to process the image with ${fileStrategy}. ${error.message}`); @@ -49,7 +177,6 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) /** * Applies the current strategy for image uploads. * Saves file metadata to the database with an expiry TTL. - * Files must be deleted from the server filesystem manually. * * @param {Object} params - The parameters object. * @param {Express.Request} params.req - The Express request object. @@ -58,11 +185,18 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) * @param {ImageMetadata} params.metadata - Additional metadata for the file. * @returns {Promise<void>} */ -const processImageUpload = async ({ req, res, file, metadata }) => { +const processImageFile = async ({ req, res, file, metadata }) => { const source = req.app.locals.fileStrategy; const { handleImageUpload } = getStrategyFunctions(source); - const { file_id, temp_file_id } = metadata; - const { filepath, bytes, width, height } = await handleImageUpload(req, file); + const { file_id, temp_file_id, endpoint } = metadata; + + const { filepath, bytes, width, height } = await handleImageUpload({ + req, + file, + file_id, + endpoint, + }); + const result = await createFile( { user: req.user.id, @@ -71,6 +205,7 @@ const processImageUpload = async ({ req, res, file, metadata }) => { bytes, filepath, filename: file.originalname, + context: FileContext.message_attachment, source, type: 'image/webp', width, @@ -81,8 +216,343 @@ const processImageUpload = async ({ req, res, file, metadata }) => { res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; +/** + * Applies the current strategy for image uploads and + * returns minimal file metadata, without saving to the database. + * + * @param {Object} params - The parameters object. + * @param {Express.Request} params.req - The Express request object. + * @param {FileContext} params.context - The context of the file (e.g., 'avatar', 'image_generation', etc.) + * @param {boolean} [params.resize=true] - Whether to resize and convert the image to WebP. Default is `true`. + * @param {{ buffer: Buffer, width: number, height: number, bytes: number, filename: string, type: string, file_id: string }} [params.metadata] - Required metadata for the file if resize is false. + * @returns {Promise<{ filepath: string, filename: string, source: string, type: 'image/webp'}>} + */ +const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true }) => { + const source = req.app.locals.fileStrategy; + const { saveBuffer } = getStrategyFunctions(source); + let { buffer, width, height, bytes, filename, file_id, type } = metadata; + if (resize) { + file_id = v4(); + type = 'image/webp'; + ({ buffer, width, height, bytes } = await resizeAndConvert(req.file.buffer)); + filename = path.basename(req.file.originalname, path.extname(req.file.originalname)) + '.webp'; + } + + const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer }); + return await createFile( + { + user: req.user.id, + file_id, + bytes, + filepath, + filename, + context, + source, + type, + width, + height, + }, + true, + ); +}; + +/** + * Applies the current strategy for file uploads. + * Saves file metadata to the database with an expiry TTL. + * Files must be deleted from the server filesystem manually. + * + * @param {Object} params - The parameters object. + * @param {Express.Request} params.req - The Express request object. + * @param {Express.Response} params.res - The Express response object. + * @param {Express.Multer.File} params.file - The uploaded file. + * @param {FileMetadata} params.metadata - Additional metadata for the file. + * @returns {Promise<void>} + */ +const processFileUpload = async ({ req, res, file, metadata }) => { + const isAssistantUpload = metadata.endpoint === EModelEndpoint.assistants; + const source = isAssistantUpload ? FileSources.openai : FileSources.vectordb; + const { handleFileUpload } = getStrategyFunctions(source); + const { file_id, temp_file_id } = metadata; + + /** @type {OpenAI | undefined} */ + let openai; + if (source === FileSources.openai) { + ({ openai } = await initializeClient({ req })); + } + + const { id, bytes, filename, filepath, embedded } = await handleFileUpload({ + req, + file, + file_id, + openai, + }); + + if (isAssistantUpload && !metadata.message_file) { + await openai.beta.assistants.files.create(metadata.assistant_id, { + file_id: id, + }); + } + + const result = await createFile( + { + user: req.user.id, + file_id: id ?? file_id, + temp_file_id, + bytes, + filename: filename ?? file.originalname, + filepath: isAssistantUpload ? `${openai.baseURL}/files/${id}` : filepath, + context: isAssistantUpload ? FileContext.assistants : FileContext.message_attachment, + model: isAssistantUpload ? req.body.model : undefined, + type: file.mimetype, + embedded, + source, + }, + true, + ); + res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); +}; + +/** + * @param {object} params - The params object. + * @param {OpenAI} params.openai - The OpenAI client instance. + * @param {string} params.file_id - The ID of the file to retrieve. + * @param {string} params.userId - The user ID. + * @param {string} [params.filename] - The name of the file. `undefined` for `file_citation` annotations. + * @param {boolean} [params.saveFile=false] - Whether to save the file metadata to the database. + * @param {boolean} [params.updateUsage=false] - Whether to update file usage in database. + */ +const processOpenAIFile = async ({ + openai, + file_id, + userId, + filename, + saveFile = false, + updateUsage = false, +}) => { + const _file = await openai.files.retrieve(file_id); + const originalName = filename ?? (_file.filename ? path.basename(_file.filename) : undefined); + const filepath = `${openai.baseURL}/files/${userId}/${file_id}${ + originalName ? `/${originalName}` : '' + }`; + const type = mime.getType(originalName ?? file_id); + + const file = { + ..._file, + type, + file_id, + filepath, + usage: 1, + user: userId, + context: _file.purpose, + source: FileSources.openai, + model: openai.req.body.model, + filename: originalName ?? file_id, + }; + + if (saveFile) { + await createFile(file, true); + } else if (updateUsage) { + try { + await updateFileUsage({ file_id }); + } catch (error) { + logger.error('Error updating file usage', error); + } + } + + return file; +}; + +/** + * Process OpenAI image files, convert to webp, save and return file metadata. + * @param {object} params - The params object. + * @param {Express.Request} params.req - The Express request object. + * @param {Buffer} params.buffer - The image buffer. + * @param {string} params.file_id - The file ID. + * @param {string} params.filename - The filename. + * @param {string} params.fileExt - The file extension. + * @returns {Promise<MongoFile>} The file metadata. + */ +const processOpenAIImageOutput = async ({ req, buffer, file_id, filename, fileExt }) => { + const currentDate = new Date(); + const formattedDate = currentDate.toISOString(); + const _file = await convertToWebP(req, buffer, 'high', `${file_id}${fileExt}`); + const file = { + ..._file, + usage: 1, + user: req.user.id, + type: 'image/webp', + createdAt: formattedDate, + updatedAt: formattedDate, + source: req.app.locals.fileStrategy, + context: FileContext.assistants_output, + file_id: `${file_id}${hostImageIdSuffix}`, + filename: `${hostImageNamePrefix}${filename}`, + }; + createFile(file, true); + createFile( + { + ...file, + file_id, + filename, + source: FileSources.openai, + type: mime.getType(fileExt), + }, + true, + ); + return file; +}; + +/** + * Retrieves and processes an OpenAI file based on its type. + * + * @param {Object} params - The params passed to the function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {RunClient} params.client - The LibreChat client instance: either refers to `openai` or `streamRunManager`. + * @param {string} params.file_id - The ID of the file to retrieve. + * @param {string} [params.basename] - The basename of the file (if image); e.g., 'image.jpg'. `undefined` for `file_citation` annotations. + * @param {boolean} [params.unknownType] - Whether the file type is unknown. + * @returns {Promise<{file_id: string, filepath: string, source: string, bytes?: number, width?: number, height?: number} | null>} + * - Returns null if `file_id` is not defined; else, the file metadata if successfully retrieved and processed. + */ +async function retrieveAndProcessFile({ + openai, + client, + file_id, + basename: _basename, + unknownType, +}) { + if (!file_id) { + return null; + } + + let basename = _basename; + const processArgs = { openai, file_id, filename: basename, userId: client.req.user.id }; + + // If no basename provided, return only the file metadata + if (!basename) { + return await processOpenAIFile({ ...processArgs, saveFile: true }); + } + + const fileExt = path.extname(basename); + if (client.attachedFileIds?.has(file_id) || client.processedFileIds?.has(file_id)) { + return processOpenAIFile({ ...processArgs, updateUsage: true }); + } + + /** + * @returns {Promise<Buffer>} The file data buffer. + */ + const getDataBuffer = async () => { + const response = await openai.files.content(file_id); + const arrayBuffer = await response.arrayBuffer(); + return Buffer.from(arrayBuffer); + }; + + let dataBuffer; + if (unknownType || !fileExt || imageExtRegex.test(basename)) { + try { + dataBuffer = await getDataBuffer(); + } catch (error) { + logger.error('Error downloading file from OpenAI:', error); + dataBuffer = null; + } + } + + if (!dataBuffer) { + return await processOpenAIFile({ ...processArgs, saveFile: true }); + } + + // If the filetype is unknown, inspect the file + if (dataBuffer && (unknownType || !fileExt)) { + const detectedExt = await determineFileType(dataBuffer); + const isImageOutput = detectedExt && imageExtRegex.test('.' + detectedExt); + + if (!isImageOutput) { + return await processOpenAIFile({ ...processArgs, saveFile: true }); + } + + return await processOpenAIImageOutput({ + file_id, + req: client.req, + buffer: dataBuffer, + filename: basename, + fileExt: detectedExt, + }); + } else if (dataBuffer && imageExtRegex.test(basename)) { + return await processOpenAIImageOutput({ + file_id, + req: client.req, + buffer: dataBuffer, + filename: basename, + fileExt, + }); + } else { + logger.debug(`[retrieveAndProcessFile] Non-image file type detected: ${basename}`); + return await processOpenAIFile({ ...processArgs, saveFile: true }); + } +} + +/** + * Filters a file based on its size and the endpoint origin. + * + * @param {Object} params - The parameters for the function. + * @param {Express.Request} params.req - The request object from Express. + * @param {Express.Multer.File} params.file - The file uploaded to the server via multer. + * @param {boolean} [params.image] - Whether the file expected is an image. + * @returns {void} + * + * @throws {Error} If a file exception is caught (invalid file size or type, lack of metadata). + */ +function filterFile({ req, file, image }) { + const { endpoint, file_id, width, height } = req.body; + + if (!file_id) { + throw new Error('No file_id provided'); + } + + /* parse to validate api call, throws error on fail */ + isUUID.parse(file_id); + + if (!endpoint) { + throw new Error('No endpoint provided'); + } + + const fileConfig = mergeFileConfig(req.app.locals.fileConfig); + + const { fileSizeLimit, supportedMimeTypes } = + fileConfig.endpoints[endpoint] ?? fileConfig.endpoints.default; + + if (file.size > fileSizeLimit) { + throw new Error( + `File size limit of ${fileSizeLimit / megabyte} MB exceeded for ${endpoint} endpoint`, + ); + } + + const isSupportedMimeType = fileConfig.checkType(file.mimetype, supportedMimeTypes); + + if (!isSupportedMimeType) { + throw new Error('Unsupported file type'); + } + + if (!image) { + return; + } + + if (!width) { + throw new Error('No width provided'); + } + + if (!height) { + throw new Error('No height provided'); + } +} + module.exports = { - processImageUpload, + filterFile, processFiles, processFileURL, + processImageFile, + uploadImageBuffer, + processFileUpload, + processDeleteRequest, + retrieveAndProcessFile, }; diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index 4e201860434..96733e4037f 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -4,38 +4,103 @@ const { prepareImageURL, saveURLToFirebase, deleteFirebaseFile, + saveBufferToFirebase, uploadImageToFirebase, processFirebaseAvatar, + getFirebaseFileStream, } = require('./Firebase'); const { getLocalFileURL, saveFileFromURL, + saveLocalBuffer, deleteLocalFile, uploadLocalImage, prepareImagesLocal, processLocalAvatar, + getLocalFileStream, } = require('./Local'); +const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); +const { uploadVectors, deleteVectors } = require('./VectorDB'); -// Firebase Strategy Functions +/** + * Firebase Storage Strategy Functions + * + * */ const firebaseStrategy = () => ({ // saveFile: + /** @type {typeof uploadVectors | null} */ + handleFileUpload: null, saveURL: saveURLToFirebase, getFileURL: getFirebaseURL, deleteFile: deleteFirebaseFile, + saveBuffer: saveBufferToFirebase, prepareImagePayload: prepareImageURL, processAvatar: processFirebaseAvatar, handleImageUpload: uploadImageToFirebase, + getDownloadStream: getFirebaseFileStream, }); -// Local Strategy Functions +/** + * Local Server Storage Strategy Functions + * + * */ const localStrategy = () => ({ - // saveFile: , + /** @type {typeof uploadVectors | null} */ + handleFileUpload: null, saveURL: saveFileFromURL, getFileURL: getLocalFileURL, + saveBuffer: saveLocalBuffer, deleteFile: deleteLocalFile, processAvatar: processLocalAvatar, handleImageUpload: uploadLocalImage, prepareImagePayload: prepareImagesLocal, + getDownloadStream: getLocalFileStream, +}); + +/** + * VectorDB Storage Strategy Functions + * + * */ +const vectorStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadVectors, + deleteFile: deleteVectors, +}); + +/** + * OpenAI Strategy Functions + * + * Note: null values mean that the strategy is not supported. + * */ +const openAIStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + deleteFile: deleteOpenAIFile, + handleFileUpload: uploadOpenAIFile, + getDownloadStream: getOpenAIFileStream, }); // Strategy Selector @@ -44,6 +109,10 @@ const getStrategyFunctions = (fileSource) => { return firebaseStrategy(); } else if (fileSource === FileSources.local) { return localStrategy(); + } else if (fileSource === FileSources.openai) { + return openAIStrategy(); + } else if (fileSource === FileSources.vectordb) { + return vectorStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 76ac061546d..69c71629a4a 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,76 +1,101 @@ -const Keyv = require('keyv'); const axios = require('axios'); -const HttpsProxyAgent = require('https-proxy-agent'); -const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); -const { isEnabled } = require('~/server/utils'); -const keyvRedis = require('~/cache/keyvRedis'); -const { extractBaseURL } = require('~/utils'); -const { logger } = require('~/config'); - -// const { getAzureCredentials, genAzureChatCompletion } = require('~/utils/'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); +const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require('~/utils'); +const getLogStores = require('~/cache/getLogStores'); const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config; -const modelsCache = isEnabled(process.env.USE_REDIS) - ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: 'models' }); - -const { - OPENROUTER_API_KEY, - OPENAI_REVERSE_PROXY, - CHATGPT_MODELS, - ANTHROPIC_MODELS, - GOOGLE_MODELS, - PROXY, -} = process.env ?? {}; - /** * Fetches OpenAI models from the specified base API path or Azure, based on the provided configuration. * * @param {Object} params - The parameters for fetching the models. + * @param {Object} params.user - The user ID to send to the API. * @param {string} params.apiKey - The API key for authentication with the API. * @param {string} params.baseURL - The base path URL for the API. * @param {string} [params.name='OpenAI'] - The name of the API; defaults to 'OpenAI'. * @param {boolean} [params.azure=false] - Whether to fetch models from Azure. + * @param {boolean} [params.userIdQuery=false] - Whether to send the user ID as a query parameter. + * @param {boolean} [params.createTokenConfig=true] - Whether to create a token configuration from the API response. + * @param {string} [params.tokenKey] - The cache key to save the token configuration. Uses `name` if omitted. * @returns {Promise<string[]>} A promise that resolves to an array of model identifiers. * @async */ -const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false }) => { +const fetchModels = async ({ + user, + apiKey, + baseURL, + name = 'OpenAI', + azure = false, + userIdQuery = false, + createTokenConfig = true, + tokenKey, +}) => { let models = []; if (!baseURL && !azure) { return models; } + if (!apiKey) { + return models; + } + try { - const payload = { + const options = { headers: { Authorization: `Bearer ${apiKey}`, }, }; - if (PROXY) { - payload.httpsAgent = new HttpsProxyAgent(PROXY); + if (process.env.PROXY) { + options.httpsAgent = new HttpsProxyAgent(process.env.PROXY); } if (process.env.OPENAI_ORGANIZATION && baseURL.includes('openai')) { - payload.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION; + options.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION; } - const res = await axios.get(`${baseURL}${azure ? '' : '/models'}`, payload); - models = res.data.data.map((item) => item.id); - } catch (err) { - logger.error(`Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`, err); + const url = new URL(`${baseURL}${azure ? '' : '/models'}`); + if (user && userIdQuery) { + url.searchParams.append('user', user); + } + const res = await axios.get(url.toString(), options); + + /** @type {z.infer<typeof inputSchema>} */ + const input = res.data; + + const validationResult = inputSchema.safeParse(input); + if (validationResult.success && createTokenConfig) { + const endpointTokenConfig = processModelData(input); + const cache = getLogStores(CacheKeys.TOKEN_CONFIG); + await cache.set(tokenKey ?? name, endpointTokenConfig); + } + models = input.data.map((item) => item.id); + } catch (error) { + const logMessage = `Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`; + logAxiosError({ message: logMessage, error }); } return models; }; -const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => { +/** + * Fetches models from the specified API path or Azure, based on the provided options. + * @async + * @function + * @param {object} opts - The options for fetching the models. + * @param {string} opts.user - The user ID to send to the API. + * @param {boolean} [opts.azure=false] - Whether to fetch models from Azure. + * @param {boolean} [opts.plugins=false] - Whether to fetch models from the plugins. + * @param {string[]} [_models=[]] - The models to use as a fallback. + */ +const fetchOpenAIModels = async (opts, _models = []) => { let models = _models.slice() ?? []; let apiKey = openAIApiKey; - let baseURL = 'https://api.openai.com/v1'; - let reverseProxyUrl = OPENAI_REVERSE_PROXY; + const openaiBaseURL = 'https://api.openai.com/v1'; + let baseURL = openaiBaseURL; + let reverseProxyUrl = process.env.OPENAI_REVERSE_PROXY; if (opts.azure) { return models; // const azure = getAzureCredentials(); @@ -78,15 +103,17 @@ const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _model // .split('/deployments')[0] // .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`); // apiKey = azureOpenAIApiKey; - } else if (OPENROUTER_API_KEY) { + } else if (process.env.OPENROUTER_API_KEY) { reverseProxyUrl = 'https://openrouter.ai/api/v1'; - apiKey = OPENROUTER_API_KEY; + apiKey = process.env.OPENROUTER_API_KEY; } if (reverseProxyUrl) { baseURL = extractBaseURL(reverseProxyUrl); } + const modelsCache = getLogStores(CacheKeys.MODEL_QUERIES); + const cachedModels = await modelsCache.get(baseURL); if (cachedModels) { return cachedModels; @@ -97,34 +124,57 @@ const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _model apiKey, baseURL, azure: opts.azure, + user: opts.user, }); } - if (!reverseProxyUrl) { + if (models.length === 0) { + return _models; + } + + if (baseURL === openaiBaseURL) { const regex = /(text-davinci-003|gpt-)/; models = models.filter((model) => regex.test(model)); + const instructModels = models.filter((model) => model.includes('instruct')); + const otherModels = models.filter((model) => !model.includes('instruct')); + models = otherModels.concat(instructModels); } await modelsCache.set(baseURL, models); return models; }; -const getOpenAIModels = async (opts = { azure: false, plugins: false }) => { - let models = [ - 'gpt-4', - 'gpt-4-0613', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k', - 'gpt-3.5-turbo-0613', - 'gpt-3.5-turbo-0301', - ]; +/** + * Loads the default models for the application. + * @async + * @function + * @param {object} opts - The options for fetching the models. + * @param {string} opts.user - The user ID to send to the API. + * @param {boolean} [opts.azure=false] - Whether to fetch models from Azure. + * @param {boolean} [opts.plugins=false] - Whether to fetch models from the plugins. + */ +const getOpenAIModels = async (opts) => { + let models = defaultModels[EModelEndpoint.openAI]; + + if (opts.assistants) { + models = defaultModels[EModelEndpoint.assistants]; + } - if (!opts.plugins) { - models.push('text-davinci-003'); + if (opts.plugins) { + models = models.filter( + (model) => + !model.includes('text-davinci') && + !model.includes('instruct') && + !model.includes('0613') && + !model.includes('0314') && + !model.includes('0301'), + ); } let key; - if (opts.azure) { + if (opts.assistants) { + key = 'ASSISTANTS_MODELS'; + } else if (opts.azure) { key = 'AZURE_OPENAI_MODELS'; } else if (opts.plugins) { key = 'PLUGIN_MODELS'; @@ -137,7 +187,11 @@ const getOpenAIModels = async (opts = { azure: false, plugins: false }) => { return models; } - if (userProvidedOpenAI && !OPENROUTER_API_KEY) { + if (userProvidedOpenAI && !process.env.OPENROUTER_API_KEY) { + return models; + } + + if (opts.assistants) { return models; } @@ -146,8 +200,8 @@ const getOpenAIModels = async (opts = { azure: false, plugins: false }) => { const getChatGPTBrowserModels = () => { let models = ['text-davinci-002-render-sha', 'gpt-4']; - if (CHATGPT_MODELS) { - models = String(CHATGPT_MODELS).split(','); + if (process.env.CHATGPT_MODELS) { + models = String(process.env.CHATGPT_MODELS).split(','); } return models; @@ -155,8 +209,8 @@ const getChatGPTBrowserModels = () => { const getAnthropicModels = () => { let models = defaultModels[EModelEndpoint.anthropic]; - if (ANTHROPIC_MODELS) { - models = String(ANTHROPIC_MODELS).split(','); + if (process.env.ANTHROPIC_MODELS) { + models = String(process.env.ANTHROPIC_MODELS).split(','); } return models; @@ -164,8 +218,8 @@ const getAnthropicModels = () => { const getGoogleModels = () => { let models = defaultModels[EModelEndpoint.google]; - if (GOOGLE_MODELS) { - models = String(GOOGLE_MODELS).split(','); + if (process.env.GOOGLE_MODELS) { + models = String(process.env.GOOGLE_MODELS).split(','); } return models; diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js new file mode 100644 index 00000000000..7c1d326fa1a --- /dev/null +++ b/api/server/services/ModelService.spec.js @@ -0,0 +1,258 @@ +const axios = require('axios'); + +const { fetchModels, getOpenAIModels } = require('./ModelService'); +jest.mock('~/utils', () => { + const originalUtils = jest.requireActual('~/utils'); + return { + ...originalUtils, + processModelData: jest.fn((...args) => { + return originalUtils.processModelData(...args); + }), + }; +}); + +jest.mock('axios'); +jest.mock('~/cache/getLogStores', () => + jest.fn().mockImplementation(() => ({ + get: jest.fn().mockResolvedValue(undefined), + set: jest.fn().mockResolvedValue(true), + })), +); +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, +})); +jest.mock('./Config/EndpointService', () => ({ + config: { + openAIApiKey: 'mockedApiKey', + userProvidedOpenAI: false, + }, +})); + +axios.get.mockResolvedValue({ + data: { + data: [{ id: 'model-1' }, { id: 'model-2' }], + }, +}); + +describe('fetchModels', () => { + it('fetches models successfully from the API', async () => { + const models = await fetchModels({ + user: 'user123', + apiKey: 'testApiKey', + baseURL: 'https://api.test.com', + name: 'TestAPI', + }); + + expect(models).toEqual(['model-1', 'model-2']); + expect(axios.get).toHaveBeenCalledWith( + expect.stringContaining('https://api.test.com/models'), + expect.any(Object), + ); + }); + + it('adds the user ID to the models query when option and ID are passed', async () => { + const models = await fetchModels({ + user: 'user123', + apiKey: 'testApiKey', + baseURL: 'https://api.test.com', + userIdQuery: true, + name: 'TestAPI', + }); + + expect(models).toEqual(['model-1', 'model-2']); + expect(axios.get).toHaveBeenCalledWith( + expect.stringContaining('https://api.test.com/models?user=user123'), + expect.any(Object), + ); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); +}); + +describe('fetchModels with createTokenConfig true', () => { + const data = { + data: [ + { + id: 'model-1', + pricing: { + prompt: '0.002', + completion: '0.001', + }, + context_length: 1024, + }, + { + id: 'model-2', + pricing: { + prompt: '0.003', + completion: '0.0015', + }, + context_length: 2048, + }, + ], + }; + + beforeEach(() => { + // Clears the mock's history before each test + const _utils = require('~/utils'); + axios.get.mockResolvedValue({ data }); + }); + + it('creates and stores token configuration if createTokenConfig is true', async () => { + await fetchModels({ + user: 'user123', + apiKey: 'testApiKey', + baseURL: 'https://api.test.com', + createTokenConfig: true, + }); + + const { processModelData } = require('~/utils'); + expect(processModelData).toHaveBeenCalled(); + expect(processModelData).toHaveBeenCalledWith(data); + }); +}); + +describe('getOpenAIModels', () => { + let originalEnv; + + beforeEach(() => { + originalEnv = { ...process.env }; + axios.get.mockRejectedValue(new Error('Network error')); + }); + + afterEach(() => { + process.env = originalEnv; + axios.get.mockReset(); + }); + + it('returns default models when no environment configurations are provided (and fetch fails)', async () => { + const models = await getOpenAIModels({ user: 'user456' }); + expect(models).toContain('gpt-4'); + }); + + it('returns `AZURE_OPENAI_MODELS` with `azure` flag (and fetch fails)', async () => { + process.env.AZURE_OPENAI_MODELS = 'azure-model,azure-model-2'; + const models = await getOpenAIModels({ azure: true }); + expect(models).toEqual(expect.arrayContaining(['azure-model', 'azure-model-2'])); + }); + + it('returns `PLUGIN_MODELS` with `plugins` flag (and fetch fails)', async () => { + process.env.PLUGIN_MODELS = 'plugins-model,plugins-model-2'; + const models = await getOpenAIModels({ plugins: true }); + expect(models).toEqual(expect.arrayContaining(['plugins-model', 'plugins-model-2'])); + }); + + it('returns `OPENAI_MODELS` with no flags (and fetch fails)', async () => { + process.env.OPENAI_MODELS = 'openai-model,openai-model-2'; + const models = await getOpenAIModels({}); + expect(models).toEqual(expect.arrayContaining(['openai-model', 'openai-model-2'])); + }); + + it('attempts to use OPENROUTER_API_KEY if set', async () => { + process.env.OPENROUTER_API_KEY = 'test-router-key'; + const expectedModels = ['model-router-1', 'model-router-2']; + + axios.get.mockResolvedValue({ + data: { + data: expectedModels.map((id) => ({ id })), + }, + }); + + const models = await getOpenAIModels({ user: 'user456' }); + + expect(models).toEqual(expect.arrayContaining(expectedModels)); + expect(axios.get).toHaveBeenCalled(); + }); + + it('utilizes proxy configuration when PROXY is set', async () => { + axios.get.mockResolvedValue({ + data: { + data: [], + }, + }); + process.env.PROXY = 'http://localhost:8888'; + await getOpenAIModels({ user: 'user456' }); + + expect(axios.get).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + httpsAgent: expect.anything(), + }), + ); + }); +}); + +describe('getOpenAIModels with mocked config', () => { + it('uses alternative behavior when userProvidedOpenAI is true', async () => { + jest.mock('./Config/EndpointService', () => ({ + config: { + openAIApiKey: 'mockedApiKey', + userProvidedOpenAI: true, + }, + })); + jest.mock('librechat-data-provider', () => { + const original = jest.requireActual('librechat-data-provider'); + return { + ...original, + defaultModels: { + [original.EModelEndpoint.openAI]: ['some-default-model'], + }, + }; + }); + + jest.resetModules(); + const { getOpenAIModels } = require('./ModelService'); + + const models = await getOpenAIModels({ user: 'user456' }); + expect(models).toContain('some-default-model'); + }); +}); + +describe('getOpenAIModels sorting behavior', () => { + beforeEach(() => { + axios.get.mockResolvedValue({ + data: { + data: [ + { id: 'gpt-3.5-turbo-instruct-0914' }, + { id: 'gpt-3.5-turbo-instruct' }, + { id: 'gpt-3.5-turbo' }, + { id: 'gpt-4-0314' }, + { id: 'gpt-4-turbo-preview' }, + ], + }, + }); + }); + + it('ensures instruct models are listed last', async () => { + const models = await getOpenAIModels({ user: 'user456' }); + + // Check if the last model is an "instruct" model + expect(models[models.length - 1]).toMatch(/instruct/); + + // Check if the "instruct" models are placed at the end + const instructIndexes = models + .map((model, index) => (model.includes('instruct') ? index : -1)) + .filter((index) => index !== -1); + const nonInstructIndexes = models + .map((model, index) => (!model.includes('instruct') ? index : -1)) + .filter((index) => index !== -1); + + expect(Math.max(...nonInstructIndexes)).toBeLessThan(Math.min(...instructIndexes)); + + const expectedOrder = [ + 'gpt-3.5-turbo', + 'gpt-4-0314', + 'gpt-4-turbo-preview', + 'gpt-3.5-turbo-instruct-0914', + 'gpt-3.5-turbo-instruct', + ]; + expect(models).toEqual(expectedOrder); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); +}); diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 61582382914..efe0bb03fd8 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -90,8 +90,7 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { const deleteUserPluginAuth = async (userId, authField) => { try { - const response = await PluginAuth.deleteOne({ userId, authField }); - return response; + return await PluginAuth.deleteOne({ userId, authField }); } catch (err) { logger.error('[deleteUserPluginAuth]', err); return err; diff --git a/api/server/services/Runs/RunMananger.js b/api/server/services/Runs/RunManager.js similarity index 59% rename from api/server/services/Runs/RunMananger.js rename to api/server/services/Runs/RunManager.js index 67a3624c187..c8deeb9264b 100644 --- a/api/server/services/Runs/RunMananger.js +++ b/api/server/services/Runs/RunManager.js @@ -1,3 +1,4 @@ +const { ToolCallTypes } = require('librechat-data-provider'); const { logger } = require('~/config'); /** @@ -18,6 +19,53 @@ const { logger } = require('~/config'); * @property {Function} handleStep - Handles a run step based on its status. */ +/** + * Generates a signature string for a given tool call object. This signature includes + * the tool call's id, type, and other distinguishing features based on its type. + * + * @param {ToolCall} toolCall The tool call object for which to generate a signature. + * @returns {string} The generated signature for the tool call. + */ +function getToolCallSignature(toolCall) { + if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) { + const inputLength = toolCall.code_interpreter?.input?.length ?? 0; + const outputsLength = toolCall.code_interpreter?.outputs?.length ?? 0; + return `${toolCall.id}-${toolCall.type}-${inputLength}-${outputsLength}`; + } + if (toolCall.type === ToolCallTypes.RETRIEVAL) { + return `${toolCall.id}-${toolCall.type}`; + } + if (toolCall.type === ToolCallTypes.FUNCTION) { + const argsLength = toolCall.function?.arguments?.length ?? 0; + const hasOutput = toolCall.function?.output ? 1 : 0; + return `${toolCall.id}-${toolCall.type}-${argsLength}-${hasOutput}`; + } + + return `${toolCall.id}-unknown-type`; +} + +/** + * Generates a signature based on the specifics of the step details. + * This function supports 'message_creation' and 'tool_calls' types, and returns a default signature + * for any other type or in case the details are undefined. + * + * @param {MessageCreationStepDetails | ToolCallsStepDetails | undefined} details - The detailed content of the step, which can be undefined. + * @returns {string} A signature string derived from the content of step details. + */ +function getDetailsSignature(details) { + if (!details) { + return 'undefined-details'; + } + + if (details.type === 'message_creation') { + return `${details.type}-${details.message_creation.message_id}`; + } else if (details.type === 'tool_calls') { + const toolCallsSignature = details.tool_calls.map(getToolCallSignature).join('|'); + return `${details.type}-${toolCallsSignature}`; + } + return 'unknown-type'; +} + /** * Manages the retrieval and processing of run steps based on run status. */ @@ -44,15 +92,25 @@ class RunManager { */ async fetchRunSteps({ openai, thread_id, run_id, runStatus, final = false }) { // const { data: steps, first_id, last_id, has_more } = await openai.beta.threads.runs.steps.list(thread_id, run_id); - const { data: _steps } = await openai.beta.threads.runs.steps.list(thread_id, run_id); + const { data: _steps } = await openai.beta.threads.runs.steps.list( + thread_id, + run_id, + {}, + { + timeout: 3000, + maxRetries: 5, + }, + ); const steps = _steps.sort((a, b) => a.created_at - b.created_at); for (const [i, step] of steps.entries()) { - if (this.seenSteps.has(step.id)) { + const detailsSignature = getDetailsSignature(step.step_details); + const stepKey = `${step.id}-${step.status}-${detailsSignature}`; + if (!final && this.seenSteps.has(stepKey)) { continue; } const isLast = i === steps.length - 1; - this.seenSteps.add(step.id); + this.seenSteps.add(stepKey); this.stepsByStatus[runStatus] = this.stepsByStatus[runStatus] || []; const currentStepPromise = (async () => { @@ -64,6 +122,13 @@ class RunManager { return await currentStepPromise; } + if (step.type === 'tool_calls') { + await currentStepPromise; + } + if (step.type === 'message_creation' && step.status === 'completed') { + await currentStepPromise; + } + this.lastStepPromiseByStatus[runStatus] = currentStepPromise; this.stepsByStatus[runStatus].push(currentStepPromise); } @@ -79,7 +144,7 @@ class RunManager { */ async handleStep({ step, runStatus, final, isLast }) { if (this.handlers[runStatus]) { - return this.handlers[runStatus]({ step, final, isLast }); + return await this.handlers[runStatus]({ step, final, isLast }); } if (final && isLast && this.handlers['final']) { diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js new file mode 100644 index 00000000000..ce78b593188 --- /dev/null +++ b/api/server/services/Runs/StreamRunManager.js @@ -0,0 +1,619 @@ +const { + StepTypes, + ContentTypes, + ToolCallTypes, + // StepStatus, + MessageContentTypes, + AssistantStreamEvents, +} = require('librechat-data-provider'); +const { retrieveAndProcessFile } = require('~/server/services/Files/process'); +const { processRequiredActions } = require('~/server/services/ToolService'); +const { createOnProgress, sendMessage } = require('~/server/utils'); +const { processMessages } = require('~/server/services/Threads'); +const { logger } = require('~/config'); + +/** + * Implements the StreamRunManager functionality for managing the streaming + * and processing of run steps, messages, and tool calls within a thread. + * @implements {StreamRunManager} + */ +class StreamRunManager { + constructor(fields) { + this.index = 0; + /** @type {Map<string, RunStep>} */ + this.steps = new Map(); + + /** @type {Map<string, number} */ + this.mappedOrder = new Map(); + /** @type {Map<string, StepToolCall} */ + this.orderedRunSteps = new Map(); + /** @type {Set<string>} */ + this.processedFileIds = new Set(); + /** @type {Map<string, (delta: ToolCallDelta | string) => Promise<void>} */ + this.progressCallbacks = new Map(); + /** @type {Run | null} */ + this.run = null; + + /** @type {Express.Request} */ + this.req = fields.req; + /** @type {Express.Response} */ + this.res = fields.res; + /** @type {OpenAI} */ + this.openai = fields.openai; + /** @type {string} */ + this.apiKey = this.openai.apiKey; + /** @type {string} */ + this.thread_id = fields.thread_id; + /** @type {RunCreateAndStreamParams} */ + this.initialRunBody = fields.runBody; + /** + * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>} + */ + this.clientHandlers = fields.handlers ?? {}; + /** @type {OpenAIRequestOptions} */ + this.streamOptions = fields.streamOptions ?? {}; + /** @type {Partial<TMessage>} */ + this.finalMessage = fields.responseMessage ?? {}; + /** @type {ThreadMessage[]} */ + this.messages = []; + /** @type {string} */ + this.text = ''; + /** @type {Set<string>} */ + this.attachedFileIds = fields.attachedFileIds; + /** @type {undefined | Promise<ChatCompletion>} */ + this.visionPromise = fields.visionPromise; + + /** + * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>} + */ + this.handlers = { + [AssistantStreamEvents.ThreadCreated]: this.handleThreadCreated, + [AssistantStreamEvents.ThreadRunCreated]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunQueued]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunInProgress]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunRequiresAction]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunCompleted]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunFailed]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunCancelling]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunCancelled]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunExpired]: this.handleRunEvent, + [AssistantStreamEvents.ThreadRunStepCreated]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepInProgress]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepCompleted]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepFailed]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepCancelled]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepExpired]: this.handleRunStepEvent, + [AssistantStreamEvents.ThreadRunStepDelta]: this.handleRunStepDeltaEvent, + [AssistantStreamEvents.ThreadMessageCreated]: this.handleMessageEvent, + [AssistantStreamEvents.ThreadMessageInProgress]: this.handleMessageEvent, + [AssistantStreamEvents.ThreadMessageCompleted]: this.handleMessageEvent, + [AssistantStreamEvents.ThreadMessageIncomplete]: this.handleMessageEvent, + [AssistantStreamEvents.ThreadMessageDelta]: this.handleMessageDeltaEvent, + [AssistantStreamEvents.ErrorEvent]: this.handleErrorEvent, + }; + } + + /** + * + * Sends the content data to the client via SSE. + * + * @param {StreamContentData} data + * @returns {Promise<void>} + */ + async addContentData(data) { + const { type, index, edited } = data; + /** @type {ContentPart} */ + const contentPart = data[type]; + this.finalMessage.content[index] = { type, [type]: contentPart }; + + if (type === ContentTypes.TEXT && !edited) { + this.text += contentPart.value; + return; + } + + const contentData = { + index, + type, + [type]: contentPart, + thread_id: this.thread_id, + messageId: this.finalMessage.messageId, + conversationId: this.finalMessage.conversationId, + }; + + sendMessage(this.res, contentData); + } + + /* <------------------ Main Event Handlers ------------------> */ + + /** + * Run the assistant and handle the events. + * @param {Object} params - + * The parameters for running the assistant. + * @param {string} params.thread_id - The thread id. + * @param {RunCreateAndStreamParams} params.body - The body of the run. + * @returns {Promise<void>} + */ + async runAssistant({ thread_id, body }) { + const streamRun = this.openai.beta.threads.runs.createAndStream( + thread_id, + body, + this.streamOptions, + ); + for await (const event of streamRun) { + await this.handleEvent(event); + } + } + + /** + * Handle the event. + * @param {AssistantStreamEvent} event - The stream event object. + * @returns {Promise<void>} + */ + async handleEvent(event) { + const handler = this.handlers[event.event]; + const clientHandler = this.clientHandlers[event.event]; + + if (clientHandler) { + await clientHandler.call(this, event); + } + + if (handler) { + await handler.call(this, event); + } else { + logger.warn(`Unhandled event type: ${event.event}`); + } + } + + /** + * Handle thread.created event + * @param {ThreadCreated} event - + * The thread.created event object. + */ + async handleThreadCreated(event) { + logger.debug('Thread created:', event.data); + } + + /** + * Handle Run Events + * @param {ThreadRunCreated | ThreadRunQueued | ThreadRunInProgress | ThreadRunRequiresAction | ThreadRunCompleted | ThreadRunFailed | ThreadRunCancelling | ThreadRunCancelled | ThreadRunExpired} event - + * The run event object. + */ + async handleRunEvent(event) { + this.run = event.data; + logger.debug('Run event:', this.run); + if (event.event === AssistantStreamEvents.ThreadRunRequiresAction) { + await this.onRunRequiresAction(event); + } else if (event.event === AssistantStreamEvents.ThreadRunCompleted) { + logger.debug('Run completed:', this.run); + } + } + + /** + * Handle Run Step Events + * @param {ThreadRunStepCreated | ThreadRunStepInProgress | ThreadRunStepCompleted | ThreadRunStepFailed | ThreadRunStepCancelled | ThreadRunStepExpired} event - + * The run step event object. + */ + async handleRunStepEvent(event) { + logger.debug('Run step event:', event.data); + + const step = event.data; + this.steps.set(step.id, step); + + if (event.event === AssistantStreamEvents.ThreadRunStepCreated) { + this.onRunStepCreated(event); + } else if (event.event === AssistantStreamEvents.ThreadRunStepCompleted) { + this.onRunStepCompleted(event); + } + } + + /* <------------------ Delta Events ------------------> */ + + /** @param {CodeImageOutput} */ + async handleCodeImageOutput(output) { + if (this.processedFileIds.has(output.image?.file_id)) { + return; + } + + const { file_id } = output.image; + const file = await retrieveAndProcessFile({ + openai: this.openai, + client: this, + file_id, + basename: `${file_id}.png`, + }); + + const prelimImage = file; + + // check if every key has a value before adding to content + const prelimImageKeys = Object.keys(prelimImage); + const validImageFile = prelimImageKeys.every((key) => prelimImage[key]); + + if (!validImageFile) { + return; + } + + const index = this.getStepIndex(file_id); + const image_file = { + [ContentTypes.IMAGE_FILE]: prelimImage, + type: ContentTypes.IMAGE_FILE, + index, + }; + this.addContentData(image_file); + this.processedFileIds.add(file_id); + } + + /** + * Create Tool Call Stream + * @param {number} index - The index of the tool call. + * @param {StepToolCall} toolCall - + * The current tool call object. + */ + createToolCallStream(index, toolCall) { + /** @type {StepToolCall} */ + const state = toolCall; + const type = state.type; + const data = state[type]; + + /** @param {ToolCallDelta} */ + const deltaHandler = async (delta) => { + for (const key in delta) { + if (!Object.prototype.hasOwnProperty.call(data, key)) { + logger.warn(`Unhandled tool call key "${key}", delta: `, delta); + continue; + } + + if (Array.isArray(delta[key])) { + if (!Array.isArray(data[key])) { + data[key] = []; + } + + for (const d of delta[key]) { + if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) { + logger.warn('Expected an object with an \'index\' for array updates but got:', d); + continue; + } + + const imageOutput = type === ToolCallTypes.CODE_INTERPRETER && d?.type === 'image'; + + if (imageOutput) { + await this.handleCodeImageOutput(d); + continue; + } + + const { index, ...updateData } = d; + // Ensure the data at index is an object or undefined before assigning + if (typeof data[key][index] !== 'object' || data[key][index] === null) { + data[key][index] = {}; + } + // Merge the updateData into data[key][index] + for (const updateKey in updateData) { + data[key][index][updateKey] = updateData[updateKey]; + } + } + } else if (typeof delta[key] === 'string' && typeof data[key] === 'string') { + // Concatenate strings + data[key] += delta[key]; + } else if ( + typeof delta[key] === 'object' && + delta[key] !== null && + !Array.isArray(delta[key]) + ) { + // Merge objects + data[key] = { ...data[key], ...delta[key] }; + } else { + // Directly set the value for other types + data[key] = delta[key]; + } + + state[type] = data; + + this.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + type: ContentTypes.TOOL_CALL, + index, + }); + } + }; + + return deltaHandler; + } + + /** + * @param {string} stepId - + * @param {StepToolCall} toolCall - + * + */ + handleNewToolCall(stepId, toolCall) { + const stepKey = this.generateToolCallKey(stepId, toolCall); + const index = this.getStepIndex(stepKey); + this.getStepIndex(toolCall.id, index); + toolCall.progress = 0.01; + this.orderedRunSteps.set(index, toolCall); + const progressCallback = this.createToolCallStream(index, toolCall); + this.progressCallbacks.set(stepKey, progressCallback); + + this.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + type: ContentTypes.TOOL_CALL, + index, + }); + } + + /** + * Handle Completed Tool Call + * @param {string} stepId - The id of the step the tool_call is part of. + * @param {StepToolCall} toolCall - The tool call object. + * + */ + handleCompletedToolCall(stepId, toolCall) { + if (toolCall.type === ToolCallTypes.FUNCTION) { + return; + } + + const stepKey = this.generateToolCallKey(stepId, toolCall); + const index = this.getStepIndex(stepKey); + toolCall.progress = 1; + this.orderedRunSteps.set(index, toolCall); + this.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + type: ContentTypes.TOOL_CALL, + index, + }); + } + + /** + * Handle Run Step Delta Event + * @param {ThreadRunStepDelta} event - + * The run step delta event object. + */ + async handleRunStepDeltaEvent(event) { + const { delta, id: stepId } = event.data; + + if (!delta.step_details) { + logger.warn('Undefined or unhandled run step delta:', delta); + return; + } + + /** @type {{ tool_calls: Array<ToolCallDeltaObject> }} */ + const { tool_calls } = delta.step_details; + + if (!tool_calls) { + logger.warn('Unhandled run step details', delta.step_details); + return; + } + + for (const toolCall of tool_calls) { + const stepKey = this.generateToolCallKey(stepId, toolCall); + + if (!this.mappedOrder.has(stepKey)) { + this.handleNewToolCall(stepId, toolCall); + continue; + } + + const toolCallDelta = toolCall[toolCall.type]; + const progressCallback = this.progressCallbacks.get(stepKey); + await progressCallback(toolCallDelta); + } + } + + /** + * Handle Message Delta Event + * @param {ThreadMessageDelta} event - + * The Message Delta event object. + */ + async handleMessageDeltaEvent(event) { + const message = event.data; + const onProgress = this.progressCallbacks.get(message.id); + const content = message.delta.content?.[0]; + + if (content && content.type === MessageContentTypes.TEXT) { + onProgress(content.text.value); + } + } + + /** + * Handle Error Event + * @param {ErrorEvent} event - + * The Error event object. + */ + async handleErrorEvent(event) { + logger.error('Error event:', event.data); + } + + /* <------------------ Misc. Helpers ------------------> */ + + /** + * Gets the step index for a given step key, creating a new index if it doesn't exist. + * @param {string} stepKey - + * The access key for the step. Either a message.id, tool_call key, or file_id. + * @param {number | undefined} [overrideIndex] - An override index to use an alternative stepKey. + * This is necessary due to the toolCall Id being unavailable in delta stream events. + * @returns {number | undefined} index - The index of the step; `undefined` if invalid key or using overrideIndex. + */ + getStepIndex(stepKey, overrideIndex) { + if (!stepKey) { + return; + } + + if (!isNaN(overrideIndex)) { + this.mappedOrder.set(stepKey, overrideIndex); + return; + } + + let index = this.mappedOrder.get(stepKey); + + if (index === undefined) { + index = this.index; + this.mappedOrder.set(stepKey, this.index); + this.index++; + } + + return index; + } + + /** + * Generate Tool Call Key + * @param {string} stepId - The id of the step the tool_call is part of. + * @param {StepToolCall} toolCall - The tool call object. + * @returns {string} key - The generated key for the tool call. + */ + generateToolCallKey(stepId, toolCall) { + return `${stepId}_tool_call_${toolCall.index}_${toolCall.type}`; + } + + /* <------------------ Run Event handlers ------------------> */ + + /** + * Handle Run Events Requiring Action + * @param {ThreadRunRequiresAction} event - + * The run event object requiring action. + */ + async onRunRequiresAction(event) { + const run = event.data; + const { submit_tool_outputs } = run.required_action; + const actions = submit_tool_outputs.tool_calls.map((item) => { + const functionCall = item.function; + const args = JSON.parse(functionCall.arguments); + return { + tool: functionCall.name, + toolInput: args, + toolCallId: item.id, + run_id: run.id, + thread_id: this.thread_id, + }; + }); + + const { tool_outputs } = await processRequiredActions(this, actions); + /** @type {AssistantStream | undefined} */ + let toolRun; + try { + toolRun = this.openai.beta.threads.runs.submitToolOutputsStream( + run.thread_id, + run.id, + { + tool_outputs, + stream: true, + }, + this.streamOptions, + ); + } catch (error) { + logger.error('Error submitting tool outputs:', error); + throw error; + } + + for await (const event of toolRun) { + await this.handleEvent(event); + } + } + + /* <------------------ RunStep Event handlers ------------------> */ + + /** + * Handle Run Step Created Events + * @param {ThreadRunStepCreated} event - + * The created run step event object. + */ + async onRunStepCreated(event) { + const step = event.data; + const isMessage = step.type === StepTypes.MESSAGE_CREATION; + + if (isMessage) { + /** @type {MessageCreationStepDetails} */ + const { message_creation } = step.step_details; + const stepKey = message_creation.message_id; + const index = this.getStepIndex(stepKey); + this.orderedRunSteps.set(index, message_creation); + // Create the Factory Function to stream the message + const { onProgress: progressCallback } = createOnProgress({ + // todo: add option to save partialText to db + // onProgress: () => {}, + }); + + // This creates a function that attaches all of the parameters + // specified here to each SSE message generated by the TextStream + const onProgress = progressCallback({ + index, + res: this.res, + messageId: this.finalMessage.messageId, + conversationId: this.finalMessage.conversationId, + thread_id: this.thread_id, + type: ContentTypes.TEXT, + }); + + this.progressCallbacks.set(stepKey, onProgress); + this.orderedRunSteps.set(index, step); + return; + } + + if (step.type !== StepTypes.TOOL_CALLS) { + logger.warn('Unhandled step creation type:', step.type); + return; + } + + /** @type {{ tool_calls: StepToolCall[] }} */ + const { tool_calls } = step.step_details; + for (const toolCall of tool_calls) { + this.handleNewToolCall(step.id, toolCall); + } + } + + /** + * Handle Run Step Completed Events + * @param {ThreadRunStepCompleted} event - + * The completed run step event object. + */ + async onRunStepCompleted(event) { + const step = event.data; + const isMessage = step.type === StepTypes.MESSAGE_CREATION; + + if (isMessage) { + logger.debug('RunStep Message completion: to be handled by Message Event.', step); + return; + } + + /** @type {{ tool_calls: StepToolCall[] }} */ + const { tool_calls } = step.step_details; + for (let i = 0; i < tool_calls.length; i++) { + const toolCall = tool_calls[i]; + toolCall.index = i; + this.handleCompletedToolCall(step.id, toolCall); + } + } + + /* <------------------ Message Event handlers ------------------> */ + + /** + * Handle Message Event + * @param {ThreadMessageCreated | ThreadMessageInProgress | ThreadMessageCompleted | ThreadMessageIncomplete} event - + * The Message event object. + */ + async handleMessageEvent(event) { + if (event.event === AssistantStreamEvents.ThreadMessageCompleted) { + await this.messageCompleted(event); + } + } + + /** + * Handle Message Completed Events + * @param {ThreadMessageCompleted} event - + * The Completed Message event object. + */ + async messageCompleted(event) { + const message = event.data; + const result = await processMessages({ + openai: this.openai, + client: this, + messages: [message], + }); + const index = this.mappedOrder.get(message.id); + this.addContentData({ + [ContentTypes.TEXT]: { value: result.text }, + type: ContentTypes.TEXT, + edited: result.edited, + index, + }); + this.messages.push(message); + } +} + +module.exports = StreamRunManager; diff --git a/api/server/services/Runs/handle.js b/api/server/services/Runs/handle.js new file mode 100644 index 00000000000..8b73b099eec --- /dev/null +++ b/api/server/services/Runs/handle.js @@ -0,0 +1,264 @@ +const { RunStatus, defaultOrderQuery, CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const { retrieveRun } = require('./methods'); +const { sleep } = require('~/server/utils'); +const RunManager = require('./RunManager'); +const { logger } = require('~/config'); + +async function withTimeout(promise, timeoutMs, timeoutMessage) { + let timeoutHandle; + + const timeoutPromise = new Promise((_, reject) => { + timeoutHandle = setTimeout(() => { + logger.debug(timeoutMessage); + reject(new Error('Operation timed out')); + }, timeoutMs); + }); + + try { + return await Promise.race([promise, timeoutPromise]); + } finally { + clearTimeout(timeoutHandle); + } +} + +/** + * Creates a run on a thread using the OpenAI API. + * + * @param {Object} params - The parameters for creating a run. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - The ID of the thread to run. + * @param {Object} params.body - The body of the request to create a run. + * @param {string} params.body.assistant_id - The ID of the assistant to use for this run. + * @param {string} [params.body.model] - Optional. The ID of the model to be used for this run. + * @param {string} [params.body.instructions] - Optional. Override the default system message of the assistant. + * @param {string} [params.body.additional_instructions] - Optional. Appends additional instructions + * at theend of the instructions for the run. This is useful for modifying + * the behavior on a per-run basis without overriding other instructions. + * @param {Object[]} [params.body.tools] - Optional. Override the tools the assistant can use for this run. + * @param {string[]} [params.body.file_ids] - Optional. + * List of File IDs the assistant can use for this run. + * + * **Note:** The API seems to prefer files added to messages, not runs. + * @param {Object} [params.body.metadata] - Optional. Metadata for the run. + * @return {Promise<Run>} A promise that resolves to the created run object. + */ +async function createRun({ openai, thread_id, body }) { + return await openai.beta.threads.runs.create(thread_id, body); +} + +/** + * Waits for a run to complete by repeatedly checking its status. It uses a RunManager instance to fetch and manage run steps based on the run status. + * + * @param {Object} params - The parameters for the waitForRun function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.run_id - The ID of the run to wait for. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @param {RunManager} params.runManager - The RunManager instance to manage run steps. + * @param {number} [params.pollIntervalMs=750] - The interval for polling the run status; default is 750 milliseconds. + * @param {number} [params.timeout=180000] - The period to wait until timing out polling; default is 3 minutes (in ms). + * @return {Promise<Run>} A promise that resolves to the last fetched run object. + */ +async function waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs = 750, + timeout = 60000 * 3, +}) { + let timeElapsed = 0; + let run; + + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const cacheKey = `${openai.req.user.id}:${openai.responseMessage.conversationId}`; + + let i = 0; + let lastSeenStatus = null; + const runIdLog = `run_id: ${run_id}`; + const runInfo = `user: ${openai.req.user.id} | thread_id: ${thread_id} | ${runIdLog}`; + const raceTimeoutMs = 3000; + let maxRetries = 5; + while (timeElapsed < timeout) { + i++; + logger.debug(`[heartbeat ${i}] ${runIdLog} | Retrieving run status...`); + let updatedRun; + + let attempt = 0; + let startTime = Date.now(); + while (!updatedRun && attempt < maxRetries) { + try { + updatedRun = await withTimeout( + retrieveRun({ thread_id, run_id, timeout: raceTimeoutMs, openai }), + raceTimeoutMs, + `[heartbeat ${i}] ${runIdLog} | Run retrieval timed out after ${raceTimeoutMs} ms. Trying again (attempt ${ + attempt + 1 + } of ${maxRetries})...`, + ); + const endTime = Date.now(); + logger.debug( + `[heartbeat ${i}] ${runIdLog} | Elapsed run retrieval time: ${endTime - startTime}`, + ); + } catch (error) { + attempt++; + startTime = Date.now(); + logger.warn(`${runIdLog} | Error retrieving run status`, error); + } + } + + if (!updatedRun) { + const errorMessage = `[waitForRun] ${runIdLog} | Run retrieval failed after ${maxRetries} attempts`; + throw new Error(errorMessage); + } + + run = updatedRun; + attempt = 0; + const runStatus = `${runInfo} | status: ${run.status}`; + + if (run.status !== lastSeenStatus) { + logger.debug(`[${run.status}] ${runInfo}`); + lastSeenStatus = run.status; + } + + logger.debug(`[heartbeat ${i}] ${runStatus}`); + + let cancelStatus; + try { + const timeoutMessage = `[heartbeat ${i}] ${runIdLog} | Cancel Status check operation timed out.`; + cancelStatus = await withTimeout(cache.get(cacheKey), raceTimeoutMs, timeoutMessage); + } catch (error) { + logger.warn(`Error retrieving cancel status: ${error}`); + } + + if (cancelStatus === 'cancelled') { + logger.warn(`[waitForRun] ${runStatus} | RUN CANCELLED`); + throw new Error('Run cancelled'); + } + + if (![RunStatus.IN_PROGRESS, RunStatus.QUEUED].includes(run.status)) { + logger.debug(`[FINAL] ${runInfo} | status: ${run.status}`); + await runManager.fetchRunSteps({ + openai, + thread_id: thread_id, + run_id: run_id, + runStatus: run.status, + final: true, + }); + break; + } + + // may use in future; for now, just fetch from the final status + await runManager.fetchRunSteps({ + openai, + thread_id: thread_id, + run_id: run_id, + runStatus: run.status, + }); + + await sleep(pollIntervalMs); + timeElapsed += pollIntervalMs; + } + + if (timeElapsed >= timeout) { + const timeoutMessage = `[waitForRun] ${runInfo} | status: ${run.status} | timed out after ${timeout} ms`; + logger.warn(timeoutMessage); + throw new Error(timeoutMessage); + } + + return run; +} + +/** + * Retrieves all steps of a run. + * + * @deprecated: Steps are handled with runAssistant now. + * @param {Object} params - The parameters for the retrieveRunSteps function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @param {string} params.run_id - The ID of the run to retrieve steps for. + * @return {Promise<RunStep[]>} A promise that resolves to an array of RunStep objects. + */ +async function _retrieveRunSteps({ openai, thread_id, run_id }) { + const runSteps = await openai.beta.threads.runs.steps.list(thread_id, run_id); + return runSteps; +} + +/** + * Initializes a RunManager with handlers, then invokes waitForRun to monitor and manage an OpenAI run. + * + * @deprecated Use runAssistant instead. + * @param {Object} params - The parameters for managing and monitoring the run. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.run_id - The ID of the run to manage and monitor. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @return {Promise<Object>} A promise that resolves to an object containing the run and managed steps. + */ +async function _handleRun({ openai, run_id, thread_id }) { + let steps = []; + let messages = []; + const runManager = new RunManager({ + // 'in_progress': async ({ step, final, isLast }) => { + // // Define logic for handling steps with 'in_progress' status + // }, + // 'queued': async ({ step, final, isLast }) => { + // // Define logic for handling steps with 'queued' status + // }, + final: async ({ step, runStatus, stepsByStatus }) => { + console.log(`Final step for ${run_id} with status ${runStatus}`); + console.dir(step, { depth: null }); + + const promises = []; + promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery)); + + // const finalSteps = stepsByStatus[runStatus]; + // for (const stepPromise of finalSteps) { + // promises.push(stepPromise); + // } + + // loop across all statuses + for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { + promises.push(...stepsPromises); + } + + const resolved = await Promise.all(promises); + const res = resolved.shift(); + messages = res.data.filter((msg) => msg.run_id === run_id); + resolved.push(step); + steps = resolved; + }, + }); + + const run = await waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs: 750, + timeout: 60000, + }); + const actions = []; + if (run.required_action) { + const { submit_tool_outputs } = run.required_action; + submit_tool_outputs.tool_calls.forEach((item) => { + const functionCall = item.function; + const args = JSON.parse(functionCall.arguments); + actions.push({ + tool: functionCall.name, + toolInput: args, + toolCallId: item.id, + run_id, + thread_id, + }); + }); + } + + return { run, steps, messages, actions }; +} + +module.exports = { + sleep, + createRun, + waitForRun, + // _handleRun, + // retrieveRunSteps, +}; diff --git a/api/server/services/Runs/index.js b/api/server/services/Runs/index.js new file mode 100644 index 00000000000..7327b271ff9 --- /dev/null +++ b/api/server/services/Runs/index.js @@ -0,0 +1,11 @@ +const handle = require('./handle'); +const methods = require('./methods'); +const RunManager = require('./RunManager'); +const StreamRunManager = require('./StreamRunManager'); + +module.exports = { + ...handle, + ...methods, + RunManager, + StreamRunManager, +}; diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js new file mode 100644 index 00000000000..c6dfcbeddeb --- /dev/null +++ b/api/server/services/Runs/methods.js @@ -0,0 +1,63 @@ +const axios = require('axios'); +const { EModelEndpoint } = require('librechat-data-provider'); +const { logAxiosError } = require('~/utils'); + +/** + * @typedef {Object} RetrieveOptions + * @property {string} thread_id - The ID of the thread to retrieve. + * @property {string} run_id - The ID of the run to retrieve. + * @property {number} [timeout] - Optional timeout for the API call. + * @property {number} [maxRetries] - TODO: not yet implemented; Optional maximum number of retries for the API call. + * @property {OpenAIClient} openai - Configuration and credentials for OpenAI API access. + */ + +/** + * Asynchronously retrieves data from an API endpoint based on provided thread and run IDs. + * + * @param {RetrieveOptions} options - The options for the retrieve operation. + * @returns {Promise<Object>} The data retrieved from the API. + */ +async function retrieveRun({ thread_id, run_id, timeout, openai }) { + const { apiKey, baseURL, httpAgent, organization } = openai; + let url = `${baseURL}/threads/${thread_id}/runs/${run_id}`; + + let headers = { + Authorization: `Bearer ${apiKey}`, + 'OpenAI-Beta': 'assistants=v1', + }; + + if (organization) { + headers['OpenAI-Organization'] = organization; + } + + /** @type {TAzureConfig | undefined} */ + const azureConfig = openai.req.app.locals[EModelEndpoint.azureOpenAI]; + + if (azureConfig && azureConfig.assistants) { + delete headers.Authorization; + headers = { ...headers, ...openai._options.defaultHeaders }; + const queryParams = new URLSearchParams(openai._options.defaultQuery).toString(); + url = `${url}?${queryParams}`; + } + + try { + const axiosConfig = { + headers: headers, + timeout: timeout, + }; + + if (httpAgent) { + axiosConfig.httpAgent = httpAgent; + axiosConfig.httpsAgent = httpAgent; + } + + const response = await axios.get(url, axiosConfig); + return response.data; + } catch (error) { + const message = '[retrieveRun] Failed to retrieve run data:'; + logAxiosError({ message, error }); + throw error; + } +} + +module.exports = { retrieveRun }; diff --git a/api/server/services/Threads/index.js b/api/server/services/Threads/index.js new file mode 100644 index 00000000000..850cddc4e15 --- /dev/null +++ b/api/server/services/Threads/index.js @@ -0,0 +1,5 @@ +const manage = require('./manage'); + +module.exports = { + ...manage, +}; diff --git a/api/server/services/Threads/manage.js b/api/server/services/Threads/manage.js new file mode 100644 index 00000000000..f875b108412 --- /dev/null +++ b/api/server/services/Threads/manage.js @@ -0,0 +1,657 @@ +const path = require('path'); +const { v4 } = require('uuid'); +const { + Constants, + ContentTypes, + EModelEndpoint, + AnnotationTypes, + defaultOrderQuery, +} = require('librechat-data-provider'); +const { retrieveAndProcessFile } = require('~/server/services/Files/process'); +const { recordMessage, getMessages } = require('~/models/Message'); +const { saveConvo } = require('~/models/Conversation'); +const spendTokens = require('~/models/spendTokens'); +const { countTokens } = require('~/server/utils'); +const { logger } = require('~/config'); + +/** + * Initializes a new thread or adds messages to an existing thread. + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {Object} params.body - The body of the request. + * @param {ThreadMessage[]} params.body.messages - A list of messages to start the thread with. + * @param {Object} [params.body.metadata] - Optional metadata for the thread. + * @param {string} [params.thread_id] - Optional existing thread ID. If provided, a message will be added to this thread. + * @return {Promise<Thread>} A promise that resolves to the newly created thread object or the updated thread object. + */ +async function initThread({ openai, body, thread_id: _thread_id }) { + let thread = {}; + const messages = []; + if (_thread_id) { + const message = await openai.beta.threads.messages.create(_thread_id, body.messages[0]); + messages.push(message); + } else { + thread = await openai.beta.threads.create(body); + } + + const thread_id = _thread_id ?? thread.id; + return { messages, thread_id, ...thread }; +} + +/** + * Saves a user message to the DB in the Assistants endpoint format. + * + * @param {Object} params - The parameters of the user message + * @param {string} params.user - The user's ID. + * @param {string} params.text - The user's prompt. + * @param {string} params.messageId - The user message Id. + * @param {string} params.model - The model used by the assistant. + * @param {string} params.assistant_id - The current assistant Id. + * @param {string} params.thread_id - The thread Id. + * @param {string} params.conversationId - The message's conversationId + * @param {string} [params.parentMessageId] - Optional if initial message. + * Defaults to Constants.NO_PARENT. + * @param {string} [params.instructions] - Optional: from preset for `instructions` field. + * Overrides the instructions of the assistant. + * @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field. + * @param {import('librechat-data-provider').TFile[]} [params.files] - Optional. List of Attached File Objects. + * @param {string[]} [params.file_ids] - Optional. List of File IDs attached to the userMessage. + * @return {Promise<Run>} A promise that resolves to the created run object. + */ +async function saveUserMessage(params) { + const tokenCount = await countTokens(params.text); + + // todo: do this on the frontend + // const { file_ids = [] } = params; + // let content; + // if (file_ids.length) { + // content = [ + // { + // value: params.text, + // }, + // ...( + // file_ids + // .filter(f => f) + // .map((file_id) => ({ + // file_id, + // })) + // ), + // ]; + // } + + const userMessage = { + user: params.user, + endpoint: EModelEndpoint.assistants, + messageId: params.messageId, + conversationId: params.conversationId, + parentMessageId: params.parentMessageId ?? Constants.NO_PARENT, + /* For messages, use the assistant_id instead of model */ + model: params.assistant_id, + thread_id: params.thread_id, + sender: 'User', + text: params.text, + isCreatedByUser: true, + tokenCount, + }; + + const convo = { + endpoint: EModelEndpoint.assistants, + conversationId: params.conversationId, + promptPrefix: params.promptPrefix, + instructions: params.instructions, + assistant_id: params.assistant_id, + model: params.model, + }; + + if (params.files?.length) { + userMessage.files = params.files.map(({ file_id }) => ({ file_id })); + convo.file_ids = params.file_ids; + } + + const message = await recordMessage(userMessage); + await saveConvo(params.user, convo); + + return message; +} + +/** + * Saves an Assistant message to the DB in the Assistants endpoint format. + * + * @param {Object} params - The parameters of the Assistant message + * @param {string} params.user - The user's ID. + * @param {string} params.messageId - The message Id. + * @param {string} params.assistant_id - The assistant Id. + * @param {string} params.thread_id - The thread Id. + * @param {string} params.model - The model used by the assistant. + * @param {ContentPart[]} params.content - The message content parts. + * @param {string} params.conversationId - The message's conversationId + * @param {string} params.parentMessageId - The latest user message that triggered this response. + * @param {string} [params.instructions] - Optional: from preset for `instructions` field. + * Overrides the instructions of the assistant. + * @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field. + * @return {Promise<Run>} A promise that resolves to the created run object. + */ +async function saveAssistantMessage(params) { + const text = params.content.reduce((acc, part) => { + if (!part.value) { + return acc; + } + + return acc + ' ' + part.value; + }, ''); + + // const tokenCount = // TODO: need to count each content part + + const message = await recordMessage({ + user: params.user, + endpoint: EModelEndpoint.assistants, + messageId: params.messageId, + conversationId: params.conversationId, + parentMessageId: params.parentMessageId, + thread_id: params.thread_id, + /* For messages, use the assistant_id instead of model */ + model: params.assistant_id, + content: params.content, + sender: 'Assistant', + isCreatedByUser: false, + text: text.trim(), + // tokenCount, + }); + + await saveConvo(params.user, { + endpoint: EModelEndpoint.assistants, + conversationId: params.conversationId, + promptPrefix: params.promptPrefix, + instructions: params.instructions, + assistant_id: params.assistant_id, + model: params.model, + }); + + return message; +} + +/** + * Records LibreChat messageId to all response messages' metadata + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - Response thread ID. + * @param {string} params.messageId - The response `messageId` generated by LibreChat. + * @param {StepMessage[] | Message[]} params.messages - A list of messages to start the thread with. + * @return {Promise<ThreadMessage[]>} A promise that resolves to the updated messages + */ +async function addThreadMetadata({ openai, thread_id, messageId, messages }) { + const promises = []; + for (const message of messages) { + promises.push( + openai.beta.threads.messages.update(thread_id, message.id, { + metadata: { + messageId, + }, + }), + ); + } + + return await Promise.all(promises); +} + +/** + * Synchronizes LibreChat messages to Thread Messages. + * Updates the LibreChat DB with any missing Thread Messages and + * updates the missing Thread Messages' metadata with their corresponding db messageId's. + * + * Also updates the existing conversation's file_ids with any new file_ids. + * + * @param {Object} params - The parameters for synchronizing messages. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {TMessage[]} params.dbMessages - The LibreChat DB messages. + * @param {ThreadMessage[]} params.apiMessages - The thread messages from the API. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.thread_id - The current thread ID. + * @param {string} [params.assistant_id] - The current assistant ID. + * @return {Promise<TMessage[]>} A promise that resolves to the updated messages + */ +async function syncMessages({ + openai, + apiMessages, + dbMessages, + conversationId, + thread_id, + assistant_id, +}) { + let result = []; + let dbMessageMap = new Map(dbMessages.map((msg) => [msg.messageId, msg])); + + const modifyPromises = []; + const recordPromises = []; + + /** + * + * Modify API message and save newMessage to DB + * + * @param {Object} params - The parameters object + * @param {TMessage} params.dbMessage + * @param {dbMessage} params.apiMessage + */ + const processNewMessage = async ({ dbMessage, apiMessage }) => { + recordPromises.push(recordMessage({ ...dbMessage, user: openai.req.user.id })); + + if (!apiMessage.id.includes('msg_')) { + return; + } + + if (dbMessage.aggregateMessages?.length > 1) { + modifyPromises.push( + addThreadMetadata({ + openai, + thread_id, + messageId: dbMessage.messageId, + messages: dbMessage.aggregateMessages, + }), + ); + return; + } + + modifyPromises.push( + openai.beta.threads.messages.update(thread_id, apiMessage.id, { + metadata: { + messageId: dbMessage.messageId, + }, + }), + ); + }; + + let lastMessage = null; + + for (let i = 0; i < apiMessages.length; i++) { + const apiMessage = apiMessages[i]; + + // Check if the message exists in the database based on metadata + const dbMessageId = apiMessage.metadata && apiMessage.metadata.messageId; + let dbMessage = dbMessageMap.get(dbMessageId); + + if (dbMessage) { + // If message exists in DB, use its messageId and update parentMessageId + dbMessage.parentMessageId = lastMessage ? lastMessage.messageId : Constants.NO_PARENT; + lastMessage = dbMessage; + result.push(dbMessage); + continue; + } + + if (apiMessage.role === 'assistant' && lastMessage && lastMessage.role === 'assistant') { + // Aggregate assistant messages + lastMessage.content = [...lastMessage.content, ...apiMessage.content]; + lastMessage.files = [...(lastMessage.files ?? []), ...(apiMessage.files ?? [])]; + lastMessage.aggregateMessages.push({ id: apiMessage.id }); + } else { + // Handle new or missing message + const newMessage = { + thread_id, + conversationId, + messageId: v4(), + endpoint: EModelEndpoint.assistants, + parentMessageId: lastMessage ? lastMessage.messageId : Constants.NO_PARENT, + role: apiMessage.role, + isCreatedByUser: apiMessage.role === 'user', + // TODO: process generated files in content parts + content: apiMessage.content, + aggregateMessages: [{ id: apiMessage.id }], + model: apiMessage.role === 'user' ? null : apiMessage.assistant_id, + user: openai.req.user.id, + }; + + if (apiMessage.file_ids?.length) { + // TODO: retrieve file objects from API + newMessage.files = apiMessage.file_ids.map((file_id) => ({ file_id })); + } + + /* Assign assistant_id if defined */ + if (assistant_id && apiMessage.role === 'assistant' && !newMessage.model) { + apiMessage.model = assistant_id; + newMessage.model = assistant_id; + } + + result.push(newMessage); + lastMessage = newMessage; + + if (apiMessage.role === 'user') { + processNewMessage({ dbMessage: newMessage, apiMessage }); + continue; + } + } + + const nextMessage = apiMessages[i + 1]; + const processAssistant = !nextMessage || nextMessage.role === 'user'; + + if (apiMessage.role === 'assistant' && processAssistant) { + processNewMessage({ dbMessage: lastMessage, apiMessage }); + } + } + + const attached_file_ids = apiMessages.reduce((acc, msg) => { + if (msg.role === 'user' && msg.file_ids?.length) { + return [...acc, ...msg.file_ids]; + } + + return acc; + }, []); + + await Promise.all(modifyPromises); + await Promise.all(recordPromises); + + await saveConvo(openai.req.user.id, { + conversationId, + file_ids: attached_file_ids, + }); + + return result; +} + +/** + * Maps messages to their corresponding steps. Steps with message creation will be paired with their messages, + * while steps without message creation will be returned as is. + * + * @param {RunStep[]} steps - An array of steps from the run. + * @param {Message[]} messages - An array of message objects. + * @returns {(StepMessage | RunStep)[]} An array where each element is either a step with its corresponding message (StepMessage) or a step without a message (RunStep). + */ +function mapMessagesToSteps(steps, messages) { + // Create a map of messages indexed by their IDs for efficient lookup + const messageMap = messages.reduce((acc, msg) => { + acc[msg.id] = msg; + return acc; + }, {}); + + // Map each step to its corresponding message, or return the step as is if no message ID is present + return steps + .sort((a, b) => a.created_at - b.created_at) + .map((step) => { + const messageId = step.step_details?.message_creation?.message_id; + + if (messageId && messageMap[messageId]) { + return { step, message: messageMap[messageId] }; + } + return step; + }); +} + +/** + * Checks for any missing messages; if missing, + * synchronizes LibreChat messages to Thread Messages + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} [params.latestMessageId] - Optional: The latest message ID from LibreChat. + * @param {string} params.thread_id - Response thread ID. + * @param {string} params.run_id - Response Run ID. + * @param {string} params.conversationId - LibreChat conversation ID. + * @return {Promise<TMessage[]>} A promise that resolves to the updated messages + */ +async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, conversationId }) { + const promises = []; + promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery)); + promises.push(openai.beta.threads.runs.steps.list(thread_id, run_id)); + /** @type {[{ data: ThreadMessage[] }, { data: RunStep[] }]} */ + const [response, stepsResponse] = await Promise.all(promises); + + const steps = mapMessagesToSteps(stepsResponse.data, response.data); + /** @type {ThreadMessage} */ + const currentMessage = { + id: v4(), + content: [], + assistant_id: null, + created_at: Math.floor(new Date().getTime() / 1000), + object: 'thread.message', + role: 'assistant', + run_id, + thread_id, + metadata: { + messageId: latestMessageId, + }, + }; + + for (const step of steps) { + if (!currentMessage.assistant_id && step.assistant_id) { + currentMessage.assistant_id = step.assistant_id; + } + if (step.message) { + currentMessage.id = step.message.id; + currentMessage.created_at = step.message.created_at; + currentMessage.content = currentMessage.content.concat(step.message.content); + } else if (step.step_details?.type === 'tool_calls' && step.step_details?.tool_calls?.length) { + currentMessage.content = currentMessage.content.concat( + step.step_details?.tool_calls.map((toolCall) => ({ + [ContentTypes.TOOL_CALL]: { + ...toolCall, + progress: 2, + }, + type: ContentTypes.TOOL_CALL, + })), + ); + } + } + + let addedCurrentMessage = false; + const apiMessages = response.data + .map((msg) => { + if (msg.id === currentMessage.id) { + addedCurrentMessage = true; + return currentMessage; + } + return msg; + }) + .sort((a, b) => new Date(a.created_at) - new Date(b.created_at)); + + if (!addedCurrentMessage) { + apiMessages.push(currentMessage); + } + + const dbMessages = await getMessages({ conversationId }); + const assistant_id = dbMessages?.[0]?.model; + + const syncedMessages = await syncMessages({ + openai, + dbMessages, + apiMessages, + thread_id, + conversationId, + assistant_id, + }); + + return Object.values( + [...dbMessages, ...syncedMessages].reduce( + (acc, message) => ({ ...acc, [message.messageId]: message }), + {}, + ), + ); +} + +/** + * Records token usage for a given completion request. + * @param {Object} params - The parameters for initializing a thread. + * @param {number} params.prompt_tokens - The number of prompt tokens used. + * @param {number} params.completion_tokens - The number of completion tokens used. + * @param {string} params.model - The model used by the assistant run. + * @param {string} params.user - The user's ID. + * @param {string} params.conversationId - LibreChat conversation ID. + * @param {string} [params.context='message'] - The context of the usage. Defaults to 'message'. + * @return {Promise<TMessage[]>} A promise that resolves to the updated messages + */ +const recordUsage = async ({ + prompt_tokens, + completion_tokens, + model, + user, + conversationId, + context = 'message', +}) => { + await spendTokens( + { + user, + model, + context, + conversationId, + }, + { promptTokens: prompt_tokens, completionTokens: completion_tokens }, + ); +}; + +/** + * Safely replaces the annotated text within the specified range denoted by start_index and end_index, + * after verifying that the text within that range matches the given annotation text. + * Proceeds with the replacement even if a mismatch is found, but logs a warning. + * + * @param {string} originalText The original text content. + * @param {number} start_index The starting index where replacement should begin. + * @param {number} end_index The ending index where replacement should end. + * @param {string} expectedText The text expected to be found in the specified range. + * @param {string} replacementText The text to insert in place of the existing content. + * @returns {string} The text with the replacement applied, regardless of text match. + */ +function replaceAnnotation(originalText, start_index, end_index, expectedText, replacementText) { + if (start_index < 0 || end_index > originalText.length || start_index > end_index) { + logger.warn(`Invalid range specified for annotation replacement. + Attempting replacement with \`replace\` method instead... + length: ${originalText.length} + start_index: ${start_index} + end_index: ${end_index}`); + return originalText.replace(originalText, replacementText); + } + + const actualTextInRange = originalText.substring(start_index, end_index); + + if (actualTextInRange !== expectedText) { + logger.warn(`The text within the specified range does not match the expected annotation text. + Attempting replacement with \`replace\` method instead... + Expected: ${expectedText} + Actual: ${actualTextInRange}`); + + return originalText.replace(originalText, replacementText); + } + + const beforeText = originalText.substring(0, start_index); + const afterText = originalText.substring(end_index); + return beforeText + replacementText + afterText; +} + +/** + * Sorts, processes, and flattens messages to a single string. + * + * @param {object} params - The OpenAI client instance. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {RunClient} params.client - The LibreChat client that manages the run: either refers to `OpenAI` or `StreamRunManager`. + * @param {ThreadMessage[]} params.messages - An array of messages. + * @returns {Promise<{messages: ThreadMessage[], text: string}>} The sorted messages and the flattened text. + */ +async function processMessages({ openai, client, messages = [] }) { + const sorted = messages.sort((a, b) => a.created_at - b.created_at); + + let text = ''; + let edited = false; + const sources = []; + for (const message of sorted) { + message.files = []; + for (const content of message.content) { + const type = content.type; + const contentType = content[type]; + const currentFileId = contentType?.file_id; + + if (type === ContentTypes.IMAGE_FILE && !client.processedFileIds.has(currentFileId)) { + const file = await retrieveAndProcessFile({ + openai, + client, + file_id: currentFileId, + basename: `${currentFileId}.png`, + }); + + client.processedFileIds.add(currentFileId); + message.files.push(file); + continue; + } + + let currentText = contentType?.value ?? ''; + + /** @type {{ annotations: Annotation[] }} */ + const { annotations } = contentType ?? {}; + + // Process annotations if they exist + if (!annotations?.length) { + text += currentText + ' '; + continue; + } + + logger.debug('[processMessages] Processing annotations:', annotations); + for (const annotation of annotations) { + let file; + const type = annotation.type; + const annotationType = annotation[type]; + const file_id = annotationType?.file_id; + const alreadyProcessed = client.processedFileIds.has(file_id); + + const replaceCurrentAnnotation = (replacement = '') => { + currentText = replaceAnnotation( + currentText, + annotation.start_index, + annotation.end_index, + annotation.text, + replacement, + ); + edited = true; + }; + + if (alreadyProcessed) { + const { file_id } = annotationType || {}; + file = await retrieveAndProcessFile({ openai, client, file_id, unknownType: true }); + } else if (type === AnnotationTypes.FILE_PATH) { + const basename = path.basename(annotation.text); + file = await retrieveAndProcessFile({ + openai, + client, + file_id, + basename, + }); + replaceCurrentAnnotation(file.filepath); + } else if (type === AnnotationTypes.FILE_CITATION) { + file = await retrieveAndProcessFile({ + openai, + client, + file_id, + unknownType: true, + }); + sources.push(file.filename); + replaceCurrentAnnotation(`^${sources.length}^`); + } + + text += currentText + ' '; + + if (!file) { + continue; + } + + client.processedFileIds.add(file_id); + message.files.push(file); + } + } + } + + if (sources.length) { + text += '\n\n'; + for (let i = 0; i < sources.length; i++) { + text += `^${i + 1}.^ ${sources[i]}${i === sources.length - 1 ? '' : '\n'}`; + } + } + + return { messages: sorted, text, edited }; +} + +module.exports = { + initThread, + recordUsage, + processMessages, + saveUserMessage, + checkMessageGaps, + addThreadMetadata, + mapMessagesToSteps, + saveAssistantMessage, +}; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js new file mode 100644 index 00000000000..81c6ca42830 --- /dev/null +++ b/api/server/services/ToolService.js @@ -0,0 +1,356 @@ +const fs = require('fs'); +const path = require('path'); +const { StructuredTool } = require('langchain/tools'); +const { zodToJsonSchema } = require('zod-to-json-schema'); +const { Calculator } = require('langchain/tools/calculator'); +const { + Tools, + ContentTypes, + imageGenTools, + actionDelimiter, + ImageVisionTool, + openapiToFunction, + validateAndParseOpenAPISpec, +} = require('librechat-data-provider'); +const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); +const { loadActionSets, createActionTool, domainParser } = require('./ActionService'); +const { recordUsage } = require('~/server/services/Threads'); +const { loadTools } = require('~/app/clients/tools/util'); +const { redactMessage } = require('~/config/parsers'); +const { sleep } = require('~/server/utils'); +const { logger } = require('~/config'); + +/** + * Loads and formats tools from the specified tool directory. + * + * The directory is scanned for JavaScript files, excluding any files in the filter set. + * For each file, it attempts to load the file as a module and instantiate a class, if it's a subclass of `StructuredTool`. + * Each tool instance is then formatted to be compatible with the OpenAI Assistant. + * Additionally, instances of LangChain Tools are included in the result. + * + * @param {object} params - The parameters for the function. + * @param {string} params.directory - The directory path where the tools are located. + * @param {Set<string>} [params.filter=new Set()] - A set of filenames to exclude from loading. + * @returns {Record<string, FunctionTool>} An object mapping each tool's plugin key to its instance. + */ +function loadAndFormatTools({ directory, filter = new Set() }) { + const tools = []; + /* Structured Tools Directory */ + const files = fs.readdirSync(directory); + + for (const file of files) { + if (file.endsWith('.js') && !filter.has(file)) { + const filePath = path.join(directory, file); + let ToolClass = null; + try { + ToolClass = require(filePath); + } catch (error) { + logger.error(`[loadAndFormatTools] Error loading tool from ${filePath}:`, error); + continue; + } + + if (!ToolClass) { + continue; + } + + if (ToolClass.prototype instanceof StructuredTool) { + /** @type {StructuredTool | null} */ + let toolInstance = null; + try { + toolInstance = new ToolClass({ override: true }); + } catch (error) { + logger.error( + `[loadAndFormatTools] Error initializing \`${file}\` tool; if it requires authentication, is the \`override\` field configured?`, + error, + ); + continue; + } + + if (!toolInstance) { + continue; + } + + const formattedTool = formatToOpenAIAssistantTool(toolInstance); + tools.push(formattedTool); + } + } + } + + /** + * Basic Tools; schema: { input: string } + */ + const basicToolInstances = [new Calculator()]; + + for (const toolInstance of basicToolInstances) { + const formattedTool = formatToOpenAIAssistantTool(toolInstance); + tools.push(formattedTool); + } + + tools.push(ImageVisionTool); + + return tools.reduce((map, tool) => { + map[tool.function.name] = tool; + return map; + }, {}); +} + +/** + * Formats a `StructuredTool` instance into a format that is compatible + * with OpenAI's ChatCompletionFunctions. It uses the `zodToJsonSchema` + * function to convert the schema of the `StructuredTool` into a JSON + * schema, which is then used as the parameters for the OpenAI function. + * + * @param {StructuredTool} tool - The StructuredTool to format. + * @returns {FunctionTool} The OpenAI Assistant Tool. + */ +function formatToOpenAIAssistantTool(tool) { + return { + type: Tools.function, + [Tools.function]: { + name: tool.name, + description: tool.description, + parameters: zodToJsonSchema(tool.schema), + }, + }; +} + +/** + * Processes the required actions by calling the appropriate tools and returning the outputs. + * @param {OpenAIClient} client - OpenAI or StreamRunManager Client. + * @param {RequiredAction} requiredActions - The current required action. + * @returns {Promise<ToolOutput>} The outputs of the tools. + */ +const processVisionRequest = async (client, currentAction) => { + if (!client.visionPromise) { + return { + tool_call_id: currentAction.toolCallId, + output: 'No image details found.', + }; + } + + /** @type {ChatCompletion | undefined} */ + const completion = await client.visionPromise; + if (completion.usage) { + recordUsage({ + user: client.req.user.id, + model: client.req.body.model, + conversationId: (client.responseMessage ?? client.finalMessage).conversationId, + ...completion.usage, + }); + } + const output = completion?.choices?.[0]?.message?.content ?? 'No image details found.'; + return { + tool_call_id: currentAction.toolCallId, + output, + }; +}; + +/** + * Processes return required actions from run. + * @param {OpenAIClient | StreamRunManager} client - OpenAI (legacy) or StreamRunManager Client. + * @param {RequiredAction[]} requiredActions - The required actions to submit outputs for. + * @returns {Promise<ToolOutputs>} The outputs of the tools. + */ +async function processRequiredActions(client, requiredActions) { + logger.debug( + `[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, + requiredActions, + ); + const tools = requiredActions.map((action) => action.tool); + const loadedTools = await loadTools({ + user: client.req.user.id, + model: client.req.body.model ?? 'gpt-3.5-turbo-1106', + tools, + functions: true, + options: { + processFileURL, + req: client.req, + uploadImageBuffer, + openAIApiKey: client.apiKey, + fileStrategy: client.req.app.locals.fileStrategy, + returnMetadata: true, + }, + skipSpecs: true, + }); + + const ToolMap = loadedTools.reduce((map, tool) => { + map[tool.name] = tool; + return map; + }, {}); + + const promises = []; + + /** @type {Action[]} */ + let actionSets = []; + let isActionTool = false; + const ActionToolMap = {}; + const ActionBuildersMap = {}; + + for (let i = 0; i < requiredActions.length; i++) { + const currentAction = requiredActions[i]; + if (currentAction.tool === ImageVisionTool.function.name) { + promises.push(processVisionRequest(client, currentAction)); + continue; + } + let tool = ToolMap[currentAction.tool] ?? ActionToolMap[currentAction.tool]; + + const handleToolOutput = async (output) => { + requiredActions[i].output = output; + + /** @type {FunctionToolCall & PartMetadata} */ + const toolCall = { + function: { + name: currentAction.tool, + arguments: JSON.stringify(currentAction.toolInput), + output, + }, + id: currentAction.toolCallId, + type: 'function', + progress: 1, + action: isActionTool, + }; + + const toolCallIndex = client.mappedOrder.get(toolCall.id); + + if (imageGenTools.has(currentAction.tool)) { + const imageOutput = output; + toolCall.function.output = `${currentAction.tool} displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.`; + + // Streams the "Finished" state of the tool call in the UI + client.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + }); + + await sleep(500); + + /** @type {ImageFile} */ + const imageDetails = { + ...imageOutput, + ...currentAction.toolInput, + }; + + const image_file = { + [ContentTypes.IMAGE_FILE]: imageDetails, + type: ContentTypes.IMAGE_FILE, + // Replace the tool call output with Image file + index: toolCallIndex, + }; + + client.addContentData(image_file); + + // Update the stored tool call + client.seenToolCalls && client.seenToolCalls.set(toolCall.id, toolCall); + + return { + tool_call_id: currentAction.toolCallId, + output: toolCall.function.output, + }; + } + + client.seenToolCalls && client.seenToolCalls.set(toolCall.id, toolCall); + client.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + // TODO: to append tool properties to stream, pass metadata rest to addContentData + // result: tool.result, + }); + + return { + tool_call_id: currentAction.toolCallId, + output, + }; + }; + + if (!tool) { + // throw new Error(`Tool ${currentAction.tool} not found.`); + + if (!actionSets.length) { + actionSets = + (await loadActionSets({ + assistant_id: client.req.body.assistant_id, + })) ?? []; + } + + const actionSet = actionSets.find((action) => + currentAction.tool.includes(domainParser(client.req, action.metadata.domain, true)), + ); + + if (!actionSet) { + // TODO: try `function` if no action set is found + // throw new Error(`Tool ${currentAction.tool} not found.`); + continue; + } + + let builders = ActionBuildersMap[actionSet.metadata.domain]; + + if (!builders) { + const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec); + if (!validationResult.spec) { + throw new Error( + `Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, + ); + } + const { requestBuilders } = openapiToFunction(validationResult.spec); + ActionToolMap[actionSet.metadata.domain] = requestBuilders; + builders = requestBuilders; + } + + const functionName = currentAction.tool.replace( + `${actionDelimiter}${domainParser(client.req, actionSet.metadata.domain, true)}`, + '', + ); + const requestBuilder = builders[functionName]; + + if (!requestBuilder) { + // throw new Error(`Tool ${currentAction.tool} not found.`); + continue; + } + + tool = createActionTool({ action: actionSet, requestBuilder }); + isActionTool = !!tool; + ActionToolMap[currentAction.tool] = tool; + } + + if (currentAction.tool === 'calculator') { + currentAction.toolInput = currentAction.toolInput.input; + } + + try { + const promise = tool + ._call(currentAction.toolInput) + .then(handleToolOutput) + .catch((error) => { + logger.error(`Error processing tool ${currentAction.tool}`, error); + return { + tool_call_id: currentAction.toolCallId, + output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message)}`, + }; + }); + promises.push(promise); + } catch (error) { + logger.error( + `tool_call_id: ${currentAction.toolCallId} | Error processing tool ${currentAction.tool}`, + error, + ); + promises.push( + Promise.resolve({ + tool_call_id: currentAction.toolCallId, + error: error.message, + }), + ); + } + } + + return { + tool_outputs: await Promise.all(promises), + }; +} + +module.exports = { + formatToOpenAIAssistantTool, + loadAndFormatTools, + processRequiredActions, +}; diff --git a/api/server/utils/countTokens.js b/api/server/utils/countTokens.js index 34c070aa8c2..641e3861014 100644 --- a/api/server/utils/countTokens.js +++ b/api/server/utils/countTokens.js @@ -3,6 +3,20 @@ const p50k_base = require('tiktoken/encoders/p50k_base.json'); const cl100k_base = require('tiktoken/encoders/cl100k_base.json'); const logger = require('~/config/winston'); +/** + * Counts the number of tokens in a given text using a specified encoding model. + * + * This function utilizes the 'Tiktoken' library to encode text based on the selected model. + * It supports two models, 'text-davinci-003' and 'gpt-3.5-turbo', each with its own encoding strategy. + * For 'text-davinci-003', the 'p50k_base' encoder is used, whereas for other models, the 'cl100k_base' encoder is applied. + * In case of an error during encoding, the error is logged, and the function returns 0. + * + * @async + * @param {string} text - The text to be tokenized. Defaults to an empty string if not provided. + * @param {string} modelName - The name of the model used for tokenizing. Defaults to 'gpt-3.5-turbo'. + * @returns {Promise<number>} The number of tokens in the provided text. Returns 0 if an error occurs. + * @throws Logs the error to a logger and rethrows if any error occurs during tokenization. + */ const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => { let encoder = null; try { diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index 9b5fed67c6b..8989084e5ab 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -19,4 +19,27 @@ function decrypt(encryptedValue) { return decrypted; } -module.exports = { encrypt, decrypt }; +// Programatically generate iv +function encryptV2(value) { + const gen_iv = crypto.randomBytes(16); + const cipher = crypto.createCipheriv(algorithm, key, gen_iv); + let encrypted = cipher.update(value, 'utf8', 'hex'); + encrypted += cipher.final('hex'); + return gen_iv.toString('hex') + ':' + encrypted; +} + +function decryptV2(encryptedValue) { + const parts = encryptedValue.split(':'); + // Already decrypted from an earlier invocation + if (parts.length === 1) { + return parts[0]; + } + const gen_iv = Buffer.from(parts.shift(), 'hex'); + const encrypted = parts.join(':'); + const decipher = crypto.createDecipheriv(algorithm, key, gen_iv); + let decrypted = decipher.update(encrypted, 'hex', 'utf8'); + decrypted += decipher.final('utf8'); + return decrypted; +} + +module.exports = { encrypt, decrypt, encryptV2, decryptV2 }; diff --git a/api/server/utils/emails/passwordReset.handlebars b/api/server/utils/emails/passwordReset.handlebars index 2d0d5426ccd..d41566c598e 100644 --- a/api/server/utils/emails/passwordReset.handlebars +++ b/api/server/utils/emails/passwordReset.handlebars @@ -1,11 +1,186 @@ -<html> - <head> - <style> +<!DOCTYPE HTML + PUBLIC "-//W3C//DTD XHTML 1.0 Transitional //EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> +<html xmlns="http://www.w3.org/1999/xhtml" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office"> + +<head> + <!--[if gte mso 9]> +<xml> +<o:OfficeDocumentSettings> + <o:AllowPNG /> + <o:PixelsPerInch>96</o:PixelsPerInch> +</o:OfficeDocumentSettings> +</xml> +<![endif]--> + <meta http-equiv="Content-Type" content="text/html; charset=UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <meta name="x-apple-disable-message-reformatting"> + <meta name="color-scheme" content="light dark"> + <!--[if !mso]><!--> + <meta http-equiv="X-UA-Compatible" content="IE=edge"> + <!--<![endif]--> + <title> + + + + + + + + + + + + +
+ +
+
+
+ + +
+
+ +
+ + + + + + + +
+
+
Hi {{name}},
+
+
+ + + + + + +
+
+
+
Your password has been updated successfully!
+
+
+
+ + + + + + +
+
+
Best regards,
+
The {{appName}} Team
+
+
+ + + + + + +
+
+
+
© {{year}} {{appName}}. All rights + reserved.
+
+
+
+ +
+ +
+
+ + +
+
+
+ +
+ + + - - - -

Hi {{name}},

-

Your password has been changed successfully.

- \ No newline at end of file diff --git a/api/server/utils/emails/requestPasswordReset.handlebars b/api/server/utils/emails/requestPasswordReset.handlebars index 1bf9853c684..e579ec0d5c4 100644 --- a/api/server/utils/emails/requestPasswordReset.handlebars +++ b/api/server/utils/emails/requestPasswordReset.handlebars @@ -1,13 +1,239 @@ - - - + + + + + + + + + + + +
+ +
+
+
+ + +
+
+ +
+ + + + + + + +
+ +

+
+
You have requested to reset your password. +
+
+

+ +
+ + + + + + +
+
+
Hi {{name}},
+
+
+ + + + + + +
+
+

Please click the button below to reset your password.

+
+
+ + + + + + +
+ + +
+ + + + + + +
+
+
+
If you did not request a password reset, please ignore this email.
+
+
+
+ + + + + + +
+
+
Best regards,
+
The {{appName}} Team
+
+
+ + + + + + +
+
+
+
© {{year}} {{appName}}. All rights + reserved.
+
+
+
+ +
+ +
+
+ + +
+
+
+ +
+ + + - - - -

Hi {{name}},

-

You have requested to reset your password.

-

Please click the link below to reset your password.

- Reset Password - \ No newline at end of file diff --git a/api/server/utils/emails/verifyEmail.handlebars b/api/server/utils/emails/verifyEmail.handlebars new file mode 100644 index 00000000000..2855d4647e4 --- /dev/null +++ b/api/server/utils/emails/verifyEmail.handlebars @@ -0,0 +1,239 @@ + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+
+ + +
+
+ +
+ + + + + + + +
+ +

+
+
Welcome to {{appName}}!
+
+

+ +
+ + + + + + +
+
+
+
Dear {{name}},
+
+
+
+ + + + + + +
+
+
+
Thank you for registering with {{appName}}. To complete your registration and verify your email address, please click the button below:
+
+
+
+ + + + + + +
+ + +
+ + + + + + +
+
+
+
If you did not create an account with {{appName}}, please ignore this email.
+
+
+
+ + + + + + +
+
+
Best regards,
+
The {{appName}} Team
+
+
+ + + + + + +
+
+
+
© {{year}} {{appName}}. All rights + reserved.
+
+
+
+ +
+ +
+
+ + +
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/api/server/utils/files.js b/api/server/utils/files.js new file mode 100644 index 00000000000..63cf95d3ab9 --- /dev/null +++ b/api/server/utils/files.js @@ -0,0 +1,47 @@ +const sharp = require('sharp'); + +/** + * Determines the file type of a buffer + * @param {Buffer} dataBuffer + * @param {boolean} [returnFileType=false] - Optional. If true, returns the file type instead of the file extension. + * @returns {Promise} - Returns the file extension if found, else null + * */ +const determineFileType = async (dataBuffer, returnFileType) => { + const fileType = await import('file-type'); + const type = await fileType.fileTypeFromBuffer(dataBuffer); + if (returnFileType) { + return type; + } + return type ? type.ext : null; // Returns extension if found, else null +}; + +/** + * Get buffer metadata + * @param {Buffer} buffer + * @returns {Promise<{ bytes: number, type: string, dimensions: Record, extension: string}>} + */ +const getBufferMetadata = async (buffer) => { + const fileType = await determineFileType(buffer, true); + const bytes = buffer.length; + let extension = fileType ? fileType.ext : 'unknown'; + + /** @type {Record} */ + let dimensions = {}; + + if (fileType && fileType.mime.startsWith('image/') && extension !== 'unknown') { + const imageMetadata = await sharp(buffer).metadata(); + dimensions = { + width: imageMetadata.width, + height: imageMetadata.height, + }; + } + + return { + bytes, + type: fileType?.mime ?? 'unknown', + dimensions, + extension, + }; +}; + +module.exports = { determineFileType, getBufferMetadata }; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index b8d17106622..bfa37e279f9 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,43 +1,20 @@ +const { Capabilities, defaultRetrievalModels } = require('librechat-data-provider'); +const { getCitations, citeText } = require('./citations'); const partialRight = require('lodash/partialRight'); const { sendMessage } = require('./streamResponse'); -const { getCitations, citeText } = require('./citations'); -const cursor = ''; const citationRegex = /\[\^\d+?\^]/g; const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { let i = 0; - let code = ''; - let precode = ''; - let codeBlock = false; let tokens = addSpaceIfNeeded(generation); const progressCallback = async (partial, { res, text, bing = false, ...rest }) => { let chunk = partial === text ? '' : partial; tokens += chunk; - precode += chunk; tokens = tokens.replaceAll('[DONE]', ''); - if (codeBlock) { - code += chunk; - } - - if (precode.includes('```') && codeBlock) { - codeBlock = false; - precode = precode.replace(/```/g, ''); - code = ''; - } - - if (precode.includes('```') && code === '') { - precode = precode.replace(/```/g, ''); - codeBlock = true; - } - - if (tokens.match(/^\n(?!:::plugins:::)/)) { - tokens = tokens.replace(/^\n/, ''); - } - if (bing) { tokens = citeText(tokens, true); } @@ -51,7 +28,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { const sendIntermediateMessage = (res, payload, extraTokens = '') => { tokens += extraTokens; sendMessage(res, { - text: tokens?.length === 0 ? cursor : tokens, + text: tokens?.length === 0 ? '' : tokens, message: true, initial: i === 0, ...payload, @@ -174,16 +151,35 @@ function isEnabled(value) { const isUserProvided = (value) => value === 'user_provided'; /** - * Extracts the value of an environment variable from a string. - * @param {string} value - The value to be processed, possibly containing an env variable placeholder. - * @returns {string} - The actual value from the environment variable or the original value. + * Generate the configuration for a given key and base URL. + * @param {string} key + * @param {string} baseURL + * @returns {boolean | { userProvide: boolean, userProvideURL?: boolean }} */ -function extractEnvVariable(value) { - const envVarMatch = value.match(/^\${(.+)}$/); - if (envVarMatch) { - return process.env[envVarMatch[1]] || value; +function generateConfig(key, baseURL, assistants = false) { + if (!key) { + return false; + } + + /** @type {{ userProvide: boolean, userProvideURL?: boolean }} */ + const config = { userProvide: isUserProvided(key) }; + + if (baseURL) { + config.userProvideURL = isUserProvided(baseURL); + } + + if (assistants) { + config.retrievalModels = defaultRetrievalModels; + config.capabilities = [ + Capabilities.code_interpreter, + Capabilities.image_vision, + Capabilities.retrieval, + Capabilities.actions, + Capabilities.tools, + ]; } - return value; + + return config; } module.exports = { @@ -194,5 +190,5 @@ module.exports = { formatAction, addSpaceIfNeeded, isUserProvided, - extractEnvVariable, + generateConfig, }; diff --git a/api/server/utils/handleText.spec.js b/api/server/utils/handleText.spec.js index a5566fb1b2b..ea440a89a57 100644 --- a/api/server/utils/handleText.spec.js +++ b/api/server/utils/handleText.spec.js @@ -1,4 +1,4 @@ -const { isEnabled, extractEnvVariable } = require('./handleText'); +const { isEnabled } = require('./handleText'); describe('isEnabled', () => { test('should return true when input is "true"', () => { @@ -48,51 +48,4 @@ describe('isEnabled', () => { test('should return false when input is an array', () => { expect(isEnabled([])).toBe(false); }); - - describe('extractEnvVariable', () => { - const originalEnv = process.env; - - beforeEach(() => { - jest.resetModules(); - process.env = { ...originalEnv }; - }); - - afterAll(() => { - process.env = originalEnv; - }); - - test('should return the value of the environment variable', () => { - process.env.TEST_VAR = 'test_value'; - expect(extractEnvVariable('${TEST_VAR}')).toBe('test_value'); - }); - - test('should return the original string if the envrionment variable is not defined correctly', () => { - process.env.TEST_VAR = 'test_value'; - expect(extractEnvVariable('${ TEST_VAR }')).toBe('${ TEST_VAR }'); - }); - - test('should return the original string if environment variable is not set', () => { - expect(extractEnvVariable('${NON_EXISTENT_VAR}')).toBe('${NON_EXISTENT_VAR}'); - }); - - test('should return the original string if it does not contain an environment variable', () => { - expect(extractEnvVariable('some_string')).toBe('some_string'); - }); - - test('should handle empty strings', () => { - expect(extractEnvVariable('')).toBe(''); - }); - - test('should handle strings without variable format', () => { - expect(extractEnvVariable('no_var_here')).toBe('no_var_here'); - }); - - test('should not process multiple variable formats', () => { - process.env.FIRST_VAR = 'first'; - process.env.SECOND_VAR = 'second'; - expect(extractEnvVariable('${FIRST_VAR} and ${SECOND_VAR}')).toBe( - '${FIRST_VAR} and ${SECOND_VAR}', - ); - }); - }); }); diff --git a/api/server/utils/index.js b/api/server/utils/index.js index d51cdd1d4eb..e87a4680fc9 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -5,6 +5,8 @@ const handleText = require('./handleText'); const cryptoUtils = require('./crypto'); const citations = require('./citations'); const sendEmail = require('./sendEmail'); +const queue = require('./queue'); +const files = require('./files'); const math = require('./math'); module.exports = { @@ -15,5 +17,7 @@ module.exports = { countTokens, removePorts, sendEmail, + ...files, + ...queue, math, }; diff --git a/api/server/utils/queue.js b/api/server/utils/queue.js new file mode 100644 index 00000000000..c32adaeffd8 --- /dev/null +++ b/api/server/utils/queue.js @@ -0,0 +1,69 @@ +/** + * A leaky bucket queue structure to manage API requests. + * @type {{queue: Array, interval: NodeJS.Timer | null}} + */ +const _LB = { + queue: [], + interval: null, +}; + +/** + * Interval in milliseconds to control the rate of API requests. + * Adjust the interval according to your rate limit needs. + */ +const _LB_INTERVAL_MS = Math.ceil(1000 / 60); // 60 req/s + +/** + * Executes the next function in the leaky bucket queue. + * This function is called at regular intervals defined by _LB_INTERVAL_MS. + */ +const _LB_EXEC_NEXT = async () => { + if (_LB.queue.length === 0) { + clearInterval(_LB.interval); + _LB.interval = null; + return; + } + + const next = _LB.queue.shift(); + if (!next) { + return; + } + + const { asyncFunc, args, callback } = next; + + try { + const data = await asyncFunc(...args); + callback(null, data); + } catch (e) { + callback(e); + } +}; + +/** + * Adds an async function call to the leaky bucket queue. + * @param {Function} asyncFunc - The async function to be executed. + * @param {Array} args - Arguments to pass to the async function. + * @param {Function} callback - Callback function for handling the result or error. + */ +function LB_QueueAsyncCall(asyncFunc, args, callback) { + _LB.queue.push({ asyncFunc, args, callback }); + + if (_LB.interval === null) { + _LB.interval = setInterval(_LB_EXEC_NEXT, _LB_INTERVAL_MS); + } +} + +/** + * Delays the execution for a specified number of milliseconds. + * + * @param {number} ms - The number of milliseconds to delay. + * @return {Promise} A promise that resolves after the specified delay. + */ +function sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +module.exports = { + sleep, + LB_QueueAsyncCall, +}; diff --git a/api/server/utils/sendEmail.js b/api/server/utils/sendEmail.js index 2f85f89dcd1..1ce33549299 100644 --- a/api/server/utils/sendEmail.js +++ b/api/server/utils/sendEmail.js @@ -2,6 +2,7 @@ const fs = require('fs'); const path = require('path'); const nodemailer = require('nodemailer'); const handlebars = require('handlebars'); +const { isEnabled } = require('~/server/utils/handleText'); const logger = require('~/config/winston'); const sendEmail = async (email, subject, payload, template) => { @@ -13,7 +14,7 @@ const sendEmail = async (email, subject, payload, template) => { requireTls: process.env.EMAIL_ENCRYPTION === 'starttls', tls: { // Whether to accept unsigned certificates - rejectUnauthorized: process.env.EMAIL_ALLOW_SELFSIGNED === 'true', + rejectUnauthorized: !isEnabled(process.env.EMAIL_ALLOW_SELFSIGNED), }, auth: { user: process.env.EMAIL_USERNAME, diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index 3511f144cc7..b7a691d91ae 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -16,12 +16,12 @@ const handleError = (res, message) => { /** * Sends message data in Server Sent Events format. - * @param {object} res - - The server response. - * @param {string} message - The message to be sent. + * @param {Express.Response} res - - The server response. + * @param {string | Object} message - The message to be sent. * @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'. */ const sendMessage = (res, message, event = 'message') => { - if (message.length === 0) { + if (typeof message === 'string' && message.length === 0) { return; } res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); @@ -32,6 +32,13 @@ const sendMessage = (res, message, event = 'message') => { * @async * @param {object} res - The server response. * @param {object} options - The options for handling the error containing message properties. + * @param {object} options.user - The user ID. + * @param {string} options.sender - The sender of the message. + * @param {string} options.conversationId - The conversation ID. + * @param {string} options.messageId - The message ID. + * @param {string} options.parentMessageId - The parent message ID. + * @param {string} options.text - The error message. + * @param {boolean} options.shouldSaveMessage - [Optional] Whether the message should be saved. Default is true. * @param {function} callback - [Optional] The callback function to be executed. */ const sendError = async (res, options, callback) => { @@ -43,7 +50,7 @@ const sendError = async (res, options, callback) => { parentMessageId, text, shouldSaveMessage, - overrideProps = {}, + ...rest } = options; const errorMessage = { sender, @@ -55,7 +62,7 @@ const sendError = async (res, options, callback) => { final: true, text, isCreatedByUser: false, - ...overrideProps, + ...rest, }; if (callback && typeof callback === 'function') { await callback(); @@ -88,7 +95,28 @@ const sendError = async (res, options, callback) => { handleError(res, errorMessage); }; +/** + * Sends the response based on whether headers have been sent or not. + * @param {Express.Response} res - The server response. + * @param {Object} data - The data to be sent. + * @param {string} [errorMessage] - The error message, if any. + */ +const sendResponse = (res, data, errorMessage) => { + if (!res.headersSent) { + if (errorMessage) { + return res.status(500).json({ error: errorMessage }); + } + return res.json(data); + } + + if (errorMessage) { + return sendError(res, { ...data, text: errorMessage }); + } + return sendMessage(res, data); +}; + module.exports = { + sendResponse, handleError, sendMessage, sendError, diff --git a/api/strategies/localStrategy.js b/api/strategies/localStrategy.js index 916766e6287..4408382cc42 100644 --- a/api/strategies/localStrategy.js +++ b/api/strategies/localStrategy.js @@ -1,7 +1,8 @@ +const { errorsToString } = require('librechat-data-provider'); const { Strategy: PassportLocalStrategy } = require('passport-local'); -const User = require('../models/User'); -const { loginSchema, errorsToString } = require('./validators'); -const logger = require('../utils/logger'); +const { loginSchema } = require('./validators'); +const logger = require('~/utils/logger'); +const User = require('~/models/User'); async function validateLoginRequest(req) { const { error } = loginSchema.safeParse(req.body); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 7219f24ba41..b6abd8f7059 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -2,6 +2,7 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); const passport = require('passport'); +const jwtDecode = require('jsonwebtoken/decode'); const { Issuer, Strategy: OpenIDStrategy } = require('openid-client'); const { logger } = require('~/config'); const User = require('~/models/User'); @@ -44,7 +45,9 @@ async function setupOpenId() { client_secret: process.env.OPENID_CLIENT_SECRET, redirect_uris: [process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL], }); - + const requiredRole = process.env.OPENID_REQUIRED_ROLE; + const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; + const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; const openidLogin = new OpenIDStrategy( { client, @@ -71,6 +74,36 @@ async function setupOpenId() { fullName = userinfo.username || userinfo.email; } + if (requiredRole) { + let decodedToken = ''; + if (requiredRoleTokenKind === 'access') { + decodedToken = jwtDecode(tokenset.access_token); + } else if (requiredRoleTokenKind === 'id') { + decodedToken = jwtDecode(tokenset.id_token); + } + const pathParts = requiredRoleParameterPath.split('.'); + let found = true; + let roles = pathParts.reduce((o, key) => { + if (o === null || o === undefined || !(key in o)) { + found = false; + return []; + } + return o[key]; + }, decodedToken); + + if (!found) { + console.error( + `Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`, + ); + } + + if (!roles.includes(requiredRole)) { + return done(null, false, { + message: `You must have the "${requiredRole}" role to log in.`, + }); + } + } + if (!user) { user = new User({ provider: 'openid', diff --git a/api/strategies/process.js b/api/strategies/process.js index f5a12a26a25..9b791023195 100644 --- a/api/strategies/process.js +++ b/api/strategies/process.js @@ -1,5 +1,6 @@ const { FileSources } = require('librechat-data-provider'); -const uploadAvatar = require('~/server/services/Files/images/avatar'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const User = require('~/models/User'); /** @@ -7,7 +8,7 @@ const User = require('~/models/User'); * '?manual=true', it updates the user's avatar with the provided URL. For local file storage, it directly updates * the avatar URL, while for other storage types, it processes the avatar URL using the specified file strategy. * - * @param {User} oldUser - The existing user object that needs to be updated. Expected to have an 'avatar' property. + * @param {User} oldUser - The existing user object that needs to be updated. * @param {string} avatarUrl - The new avatar URL to be set for the user. * * @returns {Promise} @@ -19,13 +20,17 @@ const handleExistingUser = async (oldUser, avatarUrl) => { const fileStrategy = process.env.CDN_PROVIDER; const isLocal = fileStrategy === FileSources.local; - if (isLocal && !oldUser.avatar.includes('?manual=true')) { + if (isLocal && (oldUser.avatar === null || !oldUser.avatar.includes('?manual=true'))) { oldUser.avatar = avatarUrl; await oldUser.save(); - } else if (!isLocal && !oldUser.avatar.includes('?manual=true')) { + } else if (!isLocal && (oldUser.avatar === null || !oldUser.avatar.includes('?manual=true'))) { const userId = oldUser._id; - const newavatarUrl = await uploadAvatar({ userId, input: avatarUrl, fileStrategy }); - oldUser.avatar = newavatarUrl; + const webPBuffer = await resizeAvatar({ + userId, + input: avatarUrl, + }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + oldUser.avatar = await processAvatar({ buffer: webPBuffer, userId }); await oldUser.save(); } }; @@ -78,8 +83,12 @@ const createNewUser = async ({ if (!isLocal) { const userId = newUser._id; - const newavatarUrl = await uploadAvatar({ userId, input: avatarUrl, fileStrategy }); - newUser.avatar = newavatarUrl; + const webPBuffer = await resizeAvatar({ + userId, + input: avatarUrl, + }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + newUser.avatar = await processAvatar({ buffer: webPBuffer, userId }); await newUser.save(); } diff --git a/api/strategies/validators.js b/api/strategies/validators.js index 22e4fa6ec5a..e8ae300f03c 100644 --- a/api/strategies/validators.js +++ b/api/strategies/validators.js @@ -1,17 +1,20 @@ const { z } = require('zod'); -function errorsToString(errors) { - return errors - .map((error) => { - let field = error.path.join('.'); - let message = error.message; - - return `${field}: ${message}`; - }) - .join(' '); -} - -const allowedCharactersRegex = /^[a-zA-Z0-9_.@#$%&*()\p{Script=Latin}\p{Script=Common}]+$/u; +const allowedCharactersRegex = new RegExp( + '^[' + + 'a-zA-Z0-9_.@#$%&*()' + // Basic Latin characters and symbols + '\\p{Script=Latin}' + // Latin script characters + '\\p{Script=Common}' + // Characters common across scripts + '\\p{Script=Cyrillic}' + // Cyrillic script for Russian, etc. + '\\p{Script=Devanagari}' + // Devanagari script for Hindi, etc. + '\\p{Script=Han}' + // Han script for Chinese characters, etc. + '\\p{Script=Arabic}' + // Arabic script + '\\p{Script=Hiragana}' + // Hiragana script for Japanese + '\\p{Script=Katakana}' + // Katakana script for Japanese + '\\p{Script=Hangul}' + // Hangul script for Korean + ']+$', // End of string + 'u', // Use Unicode mode +); const injectionPatternsRegex = /('|--|\$ne|\$gt|\$lt|\$or|\{|\}|\*|;|<|>|\/|=)/i; const usernameSchema = z @@ -72,5 +75,4 @@ const registerSchema = z module.exports = { loginSchema, registerSchema, - errorsToString, }; diff --git a/api/strategies/validators.spec.js b/api/strategies/validators.spec.js index bd4e2192fbb..312f06923d5 100644 --- a/api/strategies/validators.spec.js +++ b/api/strategies/validators.spec.js @@ -1,4 +1,6 @@ -const { loginSchema, registerSchema, errorsToString } = require('./validators'); +// file deepcode ignore NoHardcodedPasswords: No hard-coded passwords in tests +const { errorsToString } = require('librechat-data-provider'); +const { loginSchema, registerSchema } = require('./validators'); describe('Zod Schemas', () => { describe('loginSchema', () => { @@ -402,9 +404,6 @@ describe('Zod Schemas', () => { it('should reject invalid usernames', () => { const invalidUsernames = [ - 'Дмитрий', // Cyrillic characters - 'محمد', // Arabic characters - '张伟', // Chinese characters 'john{doe}', // Contains `{` and `}` 'j', // Only one character 'a'.repeat(81), // More than 80 characters diff --git a/api/typedefs.js b/api/typedefs.js index 7bb956c9aec..e844e1eb909 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -8,12 +8,234 @@ * @memberof typedefs */ +/** + * @exports Anthropic + * @typedef {import('@anthropic-ai/sdk').default} Anthropic + * @memberof typedefs + */ + +/** + * @exports AssistantStreamEvent + * @typedef {import('openai').default.Beta.AssistantStreamEvent} AssistantStreamEvent + * @memberof typedefs + */ + +/** + * @exports AssistantStream + * @typedef {AsyncIterable} AssistantStream + * @memberof typedefs + */ + +/** + * @exports RunCreateAndStreamParams + * @typedef {import('openai').OpenAI.Beta.Threads.RunCreateAndStreamParams} RunCreateAndStreamParams + * @memberof typedefs + */ + +/** + * @exports ChatCompletionContentPartImage + * @typedef {import('openai').OpenAI.ChatCompletionContentPartImage} ChatCompletionContentPartImage + * @memberof typedefs + */ + +/** + * @exports ChatCompletion + * @typedef {import('openai').OpenAI.ChatCompletion} ChatCompletion + * @memberof typedefs + */ + +/** + * @exports ChatCompletionPayload + * @typedef {import('openai').OpenAI.ChatCompletionCreateParams} ChatCompletionPayload + * @memberof typedefs + */ + +/** + * @exports ChatCompletionMessages + * @typedef {import('openai').OpenAI.ChatCompletionMessageParam} ChatCompletionMessages + * @memberof typedefs + */ + +/** + * @exports CohereChatStreamRequest + * @typedef {import('cohere-ai').Cohere.ChatStreamRequest} CohereChatStreamRequest + * @memberof typedefs + */ + +/** + * @exports CohereChatRequest + * @typedef {import('cohere-ai').Cohere.ChatRequest} CohereChatRequest + * @memberof typedefs + */ + +/** + * @exports OpenAIRequestOptions + * @typedef {import('openai').OpenAI.RequestOptions} OpenAIRequestOptions + * @memberof typedefs + */ + +/** + * @exports ThreadCreated + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadCreated} ThreadCreated + * @memberof typedefs + */ + +/** + * @exports ThreadRunCreated + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunCreated} ThreadRunCreated + * @memberof typedefs + */ + +/** + * @exports ThreadRunQueued + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunQueued} ThreadRunQueued + * @memberof typedefs + */ + +/** + * @exports ThreadRunInProgress + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunInProgress} ThreadRunInProgress + * @memberof typedefs + */ + +/** + * @exports ThreadRunRequiresAction + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunRequiresAction} ThreadRunRequiresAction + * @memberof typedefs + */ + +/** + * @exports ThreadRunCompleted + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunCompleted} ThreadRunCompleted + * @memberof typedefs + */ + +/** + * @exports ThreadRunFailed + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunFailed} ThreadRunFailed + * @memberof typedefs + */ + +/** + * @exports ThreadRunCancelling + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunCancelling} ThreadRunCancelling + * @memberof typedefs + */ + +/** + * @exports ThreadRunCancelled + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunCancelled} ThreadRunCancelled + * @memberof typedefs + */ + +/** + * @exports ThreadRunExpired + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunExpired} ThreadRunExpired + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepCreated + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepCreated} ThreadRunStepCreated + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepInProgress + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepInProgress} ThreadRunStepInProgress + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepDelta + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepDelta} ThreadRunStepDelta + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepCompleted + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepCompleted} ThreadRunStepCompleted + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepFailed + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepFailed} ThreadRunStepFailed + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepCancelled + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepCancelled} ThreadRunStepCancelled + * @memberof typedefs + */ + +/** + * @exports ThreadRunStepExpired + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadRunStepExpired} ThreadRunStepExpired + * @memberof typedefs + */ + +/** + * @exports ThreadMessageCreated + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadMessageCreated} ThreadMessageCreated + * @memberof typedefs + */ + +/** + * @exports ThreadMessageInProgress + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadMessageInProgress} ThreadMessageInProgress + * @memberof typedefs + */ + +/** + * @exports ThreadMessageDelta + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadMessageDelta} ThreadMessageDelta + * @memberof typedefs + */ + +/** + * @exports ThreadMessageCompleted + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadMessageCompleted} ThreadMessageCompleted + * @memberof typedefs + */ + +/** + * @exports ThreadMessageIncomplete + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ThreadMessageIncomplete} ThreadMessageIncomplete + * @memberof typedefs + */ + +/** + * @exports ErrorEvent + * @typedef {import('openai').default.Beta.AssistantStreamEvent.ErrorEvent} ErrorEvent + * @memberof typedefs + */ + +/** + * @exports ToolCallDeltaObject + * @typedef {import('openai').default.Beta.Threads.Runs.Steps.ToolCallDeltaObject} ToolCallDeltaObject + * @memberof typedefs + */ + +/** + * @exports ToolCallDelta + * @typedef {import('openai').default.Beta.Threads.Runs.Steps.ToolCallDelta} ToolCallDelta + * @memberof typedefs + */ + /** * @exports Assistant * @typedef {import('librechat-data-provider').Assistant} Assistant * @memberof typedefs */ +/** + * @exports AssistantDocument + * @typedef {import('librechat-data-provider').AssistantDocument} AssistantDocument + * @memberof typedefs + */ + /** * @exports OpenAIFile * @typedef {import('librechat-data-provider').File} OpenAIFile @@ -26,12 +248,83 @@ * @memberof typedefs */ +/** + * @exports TAzureModelConfig + * @typedef {import('librechat-data-provider').TAzureModelConfig} TAzureModelConfig + * @memberof typedefs + */ + +/** + * @exports TAzureGroup + * @typedef {import('librechat-data-provider').TAzureGroup} TAzureGroup + * @memberof typedefs + */ + +/** + * @exports TAzureGroups + * @typedef {import('librechat-data-provider').TAzureGroups} TAzureGroups + * @memberof typedefs + */ + +/** + * @exports TAzureModelGroupMap + * @typedef {import('librechat-data-provider').TAzureModelGroupMap} TAzureModelGroupMap + * @memberof typedefs + */ +/** + * @exports TAzureGroupMap + * @typedef {import('librechat-data-provider').TAzureGroupMap} TAzureGroupMap + * @memberof typedefs + */ + +/** + * @exports TAzureConfig + * @typedef {import('librechat-data-provider').TAzureConfig} TAzureConfig + * @memberof typedefs + */ + +/** + * @exports TModelsConfig + * @typedef {import('librechat-data-provider').TModelsConfig} TModelsConfig + * @memberof typedefs + */ + +/** + * @exports TPlugin + * @typedef {import('librechat-data-provider').TPlugin} TPlugin + * @memberof typedefs + */ + +/** + * @exports TCustomConfig + * @typedef {import('librechat-data-provider').TCustomConfig} TCustomConfig + * @memberof typedefs + */ + +/** + * @exports TEndpoint + * @typedef {import('librechat-data-provider').TEndpoint} TEndpoint + * @memberof typedefs + */ + +/** + * @exports TEndpointsConfig + * @typedef {import('librechat-data-provider').TEndpointsConfig} TEndpointsConfig + * @memberof typedefs + */ + /** * @exports TMessage * @typedef {import('librechat-data-provider').TMessage} TMessage * @memberof typedefs */ +/** + * @exports TPlugin + * @typedef {import('librechat-data-provider').TPlugin} TPlugin + * @memberof typedefs + */ + /** * @exports FileSources * @typedef {import('librechat-data-provider').FileSources} FileSources @@ -39,12 +332,75 @@ */ /** - * @exports ImageMetadata - * @typedef {Object} ImageMetadata + * @exports TMessage + * @typedef {import('librechat-data-provider').TMessage} TMessage + * @memberof typedefs + */ + +/** + * @exports ImageFile + * @typedef {import('librechat-data-provider').ImageFile} ImageFile + * @memberof typedefs + */ + +/** + * @exports TMessageContentParts + * @typedef {import('librechat-data-provider').TMessageContentParts} TMessageContentParts + * @memberof typedefs + */ + +/** + * @exports StreamContentData + * @typedef {import('librechat-data-provider').StreamContentData} StreamContentData + * @memberof typedefs + */ + +/** + * @exports ActionRequest + * @typedef {import('librechat-data-provider').ActionRequest} ActionRequest + * @memberof typedefs + */ + +/** + * @exports Action + * @typedef {import('librechat-data-provider').Action} Action + * @memberof typedefs + */ + +/** + * @exports ActionMetadata + * @typedef {import('librechat-data-provider').ActionMetadata} ActionMetadata + * @memberof typedefs + */ + +/** + * @exports ActionAuth + * @typedef {import('librechat-data-provider').ActionAuth} ActionAuth + * @memberof typedefs + */ + +/** + * @exports DeleteFilesBody + * @typedef {import('librechat-data-provider').DeleteFilesBody} DeleteFilesBody + * @memberof typedefs + */ + +/** + * @exports FileMetadata + * @typedef {Object} FileMetadata * @property {string} file_id - The identifier of the file. * @property {string} [temp_file_id] - The temporary identifier of the file. + * @property {string} endpoint - The conversation endpoint origin for the file upload. + * @property {string} [assistant_id] - The assistant ID if file upload is in the `knowledge` context. + * @memberof typedefs + */ + +/** + * @typedef {Object} ImageOnlyMetadata * @property {number} width - The width of the image. * @property {number} height - The height of the image. + * + * @typedef {FileMetadata & ImageOnlyMetadata} ImageMetadata * @memberof typedefs */ @@ -54,6 +410,18 @@ * @memberof typedefs */ +/** + * @exports uploadImageBuffer + * @typedef {import('~/server/services/Files/process').uploadImageBuffer} uploadImageBuffer + * @memberof typedefs + */ + +/** + * @exports processFileURL + * @typedef {import('~/server/services/Files/process').processFileURL} processFileURL + * @memberof typedefs + */ + /** * @exports AssistantCreateParams * @typedef {import('librechat-data-provider').AssistantCreateParams} AssistantCreateParams @@ -78,21 +446,154 @@ * @memberof typedefs */ +/** + * @exports ContentPart + * @typedef {import('librechat-data-provider').ContentPart} ContentPart + * @memberof typedefs + */ + +/** + * @exports StepTypes + * @typedef {import('librechat-data-provider').StepTypes} StepTypes + * @memberof typedefs + */ + +/** + * @exports TContentData + * @typedef {import('librechat-data-provider').TContentData} TContentData + * @memberof typedefs + */ + +/** + * @exports ContentPart + * @typedef {import('librechat-data-provider').ContentPart} ContentPart + * @memberof typedefs + */ + +/** + * @exports PartMetadata + * @typedef {import('librechat-data-provider').PartMetadata} PartMetadata + * @memberof typedefs + */ + /** * @exports ThreadMessage - * @typedef {import('openai').OpenAI.Beta.Threads.ThreadMessage} ThreadMessage + * @typedef {import('openai').OpenAI.Beta.Threads.Message} ThreadMessage + * @memberof typedefs + */ + +/** + * @exports Annotation + * @typedef {import('openai').OpenAI.Beta.Threads.Messages.Annotation} Annotation + * @memberof typedefs + */ + +/** + * @exports TAssistantEndpoint + * @typedef {import('librechat-data-provider').TAssistantEndpoint} TAssistantEndpoint + * @memberof typedefs + */ + +/** + * Represents details of the message creation by the run step, including the ID of the created message. + * + * @exports MessageCreationStepDetails + * @typedef {Object} MessageCreationStepDetails + * @property {Object} message_creation - Details of the message creation. + * @property {string} message_creation.message_id - The ID of the message that was created by this run step. + * @property {'message_creation'} type - Always 'message_creation'. + * @memberof typedefs + */ + +/** + * Represents a text log output from the Code Interpreter tool call. + * @typedef {Object} CodeLogOutput + * @property {'logs'} type - Always 'logs'. + * @property {string} logs - The text output from the Code Interpreter tool call. + */ + +/** + * Represents an image output from the Code Interpreter tool call. + * @typedef {Object} CodeImageOutput + * @property {'image'} type - Always 'image'. + * @property {Object} image - The image object. + * @property {string} image.file_id - The file ID of the image. + */ + +/** + * Details of the Code Interpreter tool call the run step was involved in. + * Includes the tool call ID, the code interpreter definition, and the type of tool call. + * + * @typedef {Object} CodeToolCall + * @property {string} id - The ID of the tool call. + * @property {Object} code_interpreter - The Code Interpreter tool call definition. + * @property {string} code_interpreter.input - The input to the Code Interpreter tool call. + * @property {Array<(CodeLogOutput | CodeImageOutput)>} code_interpreter.outputs - The outputs from the Code Interpreter tool call. + * @property {'code_interpreter'} type - The type of tool call, always 'code_interpreter'. + * @memberof typedefs + */ + +/** + * Details of a Function tool call the run step was involved in. + * Includes the tool call ID, the function definition, and the type of tool call. + * + * @typedef {Object} FunctionToolCall + * @property {string} id - The ID of the tool call object. + * @property {Object} function - The definition of the function that was called. + * @property {string} function.arguments - The arguments passed to the function. + * @property {string} function.name - The name of the function. + * @property {string|null} function.output - The output of the function, null if not submitted. + * @property {'function'} type - The type of tool call, always 'function'. * @memberof typedefs */ /** + * Details of a Retrieval tool call the run step was involved in. + * Includes the tool call ID and the type of tool call. + * + * @typedef {Object} RetrievalToolCall + * @property {string} id - The ID of the tool call object. + * @property {unknown} retrieval - An empty object for now. + * @property {'retrieval'} type - The type of tool call, always 'retrieval'. + * @memberof typedefs + */ + +/** + * Details of the tool calls involved in a run step. + * Can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`. + * + * @typedef {Object} ToolCallsStepDetails + * @property {Array} tool_calls - An array of tool calls the run step was involved in. + * @property {'tool_calls'} type - Always 'tool_calls'. + * @memberof typedefs + */ + +/** + * Details of the tool calls involved in a run step. + * Can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`. + * + * @exports StepToolCall + * @typedef {(CodeToolCall | RetrievalToolCall | FunctionToolCall) & PartMetadata} StepToolCall + * @memberof typedefs + */ + +/** + * Represents a tool call object required for certain actions in the OpenAI API, + * including the function definition and type of the tool call. + * * @exports RequiredActionFunctionToolCall - * @typedef {import('openai').OpenAI.Beta.Threads.RequiredActionFunctionToolCall} RequiredActionFunctionToolCall + * @typedef {Object} RequiredActionFunctionToolCall + * @property {string} id - The ID of the tool call, referenced when submitting tool outputs. + * @property {Object} function - The function definition associated with the tool call. + * @property {string} function.arguments - The arguments that the model expects to be passed to the function. + * @property {string} function.name - The name of the function. + * @property {'function'} type - The type of tool call the output is required for, currently always 'function'. * @memberof typedefs */ /** * @exports RunManager - * @typedef {import('./server/services/Runs/RunMananger.js').RunManager} RunManager + * @typedef {import('./server/services/Runs/RunManager.js').RunManager} RunManager * @memberof typedefs */ @@ -100,7 +601,7 @@ * @exports Thread * @typedef {Object} Thread * @property {string} id - The identifier of the thread. - * @property {string} object - The object type, always 'thread'. + * @property {'thread'} object - The object type, always 'thread'. * @property {number} created_at - The Unix timestamp (in seconds) for when the thread was created. * @property {Object} [metadata] - Optional metadata associated with the thread. * @property {Message[]} [messages] - An array of messages associated with the thread. @@ -111,12 +612,12 @@ * @exports Message * @typedef {Object} Message * @property {string} id - The identifier of the message. - * @property {string} object - The object type, always 'thread.message'. + * @property {'thread.message'} object - The object type, always 'thread.message'. * @property {number} created_at - The Unix timestamp (in seconds) for when the message was created. * @property {string} thread_id - The thread ID that this message belongs to. - * @property {string} role - The entity that produced the message. One of 'user' or 'assistant'. + * @property {'user'|'assistant'} role - The entity that produced the message. One of 'user' or 'assistant'. * @property {Object[]} content - The content of the message in an array of text and/or images. - * @property {string} content[].type - The type of content, either 'text' or 'image_file'. + * @property {'text'|'image_file'} content[].type - The type of content, either 'text' or 'image_file'. * @property {Object} [content[].text] - The text content, present if type is 'text'. * @property {string} content[].text.value - The data that makes up the text. * @property {Object[]} [content[].text.annotations] - Annotations for the text content. @@ -158,7 +659,7 @@ /** * @exports FunctionTool * @typedef {Object} FunctionTool - * @property {string} type - The type of tool, 'function'. + * @property {'function'} type - The type of tool, 'function'. * @property {Object} function - The function definition. * @property {string} function.description - A description of what the function does. * @property {string} function.name - The name of the function to be called. @@ -169,7 +670,7 @@ /** * @exports Tool * @typedef {Object} Tool - * @property {string} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. + * @property {'code_interpreter'|'retrieval'|'function'} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. * @property {FunctionTool} [function] - The function tool, present if type is 'function'. * @memberof typedefs */ @@ -182,7 +683,7 @@ * @property {number} created_at - The Unix timestamp (in seconds) for when the run was created. * @property {string} thread_id - The ID of the thread that was executed on as a part of this run. * @property {string} assistant_id - The ID of the assistant used for execution of this run. - * @property {string} status - The status of the run (e.g., 'queued', 'completed'). + * @property {'queued'|'in_progress'|'requires_action'|'cancelling'|'cancelled'|'failed'|'completed'|'expired'} status - The status of the run: queued, in_progress, requires_action, cancelling, cancelled, failed, completed, or expired. * @property {Object} [required_action] - Details on the action required to continue the run. * @property {string} required_action.type - The type of required action, always 'submit_tool_outputs'. * @property {Object} required_action.submit_tool_outputs - Details on the tool outputs needed for the run to continue. @@ -202,9 +703,15 @@ * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run was completed. * @property {string} [model] - The model that the assistant used for this run. * @property {string} [instructions] - The instructions that the assistant used for this run. + * @property {string} [additional_instructions] - Optional. Appends additional instructions + * at theend of the instructions for the run. This is useful for modifying * @property {Tool[]} [tools] - The list of tools used for this run. * @property {string[]} [file_ids] - The list of File IDs used for this run. * @property {Object} [metadata] - Metadata associated with this run. + * @property {Object} [usage] - Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + * @property {number} [usage.completion_tokens] - Number of completion tokens used over the course of the run. + * @property {number} [usage.prompt_tokens] - Number of prompt tokens used over the course of the run. + * @property {number} [usage.total_tokens] - Total number of tokens used (prompt + completion). * @memberof typedefs */ @@ -217,11 +724,11 @@ * @property {string} assistant_id - The ID of the assistant associated with the run step. * @property {string} thread_id - The ID of the thread that was run. * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. + * @property {'message_creation' | 'tool_calls'} type - The type of run step. + * @property {'in_progress' | 'cancelled' | 'failed' | 'completed' | 'expired'} status - The status of the run step. + * @property {MessageCreationStepDetails | ToolCallsStepDetails} step_details - The details of the run step. * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. + * @property {'server_error' | 'rate_limit_exceeded'} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. * @property {string} last_error.message - A human-readable description of the error. * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. @@ -241,8 +748,8 @@ * @property {string} assistant_id - The ID of the assistant associated with the run step. * @property {string} thread_id - The ID of the thread that was run. * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. + * @property {'message_creation'|'tool_calls'} type - The type of run step, either 'message_creation' or 'tool_calls'. + * @property {'in_progress'|'cancelled'|'failed'|'completed'|'expired'} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. * @property {Object} step_details - The details of the run step. * @property {Object} [last_error] - The last error associated with this run step. * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. @@ -350,6 +857,41 @@ * @memberof typedefs */ +/** + * @exports RequiredAction + * @typedef {Object} RequiredAction + * @property {string} tool - The name of the function. + * @property {Object} toolInput - The args to invoke the function with. + * @property {string} toolCallId - The ID of the tool call. + * @property {Run['id']} run_id - Run identifier. + * @property {Thread['id']} thread_id - Thread identifier. + * @memberof typedefs + */ + +/** + * @exports StructuredTool + * @typedef {Object} StructuredTool + * @property {string} name - The name of the function. + * @property {string} description - The description of the function. + * @property {import('zod').ZodTypeAny} schema - The structured zod schema. + * @memberof typedefs + */ + +/** + * @exports ToolOutput + * @typedef {Object} ToolOutput + * @property {string} tool_call_id - The ID of the tool call. + * @property {Object} output - The output of the tool, which can vary in structure. + * @memberof typedefs + */ + +/** + * @exports ToolOutputs + * @typedef {Object} ToolOutputs + * @property {ToolOutput[]} tool_outputs - Array of tool outputs. + * @memberof typedefs + */ + /** * @typedef {Object} ModelOptions * @property {string} modelName - The name of the model. @@ -385,3 +927,203 @@ * @property {string} [azureOpenAIApiVersion] - The Azure OpenAI API version. * @memberof typedefs */ + +/** + * @typedef {Object} TokenConfig + * A configuration object mapping model keys to their respective prompt, completion rates, and context limit. + * @property {number} prompt - The prompt rate + * @property {number} completion - The completion rate + * @property {number} context - The maximum context length supported by the model. + * @memberof typedefs + */ + +/** + * @typedef {Record} EndpointTokenConfig + * An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit. + * @memberof typedefs + */ + +/** + * @typedef {Object} ResponseMessage + * @property {string} conversationId - The ID of the conversation. + * @property {string} thread_id - The ID of the thread. + * @property {string} messageId - The ID of the message (from LibreChat). + * @property {string} parentMessageId - The ID of the parent message. + * @property {string} user - The ID of the user. + * @property {string} assistant_id - The ID of the assistant. + * @property {string} role - The role of the response. + * @property {string} model - The model used in the response. + * @property {ContentPart[]} content - The content parts accumulated from the run. + * @memberof typedefs + */ + +/** + * @typedef {Object} RunResponse + * @property {Run} run - The detailed information about the run. + * @property {RunStep[]} steps - An array of steps taken during the run. + * @property {StepMessage[]} messages - An array of messages related to the run. + * @property {ResponseMessage} finalMessage - The final response message, with all content parts. + * @property {string} text - The final response text, accumulated from message parts + * @memberof typedefs + */ + +/** + * @callback InProgressFunction + * @param {Object} params - The parameters for the in progress step. + * @param {RunStep} params.step - The step object with details about the message creation. + * @returns {Promise} - A promise that resolves when the step is processed. + * @memberof typedefs + */ + +// /** +// * @typedef {OpenAI & { +// * req: Express.Request, +// * res: Express.Response +// * getPartialText: () => string, +// * processedFileIds: Set, +// * mappedOrder: Map, +// * completeToolCallSteps: Set, +// * seenCompletedMessages: Set, +// * seenToolCalls: Map, +// * progressCallback: (options: Object) => void, +// * addContentData: (data: TContentData) => void, +// * responseMessage: ResponseMessage, +// * }} OpenAIClient - for reference only +// */ + +/** + * @typedef {Object} RunClient + * + * @property {Express.Request} req - The Express request object. + * @property {Express.Response} res - The Express response object. + * @property {?import('https-proxy-agent').HttpsProxyAgent} httpAgent - An optional HTTP proxy agent for the request. + + * @property {() => string} getPartialText - Retrieves the current tokens accumulated by `progressCallback`. + * + * Note: not used until real streaming is implemented by OpenAI. + * + * @property {string} responseText -The accumulated text values for the current run. + * @property {Set} processedFileIds - A set of IDs for processed files. + * @property {Map} mappedOrder - A map to maintain the order of individual `tool_calls` and `steps`. + * @property {Set} [attachedFileIds] - A set of user attached file ids; necessary to track which files are downloadable. + * @property {Set} completeToolCallSteps - A set of completed tool call steps. + * @property {Set} seenCompletedMessages - A set of completed messages that have been seen/processed. + * @property {Map} seenToolCalls - A map of tool calls that have been seen/processed. + * @property {object | undefined} locals - Local variables for the request. + * @property {AzureOptions} locals.azureOptions - Local Azure options for the request. + * @property {(data: TContentData) => void} addContentData - Updates the response message's relevant + * @property {InProgressFunction} in_progress - Updates the response message's relevant + * content array with the part by index & sends intermediate SSE message with content data. + * + * Note: does not send intermediate SSE message for messages, which are streamed + * (may soon be streamed) directly from OpenAI API. + * + * @property {ResponseMessage} responseMessage - A message object for responses. + * + * @typedef {OpenAI & RunClient} OpenAIClient + */ + +/** + * The body of the request to create a run, specifying the assistant, model, + * instructions, and any additional parameters needed for the run. + * + * @typedef {Object} CreateRunBody + * @property {string} assistant_id - The ID of the assistant to use for this run. + * @property {string} [model] - Optional. The ID of the model to be used for this run. + * @property {string} [instructions] - Optional. Override the default system message of the assistant. + * @property {string} [additional_instructions] - Optional. Appends additional instructions + * at the end of the instructions for the run. Useful for modifying behavior on a per-run basis without overriding other instructions. + * @property {Object[]} [tools] - Optional. Override the tools the assistant can use for this run. Should include tool call ID and the type of tool call. + * @property {string[]} [file_ids] - Optional. List of File IDs the assistant can use for this run. + * **Note:** The API seems to prefer files added to messages, not runs. + * @property {Object} [metadata] - Optional. Metadata for the run. + * @memberof typedefs + */ + +/** + * @typedef {Object} StreamRunManager + * Manages streaming and processing of run steps, messages, and tool calls within a thread. + * + * @property {number} index - Tracks the current index for step or message processing. + * @property {Map} steps - Stores run steps by their IDs. + * @property {Map} mappedOrder - Maps step or message IDs to their processing order index. + * @property {Map} orderedRunSteps - Stores run steps in order of processing. + * @property {Set} processedFileIds - Keeps track of file IDs that have been processed. + * @property {Map} progressCallbacks - Stores callbacks for reporting progress on step or message processing. + * @property {boolean} submittedToolOutputs - Indicates whether tool outputs have been submitted. + * @property {Object|null} run - Holds the current run object. + * @property {Object} req - The HTTP request object associated with the run. + * @property {Object} res - The HTTP response object for sending back data. + * @property {Object} openai - The OpenAI client instance. + * @property {string} apiKey - The API key used for OpenAI requests. + * @property {string} thread_id - The ID of the thread associated with the run. + * @property {Object} initialRunBody - The initial body of the run request. + * @property {Object.} clientHandlers - Custom handlers provided by the client. + * @property {Object} streamOptions - Options for streaming the run. + * @property {Object} finalMessage - The final message object to be constructed and sent. + * @property {Array} messages - An array of messages processed during the run. + * @property {string} text - Accumulated text from text content data. + * @property {Object.} handlers - Internal event handlers for different types of streaming events. + * + * @method addContentData Adds content data to the final message or sends it immediately depending on type. + * @method runAssistant Initializes and manages the streaming of a thread run. + * @method handleEvent Dispatches streaming events to the appropriate handlers. + * @method handleThreadCreated Handles the event when a thread is created. + * @method handleRunEvent Handles various run state events. + * @method handleRunStepEvent Handles events related to individual run steps. + * @method handleCodeImageOutput Processes and handles code-generated image outputs. + * @method createToolCallStream Initializes streaming for tool call outputs. + * @method handleNewToolCall Handles the creation of a new tool call within a run step. + * @method handleCompletedToolCall Handles the completion of tool call processing. + * @method handleRunStepDeltaEvent Handles updates (deltas) for run steps. + * @method handleMessageDeltaEvent Handles updates (deltas) for messages. + * @method handleErrorEvent Handles error events during streaming. + * @method getStepIndex Retrieves or assigns an index for a given step or message key. + * @method generateToolCallKey Generates a unique key for a tool call within a step. + * @method onRunRequiresAction Handles actions required by a run to proceed. + * @method onRunStepCreated Handles the creation of a new run step. + * @method onRunStepCompleted Handles the completion of a run step. + * @method handleMessageEvent Handles events related to messages within the run. + * @method messageCompleted Handles the completion of a message processing. + */ + +/* Native app/client methods */ + +/** + * Accumulates tokens and sends them to the client for processing. + * @callback onTokenProgress + * @param {string} token - The current token generated by the model. + * @returns {Promise} + * @memberof typedefs + */ + +/** + * Main entrypoint for API completion calls + * @callback sendCompletion + * @param {Array | string} payload - The messages or prompt to send to the model + * @param {object} opts - Options for the completion + * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress + * @param {AbortController} opts.abortController - AbortController instance + * @returns {Promise} + * @memberof typedefs + */ + +/** + * Legacy completion handler for OpenAI API. + * @callback getCompletion + * @param {Array | string} input - Array of messages or a single prompt string + * @param {(event: object | string) => Promise} onProgress - SSE progress handler + * @param {onTokenProgress} onTokenProgress - Token progress handler + * @param {AbortController} [abortController] - AbortController instance + * @returns {Promise} - Completion response + * @memberof typedefs + */ + +/** + * Cohere Stream handling. Note: abortController is not supported here. + * @callback cohereChatCompletion + * @param {object} params + * @param {CohereChatStreamRequest | CohereChatRequest} params.payload + * @param {onTokenProgress} params.onTokenProgress + * @memberof typedefs + */ diff --git a/api/utils/azureUtils.js b/api/utils/azureUtils.js index 8083ff4fb3b..91d62b20e5e 100644 --- a/api/utils/azureUtils.js +++ b/api/utils/azureUtils.js @@ -6,7 +6,7 @@ const { isEnabled } = require('~/server/utils'); * @returns {string} The sanitized model name. */ const sanitizeModelName = (modelName) => { - // Replace periods with empty strings and other disallowed characters as needed + // Replace periods with empty strings and other disallowed characters as needed. return modelName.replace(/\./g, ''); }; @@ -84,16 +84,19 @@ const getAzureCredentials = () => { * * @param {Object} params - The parameters object. * @param {string} params.baseURL - The baseURL to inspect for replacement placeholders. - * @param {AzureOptions} params.azure - The baseURL to inspect for replacement placeholders. + * @param {AzureOptions} params.azureOptions - The azure options object containing the instance and deployment names. * @returns {string} The complete baseURL with credentials injected for the Azure OpenAI API. */ -function constructAzureURL({ baseURL, azure }) { +function constructAzureURL({ baseURL, azureOptions }) { let finalURL = baseURL; // Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available - if (azure) { - finalURL = finalURL.replace('${INSTANCE_NAME}', azure.azureOpenAIApiInstanceName ?? ''); - finalURL = finalURL.replace('${DEPLOYMENT_NAME}', azure.azureOpenAIApiDeploymentName ?? ''); + if (azureOptions) { + finalURL = finalURL.replace('${INSTANCE_NAME}', azureOptions.azureOpenAIApiInstanceName ?? ''); + finalURL = finalURL.replace( + '${DEPLOYMENT_NAME}', + azureOptions.azureOpenAIApiDeploymentName ?? '', + ); } return finalURL; diff --git a/api/utils/azureUtils.spec.js b/api/utils/azureUtils.spec.js index 77db26b0911..4d844513856 100644 --- a/api/utils/azureUtils.spec.js +++ b/api/utils/azureUtils.spec.js @@ -199,7 +199,7 @@ describe('constructAzureURL', () => { test('replaces both placeholders when both properties are provided', () => { const url = constructAzureURL({ baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}', - azure: { + azureOptions: { azureOpenAIApiInstanceName: 'instance1', azureOpenAIApiDeploymentName: 'deployment1', }, @@ -210,7 +210,7 @@ describe('constructAzureURL', () => { test('replaces only INSTANCE_NAME when only azureOpenAIApiInstanceName is provided', () => { const url = constructAzureURL({ baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}', - azure: { + azureOptions: { azureOpenAIApiInstanceName: 'instance2', }, }); @@ -220,7 +220,7 @@ describe('constructAzureURL', () => { test('replaces only DEPLOYMENT_NAME when only azureOpenAIApiDeploymentName is provided', () => { const url = constructAzureURL({ baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}', - azure: { + azureOptions: { azureOpenAIApiDeploymentName: 'deployment2', }, }); @@ -230,12 +230,12 @@ describe('constructAzureURL', () => { test('does not replace any placeholders when azure object is empty', () => { const url = constructAzureURL({ baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}', - azure: {}, + azureOptions: {}, }); expect(url).toBe('https://example.com//'); }); - test('returns baseURL as is when azure object is not provided', () => { + test('returns baseURL as is when `azureOptions` object is not provided', () => { const url = constructAzureURL({ baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}', }); @@ -245,7 +245,7 @@ describe('constructAzureURL', () => { test('returns baseURL as is when no placeholders are set', () => { const url = constructAzureURL({ baseURL: 'https://example.com/my_custom_instance/my_deployment', - azure: { + azureOptions: { azureOpenAIApiInstanceName: 'instance1', azureOpenAIApiDeploymentName: 'deployment1', }, @@ -258,7 +258,7 @@ describe('constructAzureURL', () => { 'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}'; const url = constructAzureURL({ baseURL, - azure: { + azureOptions: { azureOpenAIApiInstanceName: 'instance1', azureOpenAIApiDeploymentName: 'deployment1', }, diff --git a/api/utils/extractBaseURL.js b/api/utils/extractBaseURL.js index 730473c4102..09bbb55056f 100644 --- a/api/utils/extractBaseURL.js +++ b/api/utils/extractBaseURL.js @@ -1,3 +1,5 @@ +const { CohereConstants } = require('librechat-data-provider'); + /** * Extracts a valid OpenAI baseURL from a given string, matching "url/v1," followed by an optional suffix. * The suffix can be one of several predefined values (e.g., 'openai', 'azure-openai', etc.), @@ -12,9 +14,17 @@ * - `https://api.example.com/v1/replicate` -> `https://api.example.com/v1/replicate` * * @param {string} url - The URL to be processed. - * @returns {string} The matched pattern or input if no match is found. + * @returns {string | undefined} The matched pattern or input if no match is found. */ function extractBaseURL(url) { + if (!url || typeof url !== 'string') { + return undefined; + } + + if (url.startsWith(CohereConstants.API_URL)) { + return null; + } + if (!url.includes('/v1')) { return url; } diff --git a/api/utils/index.js b/api/utils/index.js index a40c53b6aba..7b539cbb141 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,13 +1,15 @@ const loadYaml = require('./loadYaml'); const tokenHelpers = require('./tokens'); const azureUtils = require('./azureUtils'); +const logAxiosError = require('./logAxiosError'); const extractBaseURL = require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); module.exports = { - ...azureUtils, + loadYaml, ...tokenHelpers, + ...azureUtils, + logAxiosError, extractBaseURL, findMessageContent, - loadYaml, }; diff --git a/api/utils/logAxiosError.js b/api/utils/logAxiosError.js new file mode 100644 index 00000000000..17fac85f47d --- /dev/null +++ b/api/utils/logAxiosError.js @@ -0,0 +1,45 @@ +const { logger } = require('~/config'); + +/** + * Logs Axios errors based on the error object and a custom message. + * + * @param {Object} options - The options object. + * @param {string} options.message - The custom message to be logged. + * @param {Error} options.error - The Axios error object. + */ +const logAxiosError = ({ message, error }) => { + const timedOutMessage = 'Cannot read properties of undefined (reading \'status\')'; + if (error.response) { + logger.error( + `${message} The request was made and the server responded with a status code that falls out of the range of 2xx: ${ + error.message ? error.message : '' + }. Error response data:\n`, + { + headers: error.response?.headers, + status: error.response?.status, + data: error.response?.data, + }, + ); + } else if (error.request) { + logger.error( + `${message} The request was made but no response was received: ${ + error.message ? error.message : '' + }. Error Request:\n`, + { + request: error.request, + }, + ); + } else if (error?.message?.includes(timedOutMessage)) { + logger.error( + `${message}\nThe request either timed out or was unsuccessful. Error message:\n`, + error, + ); + } else { + logger.error( + `${message}\nSomething happened in setting up the request. Error message:\n`, + error, + ); + } +}; + +module.exports = logAxiosError; diff --git a/api/utils/tokens.js b/api/utils/tokens.js index ce6c51732aa..bc1a5a10f3a 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -1,3 +1,4 @@ +const z = require('zod'); const { EModelEndpoint } = require('librechat-data-provider'); const models = [ @@ -45,38 +46,61 @@ const openAIModels = { 'gpt-4-32k': 32758, // -10 from max 'gpt-4-32k-0314': 32758, // -10 from max 'gpt-4-32k-0613': 32758, // -10 from max - 'gpt-3.5-turbo': 4092, // -5 from max + 'gpt-4-1106': 127990, // -10 from max + 'gpt-4-0125': 127990, // -10 from max + 'gpt-4-turbo': 127990, // -10 from max + 'gpt-3.5-turbo': 16375, // -10 from max 'gpt-3.5-turbo-0613': 4092, // -5 from max 'gpt-3.5-turbo-0301': 4092, // -5 from max 'gpt-3.5-turbo-16k': 16375, // -10 from max 'gpt-3.5-turbo-16k-0613': 16375, // -10 from max 'gpt-3.5-turbo-1106': 16375, // -10 from max - 'gpt-4-1106': 127990, // -10 from max + 'gpt-3.5-turbo-0125': 16375, // -10 from max 'mistral-': 31990, // -10 from max }; +const cohereModels = { + 'command-light': 4086, // -10 from max + 'command-light-nightly': 8182, // -10 from max + command: 4086, // -10 from max + 'command-nightly': 8182, // -10 from max + 'command-r': 127500, // -500 from max + 'command-r-plus:': 127500, // -500 from max +}; + +const googleModels = { + /* Max I/O is combined so we subtract the amount from max response tokens for actual total */ + gemini: 32750, // -10 from max + 'text-bison-32k': 32758, // -10 from max + 'chat-bison-32k': 32758, // -10 from max + 'code-bison-32k': 32758, // -10 from max + 'codechat-bison-32k': 32758, + /* Codey, -5 from max: 6144 */ + 'code-': 6139, + 'codechat-': 6139, + /* PaLM2, -5 from max: 8192 */ + 'text-': 8187, + 'chat-': 8187, +}; + +const anthropicModels = { + 'claude-': 100000, + 'claude-2': 100000, + 'claude-2.1': 200000, + 'claude-3-haiku': 200000, + 'claude-3-sonnet': 200000, + 'claude-3-opus': 200000, +}; + +const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels }; + // Order is important here: by model series and context size (gpt-4 then gpt-3, ascending) const maxTokensMap = { - [EModelEndpoint.openAI]: openAIModels, - [EModelEndpoint.custom]: openAIModels, - [EModelEndpoint.google]: { - /* Max I/O is combined so we subtract the amount from max response tokens for actual total */ - gemini: 32750, // -10 from max - 'text-bison-32k': 32758, // -10 from max - 'chat-bison-32k': 32758, // -10 from max - 'code-bison-32k': 32758, // -10 from max - 'codechat-bison-32k': 32758, - /* Codey, -5 from max: 6144 */ - 'code-': 6139, - 'codechat-': 6139, - /* PaLM2, -5 from max: 8192 */ - 'text-': 8187, - 'chat-': 8187, - }, - [EModelEndpoint.anthropic]: { - 'claude-2.1': 200000, - 'claude-': 100000, - }, + [EModelEndpoint.azureOpenAI]: openAIModels, + [EModelEndpoint.openAI]: aggregateModels, + [EModelEndpoint.custom]: aggregateModels, + [EModelEndpoint.google]: googleModels, + [EModelEndpoint.anthropic]: anthropicModels, }; /** @@ -85,6 +109,7 @@ const maxTokensMap = { * * @param {string} modelName - The name of the model to look up. * @param {string} endpoint - The endpoint (default is 'openAI'). + * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found. * * @example @@ -92,16 +117,21 @@ const maxTokensMap = { * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767 * getModelMaxTokens('unknown-model'); // Returns undefined */ -function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) { +function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) { if (typeof modelName !== 'string') { return undefined; } - const tokensMap = maxTokensMap[endpoint]; + /** @type {EndpointTokenConfig | Record} */ + const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint]; if (!tokensMap) { return undefined; } + if (tokensMap[modelName]?.context) { + return tokensMap[modelName].context; + } + if (tokensMap[modelName]) { return tokensMap[modelName]; } @@ -109,7 +139,8 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) { const keys = Object.keys(tokensMap); for (let i = keys.length - 1; i >= 0; i--) { if (modelName.includes(keys[i])) { - return tokensMap[keys[i]]; + const result = tokensMap[keys[i]]; + return result?.context ?? result; } } @@ -145,17 +176,70 @@ function matchModelName(modelName, endpoint = EModelEndpoint.openAI) { const keys = Object.keys(tokensMap); for (let i = keys.length - 1; i >= 0; i--) { - if (modelName.includes(keys[i])) { - return keys[i]; + const modelKey = keys[i]; + if (modelName.includes(modelKey)) { + return modelKey; } } return modelName; } +const modelSchema = z.object({ + id: z.string(), + pricing: z.object({ + prompt: z.string(), + completion: z.string(), + }), + context_length: z.number(), +}); + +const inputSchema = z.object({ + data: z.array(modelSchema), +}); + +/** + * Processes a list of model data from an API and organizes it into structured data based on URL and specifics of rates and context. + * @param {{ data: Array> }} input The input object containing base URL and data fetched from the API. + * @returns {EndpointTokenConfig} The processed model data. + */ +function processModelData(input) { + const validationResult = inputSchema.safeParse(input); + if (!validationResult.success) { + throw new Error('Invalid input data'); + } + const { data } = validationResult.data; + + /** @type {EndpointTokenConfig} */ + const tokenConfig = {}; + + for (const model of data) { + const modelKey = model.id; + if (modelKey === 'openrouter/auto') { + model.pricing = { + prompt: '0.00001', + completion: '0.00003', + }; + } + const prompt = parseFloat(model.pricing.prompt) * 1000000; + const completion = parseFloat(model.pricing.completion) * 1000000; + + tokenConfig[modelKey] = { + prompt, + completion, + context: model.context_length, + }; + } + + return tokenConfig; +} + module.exports = { tiktokenModels: new Set(models), maxTokensMap, + inputSchema, + modelSchema, getModelMaxTokens, matchModelName, + processModelData, }; diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 2cb7985d312..641b300458a 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -80,6 +80,23 @@ describe('getModelMaxTokens', () => { ); }); + // 01/25 Update + test('should return correct tokens for gpt-4-turbo/0125 matches', () => { + expect(getModelMaxTokens('gpt-4-turbo')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-turbo'], + ); + expect(getModelMaxTokens('gpt-4-turbo-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-turbo'], + ); + expect(getModelMaxTokens('gpt-4-0125')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4-0125']); + expect(getModelMaxTokens('gpt-4-0125-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4-0125'], + ); + expect(getModelMaxTokens('gpt-3.5-turbo-0125')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-3.5-turbo-0125'], + ); + }); + test('should return correct tokens for Anthropic models', () => { const models = [ 'claude-2.1', @@ -124,6 +141,69 @@ describe('getModelMaxTokens', () => { maxTokensMap[EModelEndpoint.google]['chat-'], ); }); + + test('should return correct tokens when using a custom endpointTokenConfig', () => { + const customTokenConfig = { + 'custom-model': 12345, + }; + expect(getModelMaxTokens('custom-model', EModelEndpoint.openAI, customTokenConfig)).toBe(12345); + }); + + test('should prioritize endpointTokenConfig over the default configuration', () => { + const customTokenConfig = { + 'gpt-4-32k': 9999, + }; + expect(getModelMaxTokens('gpt-4-32k', EModelEndpoint.openAI, customTokenConfig)).toBe(9999); + }); + + test('should return undefined if the model is not found in custom endpointTokenConfig', () => { + const customTokenConfig = { + 'custom-model': 12345, + }; + expect( + getModelMaxTokens('nonexistent-model', EModelEndpoint.openAI, customTokenConfig), + ).toBeUndefined(); + }); + + test('should return correct tokens for exact match in azureOpenAI models', () => { + expect(getModelMaxTokens('gpt-4-turbo', EModelEndpoint.azureOpenAI)).toBe( + maxTokensMap[EModelEndpoint.azureOpenAI]['gpt-4-turbo'], + ); + }); + + test('should return undefined for no match in azureOpenAI models', () => { + expect( + getModelMaxTokens('nonexistent-azure-model', EModelEndpoint.azureOpenAI), + ).toBeUndefined(); + }); + + test('should return undefined for undefined, null, or number model argument with azureOpenAI endpoint', () => { + expect(getModelMaxTokens(undefined, EModelEndpoint.azureOpenAI)).toBeUndefined(); + expect(getModelMaxTokens(null, EModelEndpoint.azureOpenAI)).toBeUndefined(); + expect(getModelMaxTokens(1234, EModelEndpoint.azureOpenAI)).toBeUndefined(); + }); + + test('should respect custom endpointTokenConfig over azureOpenAI defaults', () => { + const customTokenConfig = { + 'custom-azure-model': 4096, + }; + expect( + getModelMaxTokens('custom-azure-model', EModelEndpoint.azureOpenAI, customTokenConfig), + ).toBe(4096); + }); + + test('should return correct tokens for partial match with custom endpointTokenConfig in azureOpenAI', () => { + const customTokenConfig = { + 'azure-custom-': 1024, + }; + expect( + getModelMaxTokens('azure-custom-gpt-3', EModelEndpoint.azureOpenAI, customTokenConfig), + ).toBe(1024); + }); + + test('should return undefined for a model when using an unsupported endpoint', () => { + expect(getModelMaxTokens('azure-gpt-3', 'unsupportedEndpoint')).toBeUndefined(); + }); }); describe('matchModelName', () => { @@ -166,6 +246,16 @@ describe('matchModelName', () => { expect(matchModelName('gpt-4-1106-vision-preview')).toBe('gpt-4-1106'); }); + // 01/25 Update + it('should return the closest matching key for gpt-4-turbo/0125 matches', () => { + expect(matchModelName('openai/gpt-4-0125')).toBe('gpt-4-0125'); + expect(matchModelName('gpt-4-turbo-preview')).toBe('gpt-4-turbo'); + expect(matchModelName('gpt-4-turbo-vision-preview')).toBe('gpt-4-turbo'); + expect(matchModelName('gpt-4-0125')).toBe('gpt-4-0125'); + expect(matchModelName('gpt-4-0125-preview')).toBe('gpt-4-0125'); + expect(matchModelName('gpt-4-0125-vision-preview')).toBe('gpt-4-0125'); + }); + // Tests for Google models it('should return the exact model name if it exists in maxTokensMap - Google models', () => { expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k'); diff --git a/bun.lockb b/bun.lockb index 9fd1d656d17..b85c088716c 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/client/index.html b/client/index.html index 3654c8766ba..bad6b4194c8 100644 --- a/client/index.html +++ b/client/index.html @@ -2,7 +2,7 @@ - + LibreChat diff --git a/client/nginx.conf b/client/nginx.conf index 29d01073e2e..49a2dd16fbe 100644 --- a/client/nginx.conf +++ b/client/nginx.conf @@ -14,12 +14,12 @@ server { # The default limits for image uploads as of 11/22/23 is 20MB/file, and 25MB/request client_max_body_size 25M; - location /api { - proxy_pass http://api:3080/api; + location /api/ { + proxy_pass http://api:3080$request_uri; } location / { - proxy_pass http://api:3080; + proxy_pass http://api:3080/; } ######################################## SSL ######################################## diff --git a/client/package.json b/client/package.json index 19c8c01caf6..ae42996cd91 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - "version": "0.6.6", + "version": "0.7.0", "description": "", "type": "module", "scripts": { @@ -25,17 +25,24 @@ "bugs": { "url": "https://github.com/danny-avila/LibreChat/issues" }, - "homepage": "https://github.com/danny-avila/LibreChat#readme", + "homepage": "https://librechat.ai", "dependencies": { + "@ariakit/react": "^0.4.5", + "@dicebear/collection": "^7.0.4", + "@dicebear/core": "^7.0.4", "@headlessui/react": "^1.7.13", + "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-alert-dialog": "^1.0.2", "@radix-ui/react-checkbox": "^1.0.3", + "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-dialog": "^1.0.2", "@radix-ui/react-dropdown-menu": "^2.0.2", "@radix-ui/react-hover-card": "^1.0.5", "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.0.0", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-radio-group": "^1.1.3", + "@radix-ui/react-select": "^2.0.0", "@radix-ui/react-separator": "^1.0.3", "@radix-ui/react-slider": "^1.1.1", "@radix-ui/react-switch": "^1.0.3", @@ -43,12 +50,14 @@ "@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-tooltip": "^1.0.6", "@tanstack/react-query": "^4.28.0", + "@tanstack/react-table": "^8.11.7", "@zattoo/use-double-click": "1.2.0", "axios": "^1.3.4", "class-variance-authority": "^0.6.0", "clsx": "^1.2.1", "copy-to-clipboard": "^3.3.3", "cross-env": "^7.0.3", + "date-fns": "^3.3.1", "downloadjs": "^1.4.7", "export-from-json": "^1.7.2", "filenamify": "^6.0.0", @@ -57,6 +66,7 @@ "librechat-data-provider": "*", "lodash": "^4.17.21", "lucide-react": "^0.220.0", + "match-sorter": "^6.3.4", "rc-input-number": "^7.4.2", "react": "^18.2.0", "react-dnd": "^16.0.1", @@ -66,6 +76,7 @@ "react-hook-form": "^7.43.9", "react-lazy-load-image-component": "^1.6.0", "react-markdown": "^8.0.6", + "react-resizable-panels": "^1.0.9", "react-router-dom": "^6.11.2", "react-textarea-autosize": "^8.4.0", "react-transition-group": "^4.4.5", @@ -115,7 +126,7 @@ "tailwindcss": "^3.4.1", "ts-jest": "^29.1.0", "typescript": "^5.0.4", - "vite": "^5.0.7", + "vite": "^5.1.1", "vite-plugin-html": "^3.2.0", "vite-plugin-node-polyfills": "^0.17.0" } diff --git a/client/public/assets/ShuttleAI_Fibonacci.png b/client/public/assets/ShuttleAI_Fibonacci.png new file mode 100644 index 00000000000..eddeb4e362d Binary files /dev/null and b/client/public/assets/ShuttleAI_Fibonacci.png differ diff --git a/client/public/assets/anyscale.png b/client/public/assets/anyscale.png new file mode 100644 index 00000000000..d86830c76dd Binary files /dev/null and b/client/public/assets/anyscale.png differ diff --git a/client/public/assets/cohere.png b/client/public/assets/cohere.png new file mode 100644 index 00000000000..3da0b837371 Binary files /dev/null and b/client/public/assets/cohere.png differ diff --git a/client/public/assets/fireworks.png b/client/public/assets/fireworks.png new file mode 100644 index 00000000000..4011e358cff Binary files /dev/null and b/client/public/assets/fireworks.png differ diff --git a/client/public/assets/groq.png b/client/public/assets/groq.png new file mode 100644 index 00000000000..83ea028f95a Binary files /dev/null and b/client/public/assets/groq.png differ diff --git a/client/public/assets/mistral.png b/client/public/assets/mistral.png index ff2f3e8b63b..beaffab92cc 100644 Binary files a/client/public/assets/mistral.png and b/client/public/assets/mistral.png differ diff --git a/client/public/assets/ollama.png b/client/public/assets/ollama.png new file mode 100644 index 00000000000..53979f88708 Binary files /dev/null and b/client/public/assets/ollama.png differ diff --git a/client/public/assets/perplexity.png b/client/public/assets/perplexity.png new file mode 100644 index 00000000000..e3edc716d2a Binary files /dev/null and b/client/public/assets/perplexity.png differ diff --git a/client/public/assets/shuttleai.png b/client/public/assets/shuttleai.png new file mode 100644 index 00000000000..411b5ad3400 Binary files /dev/null and b/client/public/assets/shuttleai.png differ diff --git a/client/public/assets/together.png b/client/public/assets/together.png new file mode 100644 index 00000000000..0401507937e Binary files /dev/null and b/client/public/assets/together.png differ diff --git a/client/public/fonts/Inter-Bold.woff2 b/client/public/fonts/Inter-Bold.woff2 new file mode 100644 index 00000000000..0f1b157633c Binary files /dev/null and b/client/public/fonts/Inter-Bold.woff2 differ diff --git a/client/public/fonts/Inter-BoldItalic.woff2 b/client/public/fonts/Inter-BoldItalic.woff2 new file mode 100644 index 00000000000..bc50f24c873 Binary files /dev/null and b/client/public/fonts/Inter-BoldItalic.woff2 differ diff --git a/client/public/fonts/Inter-Italic.woff2 b/client/public/fonts/Inter-Italic.woff2 new file mode 100644 index 00000000000..4c24ce28152 Binary files /dev/null and b/client/public/fonts/Inter-Italic.woff2 differ diff --git a/client/public/fonts/Inter-Regular.woff2 b/client/public/fonts/Inter-Regular.woff2 new file mode 100644 index 00000000000..b8699af29b0 Binary files /dev/null and b/client/public/fonts/Inter-Regular.woff2 differ diff --git a/client/public/fonts/Inter-SemiBold.woff2 b/client/public/fonts/Inter-SemiBold.woff2 new file mode 100644 index 00000000000..95c48b184ea Binary files /dev/null and b/client/public/fonts/Inter-SemiBold.woff2 differ diff --git a/client/public/fonts/Inter-SemiBoldItalic.woff2 b/client/public/fonts/Inter-SemiBoldItalic.woff2 new file mode 100644 index 00000000000..ddfe19e839c Binary files /dev/null and b/client/public/fonts/Inter-SemiBoldItalic.woff2 differ diff --git a/client/public/fonts/roboto-mono-latin-400-italic.woff2 b/client/public/fonts/roboto-mono-latin-400-italic.woff2 new file mode 100644 index 00000000000..75d29cff8e3 Binary files /dev/null and b/client/public/fonts/roboto-mono-latin-400-italic.woff2 differ diff --git a/client/public/fonts/roboto-mono-latin-400-normal.woff2 b/client/public/fonts/roboto-mono-latin-400-normal.woff2 new file mode 100644 index 00000000000..53d081f3a53 Binary files /dev/null and b/client/public/fonts/roboto-mono-latin-400-normal.woff2 differ diff --git a/client/public/fonts/roboto-mono-latin-700-normal.woff2 b/client/public/fonts/roboto-mono-latin-700-normal.woff2 new file mode 100644 index 00000000000..92fe38dd414 Binary files /dev/null and b/client/public/fonts/roboto-mono-latin-700-normal.woff2 differ diff --git a/client/public/fonts/signifier-bold-italic.woff2 b/client/public/fonts/signifier-bold-italic.woff2 deleted file mode 100644 index cebb25db24a..00000000000 Binary files a/client/public/fonts/signifier-bold-italic.woff2 and /dev/null differ diff --git a/client/public/fonts/signifier-bold.woff2 b/client/public/fonts/signifier-bold.woff2 deleted file mode 100644 index b76fecbacb3..00000000000 Binary files a/client/public/fonts/signifier-bold.woff2 and /dev/null differ diff --git a/client/public/fonts/signifier-light-italic.woff2 b/client/public/fonts/signifier-light-italic.woff2 deleted file mode 100644 index dc144f106c8..00000000000 Binary files a/client/public/fonts/signifier-light-italic.woff2 and /dev/null differ diff --git a/client/public/fonts/signifier-light.woff2 b/client/public/fonts/signifier-light.woff2 deleted file mode 100644 index 1077c6b9e9c..00000000000 Binary files a/client/public/fonts/signifier-light.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-buch-kursiv.woff2 b/client/public/fonts/soehne-buch-kursiv.woff2 deleted file mode 100644 index 8d4b03588c2..00000000000 Binary files a/client/public/fonts/soehne-buch-kursiv.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-buch.woff2 b/client/public/fonts/soehne-buch.woff2 deleted file mode 100644 index b1ceb94fa0d..00000000000 Binary files a/client/public/fonts/soehne-buch.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-halbfett-kursiv.woff2 b/client/public/fonts/soehne-halbfett-kursiv.woff2 deleted file mode 100644 index f7fd3c64b00..00000000000 Binary files a/client/public/fonts/soehne-halbfett-kursiv.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-halbfett.woff2 b/client/public/fonts/soehne-halbfett.woff2 deleted file mode 100644 index 19ed66001ea..00000000000 Binary files a/client/public/fonts/soehne-halbfett.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-kraftig-kursiv.woff2 b/client/public/fonts/soehne-kraftig-kursiv.woff2 deleted file mode 100644 index 669ab6920f2..00000000000 Binary files a/client/public/fonts/soehne-kraftig-kursiv.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-kraftig.woff2 b/client/public/fonts/soehne-kraftig.woff2 deleted file mode 100644 index 59c98a170f6..00000000000 Binary files a/client/public/fonts/soehne-kraftig.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-mono-buch-kursiv.woff2 b/client/public/fonts/soehne-mono-buch-kursiv.woff2 deleted file mode 100644 index c20b7426345..00000000000 Binary files a/client/public/fonts/soehne-mono-buch-kursiv.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-mono-buch.woff2 b/client/public/fonts/soehne-mono-buch.woff2 deleted file mode 100644 index 68e14f30396..00000000000 Binary files a/client/public/fonts/soehne-mono-buch.woff2 and /dev/null differ diff --git a/client/public/fonts/soehne-mono-halbfett.woff2 b/client/public/fonts/soehne-mono-halbfett.woff2 deleted file mode 100644 index e14cbdc5361..00000000000 Binary files a/client/public/fonts/soehne-mono-halbfett.woff2 and /dev/null differ diff --git a/client/src/App.jsx b/client/src/App.jsx index 10c9ab9b509..ce2ec3b6dec 100644 --- a/client/src/App.jsx +++ b/client/src/App.jsx @@ -6,7 +6,7 @@ import { HTML5Backend } from 'react-dnd-html5-backend'; import { ReactQueryDevtools } from '@tanstack/react-query-devtools'; import { QueryClient, QueryClientProvider, QueryCache } from '@tanstack/react-query'; import { ScreenshotProvider, ThemeProvider, useApiErrorBoundary } from './hooks'; -import { ToastProvider, AssistantsProvider } from './Providers'; +import { ToastProvider } from './Providers'; import Toast from './components/ui/Toast'; import { router } from './routes'; @@ -29,14 +29,12 @@ const App = () => { - - - - - - - - + + + + + + diff --git a/client/src/Providers/AssistantsContext.tsx b/client/src/Providers/AssistantsContext.tsx index 51561887996..10079083a2f 100644 --- a/client/src/Providers/AssistantsContext.tsx +++ b/client/src/Providers/AssistantsContext.tsx @@ -1,14 +1,10 @@ +import { useForm, FormProvider } from 'react-hook-form'; import { createContext, useContext } from 'react'; +import { defaultAssistantFormValues } from 'librechat-data-provider'; import type { UseFormReturn } from 'react-hook-form'; -import type { CreationForm } from '~/common'; -import useCreationForm from './useCreationForm'; +import type { AssistantForm } from '~/common'; -// type AssistantsContextType = { -// // open: boolean; -// // setOpen: Dispatch>; -// form: UseFormReturn; -// }; -type AssistantsContextType = UseFormReturn; +type AssistantsContextType = UseFormReturn; export const AssistantsContext = createContext({} as AssistantsContextType); @@ -23,7 +19,9 @@ export function useAssistantsContext() { } export default function AssistantsProvider({ children }) { - const hookValues = useCreationForm(); + const methods = useForm({ + defaultValues: defaultAssistantFormValues, + }); - return {children}; + return {children}; } diff --git a/client/src/Providers/AssistantsMapContext.tsx b/client/src/Providers/AssistantsMapContext.tsx new file mode 100644 index 00000000000..850e7d31290 --- /dev/null +++ b/client/src/Providers/AssistantsMapContext.tsx @@ -0,0 +1,8 @@ +import { createContext, useContext } from 'react'; +import { useAssistantsMap } from '~/hooks/Assistants'; +type AssistantsMapContextType = ReturnType; + +export const AssistantsMapContext = createContext( + {} as AssistantsMapContextType, +); +export const useAssistantsMapContext = () => useContext(AssistantsMapContext); diff --git a/client/src/Providers/FileMapContext.tsx b/client/src/Providers/FileMapContext.tsx new file mode 100644 index 00000000000..2e189cacb7d --- /dev/null +++ b/client/src/Providers/FileMapContext.tsx @@ -0,0 +1,6 @@ +import { createContext, useContext } from 'react'; +import { useFileMap } from '~/hooks/Files'; +type FileMapContextType = ReturnType; + +export const FileMapContext = createContext({} as FileMapContextType); +export const useFileMapContext = () => useContext(FileMapContext); diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index ab8b65d785d..32e5c25dc49 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -2,4 +2,6 @@ export { default as ToastProvider } from './ToastContext'; export { default as AssistantsProvider } from './AssistantsContext'; export * from './ChatContext'; export * from './ToastContext'; +export * from './FileMapContext'; export * from './AssistantsContext'; +export * from './AssistantsMapContext'; diff --git a/client/src/Providers/useCreationForm.ts b/client/src/Providers/useCreationForm.ts deleted file mode 100644 index 6fadf4c947a..00000000000 --- a/client/src/Providers/useCreationForm.ts +++ /dev/null @@ -1,19 +0,0 @@ -// import { useState } from 'react'; -import { useForm } from 'react-hook-form'; -import type { CreationForm } from '~/common'; - -export default function useViewPromptForm() { - return useForm({ - defaultValues: { - assistant: '', - id: '', - name: '', - description: '', - instructions: '', - model: 'gpt-3.5-turbo-1106', - function: false, - code_interpreter: false, - retrieval: false, - }, - }); -} diff --git a/client/src/common/assistants-types.ts b/client/src/common/assistants-types.ts index 7dc6906e7a8..3b9ad27da36 100644 --- a/client/src/common/assistants-types.ts +++ b/client/src/common/assistants-types.ts @@ -1,19 +1,23 @@ -import type { Option } from './types'; +import { Capabilities } from 'librechat-data-provider'; import type { Assistant } from 'librechat-data-provider'; +import type { Option, ExtendedFile } from './types'; -export type TAssistantOption = string | (Option & Assistant); +export type TAssistantOption = + | string + | (Option & Assistant & { files?: Array<[string, ExtendedFile]> }); export type Actions = { - function: boolean; - code_interpreter: boolean; - retrieval: boolean; + [Capabilities.code_interpreter]: boolean; + [Capabilities.image_vision]: boolean; + [Capabilities.retrieval]: boolean; }; -export type CreationForm = { +export type AssistantForm = { assistant: TAssistantOption; id: string; name: string | null; description: string | null; instructions: string | null; model: string; + functions: string[]; } & Actions; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 1ca169a0c1c..a5a8a01c11b 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,21 +1,107 @@ import { FileSources } from 'librechat-data-provider'; +import type { ColumnDef } from '@tanstack/react-table'; +import type { SetterOrUpdater } from 'recoil'; import type { + TSetOption as SetOption, TConversation, TMessage, TPreset, TLoginUser, TUser, EModelEndpoint, + Action, + AuthTypeEnum, + AuthorizationTypeEnum, + TokenExchangeMethodEnum, } from 'librechat-data-provider'; import type { UseMutationResult } from '@tanstack/react-query'; +import type { LucideIcon } from 'lucide-react'; + +export type GenericSetter = (value: T | ((currentValue: T) => T)) => void; + +export type LastSelectedModels = Record; + +export const mainTextareaId = 'prompt-textarea'; + +export enum IconContext { + landing = 'landing', + menuItem = 'menu-item', + nav = 'nav', + message = 'message', +} + +export type NavLink = { + title: string; + label?: string; + icon: LucideIcon; + Component?: React.ComponentType; + onClick?: () => void; + variant?: 'default' | 'ghost'; + id: string; +}; + +export interface NavProps { + isCollapsed: boolean; + links: NavLink[]; + resize?: (size: number) => void; + defaultActive?: string; +} + +interface ColumnMeta { + meta: { + size: number | string; + }; +} + +export enum Panel { + builder = 'builder', + actions = 'actions', +} + +export type FileSetter = + | SetterOrUpdater> + | React.Dispatch>>; + +export type ActionAuthForm = { + /* General */ + type: AuthTypeEnum; + saved_auth_fields: boolean; + /* API key */ + api_key: string; // not nested + authorization_type: AuthorizationTypeEnum; + custom_auth_header: string; + /* OAuth */ + oauth_client_id: string; // not nested + oauth_client_secret: string; // not nested + authorization_url: string; + client_url: string; + scope: string; + token_exchange_method: TokenExchangeMethodEnum; +}; + +export type AssistantPanelProps = { + index?: number; + action?: Action; + actions?: Action[]; + assistant_id?: string; + activePanel?: string; + setAction: React.Dispatch>; + setCurrentAssistantId: React.Dispatch>; + setActivePanel: React.Dispatch>; +}; + +export type AugmentedColumnDef = ColumnDef & ColumnMeta; + +export type TSetOption = SetOption; -export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void; export type TSetExample = ( i: number, type: string, newValue: number | string | boolean | null, ) => void; +export const defaultDebouncedDelay = 450; + export enum ESide { Top = 'top', Right = 'right', @@ -72,7 +158,7 @@ export type TSetOptionsPayload = { setAgentOption: TSetOption; // getConversation: () => TConversation | TPreset | null; checkPluginSelection: (value: string) => boolean; - setTools: (newValue: string) => void; + setTools: (newValue: string, remove?: boolean) => void; }; export type TPresetItemProps = { @@ -99,6 +185,7 @@ export type TAskProps = { export type TOptions = { editedMessageId?: string | null; editedText?: string | null; + resubmitFiles?: boolean; isRegenerate?: boolean; isContinued?: boolean; isEdited?: boolean; @@ -136,9 +223,9 @@ export type TAdditionalProps = { setSiblingIdx: (value: number) => void; }; -export type TMessageContent = TInitialProps & TAdditionalProps; +export type TMessageContentProps = TInitialProps & TAdditionalProps; -export type TText = Pick; +export type TText = Pick & { className?: string }; export type TEditProps = Pick & Omit; export type TDisplayProps = TText & @@ -172,6 +259,11 @@ export type TDialogProps = { onOpenChange: (open: boolean) => void; }; +export type TPluginStoreDialogProps = { + isOpen: boolean; + setIsOpen: (open: boolean) => void; +}; + export type TResError = { response: { data: { message: string } }; message: string; @@ -198,7 +290,7 @@ export type TAuthConfig = { test?: boolean; }; -export type IconProps = Pick & +export type IconProps = Pick & Pick & { size?: number; button?: boolean; @@ -207,6 +299,8 @@ export type IconProps = Pick & className?: string; endpoint?: EModelEndpoint | string | null; endpointType?: EModelEndpoint | null; + assistantName?: string; + error?: boolean; }; export type Option = Record & { @@ -214,13 +308,15 @@ export type Option = Record & { value: string | number | null; }; +export type OptionWithIcon = Option & { icon?: React.ReactNode }; + export type TOptionSettings = { showExamples?: boolean; isCodeChat?: boolean; }; export interface ExtendedFile { - file: File; + file?: File; file_id: string; temp_file_id?: string; type?: string; @@ -229,9 +325,17 @@ export interface ExtendedFile { width?: number; height?: number; size: number; - preview: string; + preview?: string; progress: number; source?: FileSources; + attached?: boolean; + embedded?: boolean; } export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void }; + +export interface SwitcherProps { + endpoint?: EModelEndpoint | null; + endpointKeyProvided: boolean; + isCollapsed: boolean; +} diff --git a/client/src/components/Auth/Login.tsx b/client/src/components/Auth/Login.tsx index b7eeb59a5e8..c811ac0a146 100644 --- a/client/src/components/Auth/Login.tsx +++ b/client/src/components/Auth/Login.tsx @@ -3,6 +3,8 @@ import { useNavigate } from 'react-router-dom'; import { useGetStartupConfig } from 'librechat-data-provider/react-query'; import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components'; import { useAuthContext } from '~/hooks/AuthContext'; +import { ThemeSelector } from '~/components/ui'; +import SocialButton from './SocialButton'; import { getLoginError } from '~/utils'; import { useLocalize } from '~/hooks'; import LoginForm from './LoginForm'; @@ -11,7 +13,6 @@ function Login() { const { login, error, isAuthenticated } = useAuthContext(); const { data: startupConfig } = useGetStartupConfig(); const localize = useLocalize(); - const navigate = useNavigate(); useEffect(() => { @@ -20,112 +21,155 @@ function Login() { } }, [isAuthenticated, navigate]); + if (!startupConfig) { + return null; + } + + const socialLogins = startupConfig.socialLogins ?? []; + + const providerComponents = { + discord: ( + + ), + facebook: ( + + ), + github: ( + + ), + google: ( + + ), + openid: ( + + startupConfig.openidImageUrl ? ( + OpenID Logo + ) : ( + + ) + } + label={startupConfig.openidLabel} + id="openid" + /> + ), + }; + + const privacyPolicy = startupConfig.interface?.privacyPolicy; + const termsOfService = startupConfig.interface?.termsOfService; + + const privacyPolicyRender = privacyPolicy?.externalUrl && ( + + {localize('com_ui_privacy_policy')} + + ); + + const termsOfServiceRender = termsOfService?.externalUrl && ( + + {localize('com_ui_terms_of_service')} + + ); + return ( -
-
-

+
+
+ +
+
+

{localize('com_auth_welcome_back')}

{error && (
{localize(getLoginError(error))}
)} - {startupConfig?.emailLoginEnabled && } - {startupConfig?.registrationEnabled && ( -

+ {startupConfig.emailLoginEnabled && } + {startupConfig.registrationEnabled && ( +

{' '} {localize('com_auth_no_account')}{' '} - + {localize('com_auth_sign_up')}

)} - {startupConfig?.socialLoginEnabled && startupConfig?.emailLoginEnabled && ( - <> -
-
Or
-
-
- - )} - {startupConfig?.googleLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.facebookLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.openidLoginEnabled && startupConfig?.socialLoginEnabled && ( + {startupConfig.socialLoginEnabled && ( <> -
- - {startupConfig.openidImageUrl ? ( - OpenID Logo - ) : ( - - )} -

{startupConfig.openidLabel}

-
+ {startupConfig.emailLoginEnabled && ( + <> +
+
+ Or +
+
+
+ + )} +
+ {socialLogins.map((provider) => providerComponents[provider] || null)}
)} - {startupConfig?.githubLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.discordLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - +
+
+ {privacyPolicyRender} + {privacyPolicyRender && termsOfServiceRender && ( +
)} + {termsOfServiceRender}
); diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx index eddb824f764..102c4826576 100644 --- a/client/src/components/Auth/LoginForm.tsx +++ b/client/src/components/Auth/LoginForm.tsx @@ -1,3 +1,4 @@ +import React from 'react'; import { useForm } from 'react-hook-form'; import { useLocalize } from '~/hooks'; import { TLoginUser } from 'librechat-data-provider'; @@ -6,15 +7,23 @@ type TLoginFormProps = { onSubmit: (data: TLoginUser) => void; }; -function LoginForm({ onSubmit }: TLoginFormProps) { +const LoginForm: React.FC = ({ onSubmit }) => { const localize = useLocalize(); - const { register, handleSubmit, formState: { errors }, } = useForm(); + const renderError = (fieldName: string) => { + const errorMessage = errors[fieldName]?.message; + return errorMessage ? ( + + {String(errorMessage)} + + ) : null; + }; + return (
+ />
- {errors.email && ( - - {/* @ts-ignore not sure why*/} - {errors.email.message} - - )} + {renderError('email')}
@@ -71,35 +65,23 @@ function LoginForm({ onSubmit }: TLoginFormProps) { aria-label={localize('com_auth_password')} {...register('password', { required: localize('com_auth_password_required'), - minLength: { - value: 8, - message: localize('com_auth_password_min_length'), - }, - maxLength: { - value: 128, - message: localize('com_auth_password_max_length'), - }, + minLength: { value: 8, message: localize('com_auth_password_min_length') }, + maxLength: { value: 128, message: localize('com_auth_password_max_length') }, })} aria-invalid={!!errors.password} - className="peer block w-full appearance-none rounded-md border border-gray-300 bg-gray-50 px-2.5 pb-2.5 pt-5 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0" + className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-black/10 bg-white px-2.5 pb-2.5 pt-5 text-sm text-gray-800 focus:border-green-500 focus:outline-none dark:border-white/20 dark:bg-gray-900 dark:text-white dark:focus:border-green-500" placeholder=" " - > + />
- - {errors.password && ( - - {/* @ts-ignore not sure why*/} - {errors.password.message} - - )} + {renderError('password')}
- + {localize('com_auth_password_forgot')}
@@ -107,12 +89,13 @@ function LoginForm({ onSubmit }: TLoginFormProps) { aria-label="Sign in" data-testid="login-button" type="submit" - className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none"> + className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-550 focus:bg-green-550 focus:outline-none disabled:cursor-not-allowed disabled:hover:bg-green-500" + > {localize('com_auth_continue')}
); -} +}; export default LoginForm; diff --git a/client/src/components/Auth/Registration.tsx b/client/src/components/Auth/Registration.tsx index 9ef96048874..02e462a0f24 100644 --- a/client/src/components/Auth/Registration.tsx +++ b/client/src/components/Auth/Registration.tsx @@ -1,15 +1,16 @@ import { useForm } from 'react-hook-form'; -import { useState, useEffect } from 'react'; import { useNavigate } from 'react-router-dom'; +import React, { useState, useEffect } from 'react'; import { useRegisterUserMutation, useGetStartupConfig } from 'librechat-data-provider/react-query'; import type { TRegisterUser } from 'librechat-data-provider'; import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components'; +import { ThemeSelector } from '~/components/ui'; +import SocialButton from './SocialButton'; import { useLocalize } from '~/hooks'; -function Registration() { +const Registration: React.FC = () => { const navigate = useNavigate(); const { data: startupConfig } = useGetStartupConfig(); - const localize = useLocalize(); const { @@ -22,23 +23,20 @@ function Registration() { const [error, setError] = useState(false); const [errorMessage, setErrorMessage] = useState(''); const registerUser = useRegisterUserMutation(); - const password = watch('password'); - const onRegisterUserFormSubmit = (data: TRegisterUser) => { - registerUser.mutate(data, { - onSuccess: () => { - navigate('/c/new'); - }, - onError: (error) => { - setError(true); + const onRegisterUserFormSubmit = async (data: TRegisterUser) => { + try { + await registerUser.mutateAsync(data); + navigate('/c/new'); + } catch (error) { + setError(true); + //@ts-ignore - error is of type unknown + if (error.response?.data?.message) { //@ts-ignore - error is of type unknown - if (error.response?.data?.message) { - //@ts-ignore - error is of type unknown - setErrorMessage(error.response?.data?.message); - } - }, - }); + setErrorMessage(error.response?.data?.message); + } + } }; useEffect(() => { @@ -47,15 +45,123 @@ function Registration() { } }, [startupConfig, navigate]); + if (!startupConfig) { + return null; + } + + const socialLogins = startupConfig.socialLogins ?? []; + + const renderInput = (id: string, label: string, type: string, validation: object) => ( +
+
+ + +
+ {errors[id] && ( + + {String(errors[id]?.message) ?? ''} + + )} +
+ ); + + const providerComponents = { + discord: ( + + ), + facebook: ( + + ), + github: ( + + ), + google: ( + + ), + openid: ( + + startupConfig.openidImageUrl ? ( + OpenID Logo + ) : ( + + ) + } + label={startupConfig.openidLabel} + id="openid" + /> + ), + }; + return ( -
-
-

+
+
+ +
+
+

{localize('com_auth_create_account')}

{error && (
@@ -66,308 +172,95 @@ function Registration() { className="mt-6" aria-label="Registration form" method="POST" - onSubmit={handleSubmit((data) => onRegisterUserFormSubmit(data))} + onSubmit={handleSubmit(onRegisterUserFormSubmit)} > -
-
- - -
- - {errors.name && ( - - {/* @ts-ignore not sure why*/} - {errors.name.message} - - )} -
-
-
- - -
- - {errors.username && ( - - {/* @ts-ignore not sure why */} - {errors.username.message} - - )} -
-
-
- - -
- {errors.email && ( - - {/* @ts-ignore - Type 'string | FieldError | Merge> | undefined' is not assignable to type 'ReactNode' */} - {errors.email.message} - - )} -
-
-
- - -
- - {errors.password && ( - - {/* @ts-ignore not sure why */} - {errors.password.message} - - )} -
-
-
- { - // e.preventDefault(); - // return false; - // }} - {...register('confirm_password', { - validate: (value) => - value === password || localize('com_auth_password_not_match'), - })} - aria-invalid={!!errors.confirm_password} - className="peer block w-full appearance-none rounded-md border border-gray-300 bg-gray-50 px-2.5 pb-2.5 pt-5 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0" - placeholder=" " - > - -
- - {errors.confirm_password && ( - - {/* @ts-ignore not sure why */} - {errors.confirm_password.message} - - )} -
+ {renderInput('name', 'com_auth_full_name', 'text', { + required: localize('com_auth_name_required'), + minLength: { + value: 3, + message: localize('com_auth_name_min_length'), + }, + maxLength: { + value: 80, + message: localize('com_auth_name_max_length'), + }, + })} + {renderInput('username', 'com_auth_username', 'text', { + minLength: { + value: 2, + message: localize('com_auth_username_min_length'), + }, + maxLength: { + value: 80, + message: localize('com_auth_username_max_length'), + }, + })} + {renderInput('email', 'com_auth_email', 'email', { + required: localize('com_auth_email_required'), + minLength: { + value: 1, + message: localize('com_auth_email_min_length'), + }, + maxLength: { + value: 120, + message: localize('com_auth_email_max_length'), + }, + pattern: { + value: /\S+@\S+\.\S+/, + message: localize('com_auth_email_pattern'), + }, + })} + {renderInput('password', 'com_auth_password', 'password', { + required: localize('com_auth_password_required'), + minLength: { + value: 8, + message: localize('com_auth_password_min_length'), + }, + maxLength: { + value: 128, + message: localize('com_auth_password_max_length'), + }, + })} + {renderInput('confirm_password', 'com_auth_password_confirm', 'password', { + validate: (value) => value === password || localize('com_auth_password_not_match'), + })}
-

- {' '} +

{localize('com_auth_already_have_account')}{' '} - + {localize('com_auth_login')}

- {startupConfig?.socialLoginEnabled && ( - <> -
-
Or
-
-
- - )} - {startupConfig?.googleLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.facebookLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.openidLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> - - - )} - {startupConfig?.githubLoginEnabled && startupConfig?.socialLoginEnabled && ( + {startupConfig.socialLoginEnabled && ( <> - - - )} - {startupConfig?.discordLoginEnabled && startupConfig?.socialLoginEnabled && ( - <> -
- - -

{localize('com_auth_discord_login')}

-
+ {startupConfig.emailLoginEnabled && ( + <> +
+
+ Or +
+
+
+ + )} +
+ {socialLogins.map((provider) => providerComponents[provider] || null)}
)}
); -} +}; export default Registration; diff --git a/client/src/components/Auth/RequestPasswordReset.tsx b/client/src/components/Auth/RequestPasswordReset.tsx index 4980b4f27f6..ded90c7db65 100644 --- a/client/src/components/Auth/RequestPasswordReset.tsx +++ b/client/src/components/Auth/RequestPasswordReset.tsx @@ -5,6 +5,7 @@ import { useRequestPasswordResetMutation, } from 'librechat-data-provider/react-query'; import type { TRequestPasswordReset, TRequestPasswordResetResponse } from 'librechat-data-provider'; +import { ThemeSelector } from '~/components/ui'; import { useLocalize } from '~/hooks'; function RequestPasswordReset() { @@ -48,7 +49,7 @@ function RequestPasswordReset() { setBodyText( {localize('com_auth_click')}{' '} - + {localize('com_auth_here')} {' '} {localize('com_auth_to_reset_your_password')} @@ -61,83 +62,101 @@ function RequestPasswordReset() { } }, [requestPasswordReset.isSuccess, config.data?.emailEnabled, resetLink, localize]); + const renderFormContent = () => { + if (bodyText) { + return ( +
+ {bodyText} +
+ ); + } else { + return ( +
+
+
+ + +
+ {errors.email && ( + + {/* @ts-ignore not sure why */} + {errors.email.message} + + )} +
+
+ + +
+
+ ); + } + }; + return ( -
-
-

{headerText}

+
+
+ +
+
+

+ {headerText} +

{requestError && (
{localize('com_auth_error_reset_password')}
)} - {bodyText ? ( -
- {bodyText} -
- ) : ( -
-
-
- - -
- {errors.email && ( - - {/* @ts-ignore not sure why */} - {errors.email.message} - - )} -
-
- -
-
- )} + {renderFormContent()}
); diff --git a/client/src/components/Auth/ResetPassword.tsx b/client/src/components/Auth/ResetPassword.tsx index 664c95377ce..bf6aa7944a3 100644 --- a/client/src/components/Auth/ResetPassword.tsx +++ b/client/src/components/Auth/ResetPassword.tsx @@ -3,6 +3,7 @@ import { useForm } from 'react-hook-form'; import { useNavigate, useSearchParams } from 'react-router-dom'; import { useResetPasswordMutation } from 'librechat-data-provider/react-query'; import type { TResetPassword } from 'librechat-data-provider'; +import { ThemeSelector } from '~/components/ui'; import { useLocalize } from '~/hooks'; function ResetPassword() { @@ -29,13 +30,16 @@ function ResetPassword() { if (resetPassword.isSuccess) { return ( -
-
-

+
+
+ +
+
+

{localize('com_auth_reset_password_success')}

{localize('com_auth_login_with_new_password')} @@ -52,14 +56,17 @@ function ResetPassword() { ); } else { return ( -
-
-

+
+
+ +
+
+

{localize('com_auth_reset_password')}

{resetError && (
{localize('com_auth_error_invalid_reset_token')}{' '} @@ -108,19 +115,19 @@ function ResetPassword() { }, })} aria-invalid={!!errors.password} - className="peer block w-full appearance-none rounded-md border border-gray-300 bg-gray-50 px-2.5 pb-2.5 pt-5 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0" + className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-black/10 bg-white px-2.5 pb-2.5 pt-5 text-sm text-gray-800 focus:border-green-500 focus:outline-none dark:border-white/20 dark:bg-gray-900 dark:text-white dark:focus:border-green-500" placeholder=" " >
{errors.password && ( - + {/* @ts-ignore not sure why */} {errors.password.message} @@ -142,30 +149,30 @@ function ResetPassword() { value === password || localize('com_auth_password_not_match'), })} aria-invalid={!!errors.confirm_password} - className="peer block w-full appearance-none rounded-md border border-gray-300 bg-gray-50 px-2.5 pb-2.5 pt-5 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0" + className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-black/10 bg-white px-2.5 pb-2.5 pt-5 text-sm text-gray-800 focus:border-green-500 focus:outline-none dark:border-white/20 dark:bg-gray-900 dark:text-white dark:focus:border-green-500" placeholder=" " >
{errors.confirm_password && ( - + {/* @ts-ignore not sure why */} {errors.confirm_password.message} )} {errors.token && ( - + {/* @ts-ignore not sure why */} {errors.token.message} )} {errors.userId && ( - + {/* @ts-ignore not sure why */} {errors.userId.message} @@ -176,7 +183,7 @@ function ResetPassword() { disabled={!!errors.password || !!errors.confirm_password} type="submit" aria-label={localize('com_auth_submit_registration')} - className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none" + className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-all duration-300 hover:bg-green-550 focus:bg-green-550 focus:outline-none" > {localize('com_auth_continue')} diff --git a/client/src/components/Auth/SocialButton.tsx b/client/src/components/Auth/SocialButton.tsx new file mode 100644 index 00000000000..7e76c6f7638 --- /dev/null +++ b/client/src/components/Auth/SocialButton.tsx @@ -0,0 +1,60 @@ +import React, { useState } from 'react'; + +const SocialButton = ({ id, enabled, serverDomain, oauthPath, Icon, label }) => { + const [isHovered, setIsHovered] = useState(false); + const [isPressed, setIsPressed] = useState(false); + const [activeButton, setActiveButton] = useState(null); + + if (!enabled) { + return null; + } + + const handleMouseEnter = () => { + setIsHovered(true); + }; + + const handleMouseLeave = () => { + setIsHovered(false); + }; + + const handleMouseDown = () => { + setIsPressed(true); + setActiveButton(id); + }; + + const handleMouseUp = () => { + setIsPressed(false); + }; + + const getButtonStyles = () => { + // Define Tailwind CSS classes based on state + const baseStyles = 'border border-solid border-gray-300 dark:border-gray-800 transition-colors'; + + const pressedStyles = 'bg-blue-200 border-blue-200 dark:bg-blue-900 dark:border-blue-600'; + const hoverStyles = 'bg-gray-100 dark:bg-gray-700'; + + return `${baseStyles} ${ + isPressed && activeButton === id ? pressedStyles : isHovered ? hoverStyles : '' + }`; + }; + + return ( + + ); +}; + +export default SocialButton; diff --git a/client/src/components/Auth/__tests__/Login.spec.tsx b/client/src/components/Auth/__tests__/Login.spec.tsx index 5a70a5fec54..a076ba5c9b6 100644 --- a/client/src/components/Auth/__tests__/Login.spec.tsx +++ b/client/src/components/Auth/__tests__/Login.spec.tsx @@ -31,13 +31,14 @@ const setup = ({ isLoading: false, isError: false, data: { - googleLoginEnabled: true, + socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'], + discordLoginEnabled: true, facebookLoginEnabled: true, + githubLoginEnabled: true, + googleLoginEnabled: true, openidLoginEnabled: true, openidLabel: 'Test OpenID', openidImageUrl: 'http://test-server.com', - githubLoginEnabled: true, - discordLoginEnabled: true, registrationEnabled: true, emailLoginEnabled: true, socialLoginEnabled: true, @@ -78,23 +79,23 @@ test('renders login form', () => { expect(getByRole('button', { name: /Sign in/i })).toBeInTheDocument(); expect(getByRole('link', { name: /Sign up/i })).toBeInTheDocument(); expect(getByRole('link', { name: /Sign up/i })).toHaveAttribute('href', '/register'); - expect(getByRole('link', { name: /Login with Google/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Google/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Google/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Google/i })).toHaveAttribute( 'href', 'mock-server/oauth/google', ); - expect(getByRole('link', { name: /Login with Facebook/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Facebook/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Facebook/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Facebook/i })).toHaveAttribute( 'href', 'mock-server/oauth/facebook', ); - expect(getByRole('link', { name: /Login with Github/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Github/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Github/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Github/i })).toHaveAttribute( 'href', 'mock-server/oauth/github', ); - expect(getByRole('link', { name: /Login with Discord/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Discord/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Discord/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Discord/i })).toHaveAttribute( 'href', 'mock-server/oauth/discord', ); diff --git a/client/src/components/Auth/__tests__/Registration.spec.tsx b/client/src/components/Auth/__tests__/Registration.spec.tsx index 324d593a1a6..d4a98900709 100644 --- a/client/src/components/Auth/__tests__/Registration.spec.tsx +++ b/client/src/components/Auth/__tests__/Registration.spec.tsx @@ -32,14 +32,14 @@ const setup = ({ isLoading: false, isError: false, data: { - googleLoginEnabled: true, + socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'], + discordLoginEnabled: true, facebookLoginEnabled: true, + githubLoginEnabled: true, + googleLoginEnabled: true, openidLoginEnabled: true, openidLabel: 'Test OpenID', openidImageUrl: 'http://test-server.com', - githubLoginEnabled: true, - discordLoginEnabled: true, - emailLoginEnabled: true, registrationEnabled: true, socialLoginEnabled: true, serverDomain: 'mock-server', @@ -85,23 +85,23 @@ test('renders registration form', () => { expect(getByRole('button', { name: /Submit registration/i })).toBeInTheDocument(); expect(getByRole('link', { name: 'Login' })).toBeInTheDocument(); expect(getByRole('link', { name: 'Login' })).toHaveAttribute('href', '/login'); - expect(getByRole('link', { name: /Login with Google/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Google/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Google/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Google/i })).toHaveAttribute( 'href', 'mock-server/oauth/google', ); - expect(getByRole('link', { name: /Login with Facebook/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Facebook/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Facebook/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Facebook/i })).toHaveAttribute( 'href', 'mock-server/oauth/facebook', ); - expect(getByRole('link', { name: /Login with Github/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Github/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Github/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Github/i })).toHaveAttribute( 'href', 'mock-server/oauth/github', ); - expect(getByRole('link', { name: /Login with Discord/i })).toBeInTheDocument(); - expect(getByRole('link', { name: /Login with Discord/i })).toHaveAttribute( + expect(getByRole('link', { name: /Continue with Discord/i })).toBeInTheDocument(); + expect(getByRole('link', { name: /Continue with Discord/i })).toHaveAttribute( 'href', 'mock-server/oauth/discord', ); diff --git a/client/src/components/Chat/ChatView.tsx b/client/src/components/Chat/ChatView.tsx index 30a7edc187e..604c8f1e78a 100644 --- a/client/src/components/Chat/ChatView.tsx +++ b/client/src/components/Chat/ChatView.tsx @@ -2,16 +2,13 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; import { useParams } from 'react-router-dom'; import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query'; -import { useChatHelpers, useSSE } from '~/hooks'; -// import GenerationButtons from './Input/GenerationButtons'; +import { ChatContext, useFileMapContext } from '~/Providers'; import MessagesView from './Messages/MessagesView'; -// import OptionsBar from './Input/OptionsBar'; -import { useGetFiles } from '~/data-provider'; -import { buildTree, mapFiles } from '~/utils'; +import { useChatHelpers, useSSE } from '~/hooks'; import { Spinner } from '~/components/svg'; -import { ChatContext } from '~/Providers'; import Presentation from './Presentation'; import ChatForm from './Input/ChatForm'; +import { buildTree } from '~/utils'; import Landing from './Landing'; import Header from './Header'; import Footer from './Footer'; @@ -22,9 +19,7 @@ function ChatView({ index = 0 }: { index?: number }) { const submissionAtIndex = useRecoilValue(store.submissionByIndex(0)); useSSE(submissionAtIndex); - const { data: fileMap } = useGetFiles({ - select: mapFiles, - }); + const fileMap = useFileMapContext(); const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', { select: (data) => { @@ -38,7 +33,7 @@ function ChatView({ index = 0 }: { index?: number }) { return ( - + {isLoading && conversationId !== 'new' ? (
@@ -48,8 +43,6 @@ function ChatView({ index = 0 }: { index?: number }) { ) : ( } /> )} - {/* */} - {/* */}