Skip to content

Commit

Permalink
Add Support for ECDSA and Ed25519 keys (#1641)
Browse files Browse the repository at this point in the history
Closes: #1443
  • Loading branch information
swsvc authored Aug 30, 2024
1 parent 618bf10 commit fd8cdf1
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 15 deletions.
7 changes: 5 additions & 2 deletions src/dstack/_internal/cli/commands/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dstack._internal.core.models.runs import Requirements, get_policy_map
from dstack._internal.utils.common import pretty_date
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, rsa_pkey_from_str
from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, pkey_from_str
from dstack.api._public.resources import Resources
from dstack.api.utils import load_profile

Expand Down Expand Up @@ -316,12 +316,15 @@ def _add_ssh(self, args: argparse.Namespace) -> None:
try:
pub_key = args.ssh_identity_file.with_suffix(".pub").read_text()
except FileNotFoundError:
pub_key = generate_public_key(rsa_pkey_from_str(private_key))
pub_key = generate_public_key(pkey_from_str(private_key))
ssh_key = SSHKey(public=pub_key, private=private_key)
ssh_keys.append(ssh_key)
except OSError:
console.print("[error]Unable to read the public key.[/]")
return
except ValueError:
console.print("[error]Key type is not supported.[/]")
return

login, ssh_host, port = parse_destination(args.destination)

Expand Down
8 changes: 6 additions & 2 deletions src/dstack/_internal/cli/services/configurators/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec
from dstack._internal.core.models.instances import SSHKey
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, rsa_pkey_from_str
from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, pkey_from_str
from dstack.api.utils import load_profile

logger = get_logger(__name__)
Expand Down Expand Up @@ -158,9 +158,13 @@ def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]:
try:
pub_key = ssh_key_path_obj.with_suffix(".pub").read_text()
except FileNotFoundError:
pub_key = generate_public_key(rsa_pkey_from_str(private_key))
pub_key = generate_public_key(pkey_from_str(private_key))
return SSHKey(public=pub_key, private=private_key)
except OSError as e:
logger.debug("Got OSError: %s", repr(e))
console.print(f"[error]Unable to read the SSH key at {ssh_key_path}[/]")
exit()
except ValueError as e:
logger.debug("Key type is not supported", repr(e))
console.print("[error]Key type is not supported[/]")
exit()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import requests
from paramiko.pkey import PKey
from paramiko.ssh_exception import PasswordRequiredException, SSHException
from paramiko.ssh_exception import PasswordRequiredException
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import joinedload
Expand Down Expand Up @@ -72,7 +72,7 @@
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.network import get_ip_from_network
from dstack._internal.utils.ssh import (
rsa_pkey_from_str,
pkey_from_str,
)

PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)
Expand Down Expand Up @@ -249,7 +249,7 @@ async def add_remote(instance_id: UUID) -> None:
# Prepare connection key
try:
pkeys = [
rsa_pkey_from_str(sk.private)
pkey_from_str(sk.private)
for sk in remote_details.ssh_keys
if sk.private is not None
]
Expand All @@ -271,14 +271,14 @@ async def add_remote(instance_id: UUID) -> None:
},
)
return
except SSHException:
except ValueError:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
instance.termination_reason = "Cannot parse private key, RSA key required"
instance.termination_reason = "Cannot parse private key, key type is not supported"
await session.commit()
logger.warning(
"Failed to start instance %s: private SSH key is not a valid RSA key",
"Failed to start instance %s: unsupported private SSH key type",
instance.name,
extra={
"instance_name": instance.name,
Expand Down
19 changes: 14 additions & 5 deletions src/dstack/_internal/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from filelock import FileLock
from paramiko.config import SSHConfig
from paramiko.pkey import PKey, PublicBlob
from paramiko.ssh_exception import SSHException

from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike
Expand All @@ -20,6 +21,8 @@

default_ssh_config_path = "~/.ssh/config"

SUPPORTED_KEY_TYPES = (paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key)


def get_public_key_fingerprint(text: str) -> str:
pb = PublicBlob.from_string(text)
Expand Down Expand Up @@ -147,11 +150,17 @@ def convert_pkcs8_to_pem(private_string: str) -> str:
return private_string


def rsa_pkey_from_str(private_string: str) -> PKey:
key_file = io.StringIO(private_string.strip())
pkey = paramiko.RSAKey.from_private_key(key_file)
key_file.close()
return pkey
def pkey_from_str(private_string: str) -> PKey:
for key_type in SUPPORTED_KEY_TYPES:
try:
key_file = io.StringIO(private_string.strip())
pkey = key_type.from_private_key(key_file)
key_file.close()
return pkey
except (SSHException, ValueError):
pass

raise ValueError("Unsupported key type")


def generate_public_key(private_key: PKey) -> str:
Expand Down

0 comments on commit fd8cdf1

Please sign in to comment.