From fd8cdf1fc4e6bc1ff9038ae8385d2ef119c04b88 Mon Sep 17 00:00:00 2001 From: swsvc <23121066+swsvc@users.noreply.github.com> Date: Fri, 30 Aug 2024 12:34:24 +0300 Subject: [PATCH] Add Support for ECDSA and Ed25519 keys (#1641) Closes: https://github.com/dstackai/dstack/issues/1443 --- src/dstack/_internal/cli/commands/pool.py | 7 +++++-- .../cli/services/configurators/fleet.py | 8 ++++++-- .../background/tasks/process_instances.py | 12 ++++++------ src/dstack/_internal/utils/ssh.py | 19 ++++++++++++++----- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index ce7a4d217..ea1b96f38 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -30,7 +30,7 @@ from dstack._internal.core.models.runs import Requirements, get_policy_map from dstack._internal.utils.common import pretty_date from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, rsa_pkey_from_str +from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, pkey_from_str from dstack.api._public.resources import Resources from dstack.api.utils import load_profile @@ -316,12 +316,15 @@ def _add_ssh(self, args: argparse.Namespace) -> None: try: pub_key = args.ssh_identity_file.with_suffix(".pub").read_text() except FileNotFoundError: - pub_key = generate_public_key(rsa_pkey_from_str(private_key)) + pub_key = generate_public_key(pkey_from_str(private_key)) ssh_key = SSHKey(public=pub_key, private=private_key) ssh_keys.append(ssh_key) except OSError: console.print("[error]Unable to read the public key.[/]") return + except ValueError: + console.print("[error]Key type is not supported.[/]") + return login, ssh_host, port = parse_destination(args.destination) diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 90ce4cdf3..54405da44 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -14,7 +14,7 @@ from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec from dstack._internal.core.models.instances import SSHKey from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, rsa_pkey_from_str +from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, pkey_from_str from dstack.api.utils import load_profile logger = get_logger(__name__) @@ -158,9 +158,13 @@ def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]: try: pub_key = ssh_key_path_obj.with_suffix(".pub").read_text() except FileNotFoundError: - pub_key = generate_public_key(rsa_pkey_from_str(private_key)) + pub_key = generate_public_key(pkey_from_str(private_key)) return SSHKey(public=pub_key, private=private_key) except OSError as e: logger.debug("Got OSError: %s", repr(e)) console.print(f"[error]Unable to read the SSH key at {ssh_key_path}[/]") exit() + except ValueError as e: + logger.debug("Key type is not supported", repr(e)) + console.print("[error]Key type is not supported[/]") + exit() diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index a7d46f091..1ef499984 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -6,7 +6,7 @@ import requests from paramiko.pkey import PKey -from paramiko.ssh_exception import PasswordRequiredException, SSHException +from paramiko.ssh_exception import PasswordRequiredException from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import joinedload @@ -72,7 +72,7 @@ from dstack._internal.utils.logging import get_logger from dstack._internal.utils.network import get_ip_from_network from dstack._internal.utils.ssh import ( - rsa_pkey_from_str, + pkey_from_str, ) PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) @@ -249,7 +249,7 @@ async def add_remote(instance_id: UUID) -> None: # Prepare connection key try: pkeys = [ - rsa_pkey_from_str(sk.private) + pkey_from_str(sk.private) for sk in remote_details.ssh_keys if sk.private is not None ] @@ -271,14 +271,14 @@ async def add_remote(instance_id: UUID) -> None: }, ) return - except SSHException: + except ValueError: instance.status = InstanceStatus.TERMINATED instance.deleted = True instance.deleted_at = get_current_datetime() - instance.termination_reason = "Cannot parse private key, RSA key required" + instance.termination_reason = "Cannot parse private key, key type is not supported" await session.commit() logger.warning( - "Failed to start instance %s: private SSH key is not a valid RSA key", + "Failed to start instance %s: unsupported private SSH key type", instance.name, extra={ "instance_name": instance.name, diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index 92f3e285e..dceb2e1d2 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -11,6 +11,7 @@ from filelock import FileLock from paramiko.config import SSHConfig from paramiko.pkey import PKey, PublicBlob +from paramiko.ssh_exception import SSHException from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike @@ -20,6 +21,8 @@ default_ssh_config_path = "~/.ssh/config" +SUPPORTED_KEY_TYPES = (paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key) + def get_public_key_fingerprint(text: str) -> str: pb = PublicBlob.from_string(text) @@ -147,11 +150,17 @@ def convert_pkcs8_to_pem(private_string: str) -> str: return private_string -def rsa_pkey_from_str(private_string: str) -> PKey: - key_file = io.StringIO(private_string.strip()) - pkey = paramiko.RSAKey.from_private_key(key_file) - key_file.close() - return pkey +def pkey_from_str(private_string: str) -> PKey: + for key_type in SUPPORTED_KEY_TYPES: + try: + key_file = io.StringIO(private_string.strip()) + pkey = key_type.from_private_key(key_file) + key_file.close() + return pkey + except (SSHException, ValueError): + pass + + raise ValueError("Unsupported key type") def generate_public_key(private_key: PKey) -> str: