diff --git a/src/dstack/_internal/core/services/ssh/client.py b/src/dstack/_internal/core/services/ssh/client.py index 738201782..2126ffcb8 100644 --- a/src/dstack/_internal/core/services/ssh/client.py +++ b/src/dstack/_internal/core/services/ssh/client.py @@ -1,6 +1,4 @@ -import os import re -import shutil import subprocess from dataclasses import dataclass from pathlib import Path @@ -9,6 +7,7 @@ from dstack._internal.compat import IS_WINDOWS from dstack._internal.core.errors import SSHError from dstack._internal.utils.path import PathLike +from dstack._internal.utils.ssh import find_ssh_client @dataclass @@ -87,39 +86,6 @@ def inspect_ssh_client(path: PathLike) -> SSHClientInfo: raise SSHError(f"failed to parse `{path} -V` output: {output}") -def find_ssh_client() -> Optional[Path]: - path_str = os.getenv("DSTACK_SSH_CLIENT") - if path_str: - return Path(path_str) - if not IS_WINDOWS: - path_str = shutil.which("ssh") - if path_str: - return Path(path_str) - return None - # First, we check for ssh bundled with Git for Windows (MSYS2/MinGW-w64-built OpenSSH Portable) - # as a preferred client. It supports ForkAfterAuthentication; ControlMaster is only partially - # supported, we don't use it. - git_path_str = shutil.which("git") - if git_path_str: - # C:\Program Files\Git\cmd\git.exe -> C:\Program Files\Git\usr\bin\ssh.exe - path = Path(git_path_str).parent.parent / "usr" / "bin" / "ssh.exe" - if path.exists(): - return path - # Then we check for OpenSSH for Windows (Microsoft's fork of OpenSSH Portable). - # It does not support some features, namely ControlMaster and ForkAfterAuthentication. - windir_str = os.getenv("WINDIR") - if windir_str: - path = Path(windir_str) / "System32" / "OpenSSH" / "ssh.exe" - if path.exists(): - return path - # Finally, we check for any ssh client in PATH. It can be anything, it can be not compatible, - # so we use it only as a last resort. - path_str = shutil.which("ssh") - if path_str: - return Path(path_str) - return None - - _ssh_client_info: Optional[SSHClientInfo] = None diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index 0992b69a3..068585c48 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -1,6 +1,7 @@ import io import os import re +import shutil import subprocess import sys import tempfile @@ -40,12 +41,22 @@ def get_host_config(hostname: str, ssh_config_path: PathLike = default_ssh_confi def make_ssh_command_for_git(identity_file: PathLike) -> str: - return f"ssh -F none -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o IdentitiesOnly=yes -F /dev/null -o IdentityFile={identity_file}" + # No need to use :func:`find_ssh_client()` even on Windows even if `ssh` not in + # Windows `PATH` -- MSYS2 git (from Git for Windows) always has access to it, + # see https://www.msys2.org/docs/environments/ ("MSYS environment [...] is always active") + return ( + f'ssh -F none -i "{normalize_path(identity_file)}"' + " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o IdentitiesOnly=yes" + ) def try_ssh_key_passphrase(identity_file: PathLike, passphrase: str = "") -> bool: + ssh_keygen = find_ssh_util("ssh-keygen") + if ssh_keygen is None: + logger.warning("ssh-keygen not found") + return False r = subprocess.run( - ["ssh-keygen", "-y", "-P", passphrase, "-f", identity_file], + [ssh_keygen, "-y", "-P", passphrase, "-f", identity_file], stdout=subprocess.DEVNULL, stderr=sys.stdout.buffer, ) @@ -179,19 +190,19 @@ def convert_ssh_key_to_pem(private_string: str) -> str: with tempfile.NamedTemporaryFile(mode="w+") as key_file: key_file.write(private_string) key_file.flush() - cmd = ["ssh-keygen", "-p", "-m", "PEM", "-f", key_file.name, "-y", "-q", "-N", ""] - try: - subprocess.run( - cmd, - check=True, - capture_output=True, - text=True, - ) - except FileNotFoundError: + if ssh_keygen := find_ssh_util("ssh-keygen"): + cmd = [ssh_keygen, "-p", "-m", "PEM", "-f", key_file.name, "-y", "-q", "-N", ""] + try: + subprocess.run( + cmd, + check=True, + capture_output=True, + text=True, + ) + except subprocess.CalledProcessError as e: + logger.error("Fail to convert ssh key: stdout=%s, stderr=%s", e.stdout, e.stderr) + else: logger.error("Use a PEM key or install ssh-keygen to convert it automatically") - except subprocess.CalledProcessError as e: - logger.error("Fail to convert ssh key: stdout=%s, stderr=%s", e.stdout, e.stderr) - key_file.seek(0) private_string = key_file.read() return private_string @@ -234,3 +245,74 @@ def check_required_ssh_version() -> bool: return False return False + + +def find_ssh_client() -> Optional[Path]: + """ + Finds and returns an absolute path of `ssh` executable or `None` if not found. + + If the `DSTACK_SSH_CLIENT` environment variable is set, return its value, otherwise: + * on POSIX, look for `ssh` executable in `PATH` and return it (if any). + * on Windows, first look for OpenSSH bundled with Git for Windows checking + a known directory structure, then check `PATH`, and finally check a well-known location of + OpenSSH for Windows. + """ + path_str = os.getenv("DSTACK_SSH_CLIENT") + if path_str: + path = Path(path_str) + if path.exists(): + return path.resolve() + logger.warning("DSTACK_SSH_CLIENT=%s does not exist", path_str) + return None + if not IS_WINDOWS: + path_str = shutil.which("ssh") + if path_str: + return Path(path_str) + return None + # First, we check for ssh bundled with Git for Windows (MSYS2/MinGW-w64-built OpenSSH Portable) + # as a preferred client. It supports ForkAfterAuthentication; ControlMaster is only partially + # supported, we don't use it. + git_path_str = shutil.which("git") + if git_path_str: + # C:\Program Files\Git\cmd\git.exe -> C:\Program Files\Git\usr\bin\ssh.exe + path = Path(git_path_str).parent.parent / "usr" / "bin" / "ssh.exe" + if path.exists(): + return path + # Then we check for any ssh client in PATH. It can be anything, but most likely it will be + # OpenSSH for Windows (see below). Nonetheless, it's worth trying since it's also may be + # MSYS2/Cygwin OpenSSH Portable. + path_str = shutil.which("ssh") + if path_str: + return Path(path_str) + # Finally we check for OpenSSH for Windows (Microsoft's fork of OpenSSH Portable). + # It does not support some features, namely ControlMaster and ForkAfterAuthentication. + windir_str = os.getenv("WINDIR") + if windir_str: + path = Path(windir_str) / "System32" / "OpenSSH" / "ssh.exe" + if path.exists(): + return path + return None + + +_ssh_util_dir: Optional[Path] = None + + +def find_ssh_util(name: str) -> Optional[Path]: + """ + Returns an absolute path of a given `ssh*` utility or `None` if not found. + + :param name: a utility binary name without `.exe` suffix, e.g., `ssh-keygen`, `ssh-copy-id`. + :return: a Path object. + """ + global _ssh_util_dir + if _ssh_util_dir is None: + ssh_client_path = find_ssh_client() + if ssh_client_path is None: + return None + _ssh_util_dir = ssh_client_path.parent + if IS_WINDOWS: + name = f"{name}.exe" + path = _ssh_util_dir / name + if path.exists(): + return path + return None