Skip to content

Commit

Permalink
Check client/server compatibility (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Dec 20, 2023
1 parent aa3d96c commit 809e0ed
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from dstack._internal.server.settings import DEFAULT_PROJECT_NAME, SERVER_URL
from dstack._internal.server.utils.logging import configure_logging
from dstack._internal.server.utils.routers import (
check_client_server_compatibility,
error_detail,
error_forbidden,
get_server_client_error_details,
)
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -130,6 +130,18 @@ async def log_request(request: Request, call_next):
)
return response

@app.middleware("http")
async def check_client_version(request: Request, call_next):
if not request.url.path.startswith("/api/") or request.url.path == "/api/docs":
return await call_next(request)
response = check_client_server_compatibility(
client_version=request.headers.get("x-api-version"),
server_version=settings.SERVER_API_VERSION,
)
if response is not None:
return response
return await call_next(request)

@app.get("/healthcheck")
async def healthcheck():
return JSONResponse(content={"status": "running"})
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path

from dstack import version

DSTACK_DIR_PATH = Path("~/.dstack/").expanduser()

SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server"))
Expand All @@ -25,6 +27,8 @@
"DSTACK_ALEMBIC_MIGRATIONS_LOCATION", "dstack._internal.server:migrations"
)

SERVER_API_VERSION = os.getenv("DSTACK_SERVER_API_VERSION", version.__version__)

SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None
SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED
LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None
Expand Down
57 changes: 57 additions & 0 deletions src/dstack/_internal/server/utils/routers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict, List, Optional

from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from packaging import version

from dstack._internal.core.errors import ServerClientError
from dstack._internal.server import settings


def error_detail(msg: str, code: Optional[str] = None, **kwargs) -> Dict:
Expand Down Expand Up @@ -53,3 +56,57 @@ def request_size_exceeded(request: Request, limit: int) -> bool:
if content_length > limit:
return True
return False


def check_client_server_compatibility(
client_version: Optional[str],
server_version: Optional[str],
) -> Optional[JSONResponse]:
"""
Returns `JSONResponse` with error if client/server versions are incompatible.
Returns `None` otherwise.
"""
if server_version is None:
return None
parsed_server_version = version.parse(server_version)
if client_version is None:
return error_incompatible_versions(client_version, server_version, ask_cli_update=True)
# latest allows client to bypass compatibility check (e.g. frontend)
if client_version == "latest":
return None
try:
parsed_client_version = version.parse(client_version)
except version.InvalidVersion:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
"detail": get_server_client_error_details(
ServerClientError("Bad API version specified")
)
},
)
if parsed_client_version < parsed_server_version and (
parsed_client_version.major < parsed_server_version.major
or parsed_client_version.minor < parsed_server_version.minor
):
return error_incompatible_versions(client_version, server_version, ask_cli_update=True)
elif parsed_client_version > parsed_server_version and (
parsed_client_version.major > parsed_server_version.major
or parsed_client_version.minor > parsed_server_version.minor
):
return error_incompatible_versions(client_version, server_version, ask_cli_update=False)
return None


def error_incompatible_versions(
client_version: Optional[str],
server_version: str,
ask_cli_update: bool,
) -> JSONResponse:
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
if ask_cli_update:
msg += f" Update the dstack CLI: `pip install dstack=={server_version}`."
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": get_server_client_error_details(ServerClientError(msg=msg))},
)
3 changes: 3 additions & 0 deletions src/dstack/api/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import requests

from dstack import version
from dstack._internal.core.errors import ClientError, ServerClientError
from dstack._internal.utils.logging import get_logger
from dstack.api.server._backends import BackendsAPIClient
Expand Down Expand Up @@ -45,6 +46,8 @@ def __init__(self, base_url: str, token: str):
self._token = token
self._s = requests.session()
self._s.headers.update({"Authorization": f"Bearer {token}"})
if version.__version__ is not None:
self._s.headers.update({"X-API-VERSION": version.__version__})

@property
def users(self) -> UsersAPIClient:
Expand Down
Empty file.
84 changes: 84 additions & 0 deletions src/tests/_internal/server/utils/test_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Optional

import pytest

from dstack._internal.server.utils.routers import check_client_server_compatibility


class TestCheckClientServerCompatibility:
@pytest.mark.parametrize("client_version", ["12.12.12", None])
def test_returns_none_if_server_version_is_none(self, client_version: Optional[str]):
assert (
check_client_server_compatibility(
client_version=client_version,
server_version=None,
)
is None
)

@pytest.mark.parametrize(
"client_version,server_version",
[
("0.12.4", "0.12.4"),
("0.12.4", "0.12.5"),
("0.12.5", "0.12.4"),
("1.0.5", "1.0.6"),
("0.12.4", "0.12.5rc1"),
],
)
def test_returns_none_if_compatible(
self, client_version: Optional[str], server_version: Optional[str]
):
assert (
check_client_server_compatibility(
client_version=client_version,
server_version=server_version,
)
is None
)

@pytest.mark.parametrize(
"client_version,server_version",
[
("0.12.4", "0.13.0"),
("0.12.0", "1.12.0"),
],
)
def test_returns_error_if_client_version_smaller(
self, client_version: Optional[str], server_version: Optional[str]
):
res = check_client_server_compatibility(
client_version=client_version,
server_version=server_version,
)
assert res is not None

@pytest.mark.parametrize(
"client_version,server_version",
[
("0.13.0", "0.12.4"),
("1.12.0", "0.12.0"),
],
)
def test_returns_error_if_client_version_larger(
self, client_version: Optional[str], server_version: Optional[str]
):
res = check_client_server_compatibility(
client_version=client_version,
server_version=server_version,
)
assert res is not None

@pytest.mark.parametrize(
"server_version",
[
None,
"0.1.12",
],
)
def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]):
res = check_client_server_compatibility(
client_version="latest",
server_version=server_version,
)
assert res is None

0 comments on commit 809e0ed

Please sign in to comment.