Skip to content

Commit

Permalink
Support gateways without public IPs on AWS (#1224)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored May 15, 2024
1 parent 3965d6c commit 3f1504b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 33 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
BackendType.LAMBDA,
BackendType.TENSORDOCK,
]
BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = [BackendType.AWS]
15 changes: 14 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def create_gateway(
]
if settings.DSTACK_VERSION is not None:
tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION})
vpc_id, subnet_id = get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
config=self.config,
region=configuration.region,
allocate_public_ip=configuration.public_ip,
)
response = ec2.create_instances(
**aws_resources.create_instances_struct(
disk_size=10,
Expand All @@ -215,17 +221,24 @@ def create_gateway(
security_group_id=aws_resources.create_gateway_security_group(
ec2_client=ec2_client,
project_id=configuration.project_name,
vpc_id=vpc_id,
),
spot=False,
subnet_id=subnet_id,
allocate_public_ip=configuration.public_ip,
)
)
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address
if configuration.public_ip:
ip_address = instance.public_ip_address
else:
ip_address = instance.private_ip_address
return LaunchedGatewayInfo(
instance_id=instance.instance_id,
region=configuration.region,
ip_address=instance.public_ip_address,
ip_address=ip_address,
)


Expand Down
32 changes: 22 additions & 10 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,31 @@ def get_gateway_image_id(ec2_client: botocore.client.BaseClient) -> str:
return image["ImageId"]


def create_gateway_security_group(ec2_client: botocore.client.BaseClient, project_id: str) -> str:
def create_gateway_security_group(
ec2_client: botocore.client.BaseClient,
project_id: str,
vpc_id: Optional[str],
) -> str:
security_group_name = "dstack_gw_sg_" + project_id.replace("-", "_").lower()

response = ec2_client.describe_security_groups(
Filters=[
describe_security_groups_filters = [
{
"Name": "group-name",
"Values": [security_group_name],
},
]
if vpc_id is not None:
describe_security_groups_filters.append(
{
"Name": "group-name",
"Values": [security_group_name],
},
],
)
"Name": "vpc-id",
"Values": [vpc_id],
}
)
response = ec2_client.describe_security_groups(Filters=describe_security_groups_filters)
if response.get("SecurityGroups"):
return response["SecurityGroups"][0]["GroupId"]

create_security_group_kwargs = {}
if vpc_id is not None:
create_security_group_kwargs["VpcId"] = vpc_id
security_group = ec2_client.create_security_group(
Description="Generated by dstack",
GroupName=security_group_name,
Expand All @@ -198,6 +209,7 @@ def create_gateway_security_group(ec2_client: botocore.client.BaseClient, projec
],
},
],
**create_security_group_kwargs,
)
group_id = security_group["GroupId"]

Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class GatewayConfiguration(CoreModel):
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `*.example.com`")
] = None
# public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True
public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True


class GatewayComputeConfiguration(CoreModel):
Expand Down
58 changes: 39 additions & 19 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import dstack._internal.server.services.jobs as jobs_services
import dstack._internal.utils.random_names as random_names
from dstack._internal.core.backends import BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
from dstack._internal.core.backends.base.compute import (
Compute,
get_dstack_gateway_wheel,
Expand Down Expand Up @@ -87,17 +88,27 @@ async def get_project_default_gateway(


async def create_gateway_compute(
project_name: str,
backend_compute: Compute,
configuration: GatewayComputeConfiguration,
configuration: GatewayConfiguration,
backend_id: Optional[uuid.UUID] = None,
) -> GatewayComputeModel:
private_bytes, public_bytes = generate_rsa_key_pair_bytes()
gateway_ssh_private_key = private_bytes.decode()
gateway_ssh_public_key = public_bytes.decode()

compute_configuration = GatewayComputeConfiguration(
project_name=project_name,
instance_name=configuration.name,
backend=configuration.backend,
region=configuration.region,
public_ip=configuration.public_ip,
ssh_key_pub=gateway_ssh_public_key,
)

info = await run_async(
backend_compute.create_gateway,
configuration,
compute_configuration,
)

return GatewayComputeModel(
Expand All @@ -122,6 +133,15 @@ async def create_gateway(
else:
raise ResourceNotExistsError()

if (
not configuration.public_ip
and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
):
raise GatewayError(
f"Private gateways are not supported for {configuration.backend.value} backend. "
f"Supported backends: {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}."
)

if configuration.name is None:
configuration.name = await generate_gateway_name(session=session, project=project)

Expand All @@ -139,19 +159,11 @@ async def create_gateway(
if project.default_gateway is None or configuration.default:
await set_default_gateway(session=session, project=project, name=configuration.name)

compute_configuration = GatewayComputeConfiguration(
project_name=project.name,
instance_name=gateway.name,
backend=configuration.backend,
region=configuration.region,
public_ip=True,
ssh_key_pub=project.name,
)

try:
gateway.gateway_compute = await create_gateway_compute(
backend_compute=backend.compute(),
configuration=compute_configuration,
project_name=project.name,
configuration=configuration,
backend_id=backend_model.id,
)
session.add(gateway)
Expand Down Expand Up @@ -321,13 +333,6 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
async def register_service(session: AsyncSession, run_model: RunModel):
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)

service_https = run_spec.configuration.https
service_protocol = "https" if service_https else "http"

# Currently, gateway endpoint is always https
gateway_https = True
gateway_protocol = "https" if gateway_https else "http"

# TODO(egor-s): allow to configure gateway name
gateway_name: Optional[str] = None
if gateway_name is None:
Expand All @@ -343,6 +348,21 @@ async def register_service(session: AsyncSession, run_model: RunModel):
if gateway.gateway_compute is None:
raise ServerClientError("Gateway has no instance associated with it")

service_https = run_spec.configuration.https
service_protocol = "https" if service_https else "http"

gateway_configuration = None
if gateway.configuration is not None:
gateway_configuration = GatewayConfiguration.__response__.parse_raw(gateway.configuration)
if service_https and not gateway_configuration.public_ip:
raise ServerClientError("Cannot run HTTPS service on gateway without public IP")

gateway_https = True
if gateway_configuration is not None:
# Currently, https is always False for private gateways
gateway_https = gateway_configuration.public_ip
gateway_protocol = "https" if gateway_https else "http"

wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
if wildcard_domain is None:
raise ServerClientError("Domain is required for gateway")
Expand Down
6 changes: 4 additions & 2 deletions src/dstack/_internal/server/services/gateways/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import aiorwlock

from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.services.gateways.client import (
GATEWAY_MANAGEMENT_PORT,
GatewayClient,
Expand All @@ -29,9 +30,10 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
self._lock = aiorwlock.RWLock()
self.stats: Dict[str, Dict[int, Stat]] = {}
self.ip_address = ip_address

self.ports_lock = PortsLock(restrictions={server_port: 0}).acquire()
local_port = self.ports_lock.dict()[server_port]
args = ["-L", "{temp_dir}/gateway:localhost:%d" % GATEWAY_MANAGEMENT_PORT]
args += ["-R", f"localhost:8001:localhost:{server_port}"]
args += ["-R", f"localhost:{local_port}:localhost:{server_port}"]
self.tunnel = AsyncSSHTunnel(
f"ubuntu@{ip_address}",
id_rsa,
Expand Down
7 changes: 7 additions & 0 deletions src/tests/_internal/server/routers/test_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def test_list(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": False,
"public_ip": True,
},
}
]
Expand Down Expand Up @@ -124,6 +125,7 @@ async def test_get(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": False,
"public_ip": True,
},
}

Expand Down Expand Up @@ -203,6 +205,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession):
"region": "us",
"domain": None,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -257,6 +260,7 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession)
"region": "us",
"domain": None,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -391,6 +395,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -498,6 +503,7 @@ def get_backend(_, backend_type):
"region": gateway_gcp.region,
"domain": gateway_gcp.wildcard_domain,
"default": False,
"public_ip": True,
},
}
]
Expand Down Expand Up @@ -557,6 +563,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": "test.com",
"default": False,
"public_ip": True,
},
}

Expand Down

0 comments on commit 3f1504b

Please sign in to comment.