From 809e0ed8e41549717fee890c192b4167beb173b2 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Dec 2023 15:35:17 +0500 Subject: [PATCH] Check client/server compatibility (#810) --- src/dstack/_internal/server/app.py | 14 +++- src/dstack/_internal/server/settings.py | 4 + src/dstack/_internal/server/utils/routers.py | 57 +++++++++++++ src/dstack/api/server/__init__.py | 3 + src/tests/_internal/server/utils/__init__.py | 0 .../_internal/server/utils/test_routers.py | 84 +++++++++++++++++++ 6 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 src/tests/_internal/server/utils/__init__.py create mode 100644 src/tests/_internal/server/utils/test_routers.py diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 555516e1d..6e983dfa7 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -30,8 +30,8 @@ from dstack._internal.server.settings import DEFAULT_PROJECT_NAME, SERVER_URL from dstack._internal.server.utils.logging import configure_logging from dstack._internal.server.utils.routers import ( + check_client_server_compatibility, error_detail, - error_forbidden, get_server_client_error_details, ) from dstack._internal.utils.logging import get_logger @@ -130,6 +130,18 @@ async def log_request(request: Request, call_next): ) return response + @app.middleware("http") + async def check_client_version(request: Request, call_next): + if not request.url.path.startswith("/api/") or request.url.path == "/api/docs": + return await call_next(request) + response = check_client_server_compatibility( + client_version=request.headers.get("x-api-version"), + server_version=settings.SERVER_API_VERSION, + ) + if response is not None: + return response + return await call_next(request) + @app.get("/healthcheck") async def healthcheck(): return JSONResponse(content={"status": "running"}) diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 290a6a263..566af9546 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -1,6 +1,8 @@ import os from pathlib import Path +from dstack import version + DSTACK_DIR_PATH = Path("~/.dstack/").expanduser() SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server")) @@ -25,6 +27,8 @@ "DSTACK_ALEMBIC_MIGRATIONS_LOCATION", "dstack._internal.server:migrations" ) +SERVER_API_VERSION = os.getenv("DSTACK_SERVER_API_VERSION", version.__version__) + SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index 1a5ca5be8..6f9ffd280 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -1,8 +1,11 @@ from typing import Dict, List, Optional from fastapi import HTTPException, Request, Response, status +from fastapi.responses import JSONResponse +from packaging import version from dstack._internal.core.errors import ServerClientError +from dstack._internal.server import settings def error_detail(msg: str, code: Optional[str] = None, **kwargs) -> Dict: @@ -53,3 +56,57 @@ def request_size_exceeded(request: Request, limit: int) -> bool: if content_length > limit: return True return False + + +def check_client_server_compatibility( + client_version: Optional[str], + server_version: Optional[str], +) -> Optional[JSONResponse]: + """ + Returns `JSONResponse` with error if client/server versions are incompatible. + Returns `None` otherwise. + """ + if server_version is None: + return None + parsed_server_version = version.parse(server_version) + if client_version is None: + return error_incompatible_versions(client_version, server_version, ask_cli_update=True) + # latest allows client to bypass compatibility check (e.g. frontend) + if client_version == "latest": + return None + try: + parsed_client_version = version.parse(client_version) + except version.InvalidVersion: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "detail": get_server_client_error_details( + ServerClientError("Bad API version specified") + ) + }, + ) + if parsed_client_version < parsed_server_version and ( + parsed_client_version.major < parsed_server_version.major + or parsed_client_version.minor < parsed_server_version.minor + ): + return error_incompatible_versions(client_version, server_version, ask_cli_update=True) + elif parsed_client_version > parsed_server_version and ( + parsed_client_version.major > parsed_server_version.major + or parsed_client_version.minor > parsed_server_version.minor + ): + return error_incompatible_versions(client_version, server_version, ask_cli_update=False) + return None + + +def error_incompatible_versions( + client_version: Optional[str], + server_version: str, + ask_cli_update: bool, +) -> JSONResponse: + msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})." + if ask_cli_update: + msg += f" Update the dstack CLI: `pip install dstack=={server_version}`." + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": get_server_client_error_details(ServerClientError(msg=msg))}, + ) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 7bd397e2d..2628c2e98 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -3,6 +3,7 @@ import requests +from dstack import version from dstack._internal.core.errors import ClientError, ServerClientError from dstack._internal.utils.logging import get_logger from dstack.api.server._backends import BackendsAPIClient @@ -45,6 +46,8 @@ def __init__(self, base_url: str, token: str): self._token = token self._s = requests.session() self._s.headers.update({"Authorization": f"Bearer {token}"}) + if version.__version__ is not None: + self._s.headers.update({"X-API-VERSION": version.__version__}) @property def users(self) -> UsersAPIClient: diff --git a/src/tests/_internal/server/utils/__init__.py b/src/tests/_internal/server/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/tests/_internal/server/utils/test_routers.py b/src/tests/_internal/server/utils/test_routers.py new file mode 100644 index 000000000..adcb50376 --- /dev/null +++ b/src/tests/_internal/server/utils/test_routers.py @@ -0,0 +1,84 @@ +from typing import Optional + +import pytest + +from dstack._internal.server.utils.routers import check_client_server_compatibility + + +class TestCheckClientServerCompatibility: + @pytest.mark.parametrize("client_version", ["12.12.12", None]) + def test_returns_none_if_server_version_is_none(self, client_version: Optional[str]): + assert ( + check_client_server_compatibility( + client_version=client_version, + server_version=None, + ) + is None + ) + + @pytest.mark.parametrize( + "client_version,server_version", + [ + ("0.12.4", "0.12.4"), + ("0.12.4", "0.12.5"), + ("0.12.5", "0.12.4"), + ("1.0.5", "1.0.6"), + ("0.12.4", "0.12.5rc1"), + ], + ) + def test_returns_none_if_compatible( + self, client_version: Optional[str], server_version: Optional[str] + ): + assert ( + check_client_server_compatibility( + client_version=client_version, + server_version=server_version, + ) + is None + ) + + @pytest.mark.parametrize( + "client_version,server_version", + [ + ("0.12.4", "0.13.0"), + ("0.12.0", "1.12.0"), + ], + ) + def test_returns_error_if_client_version_smaller( + self, client_version: Optional[str], server_version: Optional[str] + ): + res = check_client_server_compatibility( + client_version=client_version, + server_version=server_version, + ) + assert res is not None + + @pytest.mark.parametrize( + "client_version,server_version", + [ + ("0.13.0", "0.12.4"), + ("1.12.0", "0.12.0"), + ], + ) + def test_returns_error_if_client_version_larger( + self, client_version: Optional[str], server_version: Optional[str] + ): + res = check_client_server_compatibility( + client_version=client_version, + server_version=server_version, + ) + assert res is not None + + @pytest.mark.parametrize( + "server_version", + [ + None, + "0.1.12", + ], + ) + def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]): + res = check_client_server_compatibility( + client_version="latest", + server_version=server_version, + ) + assert res is None