Skip to content

Commit

Permalink
Added router tests for pools (#916)
Browse files Browse the repository at this point in the history
* Added router tests for pools
* Handle unexpected exceptions
* Added tests for process_pool.py. Improve termination handling
  • Loading branch information
Sergey Mezentsev authored Feb 21, 2024
1 parent 5e62c44 commit 718eeff
Show file tree
Hide file tree
Showing 9 changed files with 974 additions and 33 deletions.
51 changes: 35 additions & 16 deletions src/dstack/_internal/server/background/tasks/process_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, Optional, Union
from uuid import UUID

import requests
from pydantic import parse_raw_as
from sqlalchemy import select
from sqlalchemy.orm import joinedload
Expand All @@ -26,13 +27,18 @@

PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)

TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)

# Terminate instance if the instance has not started within 10 minutes
STARTING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds


@dataclass
class HealthStatus:
healthy: bool
reason: str

def __str__(self):
def __str__(self) -> str:
return self.reason


Expand Down Expand Up @@ -99,8 +105,8 @@ async def check_shim(instance_id: UUID) -> None:

if health.healthy:
logger.debug("check instance %s status: shim health is OK", instance.name)
instance.fail_count = 0
instance.fail_reason = None
instance.termination_deadline = None
instance.health_status = None

if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING):
instance.status = (
Expand All @@ -110,24 +116,32 @@ async def check_shim(instance_id: UUID) -> None:
else:
logger.debug("check instance %s status: shim health: %s", instance.name, health)

instance.fail_count += 1
instance.fail_reason = health.reason
if instance.termination_deadline is None:
instance.termination_deadline = (
get_current_datetime() + TERMINATION_DEADLINE_OFFSET
)
instance.health_status = health.reason

if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
logger.warning(
"instance %s: shim has become unavailable, marked as failed", instance.name
)
FAIL_THRESHOLD = 10 * 6 * 20 # instance_healthcheck fails 20 minutes constantly
if instance.fail_count > FAIL_THRESHOLD:
logger.warning("instance %s shim is not available", instance.name)
deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
if get_current_datetime() > deadline:
instance.status = InstanceStatus.TERMINATING
instance.termination_reason = "Termination deadline"
logger.warning("mark instance %s as TERMINATED", instance.name)

if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
STARTING_TIMEOUT = 10 * 60 # 10 minutes
starting_time_threshold = instance.started_at + timedelta(seconds=STARTING_TIMEOUT)
starting_time_threshold = instance.started_at.replace(
tzinfo=datetime.timezone.utc
) + timedelta(seconds=STARTING_TIMEOUT_SECONDS)
expire_starting = starting_time_threshold < get_current_datetime()
if expire_starting:
instance.status = InstanceStatus.TERMINATING
logger.warning(
"The Instance %s can't start in %s seconds. Marked as TERMINATED",
instance.name,
STARTING_TIMEOUT_SECONDS,
)

await session.commit()

Expand All @@ -148,8 +162,13 @@ def instance_healthcheck(*, ports: Dict[int, int]) -> HealthStatus:
healthy=False,
reason=f"Service name is {resp.service}, service version: {resp.version}",
)
except requests.RequestException as e:
return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
except Exception as e:
return HealthStatus(healthy=False, reason=f"Exception ({e.__class__.__name__}): {e}")
logger.exception("Unknown exception from shim.healthcheck: %s", e)
return HealthStatus(
healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
)


async def terminate(instance_id: UUID) -> None:
Expand All @@ -163,11 +182,10 @@ async def terminate(instance_id: UUID) -> None:
).one()

jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)
BACKEND_TYPE = jpd.backend
backends = await backends_services.get_project_backends(project=instance.project)
backend = next((b for b in backends if b.TYPE == BACKEND_TYPE), None)
backend = next((b for b in backends if b.TYPE == jpd.backend), None)
if backend is None:
raise ValueError(f"there is no backend {BACKEND_TYPE}")
raise ValueError(f"there is no backend {jpd.backend}")

await run_async(
backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data
Expand Down Expand Up @@ -217,6 +235,7 @@ async def terminate_idle_instance() -> None:
instance.deleted_at = get_current_datetime()
instance.finished_at = get_current_datetime()
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Idle timeout"

idle_time = current_time - last_time
logger.info(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Rework termination handling
Revision ID: 1a48dfe44a40
Revises: 9eea6af28e10
Create Date: 2024-02-21 10:11:32.350099
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a48dfe44a40"
down_revision = "9eea6af28e10"
branch_labels = None
depends_on = None


def upgrade() -> None:
with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.drop_column("fail_reason")
batch_op.drop_column("fail_count")

with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.add_column(sa.Column("termination_deadline", sa.DateTime(), nullable=True))
batch_op.add_column(
sa.Column("termination_reason", sa.VARCHAR(length=4000), nullable=True)
)
batch_op.add_column(sa.Column("health_status", sa.VARCHAR(length=4000), nullable=True))


def downgrade() -> None:
with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.add_column(
sa.Column("fail_count", sa.Integer(), server_default=sa.text("0"), nullable=False)
)
batch_op.add_column(sa.Column("fail_reason", sa.String(length=4000), nullable=True))

with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.drop_column("termination_deadline")
batch_op.drop_column("termination_reason")
batch_op.drop_column("health_status")
10 changes: 5 additions & 5 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
String,
Text,
UniqueConstraint,
text,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.sql import false
Expand Down Expand Up @@ -287,17 +286,18 @@ class InstanceModel(BaseModel):

# VM
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)

# temination policy
termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50))
termination_idle_time: Mapped[int] = mapped_column(
Integer, default=DEFAULT_POOL_TERMINATION_IDLE_TIME
)

# connection fail handling
fail_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
fail_reason: Mapped[Optional[str]] = mapped_column(String(4000))
# instance termination handling
termination_deadline: Mapped[Optional[datetime]] = mapped_column(DateTime)
termination_reason: Mapped[Optional[str]] = mapped_column(String(4000))
health_status: Mapped[Optional[str]] = mapped_column(String(4000))

# backend
backend: Mapped[BackendType] = mapped_column(Enum(BackendType))
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/routers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest
from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember
from dstack._internal.server.security.permissions import ProjectMember
from dstack._internal.server.services.runs import (
abort_runs_of_pool,
list_project_runs,
Expand All @@ -33,7 +33,7 @@ async def list_pool(
async def remove_instance(
body: schemas.RemoveInstanceRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
_, project_model = user_project
await pools.remove_instance(
Expand All @@ -45,7 +45,7 @@ async def remove_instance(
async def set_default_pool(
body: schemas.SetDefaultPoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> bool:
_, project_model = user_project
return await pools.set_default_pool(session, project_model, body.pool_name)
Expand All @@ -55,7 +55,7 @@ async def set_default_pool(
async def delete_pool(
body: schemas.DeletePoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
pool_name = body.name
_, project_model = user_project
Expand Down Expand Up @@ -87,7 +87,7 @@ async def delete_pool(
async def create_pool(
body: schemas.CreatePoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
_, project = user_project
await pools.create_pool_model(session=session, project=project, name=body.name)
Expand All @@ -97,7 +97,7 @@ async def create_pool(
async def show_pool(
body: schemas.ShowPoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> models.PoolInstances:
_, project = user_project
instances = await pools.show_pool(session, project, pool_name=body.name)
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def add_remote(
name="instance",
resources=instance_resource,
),
region="",
region="", # TODO: add region
price=0.0,
availability=InstanceAvailability.AVAILABLE,
)
Expand All @@ -361,11 +361,14 @@ async def add_remote(
name=instance_name,
project=project,
pool=pool_model,
backend=BackendType.REMOTE,
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
job_provisioning_data=local.json(),
offer=offer.json(),
region=offer.region,
price=offer.price,
termination_policy=profile.termination_policy,
termination_idle_time=profile.termination_idle_time,
)
Expand Down
Loading

0 comments on commit 718eeff

Please sign in to comment.