Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support provisioning instances without public IPs on AWS #1203

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create_instance(
ec2 = self.session.resource("ec2", region_name=instance_offer.region)
ec2_client = self.session.client("ec2", region_name=instance_offer.region)
iam_client = self.session.client("iam", region_name=instance_offer.region)
allocate_public_ip = self.config.allocate_public_ips

tags = [
{"Key": "Name", "Value": instance_config.instance_name},
Expand All @@ -117,6 +118,7 @@ def create_instance(
ec2_client=ec2_client,
config=self.config,
region=instance_offer.region,
allocate_public_ip=allocate_public_ip,
)
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
response = ec2.create_instances(
Expand All @@ -140,6 +142,7 @@ def create_instance(
),
spot=instance_offer.instance.resources.spot,
subnet_id=subnet_id,
allocate_public_ip=allocate_public_ip,
)
)
instance = response[0]
Expand All @@ -149,11 +152,16 @@ def create_instance(
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
)
if allocate_public_ip:
hostname = instance.public_ip_address
else:
hostname = instance.private_ip_address
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance.instance_id,
hostname=instance.public_ip_address,
public_ip_enabled=allocate_public_ip,
hostname=hostname,
internal_ip=instance.private_ip_address,
region=instance_offer.region,
price=instance_offer.price,
Expand Down Expand Up @@ -247,6 +255,7 @@ def get_vpc_id_subnet_id_or_error(
ec2_client: botocore.client.BaseClient,
config: AWSConfig,
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
if config.vpc_ids is not None:
vpc_id = config.vpc_ids.get(region)
Expand All @@ -259,6 +268,7 @@ def get_vpc_id_subnet_id_or_error(
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
Expand All @@ -268,13 +278,15 @@ def get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
vpc_name=config.vpc_name,
region=region,
allocate_public_ip=allocate_public_ip,
)


def _get_vpc_id_subnet_id_by_vpc_name_or_error(
ec2_client: botocore.client.BaseClient,
vpc_name: Optional[str],
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
if vpc_name is not None:
vpc_id = aws_resources.get_vpc_id_by_name(
Expand All @@ -290,9 +302,20 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
if vpc_name is not None:
raise ComputeError(f"Failed to find public subnet for VPC {vpc_name} in region {region}")
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
if allocate_public_ip:
raise ComputeError(
f"Failed to find public subnet for VPC {vpc_name} in region {region}"
)
raise ComputeError(
f"Failed to find private subnet with NAT for VPC {vpc_name} in region {region}"
)
if allocate_public_ip:
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
raise ComputeError(
f"Failed to find private subnet with NAT for default VPC in region {region}"
)
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/backends/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@

class AWSConfig(AWSStoredConfig, BackendConfig):
    # AWS credentials (one of the `AnyAWSCreds` variants).
    creds: AnyAWSCreds

    @property
    def allocate_public_ips(self) -> bool:
        """Return whether instances should be given public IPs.

        An unset (`None`) `public_ips` setting is treated as enabled.
        """
        configured = self.public_ips
        return True if configured is None else configured
62 changes: 55 additions & 7 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def create_instances_struct(
security_group_id: str,
spot: bool,
subnet_id: Optional[str] = None,
allocate_public_ip: bool = True,
) -> Dict[str, Any]:
struct = dict(
BlockDeviceMappings=[
Expand Down Expand Up @@ -230,7 +231,7 @@ def create_instances_struct(
if subnet_id is not None:
struct["NetworkInterfaces"] = [
{
"AssociatePublicIpAddress": True,
"AssociatePublicIpAddress": allocate_public_ip,
"DeviceIndex": 0,
"SubnetId": subnet_id,
"Groups": [security_group_id],
Expand Down Expand Up @@ -334,18 +335,31 @@ def get_vpc_by_vpc_id(ec2_client: botocore.client.BaseClient, vpc_id: str) -> Op
def get_subnet_id_for_vpc(
    ec2_client: botocore.client.BaseClient,
    vpc_id: str,
    allocate_public_ip: bool,
) -> Optional[str]:
    """
    If `allocate_public_ip` is True, returns a first public subnet found in the VPC.
    If `allocate_public_ip` is False, returns a first subnet with NAT found in the VPC.
    Returns None if the VPC has no matching subnet (including when it has no subnets at all).
    """
    # Select the suitability check once. Both helpers are called with the same
    # keyword arguments, so exactly one of them is applied to every subnet.
    # A public subnet must never be returned when public IPs are disabled.
    is_suitable_subnet = _is_public_subnet if allocate_public_ip else _is_subnet_behind_nat
    subnets = _get_subnets_by_vpc_id(ec2_client=ec2_client, vpc_id=vpc_id)
    for subnet in subnets:
        subnet_id = subnet["SubnetId"]
        if is_suitable_subnet(ec2_client=ec2_client, vpc_id=vpc_id, subnet_id=subnet_id):
            return subnet_id
    return None


Expand Down Expand Up @@ -440,3 +454,37 @@ def _is_public_subnet(
return True

return False


def _has_nat_gateway_route(route_tables: list) -> bool:
    """Return True if any route of the given route tables targets a NAT gateway."""
    return any(
        "NatGatewayId" in route and route["NatGatewayId"].startswith("nat-")
        for route_table in route_tables
        for route in route_table["Routes"]
    )


def _is_subnet_behind_nat(
    ec2_client: botocore.client.BaseClient,
    vpc_id: str,
    subnet_id: str,
) -> bool:
    """
    Returns True if the route table that applies to the subnet has a route via a NAT gateway.

    Route tables explicitly associated with the subnet are checked first.
    The VPC's main route table is only consulted when the subnet has
    no explicit route table association.
    """
    # Check explicitly associated route tables
    explicit_route_tables = ec2_client.describe_route_tables(
        Filters=[{"Name": "association.subnet-id", "Values": [subnet_id]}]
    )["RouteTables"]

    # Main route table controls the routing of all subnets
    # that are not explicitly associated with any other route table,
    # so an explicit association fully decides the result.
    if explicit_route_tables:
        return _has_nat_gateway_route(explicit_route_tables)

    # Check implicitly associated main route table
    main_route_tables = ec2_client.describe_route_tables(
        Filters=[
            {"Name": "association.main", "Values": ["true"]},
            {"Name": "vpc-id", "Values": [vpc_id]},
        ]
    )["RouteTables"]
    return _has_nat_gateway_route(main_route_tables)
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/backends/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class AWSConfigInfo(CoreModel):
regions: Optional[List[str]] = None
vpc_name: Optional[str] = None
vpc_ids: Optional[Dict[str, str]] = None
public_ips: Optional[bool] = None


class AWSAccessKeyCreds(CoreModel):
Expand Down Expand Up @@ -46,6 +47,7 @@ class AWSConfigInfoWithCredsPartial(CoreModel):
regions: Optional[List[str]]
vpc_name: Optional[str]
vpc_ids: Optional[Dict[str, str]]
public_ips: Optional[bool]


class AWSConfigValues(CoreModel):
Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ class JobProvisioningData(CoreModel):
backend: BackendType
instance_type: InstanceType
instance_id: str
# hostname may not be set immediately after instance provisioning
# hostname may not be set immediately after instance provisioning.
# It is set to a public IP or, if public IPs are disabled, to a private IP.
hostname: Optional[str]
internal_ip: Optional[str]
# public_ip_enabled can be used to distinguish instances with and without public IPs.
# hostname being None is not enough since it can be filled after provisioning.
public_ip_enabled: bool = True
region: str
price: float
username: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart
regions = config.regions
if regions is None:
regions = DEFAULT_REGIONS
allocate_public_ip = config.public_ips if config.public_ips is not None else True
# The number of workers should be >= the number of regions
with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
futures = []
Expand All @@ -149,6 +150,7 @@ def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart
ec2_client=ec2_client,
config=AWSConfig.parse_obj(config),
region=region,
allocate_public_ip=allocate_public_ip,
)
futures.append(future)
for future in concurrent.futures.as_completed(futures):
Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ class AWSConfig(CoreModel):
vpc_ids: Annotated[
Optional[Dict[str, str]], Field(description="The mapping from AWS regions to VPC IDs")
] = None
public_ips: Annotated[
Optional[bool],
Field(
description="A flag to enable/disable public IP assigning on instances. Defaults to `true`."
),
] = None
creds: AnyAWSCreds = Field(..., description="The credentials", discriminator="type")


Expand All @@ -76,8 +82,8 @@ class AzureConfig(CoreModel):
# Settings for a backend of type `cudo`.
class CudoConfig(CoreModel):
    type: Annotated[Literal["cudo"], Field(description="The type of backend")] = "cudo"
    regions: Optional[List[str]] = None
    project_id: Annotated[str, Field(description="The project ID")]
    # Declared a single time, after `project_id`; a repeated annotation of the
    # same field would be redundant and obscure the intended field order.
    creds: Annotated[AnyCudoCreds, Field(description="The credentials")]


class DataCrunchConfig(CoreModel):
Expand Down
1 change: 1 addition & 0 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,7 @@ async def test_returns_config_info(self, test_db, session: AsyncSession):
"regions": json.loads(backend.config)["regions"],
"vpc_name": None,
"vpc_ids": None,
"public_ips": None,
"creds": json.loads(backend.auth),
}

Expand Down
1 change: 1 addition & 0 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession):
"regions": json.loads(backend.config)["regions"],
"vpc_name": None,
"vpc_ids": None,
"public_ips": None,
},
}
],
Expand Down
Loading