From 987113ab6b48238fcbb1c52c9698e5527961f398 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Wed, 2 Aug 2023 13:27:15 +0400 Subject: [PATCH 1/2] Create and delete gateway in AWS --- cli/dstack/_internal/backend/aws/compute.py | 29 ++++ cli/dstack/_internal/backend/aws/gateway.py | 162 ++++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 cli/dstack/_internal/backend/aws/gateway.py diff --git a/cli/dstack/_internal/backend/aws/compute.py b/cli/dstack/_internal/backend/aws/compute.py index 33320a0d8..06f34270b 100644 --- a/cli/dstack/_internal/backend/aws/compute.py +++ b/cli/dstack/_internal/backend/aws/compute.py @@ -2,10 +2,12 @@ from boto3 import Session +import dstack._internal.backend.aws.gateway as gateway from dstack._internal.backend.aws import runners from dstack._internal.backend.aws import utils as aws_utils from dstack._internal.backend.aws.config import AWSConfig from dstack._internal.backend.base.compute import Compute +from dstack._internal.core.gateway import GatewayHead from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo from dstack._internal.core.job import Job from dstack._internal.core.request import RequestHead @@ -63,6 +65,33 @@ def cancel_spot_request(self, runner: Runner): request_id=runner.request_id, ) + def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead: + instance = gateway.create_gateway_instance( + ec2_client=self._get_ec2_client(region=self.backend_config.region_name), + subnet_id=self.backend_config.subnet_id, + bucket_name=self.backend_config.bucket_name, + instance_name=instance_name, + ssh_key_pub=ssh_key_pub, + ) + return GatewayHead( + instance_name=instance_name, + external_ip=instance["PublicIpAddress"], + internal_ip=instance["PrivateIpAddress"], + ) + + def delete_instance(self, instance_name: str): + try: + instance_id = gateway.get_instance_id( + ec2_client=self._get_ec2_client(region=self.backend_config.region_name), + instance_name=instance_name, + ) + runners.terminate_instance( + ec2_client=self._get_ec2_client(region=self.backend_config.region_name), + request_id=instance_id, + ) + except IndexError: + return + def _get_ec2_client(self, region: Optional[str] = None): if region is None: return aws_utils.get_ec2_client(self.session) diff --git a/cli/dstack/_internal/backend/aws/gateway.py b/cli/dstack/_internal/backend/aws/gateway.py new file mode 100644 index 000000000..81c32d12e --- /dev/null +++ b/cli/dstack/_internal/backend/aws/gateway.py @@ -0,0 +1,162 @@ +import time +from typing import Optional + +from botocore.client import BaseClient + + +def create_gateway_instance( + ec2_client: BaseClient, + subnet_id: Optional[str], + bucket_name: str, + instance_name: str, + ssh_key_pub: str, + machine_type: str = "t2.micro", +) -> dict: + launch_specification = {} + if subnet_id: + launch_specification["NetworkInterfaces"] = [ + { + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "SubnetId": subnet_id, + "Groups": [gateway_security_group_id(ec2_client, subnet_id, bucket_name)], + }, + ] + else: + launch_specification["SecurityGroupIds"] = [ + gateway_security_group_id(ec2_client, subnet_id, bucket_name) + ] + tags = [ + {"Key": "Name", "Value": instance_name}, + {"Key": "owner", "Value": "dstack"}, + {"Key": "role", "Value": "gateway"}, + {"Key": "dstack_bucket", "Value": bucket_name}, + ] + response = ec2_client.run_instances( + BlockDeviceMappings=[ + { + "DeviceName": "/dev/sda1", + "Ebs": { + "VolumeSize": 10, + "VolumeType": "gp2", + }, + } + ], + ImageId="ami-0cffefff2d52e0a23", # Ubuntu 22.04 LTS + InstanceType=machine_type, + MinCount=1, + MaxCount=1, + UserData=gateway_user_data_script(ssh_key_pub), + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": tags, + }, + ], + **launch_specification, + ) + return wait_till_running(ec2_client, response["Instances"][0]) + + +def gateway_security_group_id( + ec2_client: BaseClient, subnet_id: Optional[str], bucket_name: str +) -> str: + name_parts = ["dstack_gateway_sg"] + if subnet_id: + name_parts.append(subnet_id.replace("-", "_")) + name_parts.append(bucket_name.replace("-", "_")) + security_group_name = "_".join(name_parts) + response = ec2_client.describe_security_groups( + Filters=[ + { + "Name": "group-name", + "Values": [ + security_group_name, + ], + }, + ], + ) + if response.get("SecurityGroups"): + return response["SecurityGroups"][0]["GroupId"] + + group_specification = {} + if subnet_id: + subnets_response = ec2_client.describe_subnets(SubnetIds=[subnet_id]) + group_specification["VpcId"] = subnets_response["Subnets"][0]["VpcId"] + security_group = ec2_client.create_security_group( + Description="Generated by dstack", + GroupName=security_group_name, + TagSpecifications=[ + { + "ResourceType": "security-group", + "Tags": [ + {"Key": "owner", "Value": "dstack"}, + {"Key": "role", "Value": "gateway"}, + ], + }, + ], + **group_specification, + ) + security_group_id = security_group["GroupId"] + ip_permissions = [ + { + "FromPort": 0, + "ToPort": 65535, + "IpProtocol": "tcp", + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + } + ] + ec2_client.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=ip_permissions + ) + ec2_client.authorize_security_group_egress( + GroupId=security_group_id, + IpPermissions=[ + { + "IpProtocol": "-1", + } + ], + ) + return security_group_id + + +def wait_till_running( + ec2_client: BaseClient, instance: dict, delay: int = 5, attempts: int = 30 +) -> dict: + instance_id = instance["InstanceId"] + attempt = 0 + while instance["State"]["Name"] != "running": + if attempt >= attempts: + raise RuntimeError(f"Instance {instance_id} is not running") + time.sleep(delay) + attempt += 1 + desc = ec2_client.describe_instances(InstanceIds=[instance_id]) + instance = desc["Reservations"][0]["Instances"][0] + return instance + + +def get_instance_id(ec2_client: BaseClient, instance_name: str) -> str: + desc = ec2_client.describe_instances( + Filters=[ + { + "Name": "tag:Name", + "Values": [instance_name], + } + ] + ) + return desc["Reservations"][0]["Instances"][0]["InstanceId"] + + +def gateway_user_data_script(ssh_key_pub: str) -> str: + return f"""#!/bin/bash +sudo apt-get update +DEBIAN_FRONTEND=noninteractive sudo apt-get install -y -q nginx +UBUNTU_UID=$(id -u ubuntu) +UBUNTU_GID=$(id -g ubuntu) +install -m 700 -o $UBUNTU_UID -g $UBUNTU_GID -d /home/ubuntu/.ssh +install -m 600 -o $UBUNTU_UID -g $UBUNTU_GID /dev/null /home/ubuntu/.ssh/authorized_keys +echo "{ssh_key_pub}" > /home/ubuntu/.ssh/authorized_keys +WWW_UID=$(id -u www-data) +WWW_GID=$(id -g www-data) +install -m 700 -o $WWW_UID -g $WWW_GID -d /var/www/.ssh +install -m 600 -o $WWW_UID -g $WWW_GID /dev/null /var/www/.ssh/authorized_keys""" From 5b4f019d7457f80ab8e96afb9875847468c5d705 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Wed, 2 Aug 2023 16:50:06 +0400 Subject: [PATCH 2/2] Create and delete gateway in Azure --- cli/dstack/_internal/backend/azure/compute.py | 40 ++++ cli/dstack/_internal/backend/azure/gateway.py | 195 ++++++++++++++++++ cli/dstack/_internal/backend/base/compute.py | 2 - 3 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 cli/dstack/_internal/backend/azure/gateway.py diff --git a/cli/dstack/_internal/backend/azure/compute.py b/cli/dstack/_internal/backend/azure/compute.py index c9f8e3230..dc203caf0 100644 --- a/cli/dstack/_internal/backend/azure/compute.py +++ b/cli/dstack/_internal/backend/azure/compute.py @@ -34,7 +34,9 @@ from azure.mgmt.keyvault import KeyVaultManagementClient from azure.mgmt.network import NetworkManagementClient from azure.mgmt.resource import ResourceManagementClient +from msrestazure.tools import parse_resource_id +import dstack._internal.backend.azure.gateway as gateway from dstack import version from dstack._internal.backend.azure import utils as azure_utils from dstack._internal.backend.azure.config import AzureConfig @@ -47,6 +49,7 @@ ) from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml +from dstack._internal.core.gateway import GatewayHead from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo from dstack._internal.core.job import Job from dstack._internal.core.request import RequestHead, RequestStatus @@ -140,6 +143,43 @@ def terminate_instance(self, runner: Runner): def cancel_spot_request(self, runner: Runner): self.terminate_instance(runner) + def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead: + vm = gateway.create_gateway( + compute_client=self._compute_client, + network_client=self._network_client, + subscription_id=self.azure_config.subscription_id, + location=self.azure_config.location, + resource_group=self.azure_config.resource_group, + network=self.azure_config.network, + subnet=self.azure_config.subnet, + instance_name=instance_name, + ssh_key_pub=ssh_key_pub, + ) + interface = gateway.get_network_interface( + network_client=self._network_client, + resource_group=self.azure_config.resource_group, + interface=parse_resource_id(vm.network_profile.network_interfaces[0].id)[ + "resource_name" + ], + ) + public_ip = gateway.get_public_ip( + network_client=self._network_client, + resource_group=self.azure_config.resource_group, + public_ip=interface.ip_configurations[0].public_ip_address.name, + ) + return GatewayHead( + instance_name=instance_name, + external_ip=public_ip.ip_address, + internal_ip=interface.ip_configurations[0].private_ip_address, + ) + + def delete_instance(self, instance_name: str): + _terminate_instance( + compute_client=self._compute_client, + resource_group=self.azure_config.resource_group, + instance_name=instance_name, + ) + def _get_instance_types(client: ComputeManagementClient, location: str) -> List[InstanceType]: instance_types = [] diff --git a/cli/dstack/_internal/backend/azure/gateway.py b/cli/dstack/_internal/backend/azure/gateway.py new file mode 100644 index 000000000..3afc518f5 --- /dev/null +++ b/cli/dstack/_internal/backend/azure/gateway.py @@ -0,0 +1,195 @@ +import base64 +from typing import List + +from azure.mgmt.compute import ComputeManagementClient +from azure.mgmt.compute.models import ( + DiskCreateOptionTypes, + HardwareProfile, + ImageReference, + LinuxConfiguration, + ManagedDiskParameters, + NetworkProfile, + OSDisk, + OSProfile, + SshConfiguration, + SshPublicKey, + StorageAccountTypes, + StorageProfile, + SubResource, + VirtualMachine, + VirtualMachineNetworkInterfaceConfiguration, + VirtualMachineNetworkInterfaceIPConfiguration, + VirtualMachinePublicIPAddressConfiguration, +) +from azure.mgmt.network import NetworkManagementClient +from azure.mgmt.network.models import ( + NetworkInterface, + NetworkInterfaceIPConfiguration, + NetworkSecurityGroup, + PublicIPAddress, + SecurityRule, + SecurityRuleAccess, + SecurityRuleDirection, + SecurityRuleProtocol, +) + +import dstack._internal.backend.azure.utils as azure_utils + + +def create_gateway( + compute_client: ComputeManagementClient, + network_client: NetworkManagementClient, + subscription_id: str, + location: str, + resource_group: str, + network: str, + subnet: str, + instance_name: str, + ssh_key_pub: str, + vm_size: str = "Standard_B1s", +) -> VirtualMachine: + poller = compute_client.virtual_machines.begin_create_or_update( + resource_group, + instance_name, + VirtualMachine( + location=location, + hardware_profile=HardwareProfile(vm_size=vm_size), + storage_profile=gateway_storage_profile(), + os_profile=OSProfile( + computer_name="gatewayvm", + admin_username="ubuntu", + linux_configuration=LinuxConfiguration( + ssh=SshConfiguration( + public_keys=[ + SshPublicKey( + path="/home/ubuntu/.ssh/authorized_keys", + key_data=ssh_key_pub, + ) + ] + ) + ), + ), + network_profile=NetworkProfile( + network_api_version=NetworkManagementClient.DEFAULT_API_VERSION, + network_interface_configurations=gateway_interface_configurations( + network_client=network_client, + subscription_id=subscription_id, + location=location, + resource_group=resource_group, + network=network, + subnet=subnet, + ), + ), + priority="Regular", + user_data=base64.b64encode(gateway_user_data_script().encode()).decode(), + tags={ + "owner": "dstack", + "role": "gateway", + }, + ), + ) + vm = poller.result() + return vm + + +def gateway_storage_profile() -> StorageProfile: + return StorageProfile( + image_reference=ImageReference( + publisher="canonical", + offer="0001-com-ubuntu-server-jammy", + sku="22_04-lts", + version="latest", + ), + os_disk=OSDisk( + create_option=DiskCreateOptionTypes.FROM_IMAGE, + managed_disk=ManagedDiskParameters( + storage_account_type=StorageAccountTypes.STANDARD_SSD_LRS + ), + disk_size_gb=30, + delete_option="Delete", + ), + ) + + +def gateway_interface_configurations( + network_client: NetworkManagementClient, + subscription_id: str, + location: str, + resource_group: str, + network: str, + subnet: str, +) -> List[VirtualMachineNetworkInterfaceConfiguration]: + conf = VirtualMachineNetworkInterfaceConfiguration( + name="nic_config", + network_security_group=SubResource( + id=gateway_network_security_group(network_client, location, resource_group) + ), + ip_configurations=[ + VirtualMachineNetworkInterfaceIPConfiguration( + name="ip_config", + subnet=SubResource( + id=azure_utils.get_subnet_id( + subscription_id, + resource_group, + network, + subnet, + ) + ), + public_ip_address_configuration=VirtualMachinePublicIPAddressConfiguration( + name="public_ip_config", + ), + ) + ], + ) + return [conf] + + +def gateway_network_security_group( + network_client: NetworkManagementClient, + location: str, + resource_group: str, +) -> str: + poller = network_client.network_security_groups.begin_create_or_update( + resource_group_name=resource_group, + network_security_group_name="dstack-gateway-network-security-group", + parameters=NetworkSecurityGroup( + location=location, + security_rules=[ + SecurityRule( + name="runner_service", + protocol=SecurityRuleProtocol.TCP, + source_address_prefix="Internet", + source_port_range="*", + destination_address_prefix="*", + destination_port_range="0-65535", + access=SecurityRuleAccess.ALLOW, + priority=101, + direction=SecurityRuleDirection.INBOUND, + ) + ], + ), + ) + security_group: NetworkSecurityGroup = poller.result() + return security_group.id + + +def get_network_interface( + network_client: NetworkManagementClient, resource_group: str, interface: str +) -> NetworkInterface: + return network_client.network_interfaces.get(resource_group, interface) + + +def get_public_ip( + network_client: NetworkManagementClient, resource_group: str, public_ip: str +) -> PublicIPAddress: + return network_client.public_ip_addresses.get(resource_group, public_ip) + + +def gateway_user_data_script() -> str: + return f"""#!/bin/sh +sudo apt-get update +DEBIAN_FRONTEND=noninteractive sudo apt-get install -y -q nginx +WWW_UID=$(id -u www-data) +WWW_GID=$(id -g www-data) +install -m 700 -o $WWW_UID -g $WWW_GID -d /var/www/.ssh +install -m 600 -o $WWW_UID -g $WWW_GID /dev/null /var/www/.ssh/authorized_keys""" diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py index 5305301a8..db7618b4f 100644 --- a/cli/dstack/_internal/backend/base/compute.py +++ b/cli/dstack/_internal/backend/base/compute.py @@ -51,11 +51,9 @@ def cancel_spot_request(self, runner: Runner): pass def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead: - # todo make abstract & implement for each backend raise NotImplementedError() def delete_instance(self, instance_name: str): - # todo make abstract & implement for each backend raise NotImplementedError()