diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py
index 1879e4e5b..e3fb19c99 100644
--- a/src/dstack/_internal/server/background/tasks/process_pools.py
+++ b/src/dstack/_internal/server/background/tasks/process_pools.py
@@ -4,6 +4,7 @@
 from typing import Dict, Optional, Union
 from uuid import UUID
 
+import requests
 from pydantic import parse_raw_as
 from sqlalchemy import select
 from sqlalchemy.orm import joinedload
@@ -26,13 +27,18 @@
 PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)
 
+TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)
+
+# Terminate instance if the instance has not started within 10 minutes
+STARTING_TIMEOUT_SECONDS = 10 * 60  # 10 minutes in seconds
+
 
 @dataclass
 class HealthStatus:
     healthy: bool
     reason: str
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.reason
 
 
@@ -99,8 +105,8 @@ async def check_shim(instance_id: UUID) -> None:
         if health.healthy:
             logger.debug("check instance %s status: shim health is OK", instance.name)
-            instance.fail_count = 0
-            instance.fail_reason = None
+            instance.termination_deadline = None
+            instance.health_status = None
 
             if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING):
                 instance.status = (
@@ -110,24 +116,32 @@
         else:
             logger.debug("check instance %s status: shim health: %s", instance.name, health)
-            instance.fail_count += 1
-            instance.fail_reason = health.reason
+            if instance.termination_deadline is None:
+                instance.termination_deadline = (
+                    get_current_datetime() + TERMINATION_DEADLINE_OFFSET
+                )
+            instance.health_status = health.reason
 
             if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
-                logger.warning(
-                    "instance %s: shim has become unavailable, marked as failed", instance.name
-                )
-                FAIL_THRESHOLD = 10 * 6 * 20  # instance_healthcheck fails 20 minutes constantly
-                if instance.fail_count > FAIL_THRESHOLD:
+                logger.warning("instance %s shim is not available", instance.name)
+                deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
+                if get_current_datetime() > deadline:
                     instance.status = InstanceStatus.TERMINATING
+                    instance.termination_reason = "Termination deadline"
                     logger.warning("mark instance %s as TERMINATED", instance.name)
 
         if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
-            STARTING_TIMEOUT = 10 * 60  # 10 minutes
-            starting_time_threshold = instance.started_at + timedelta(seconds=STARTING_TIMEOUT)
+            starting_time_threshold = instance.started_at.replace(
+                tzinfo=datetime.timezone.utc
+            ) + timedelta(seconds=STARTING_TIMEOUT_SECONDS)
             expire_starting = starting_time_threshold < get_current_datetime()
             if expire_starting:
                 instance.status = InstanceStatus.TERMINATING
+                logger.warning(
+                    "The Instance %s can't start in %s seconds. Marked as TERMINATED",
+                    instance.name,
+                    STARTING_TIMEOUT_SECONDS,
+                )
 
         await session.commit()
 
@@ -148,8 +162,13 @@
             healthy=False,
             reason=f"Service name is {resp.service}, service version: {resp.version}",
         )
+    except requests.RequestException as e:
+        return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
     except Exception as e:
-        return HealthStatus(healthy=False, reason=f"Exception ({e.__class__.__name__}): {e}")
+        logger.exception("Unknown exception from shim.healthcheck: %s", e)
+        return HealthStatus(
+            healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
+        )
 
 
 async def terminate(instance_id: UUID) -> None:
@@ -163,11 +182,10 @@ async def terminate(instance_id: UUID) -> None:
         ).one()
 
         jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)
-        BACKEND_TYPE = jpd.backend
         backends = await backends_services.get_project_backends(project=instance.project)
-        backend = next((b for b in backends if b.TYPE == BACKEND_TYPE), None)
+        backend = next((b for b in backends if b.TYPE == jpd.backend), None)
         if backend is None:
-            raise ValueError(f"there is no backend {BACKEND_TYPE}")
+            raise ValueError(f"there is no backend {jpd.backend}")
 
         await run_async(
             backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data
@@ -217,6 +235,7 @@ async def terminate_idle_instance() -> None:
             instance.deleted_at = get_current_datetime()
             instance.finished_at = get_current_datetime()
             instance.status = InstanceStatus.TERMINATED
+            instance.termination_reason = "Idle timeout"
 
             idle_time = current_time - last_time
             logger.info(
diff --git a/src/dstack/_internal/server/migrations/versions/1a48dfe44a40_rework_termination_handling.py b/src/dstack/_internal/server/migrations/versions/1a48dfe44a40_rework_termination_handling.py
new file mode 100644
index 000000000..d106c9299
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/1a48dfe44a40_rework_termination_handling.py
@@ -0,0 +1,41 @@
+"""Rework termination handling
+
+Revision ID: 1a48dfe44a40
+Revises: 9eea6af28e10
+Create Date: 2024-02-21 10:11:32.350099
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "1a48dfe44a40" +down_revision = "9eea6af28e10" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_column("fail_reason") + batch_op.drop_column("fail_count") + + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column(sa.Column("termination_deadline", sa.DateTime(), nullable=True)) + batch_op.add_column( + sa.Column("termination_reason", sa.VARCHAR(length=4000), nullable=True) + ) + batch_op.add_column(sa.Column("health_status", sa.VARCHAR(length=4000), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column( + sa.Column("fail_count", sa.Integer(), server_default=sa.text("0"), nullable=False) + ) + batch_op.add_column(sa.Column("fail_reason", sa.String(length=4000), nullable=True)) + + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_column("termination_deadline") + batch_op.drop_column("termination_reason") + batch_op.drop_column("health_status") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 9f479df76..97cbf69de 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -14,7 +14,6 @@ String, Text, UniqueConstraint, - text, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.sql import false @@ -287,7 +286,7 @@ class InstanceModel(BaseModel): # VM started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) # temination policy termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) @@ -295,9 +294,10 @@ class InstanceModel(BaseModel): Integer, default=DEFAULT_POOL_TERMINATION_IDLE_TIME ) - # connection fail handling - fail_count: Mapped[int] = mapped_column(Integer, server_default=text("0")) - fail_reason: Mapped[Optional[str]] = mapped_column(String(4000)) + # instance termination handling + termination_deadline: Mapped[Optional[datetime]] = mapped_column(DateTime) + termination_reason: Mapped[Optional[str]] = mapped_column(String(4000)) + health_status: Mapped[Optional[str]] = mapped_column(String(4000)) # backend backend: Mapped[BackendType] = mapped_column(Enum(BackendType)) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 3b8aa026c..8ac142b45 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -10,7 +10,7 @@ from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest -from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember +from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services.runs import ( abort_runs_of_pool, list_project_runs, @@ -33,7 +33,7 @@ async def list_pool( async def remove_instance( body: schemas.RemoveInstanceRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> None: _, 
project_model = user_project await pools.remove_instance( @@ -45,7 +45,7 @@ async def remove_instance( async def set_default_pool( body: schemas.SetDefaultPoolRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> bool: _, project_model = user_project return await pools.set_default_pool(session, project_model, body.pool_name) @@ -55,7 +55,7 @@ async def set_default_pool( async def delete_pool( body: schemas.DeletePoolRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> None: pool_name = body.name _, project_model = user_project @@ -87,7 +87,7 @@ async def delete_pool( async def create_pool( body: schemas.CreatePoolRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> None: _, project = user_project await pools.create_pool_model(session=session, project=project, name=body.name) @@ -97,7 +97,7 @@ async def create_pool( async def show_pool( body: schemas.ShowPoolRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> models.PoolInstances: _, project = user_project instances = await pools.show_pool(session, project, pool_name=body.name) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 8d9a8c055..d9a8370d4 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -352,7 +352,7 @@ async def add_remote( name="instance", resources=instance_resource, ), - region="", + region="", # TODO: add region price=0.0, availability=InstanceAvailability.AVAILABLE, ) @@ -361,11 +361,14 @@ async def add_remote( name=instance_name, project=project, pool=pool_model, + backend=BackendType.REMOTE, created_at=common_utils.get_current_datetime(), started_at=common_utils.get_current_datetime(), status=InstanceStatus.PENDING, job_provisioning_data=local.json(), offer=offer.json(), + region=offer.region, + price=offer.price, termination_policy=profile.termination_policy, termination_idle_time=profile.termination_idle_time, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_pool.py b/src/tests/_internal/server/background/tasks/test_process_pool.py new file mode 100644 index 000000000..e04c732ed --- /dev/null +++ b/src/tests/_internal/server/background/tasks/test_process_pool.py @@ -0,0 +1,254 @@ +import datetime as dt +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import InstanceStatus, JobStatus +from dstack._internal.server.background.tasks.process_pools import ( + HealthStatus, + process_pools, + terminate_idle_instance, +) +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_pool, + create_project, + create_repo, + create_run, + create_user, +) +from dstack._internal.utils.common import 
get_current_datetime + + +class TestCheckShim: + @pytest.mark.asyncio + async def test_check_shim_transitions_starting_on_ready(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.STARTING) + instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) + instance.health_status = "ssh connect problem" + + await session.commit() + + with patch( + "dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.READY + assert instance.termination_deadline is None + assert instance.health_status is None + + @pytest.mark.asyncio + async def test_check_shim_transitions_starting_on_terminating( + self, test_db, session: AsyncSession + ): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.STARTING) + instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) + instance.health_status = "ssh connect problem" + + await session.commit() + + health_reason = "Shim problem" + + with patch( + "dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=False, reason=health_reason) + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline is not None + assert instance.health_status == health_reason + + @pytest.mark.asyncio + async def test_check_shim_transitions_creating_on_busy(self, test_db, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await create_pool(session, project) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PENDING, + ) + + instance = await create_instance(session, project, pool, status=InstanceStatus.CREATING) + instance.termination_deadline = get_current_datetime().replace( + tzinfo=dt.timezone.utc + ) + dt.timedelta(days=1) + instance.health_status = "ssh connect problem" + instance.job = job + + await session.commit() + + with patch( + "dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.BUSY + assert instance.termination_deadline is None + assert instance.health_status is None + assert instance.job == job + + @pytest.mark.asyncio + async def test_check_shim_start_termination_deadline(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.READY) + + health_status = "SSH connection fail" + with patch( + 
"dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=False, reason=health_status) + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.READY + assert instance.termination_deadline is not None + assert instance.termination_deadline.replace( + tzinfo=dt.timezone.utc + ) > get_current_datetime() + dt.timedelta(minutes=19) + assert instance.health_status == health_status + + @pytest.mark.asyncio + async def test_check_shim_stop_termination_deadline(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.READY) + instance.termination_deadline = get_current_datetime() + dt.timedelta(minutes=19) + await session.commit() + + with patch( + "dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.READY + assert instance.termination_deadline is None + assert instance.health_status is None + + @pytest.mark.asyncio + async def test_check_shim_terminate_instance_by_dedaline(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.READY) + termination_deadline_time = get_current_datetime() + dt.timedelta(minutes=-19) + instance.termination_deadline = termination_deadline_time + await session.commit() + + health_status = "Not ok" + with patch( + "dstack._internal.server.background.tasks.process_pools.instance_healthcheck" + ) as healthcheck: + healthcheck.return_value = HealthStatus(healthy=False, reason=health_status) + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.TERMINATING + assert ( + instance.termination_deadline.replace(tzinfo=dt.timezone.utc) + == termination_deadline_time + ) + assert instance.termination_reason == "Termination deadline" + assert instance.health_status == health_status + + +class TestIdleTime: + @pytest.mark.asyncio + async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await create_instance(session, project, pool, status=InstanceStatus.READY) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + with patch( + "dstack._internal.server.background.tasks.process_pools.terminate_job_provisioning_data_instance" + ): + await terminate_idle_instance() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == "Idle timeout" + + +class TestTerminate: + @pytest.mark.asyncio + async def test_terminate(self, test_db, session: AsyncSession): + project = await create_project(session=session) + pool = await create_pool(session, project) + + instance = await 
create_instance(session, project, pool, status=InstanceStatus.TERMINATING) + + reason = "some reason" + instance.termination_reason = reason + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + with patch( + "dstack._internal.server.background.tasks.process_pools.backends_services.get_project_backends" + ) as get_backends: + backend = Mock() + backend.TYPE = BackendType.DATACRUNCH + backend.compute.return_value.terminate_instance.return_value = Mock() + + get_backends.return_value = [backend] + + await process_pools() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == "some reason" + assert instance.deleted == True + assert instance.deleted_at is not None + assert instance.finished_at is not None diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 0fd6943aa..306c84c0d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -17,8 +17,6 @@ create_user, ) -MODULE = "dstack._internal.server.services.jobs" - class TestProcessFinishedJobs: @pytest.mark.asyncio @@ -63,7 +61,7 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A ssh_proxy=None, ), ) - with patch(f"{MODULE}.terminate_job_provisioning_data_instance"): + with patch("dstack._internal.server.background.tasks.process_finished_jobs.submit_stop"): await process_finished_jobs() await session.refresh(job) assert job is not None diff --git a/src/tests/_internal/server/routers/test_pools.py b/src/tests/_internal/server/routers/test_pools.py new file mode 100644 index 000000000..206fe76ba --- /dev/null +++ b/src/tests/_internal/server/routers/test_pools.py @@ -0,0 +1,537 @@ +import datetime as dt + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.main import app +from dstack._internal.server.schemas.pools import ( + CreatePoolRequest, + DeletePoolRequest, + RemoveInstanceRequest, + SetDefaultPoolRequest, + ShowPoolRequest, +) +from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_instance, + create_pool, + create_project, + create_user, + get_auth_headers, +) + +client = TestClient(app) + +TEST_POOL_NAME = "test_router_pool_name" + + +class TestListPool: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/pool/list", + json={}, + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_default_and_list(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) 
+ await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + pool = result[0] + expected = [ + { + "name": "default-pool", + "default": True, + "created_at": str(pool["created_at"]), + "total_instances": 0, + "available_instances": 0, + } + ] + assert result == expected + + @pytest.mark.asyncio + async def test_list_pools(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + + await create_pool(session, project, pool_name=TEST_POOL_NAME) + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + pool = result[0] + expected = [ + { + "name": TEST_POOL_NAME, + "default": False, + "created_at": str(pool["created_at"]), + "total_instances": 0, + "available_instances": 0, + } + ] + assert result == expected + + +class TestDeletePool: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/pool/delete", + json=DeletePoolRequest(name=TEST_POOL_NAME, force=False).dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_delete_last_pool(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + response = client.post( + f"/api/project/{project.name}/pool/delete", + headers=get_auth_headers(user.token), + json=DeletePoolRequest(name=TEST_POOL_NAME, force=False).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + + default_pool = result[0] + assert default_pool["name"] == DEFAULT_POOL_NAME + assert dt.datetime.fromisoformat(default_pool["created_at"]) > pool.created_at + + @pytest.mark.asyncio + async def test_delete_pool(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool_left = await create_pool(session, project, pool_name=f"{TEST_POOL_NAME}-left") + pool_right = await create_pool(session, project, pool_name=f"{TEST_POOL_NAME}-right") + response = client.post( + f"/api/project/{project.name}/pool/delete", + headers=get_auth_headers(user.token), + json=DeletePoolRequest(name=pool_left.name, 
force=False).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + + default_pool = result[0] + assert default_pool["name"] == pool_right.name + + @pytest.mark.asyncio + async def test_delete_missing(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + response = client.post( + f"/api/project/{project.name}/pool/delete", + headers=get_auth_headers(user.token), + json=DeletePoolRequest(name="missing name", force=False).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + + default_pool = result[0] + assert default_pool["name"] == pool.name + + +class TestSetDefaultPool: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/pool/set_default", + json=SetDefaultPoolRequest(pool_name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_set_default(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + await create_pool(session, project, pool_name=f"{TEST_POOL_NAME}-left") + pool_right = await create_pool(session, project, pool_name=f"{TEST_POOL_NAME}-right") + response = client.post( + f"/api/project/{project.name}/pool/set_default", + headers=get_auth_headers(user.token), + json=SetDefaultPoolRequest(pool_name=pool_right.name).dict(), + ) + assert response.status_code == 200 + assert response.json() == True + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 2 + + default_pool = [p for p in result if p["default"]][0] + assert default_pool["name"] == pool_right.name + + @pytest.mark.asyncio + async def test_set_default_missing(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + response = client.post( + f"/api/project/{project.name}/pool/set_default", + headers=get_auth_headers(user.token), + json=SetDefaultPoolRequest(pool_name="missing pool").dict(), + ) + assert response.status_code == 200 + assert 
response.json() == False + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + + result_pool = result[0] + assert result_pool["name"] == pool.name + assert result_pool["default"] == False + + +class TestCreatePool: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/pool/create", + json=CreatePoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_pool(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = client.post( + f"/api/project/{project.name}/pool/create", + headers=get_auth_headers(user.token), + json=CreatePoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + response = client.post( + f"/api/project/{project.name}/pool/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200 + + result = response.json() + assert len(result) == 1 + + default_pool = result[0] + assert default_pool["name"] == TEST_POOL_NAME + + @pytest.mark.asyncio + async def test_duplicate_name(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + response = client.post( + f"/api/project/{project.name}/pool/create", + headers=get_auth_headers(user.token), + json=CreatePoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + with pytest.raises(ValueError): + response = client.post( + f"/api/project/{project.name}/pool/create", + headers=get_auth_headers(user.token), + json=CreatePoolRequest(name=TEST_POOL_NAME).dict(), + ) + + +class TestShowPool: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/pool/show", + json=CreatePoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_show_pool(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + await create_instance(session, project, pool) + + response = client.post( + f"/api/project/{project.name}/pool/show", + headers=get_auth_headers(user.token), + json=ShowPoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 
200 + assert response.json() == { + "name": "test_router_pool_name", + "instances": [ + { + "backend": "datacrunch", + "instance_type": { + "name": "instance", + "resources": { + "cpus": 1, + "memory_mib": 512, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "", + }, + }, + "name": "test_instance", + "job_name": None, + "job_status": None, + "hostname": "running_instance.ip", + "status": "ready", + "created": "2023-01-02T03:04:00", + "region": "en", + "price": 0.1, + } + ], + } + + @pytest.mark.asyncio + async def test_show_missing_pool(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + await create_instance(session, project, pool) + + response = client.post( + f"/api/project/{project.name}/pool/show", + headers=get_auth_headers(user.token), + json=ShowPoolRequest(name="missing_pool").dict(), + ) + assert response.status_code == 400 + assert response.json() == { + "detail": [{"msg": "Pool is not found", "code": "resource_not_exists"}] + } + + +class TestAddRemote: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + remote = AddRemoteInstanceRequest( + instance_name="test_instance_name", + host="localhost", + port="22", + resources=ResourcesSpec(cpu=1), + profile=Profile(name="test_profile"), + ) + response = client.post( + f"/api/project/{project.name}/pool/add_remote", + json=remote.dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_add_remote(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + remote = AddRemoteInstanceRequest( + instance_name="test_instance_name", + host="localhost", + port="22", + resources=ResourcesSpec(cpu=1), + profile=Profile(name="test_profile"), + ) + response = client.post( + f"/api/project/{project.name}/pool/add_remote", + headers=get_auth_headers(user.token), + json=remote.dict(), + ) + assert response.status_code == 200 + assert response.json() == True + + +class TestRemoveInstance: + @pytest.mark.asyncio + async def test_returns_403_if_not_authenticated(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + remote = AddRemoteInstanceRequest( + instance_name="test_instance_name", + host="localhost", + port="22", + resources=ResourcesSpec(cpu=1), + profile=Profile(name="test_profile"), + ) + response = client.post( + f"/api/project/{project.name}/pool/add_remote", + json=remote.dict(), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_remove_instance(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, 
project_role=ProjectRole.ADMIN + ) + + pool = await create_pool(session, project, pool_name=TEST_POOL_NAME) + instance = await create_instance(session, project, pool) + + response = client.post( + f"/api/project/{project.name}/pool/remove", + headers=get_auth_headers(user.token), + json=RemoveInstanceRequest( + pool_name=TEST_POOL_NAME, instance_name=instance.name + ).dict(), + ) + assert response.status_code == 200 + assert response.json() is None + + response = client.post( + f"/api/project/{project.name}/pool/show", + headers=get_auth_headers(user.token), + json=ShowPoolRequest(name=TEST_POOL_NAME).dict(), + ) + assert response.status_code == 200 + assert response.json() == { + "name": "test_router_pool_name", + "instances": [ + { + "backend": "datacrunch", + "instance_type": { + "name": "instance", + "resources": { + "cpus": 1, + "memory_mib": 512, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "", + }, + }, + "name": "test_instance", + "job_name": None, + "job_status": None, + "hostname": "running_instance.ip", + "status": "terminating", + "created": "2023-01-02T03:04:00", + "region": "en", + "price": 0.1, + } + ], + } diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 85211cf26..1e30cd068 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -14,13 +14,17 @@ InstanceAvailability, InstanceOfferWithAvailability, InstanceType, + LaunchedInstanceInfo, Resources, + SSHKey, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME -from dstack._internal.core.models.runs import JobSpec, JobStatus, RunSpec +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import JobSpec, JobStatus, Requirements, RunSpec from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.main import app from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.schemas.runs import CreateInstanceRequest from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( create_job, @@ -734,3 +738,88 @@ async def test_returns_400_if_runs_active(self, test_db, session: AsyncSession): assert len(res.scalars().all()) == 1 res = await session.execute(select(JobModel)) assert len(res.scalars().all()) == 1 + + +class TestCreateInstance: + @pytest.mark.asyncio + async def test_returns_403_if_not_project_member(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = client.post( + f"/api/project/{project.name}/runs/create_instance", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_instance(self, test_db, session: AsyncSession): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + request = CreateInstanceRequest( + pool_name=DEFAULT_POOL_NAME, + profile=Profile(name="test_profile"), + requirements=Requirements(resources=ResourcesSpec(cpu=1)), + ssh_key=SSHKey(public="test_public_key"), + ) + + with 
patch( + "dstack._internal.server.services.runs.get_run_plan_by_requirements" + ) as run_plan_by_req: + offers = InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="eu", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ) + instance_info = LaunchedInstanceInfo( + instance_id="test_instance", + region="eu", + ip_address="127.0.0.1", + username="ubuntu", + ssh_port=22, + dockerized=False, + ) + backend = Mock() + backend.compute.return_value.get_offers.return_value = [offers] + backend.compute.return_value.create_instance.return_value = instance_info + backend.TYPE = BackendType.AWS + run_plan_by_req.return_value = [(backend, offers)] + + response = client.post( + f"/api/project/{project.name}/runs/create_instance", + headers=get_auth_headers(user.token), + json=request.dict(), + ) + assert response.status_code == 200 + + result = response.json() + expected = { + "backend": "aws", + "instance_type": { + "name": "instance", + "resources": { + "cpus": 1, + "memory_mib": 512, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "", + }, + }, + "name": result["name"], + "job_name": None, + "job_status": None, + "hostname": "127.0.0.1", + "status": "starting", + "created": result["created"], + "region": "eu", + "price": 1.0, + } + assert result == expected
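
For reference, the termination logic the patch introduces in check_shim can be summarized as a small, self-contained sketch. This is not dstack's implementation: the Instance dataclass and the now() helper below are stand-ins for InstanceModel and get_current_datetime. It only mirrors the behaviour added above: the first failed health check arms a 20-minute deadline, a healthy check disarms it, a READY or BUSY instance past the deadline is marked TERMINATING, and an instance stuck in STARTING longer than STARTING_TIMEOUT_SECONDS is terminated as well.

# Simplified sketch of the deadline-based termination logic, not dstack's code.
import datetime
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Optional


class InstanceStatus(str, Enum):
    CREATING = "creating"
    STARTING = "starting"
    READY = "ready"
    BUSY = "busy"
    TERMINATING = "terminating"


@dataclass
class Instance:
    # Stand-in for InstanceModel; started_at is expected to be timezone-aware.
    status: InstanceStatus
    started_at: Optional[datetime.datetime] = None
    termination_deadline: Optional[datetime.datetime] = None
    termination_reason: Optional[str] = None
    health_status: Optional[str] = None


TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)
STARTING_TIMEOUT_SECONDS = 10 * 60


def now() -> datetime.datetime:
    # Stand-in for get_current_datetime(): timezone-aware UTC time.
    return datetime.datetime.now(datetime.timezone.utc)


def apply_health(instance: Instance, healthy: bool, reason: str) -> None:
    if healthy:
        # A healthy shim clears any pending deadline and health problem.
        instance.termination_deadline = None
        instance.health_status = None
        return

    # The first failed check arms the deadline; later failures keep the original one.
    if instance.termination_deadline is None:
        instance.termination_deadline = now() + TERMINATION_DEADLINE_OFFSET
    instance.health_status = reason

    if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
        if now() > instance.termination_deadline:
            instance.status = InstanceStatus.TERMINATING
            instance.termination_reason = "Termination deadline"

    # Independently of shim health, a STARTING instance that never came up
    # within the timeout is terminated.
    if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
        if instance.started_at + timedelta(seconds=STARTING_TIMEOUT_SECONDS) < now():
            instance.status = InstanceStatus.TERMINATING


if __name__ == "__main__":
    inst = Instance(status=InstanceStatus.READY)
    apply_health(inst, healthy=False, reason="SSH connection fail")
    assert inst.termination_deadline is not None  # deadline armed, instance still READY
    apply_health(inst, healthy=True, reason="OK")
    assert inst.termination_deadline is None  # recovered before the deadline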
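
The error handling added to instance_healthcheck follows a simple split: transport errors raised by requests become an expected "Can't request shim" status, while anything else is logged with a traceback and surfaced as an unknown error. Below is a hedged sketch of that pattern; the /healthcheck URL and the JSON payload fields are assumptions for illustration only, since the real instance_healthcheck talks to dstack-shim through dstack's own client rather than a plain HTTP URL.

# Illustrative sketch of the RequestException/Exception split; URL and payload
# shape are assumptions, not dstack-shim's actual API.
import logging
from dataclasses import dataclass

import requests

logger = logging.getLogger(__name__)


@dataclass
class HealthStatus:
    healthy: bool
    reason: str

    def __str__(self) -> str:
        return self.reason


def healthcheck(base_url: str) -> HealthStatus:
    try:
        resp = requests.get(f"{base_url}/healthcheck", timeout=5)
        resp.raise_for_status()
        payload = resp.json()
        if payload.get("service") == "dstack-shim":
            return HealthStatus(healthy=True, reason="OK")
        return HealthStatus(
            healthy=False,
            reason=f"Service name is {payload.get('service')}, service version: {payload.get('version')}",
        )
    except requests.RequestException as e:
        # HTTP and transport problems (connection refused, timeout, DNS, ...) are
        # expected while an instance is still provisioning, so they get their own reason.
        return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
    except Exception as e:
        # Anything else is unexpected: log the traceback and report it as unknown.
        logger.exception("Unknown exception from shim.healthcheck: %s", e)
        return HealthStatus(
            healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
        )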