Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added router tests for pools #916

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions src/dstack/_internal/server/background/tasks/process_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, Optional, Union
from uuid import UUID

import requests
from pydantic import parse_raw_as
from sqlalchemy import select
from sqlalchemy.orm import joinedload
Expand Down Expand Up @@ -32,7 +33,7 @@ class HealthStatus:
healthy: bool
reason: str

def __str__(self):
def __str__(self) -> str:
return self.reason


Expand Down Expand Up @@ -99,8 +100,8 @@ async def check_shim(instance_id: UUID) -> None:

if health.healthy:
logger.debug("check instance %s status: shim health is OK", instance.name)
instance.fail_count = 0
instance.fail_reason = None
instance.termination_deadline = None
instance.health_status = None

if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING):
instance.status = (
Expand All @@ -110,24 +111,29 @@ async def check_shim(instance_id: UUID) -> None:
else:
logger.debug("check instance %s status: shim health: %s", instance.name, health)

instance.fail_count += 1
instance.fail_reason = health.reason
if instance.termination_deadline is None:
instance.termination_deadline = get_current_datetime() + timedelta(minutes=20)
TheBits marked this conversation as resolved.
Show resolved Hide resolved
instance.health_status = health.reason

if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
logger.warning(
"instance %s shim is not available, marked as failed", instance.name
)
FAIL_THRESHOLD = 10 * 6 * 20 # instance_healthcheck fails 20 minutes constantly
if instance.fail_count > FAIL_THRESHOLD:
logger.warning("instance %s shim is not available", instance.name)
deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
if get_current_datetime() > deadline:
instance.status = InstanceStatus.TERMINATING
instance.termination_reason = "Termination deadline"
logger.warning("mark instance %s as TERMINATED", instance.name)

if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
STARTING_TIMEOUT = 10 * 60 # 10 minutes
starting_time_threshold = instance.started_at + timedelta(seconds=STARTING_TIMEOUT)
starting_timeout = 10 * 60 # 10 minutes
TheBits marked this conversation as resolved.
Show resolved Hide resolved
starting_time_threshold = instance.started_at + timedelta(seconds=starting_timeout)
expire_starting = starting_time_threshold < get_current_datetime()
if expire_starting:
instance.status = InstanceStatus.TERMINATING
logger.warning(
"The Instance %s canot start in %s seconds. Marked as TERMINATED",
TheBits marked this conversation as resolved.
Show resolved Hide resolved
instance.name,
starting_timeout,
)

await session.commit()

Expand All @@ -148,8 +154,13 @@ def instance_healthcheck(*, ports: Dict[int, int]) -> HealthStatus:
healthy=False,
reason=f"Service name is {resp.service}, service version: {resp.version}",
)
except requests.RequestException as e:
return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
except Exception as e:
return HealthStatus(healthy=False, reason=f"Exception ({e.__class__.__name__}): {e}")
logger.exception("Unkonw exception from shim.healthcheck: %s", e)
TheBits marked this conversation as resolved.
Show resolved Hide resolved
return HealthStatus(
healthy=False, reason=f"Unkonw exception ({e.__class__.__name__}): {e}"
TheBits marked this conversation as resolved.
Show resolved Hide resolved
)


async def terminate(instance_id: UUID) -> None:
Expand All @@ -163,11 +174,10 @@ async def terminate(instance_id: UUID) -> None:
).one()

jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)
BACKEND_TYPE = jpd.backend
backends = await backends_services.get_project_backends(project=instance.project)
backend = next((b for b in backends if b.TYPE == BACKEND_TYPE), None)
backend = next((b for b in backends if b.TYPE == jpd.backend), None)
if backend is None:
raise ValueError(f"there is no backend {BACKEND_TYPE}")
raise ValueError(f"there is no backend {jpd.backend}")

await run_async(
backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data
Expand Down Expand Up @@ -217,6 +227,7 @@ async def terminate_idle_instance() -> None:
instance.deleted_at = get_current_datetime()
instance.finished_at = get_current_datetime()
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Idle timeout"

idle_time = current_time - last_time
logger.info(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Rework termination handling

Revision ID: 1a48dfe44a40
Revises: 9eea6af28e10
Create Date: 2024-02-21 10:11:32.350099

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a48dfe44a40"
down_revision = "9eea6af28e10"
branch_labels = None
depends_on = None


def upgrade() -> None:
with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.drop_column("fail_reason")
batch_op.drop_column("fail_count")

with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.add_column(sa.Column("termination_deadline", sa.DateTime(), nullable=True))
batch_op.add_column(
sa.Column("termination_reason", sa.VARCHAR(length=4000), nullable=True)
)
batch_op.add_column(sa.Column("health_status", sa.VARCHAR(length=4000), nullable=True))


def downgrade() -> None:
with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.add_column(
sa.Column("fail_count", sa.Integer(), server_default=sa.text("0"), nullable=False)
)
batch_op.add_column(sa.Column("fail_reason", sa.String(length=4000), nullable=True))

with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.drop_column("termination_deadline")
batch_op.drop_column("termination_reason")
batch_op.drop_column("health_status")
10 changes: 5 additions & 5 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
String,
Text,
UniqueConstraint,
text,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.sql import false
Expand Down Expand Up @@ -284,17 +283,18 @@ class InstanceModel(BaseModel):

# VM
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)

# temination policy
termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50))
termination_idle_time: Mapped[int] = mapped_column(
Integer, default=DEFAULT_TERMINATION_IDLE_TIME
)

# connection fail handling
fail_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
fail_reason: Mapped[Optional[str]] = mapped_column(String(4000))
# instance termination handling
termination_deadline: Mapped[Optional[datetime]] = mapped_column(DateTime)
termination_reason: Mapped[Optional[str]] = mapped_column(String(4000))
health_status: Mapped[Optional[str]] = mapped_column(String(4000))

# backend
backend: Mapped[BackendType] = mapped_column(Enum(BackendType))
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/routers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def show_pool(
async def add_instance(
body: AddRemoteInstanceRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
TheBits marked this conversation as resolved.
Show resolved Hide resolved
) -> bool:
_, project = user_project
result = await pools.add_remote(
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def add_remote(
name="instance",
resources=instance_resource,
),
region="",
region="", # TODO: add region
price=0.0,
availability=InstanceAvailability.AVAILABLE,
)
Expand All @@ -361,11 +361,14 @@ async def add_remote(
name=instance_name,
project=project,
pool=pool_model,
backend=BackendType.REMOTE,
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
job_provisioning_data=local.json(),
offer=offer.json(),
region=offer.region,
price=offer.price,
termination_policy=profile.termination_policy,
termination_idle_time=profile.termination_idle_time,
)
Expand Down
Loading
Loading