Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if the build exists before starting an instance #534

Merged
merged 13 commits into from
Jul 5, 2023
13 changes: 13 additions & 0 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from datetime import datetime
from typing import Generator, List, Optional

import dstack._internal.core.build
from dstack._internal.backend.base import artifacts as base_artifacts
from dstack._internal.backend.base import build as base_build
from dstack._internal.backend.base import cache as base_cache
from dstack._internal.backend.base import jobs as base_jobs
from dstack._internal.backend.base import repos as base_repos
Expand All @@ -14,6 +16,7 @@
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.artifact import Artifact
from dstack._internal.core.build import BuildPlan
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobHead, JobStatus
from dstack._internal.core.log_event import LogEvent
Expand Down Expand Up @@ -230,6 +233,10 @@ def get_signed_download_url(self, object_key: str) -> str:
def get_signed_upload_url(self, object_key: str) -> str:
pass

@abstractmethod
def predict_build_plan(self, job: Job) -> BuildPlan:
pass


class ComponentBasedBackend(Backend):
@abstractmethod
Expand Down Expand Up @@ -264,6 +271,7 @@ def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:
return base_jobs.list_jobs(self.storage(), repo_id, run_name)

def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus):
self.predict_build_plan(job) # raises exception on missing build
base_jobs.run_job(self.storage(), self.compute(), job, failed_to_start_job_new_status)

def stop_job(self, repo_id: str, abort: bool, job_id: str):
Expand Down Expand Up @@ -435,3 +443,8 @@ def delete_configuration_cache(
base_cache.delete_configuration_cache(
self.storage(), repo_id, hub_user_name, configuration_path
)

def predict_build_plan(self, job: Job) -> BuildPlan:
return base_build.predict_build_plan(
self.storage(), job, dstack._internal.core.build.DockerPlatform.amd64
)
63 changes: 63 additions & 0 deletions cli/dstack/_internal/backend/base/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path
from platform import uname as platform_uname
from typing import Optional

import cpuinfo

from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.build import BuildNotFoundError, BuildPlan, DockerPlatform
from dstack._internal.core.job import Job
from dstack._internal.utils.escape import escape_head


def predict_build_plan(
storage: Storage, job: Job, platform: Optional[DockerPlatform]
) -> BuildPlan:
if job.build_policy in ["force-build", "build-only"]:
return BuildPlan.yes

if platform is None:
platform = guess_docker_platform()
if build_exists(storage, job, platform):
return BuildPlan.use

if job.build_commands:
if job.build_policy == "use-build":
raise BuildNotFoundError("Build not found. Run `dstack build` or add `--build` flag")
return BuildPlan.yes

if job.optional_build_commands and job.build_policy == "build":
return BuildPlan.yes
return BuildPlan.no


def build_exists(storage: Storage, job: Job, platform: DockerPlatform) -> bool:
prefix = _get_build_head_prefix(job, platform)
return len(storage.list_objects(prefix)) > 0


def _get_build_head_prefix(job: Job, platform: DockerPlatform) -> str:
parts = [
job.configuration_type.value,
job.configuration_path or "",
(Path("/workflow") / (job.working_dir or "")).as_posix(),
job.image_name,
platform.value,
# digest
# timestamp_utc
]
parts = ";".join(escape_head(p) for p in parts)
return f"builds/{job.repo_ref.repo_id}/{parts};"


def guess_docker_platform() -> DockerPlatform:
uname = platform_uname()
if uname.system == "Darwin":
brand = cpuinfo.get_cpu_info().get("brand_raw")
m_arch = "m1" in brand.lower() or "m2" in brand.lower()
arch = "arm64" if m_arch else "x86_64"
else:
arch = uname.machine
if uname.system == "Darwin" and arch in ["arm64", "aarch64"]:
return DockerPlatform.arm64
return DockerPlatform.amd64
2 changes: 2 additions & 0 deletions cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import yaml

from dstack._internal.backend.base import runners
from dstack._internal.backend.base.build import predict_build_plan
from dstack._internal.backend.base.compute import Compute, NoCapacityError
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.build import DockerPlatform
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobErrorCode, JobHead, JobStatus, SpotPolicy
Expand Down
7 changes: 7 additions & 0 deletions cli/dstack/_internal/backend/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Optional

from dstack._internal.backend.base import ComponentBasedBackend
from dstack._internal.backend.base import build as base_build
from dstack._internal.backend.local.compute import LocalCompute
from dstack._internal.backend.local.config import LocalConfig
from dstack._internal.backend.local.logs import LocalLogging
from dstack._internal.backend.local.secrets import LocalSecretsManager
from dstack._internal.backend.local.storage import LocalStorage
from dstack._internal.core.build import BuildPlan
from dstack._internal.core.job import Job


class LocalBackend(ComponentBasedBackend):
Expand Down Expand Up @@ -39,3 +42,7 @@ def secrets_manager(self) -> LocalSecretsManager:

def logging(self) -> LocalLogging:
return self._logging

def predict_build_plan(self, job: Job) -> BuildPlan:
# guess platform from uname
return base_build.predict_build_plan(self.storage(), job, platform=None)
8 changes: 4 additions & 4 deletions cli/dstack/_internal/cli/commands/build/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def _command(self, args: argparse.Namespace):
ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)

run_plan = hub_client.get_run_plan(
provider_name=provider_name, provider_data=provider_data, args=args
configuration_path=configuration_path,
provider_name=provider_name,
provider_data=provider_data,
args=args,
)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configuration_path, run_plan)
Expand All @@ -69,9 +72,6 @@ def _command(self, args: argparse.Namespace):
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
if run.status == JobStatus.FAILED:
console.print("\nProvisioning failed\n")
exit(1)
_poll_run(
hub_client,
run,
Expand Down
15 changes: 13 additions & 2 deletions cli/dstack/_internal/cli/commands/run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def _command(self, args: Namespace):
ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)

run_plan = hub_client.get_run_plan(
provider_name=provider_name, provider_data=provider_data, args=args
configuration_path=configuration_path,
provider_name=provider_name,
provider_data=provider_data,
args=args,
)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configuration_path, run_plan)
Expand Down Expand Up @@ -184,12 +187,20 @@ def _print_run_plan(configuration_file: str, run_plan: RunPlan):
table.add_column("INSTANCE")
table.add_column("RESOURCES")
table.add_column("SPOT POLICY")
table.add_column("BUILD")
job_plan = run_plan.job_plans[0]
instance = job_plan.instance_type.instance_name or "-"
instance_info = _format_resources(job_plan.instance_type)
spot = job_plan.job.spot_policy.value
build_plan = job_plan.build_plan.value.title()
table.add_row(
configuration_file, run_plan.hub_user_name, run_plan.project, instance, instance_info, spot
configuration_file,
run_plan.hub_user_name,
run_plan.project,
instance,
instance_info,
spot,
build_plan,
)
console.print(table)
console.print()
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/_internal/cli/commands/run/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _parse_dev_environment_configuration_data(
"sea_green3]Command Palette[/sea_green3], executing [sea_green3]Shell Command: Install 'code' command in "
"PATH[/sea_green3], and restarting terminal.[/]\n"
)
provider_data["optional_build"].append("pip install -q --no-cache-dir ipykernel")
for key in ["optional_build", "commands"]:
provider_data[key].append("pip install -q --no-cache-dir ipykernel")
provider_data["commands"].extend(configuration_data.get("init") or [])
return provider_name, provider_data

Expand Down
18 changes: 18 additions & 0 deletions cli/dstack/_internal/core/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from enum import Enum

from dstack._internal.core.error import DstackError


class DockerPlatform(str, Enum):
amd64 = "amd64"
arm64 = "arm64"


class BuildPlan(str, Enum):
no = "no"
use = "use"
yes = "yes"


class BuildNotFoundError(DstackError):
code = "build_not_found"
2 changes: 2 additions & 0 deletions cli/dstack/_internal/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from pydantic import BaseModel

from dstack._internal.core.build import BuildPlan
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job


class JobPlan(BaseModel):
job: Job
instance_type: InstanceType
build_plan: BuildPlan


class RunPlan(BaseModel):
Expand Down
6 changes: 6 additions & 0 deletions cli/dstack/_internal/hub/routers/runners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.hub.models import StopRunners
Expand Down Expand Up @@ -29,6 +30,11 @@ async def run_runners(project_name: str, job: Job):
NoMatchingInstanceError.message, code=NoMatchingInstanceError.code
),
)
except BuildNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)


@router.post("/{project_name}/runners/stop")
Expand Down
10 changes: 9 additions & 1 deletion cli/dstack/_internal/hub/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi.responses import PlainTextResponse

from dstack._internal.backend.base import Backend
from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.core.plan import JobPlan, RunPlan
Expand Down Expand Up @@ -35,7 +36,14 @@ async def get_run_plan(
msg=NoMatchingInstanceError.message, code=NoMatchingInstanceError.code
),
)
job_plans.append(JobPlan(job=job, instance_type=instance_type))
try:
build = backend.predict_build_plan(job)
except BuildNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)
job_plans.append(JobPlan(job=job, instance_type=instance_type, build_plan=build))
run_plan = RunPlan(project=project_name, hub_user_name=user.name, job_plans=job_plans)
return run_plan

Expand Down
5 changes: 5 additions & 0 deletions cli/dstack/api/hub/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests

from dstack._internal.core.artifact import Artifact
from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobHead
from dstack._internal.core.log_event import LogEvent
Expand Down Expand Up @@ -83,6 +84,8 @@ def get_run_plan(self, jobs: List[Job]) -> RunPlan:
body = resp.json()
if body["detail"]["code"] == NoMatchingInstanceError.code:
raise HubClientError(body["detail"]["msg"])
elif body["detail"]["code"] == BuildNotFoundError.code:
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

def create_run(self) -> str:
Expand Down Expand Up @@ -168,6 +171,8 @@ def run_job(self, job: Job):
body = resp.json()
if body["detail"]["code"] == NoMatchingInstanceError.code:
raise HubClientError(body["detail"]["msg"])
elif body["detail"]["code"] == BuildNotFoundError.code:
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

def stop_job(self, job_id: str, abort: bool):
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/api/hub/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def delete_configuration_cache(self, configuration_path: str):

def get_run_plan(
self,
configuration_path: str,
provider_name: str,
provider_data: Optional[Dict[str, Any]] = None,
args: Optional[argparse.Namespace] = None,
Expand All @@ -277,7 +278,7 @@ def get_run_plan(
run_name="dry-run",
ssh_key_pub="",
)
jobs = provider.get_jobs(repo=self.repo)
jobs = provider.get_jobs(repo=self.repo, configuration_path=configuration_path)
run_plan = self._api_client.get_run_plan(jobs)
return run_plan

Expand Down
16 changes: 12 additions & 4 deletions runner/internal/backend/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/dstackai/dstack/runner/internal/container"
"io"
"io/ioutil"
"path"
Expand Down Expand Up @@ -321,13 +322,20 @@ func (s *AWSBackend) GetRepoArchive(ctx context.Context, path, dir string) error
return gerrors.Wrap(base.GetRepoArchive(ctx, s.storage, path, dir))
}

func (s *AWSBackend) GetBuildDiffInfo(ctx context.Context, spec *container.BuildSpec) (*base.StorageObject, error) {
obj, err := base.GetBuildDiffInfo(ctx, s.storage, spec)
if err != nil {
return nil, gerrors.Wrap(err)
}
return obj, nil
}

func (s *AWSBackend) GetBuildDiff(ctx context.Context, key, dst string) error {
_ = base.DownloadFile(ctx, s.storage, key, dst)
return nil
return gerrors.Wrap(base.DownloadFile(ctx, s.storage, key, dst))
}

func (s *AWSBackend) PutBuildDiff(ctx context.Context, src, key string) error {
return gerrors.Wrap(base.UploadFile(ctx, s.storage, src, key))
func (s *AWSBackend) PutBuildDiff(ctx context.Context, src string, spec *container.BuildSpec) error {
return gerrors.Wrap(base.PutBuildDiff(ctx, s.storage, src, spec))
}

func (s *AWSBackend) GetTMPDir(ctx context.Context) string {
Expand Down
16 changes: 12 additions & 4 deletions runner/internal/backend/azure/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"errors"
"fmt"
"github.com/dstackai/dstack/runner/internal/backend/base"
"github.com/dstackai/dstack/runner/internal/container"
"io"
"os"
"path"
Expand Down Expand Up @@ -171,7 +172,7 @@
azbackend.config.StorageAccount,
)
logger := NewAzureLogger(loggingClient, azbackend.state.Job.JobID, logGroup, logName)
logger.Launch(ctx)

Check failure on line 175 in runner/internal/backend/azure/backend.go

View workflow job for this annotation

GitHub Actions / runner-test-master

Error return value of `logger.Launch` is not checked (errcheck)
return logger
}

Expand Down Expand Up @@ -236,13 +237,20 @@
return gerrors.Wrap(base.GetRepoArchive(ctx, azbackend.storage, path, dir))
}

func (azbackend *AzureBackend) GetBuildDiffInfo(ctx context.Context, spec *container.BuildSpec) (*base.StorageObject, error) {
obj, err := base.GetBuildDiffInfo(ctx, azbackend.storage, spec)
if err != nil {
return nil, gerrors.Wrap(err)
}
return obj, nil
}

func (azbackend *AzureBackend) GetBuildDiff(ctx context.Context, key, dst string) error {
_ = base.DownloadFile(ctx, azbackend.storage, key, dst)
return nil
return gerrors.Wrap(base.DownloadFile(ctx, azbackend.storage, key, dst))
}

func (azbackend *AzureBackend) PutBuildDiff(ctx context.Context, src, key string) error {
return gerrors.Wrap(base.UploadFile(ctx, azbackend.storage, src, key))
func (azbackend *AzureBackend) PutBuildDiff(ctx context.Context, src string, spec *container.BuildSpec) error {
return gerrors.Wrap(base.PutBuildDiff(ctx, azbackend.storage, src, spec))
}

func (azbackend *AzureBackend) GetTMPDir(ctx context.Context) string {
Expand Down
Loading
Loading