Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dstack-proxy naming tweaks #1734

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions src/dstack/_internal/gateway/deps.py

This file was deleted.

Empty file.
Empty file.
8 changes: 8 additions & 0 deletions src/dstack/_internal/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
dstack-proxy is a component responsible for proxying ingress HTTP traffic to
services and models hosted by dstack. It can also perform load balancing,
collect service usage stats, obtain SSL certificates, etc.

This component can run as a standalone web application on a gateway instance or
as part of the dstack-server web application.
"""
36 changes: 36 additions & 0 deletions src/dstack/_internal/proxy/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator

from fastapi import Depends, Request
from typing_extensions import Annotated

from dstack._internal.proxy.repos.base import BaseProxyRepo


class BaseProxyDependencyInjector(ABC):
"""
dstack-proxy uses different implementations of this injector in different
environments: within dstack-serer and on a gateway instance. An object with
the injector interface stored in FastAPI's
app.state.proxy_dependency_injector configures dstack-proxy to use a
specific set of dependencies, e.g. a specific repo implementation.
"""

@abstractmethod
async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
if False:
yield # show type checkers this is a generator


async def get_injector(request: Request) -> BaseProxyDependencyInjector:
injector = request.app.state.proxy_dependency_injector
if not isinstance(injector, BaseProxyDependencyInjector):
raise RuntimeError(f"Wrong BaseProxyDependencyInjector type {type(injector)}")
return injector


async def get_proxy_repo(
injector: Annotated[BaseProxyDependencyInjector, Depends(get_injector)],
) -> AsyncGenerator[BaseProxyRepo, None]:
async for repo in injector.get_repo():
yield repo
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/repos/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""dstack-proxy data access layer"""
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Project(BaseModel):
ssh_private_key: str


class BaseGatewayRepo(ABC):
class BaseProxyRepo(ABC):
@abstractmethod
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
pass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, Optional

from dstack._internal.gateway.repos.base import BaseGatewayRepo, Project, Service
from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Service


class InMemoryGatewayRepo(BaseGatewayRepo):
class InMemoryProxyRepo(BaseProxyRepo):
def __init__(self) -> None:
self.services: Dict[str, Dict[str, Service]] = {}
self.projects: Dict[str, Project] = {}
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""dstack-proxy web endpoints"""
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from fastapi.responses import RedirectResponse, Response
from typing_extensions import Annotated

from dstack._internal.gateway.deps import get_gateway_repo
from dstack._internal.gateway.repos.base import BaseGatewayRepo
from dstack._internal.gateway.services import service_proxy
from dstack._internal.proxy.deps import get_proxy_repo
from dstack._internal.proxy.repos.base import BaseProxyRepo
from dstack._internal.proxy.services import service_proxy

REDIRECTED_HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"]
PROXIED_HTTP_METHODS = REDIRECTED_HTTP_METHODS + ["OPTIONS"]
Expand All @@ -27,7 +27,7 @@ async def service_reverse_proxy(
run_name: str,
path: str,
request: Request,
repo: Annotated[BaseGatewayRepo, Depends(get_gateway_repo)],
repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)],
) -> Response:
return await service_proxy.proxy(project_name, run_name, path, request, repo)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""dstack-proxy business logic layer"""
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SSHTunnel,
UnixSocket,
)
from dstack._internal.gateway.repos.base import Project, Replica, Service
from dstack._internal.proxy.repos.base import Project, Replica, Service
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import FileContent

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import httpx
from starlette.requests import ClientDisconnect

from dstack._internal.gateway.repos.base import BaseGatewayRepo, Replica, Service
from dstack._internal.gateway.services.service_connection import service_replica_connection_pool
from dstack._internal.proxy.repos.base import BaseProxyRepo, Replica, Service
from dstack._internal.proxy.services.service_connection import service_replica_connection_pool
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -17,7 +17,7 @@ async def proxy(
run_name: str,
path: str,
request: fastapi.Request,
repo: BaseGatewayRepo,
repo: BaseProxyRepo,
) -> fastapi.responses.Response:
if "Upgrade" in request.headers:
raise fastapi.exceptions.HTTPException(
Expand Down Expand Up @@ -72,7 +72,7 @@ async def proxy(


async def get_replica_client(
project_name: str, service: Service, replica: Replica, repo: BaseGatewayRepo
project_name: str, service: Service, replica: Replica, repo: BaseProxyRepo
) -> httpx.AsyncClient:
connection = await service_replica_connection_pool.get(replica.id)
if connection is None:
Expand Down
30 changes: 14 additions & 16 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import ForbiddenError, ServerClientError
from dstack._internal.core.services.configs import update_default_project
from dstack._internal.gateway.routers import service_proxy
from dstack._internal.gateway.services.service_connection import service_replica_connection_pool
from dstack._internal.proxy.routers import service_proxy
from dstack._internal.proxy.services.service_connection import service_replica_connection_pool
from dstack._internal.server import settings
from dstack._internal.server.background import start_background_tasks
from dstack._internal.server.db import get_db, get_session_ctx, migrate
Expand All @@ -34,12 +34,12 @@
volumes,
)
from dstack._internal.server.services.config import ServerConfigManager
from dstack._internal.server.services.gateway_in_server.deps import (
GatewayInServerDependencyInjector,
)
from dstack._internal.server.services.gateways import gateway_connections_pool, init_gateways
from dstack._internal.server.services.locking import advisory_lock_ctx
from dstack._internal.server.services.projects import get_or_create_default_project
from dstack._internal.server.services.proxy.deps import (
ServerProxyDependencyInjector,
)
from dstack._internal.server.services.storage import init_default_storage
from dstack._internal.server.services.users import get_or_create_admin_user
from dstack._internal.server.settings import (
Expand Down Expand Up @@ -74,7 +74,7 @@ def create_app() -> FastAPI:
)

app = FastAPI(docs_url="/api/docs", lifespan=lifespan)
app.state.gateway_dependency_injector = GatewayInServerDependencyInjector()
app.state.proxy_dependency_injector = ServerProxyDependencyInjector()
return app


Expand Down Expand Up @@ -174,10 +174,8 @@ def register_routes(app: FastAPI, ui: bool = True):
app.include_router(gateways.router)
app.include_router(volumes.root_router)
app.include_router(volumes.project_router)
if FeatureFlags.GATEWAY_IN_SERVER:
app.include_router(
service_proxy.router, prefix="/gateway/services", tags=["gateway-in-server"]
)
if FeatureFlags.PROXY:
app.include_router(service_proxy.router, prefix="/services", tags=["service-proxy"])

@app.exception_handler(ForbiddenError)
async def forbidden_error_handler(request: Request, exc: ForbiddenError):
Expand Down Expand Up @@ -243,8 +241,8 @@ async def healthcheck():
async def custom_http_exception_handler(request, exc):
if (
request.url.path.startswith("/api")
or FeatureFlags.GATEWAY_IN_SERVER
and _is_gateway_in_server_request(request)
or FeatureFlags.PROXY
and _is_proxied_service_request(request)
):
return JSONResponse(
{"detail": exc.detail},
Expand All @@ -264,16 +262,16 @@ async def index():
return RedirectResponse("/api/docs")


def _is_gateway_in_server_request(request: Request) -> bool:
if request.url.path.startswith("/gateway"):
def _is_proxied_service_request(request: Request) -> bool:
if request.url.path.startswith("/services"):
return True
# Attempt detecting requests originating from services served by gateway-in-server.
# Attempt detecting requests originating from services proxied by dstack-proxy.
# Such requests can "leak" to dstack server paths if the service does not support
# running under a path prefix properly.
referrer = URL(request.headers.get("Referer", ""))
return (
referrer.netloc == "" or referrer.netloc == request.url.netloc
) and referrer.path.startswith("/gateway/services")
) and referrer.path.startswith("/services")


def _print_dstack_logo():
Expand Down
Empty file.
12 changes: 0 additions & 12 deletions src/dstack/_internal/server/services/gateway_in_server/deps.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/dstack/_internal/server/services/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Dependencies for dstack-proxy that allow it to run as part of dstack-server.
"""
12 changes: 12 additions & 0 deletions src/dstack/_internal/server/services/proxy/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import AsyncGenerator

from dstack._internal.proxy.deps import BaseProxyDependencyInjector
from dstack._internal.proxy.repos.base import BaseProxyRepo
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.services.proxy.repo import DBProxyRepo


class ServerProxyDependencyInjector(BaseProxyDependencyInjector):
async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
async with get_session_ctx() as session:
yield DBProxyRepo(session)
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData, JobStatus, RunSpec
from dstack._internal.gateway.repos.base import BaseGatewayRepo, Project, Replica, Service
from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service
from dstack._internal.server.models import JobModel, ProjectModel


class DBGatewayRepo(BaseGatewayRepo):
class DBProxyRepo(BaseProxyRepo):
"""
A gateway repo implementation used for gateway-in-server that retrieves data from
dstack-server's database. Since the database is populated by dstack-server, all or
most writer methods in this implementation are expected to be empty.
A repo implementation used by dstack-proxy running within dstack-server.
Retrieves data from dstack-server's database. Since the database is
populated by dstack-server, all or most writer methods in this
implementation are expected to be empty.
"""

def __init__(self, session: AsyncSession) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ class FeatureFlags:
development. Feature flags are environment variables of the form DSTACK_FF_*
"""

GATEWAY_IN_SERVER = bool(os.getenv("DSTACK_FF_GATEWAY_IN_SERVER"))
PROXY = bool(os.getenv("DSTACK_FF_PROXY"))
Loading