diff --git a/cmd/puppeth/ssh.go b/cmd/puppeth/ssh.go index da2862db2f10..039cb6cb45d9 100644 --- a/cmd/puppeth/ssh.go +++ b/cmd/puppeth/ssh.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/log" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/terminal" ) @@ -43,6 +44,8 @@ type sshClient struct { logger log.Logger } +const EnvSSHAuthSock = "SSH_AUTH_SOCK" + // dial establishes an SSH connection to a remote node using the current user and // the user's configured private RSA key. If that fails, password authentication // is fallen back to. server can be a string like user:identity@server:port. @@ -79,38 +82,49 @@ func dial(server string, pubkey []byte) (*sshClient, error) { if username == "" { username = user.Username } - // Configure the supported authentication methods (private key and password) - var auths []ssh.AuthMethod - path := filepath.Join(user.HomeDir, ".ssh", identity) - if buf, err := ioutil.ReadFile(path); err != nil { - log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) + // Configure the supported authentication methods (ssh agent, private key and password) + var ( + auths []ssh.AuthMethod + conn net.Conn + ) + if conn, err = net.Dial("unix", os.Getenv(EnvSSHAuthSock)); err != nil { + log.Warn("Unable to dial SSH agent, falling back to private keys", "err", err) } else { - key, err := ssh.ParsePrivateKey(buf) - if err != nil { - fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path) - blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() - if err != nil { - log.Warn("Couldn't read password", "err", err) - } - key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob) + client := agent.NewClient(conn) + auths = append(auths, ssh.PublicKeysCallback(client.Signers)) + } + if err != nil { + path := filepath.Join(user.HomeDir, ".ssh", identity) + if buf, err := ioutil.ReadFile(path); err != nil { + log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) + } else { + key, err := ssh.ParsePrivateKey(buf) if err != nil { - log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err) + fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path) + blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + if err != nil { + log.Warn("Couldn't read password", "err", err) + } + key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob) + if err != nil { + log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err) + } else { + auths = append(auths, ssh.PublicKeys(key)) + } } else { auths = append(auths, ssh.PublicKeys(key)) } - } else { - auths = append(auths, ssh.PublicKeys(key)) } - } - auths = append(auths, ssh.PasswordCallback(func() (string, error) { - fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server) - blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) + auths = append(auths, ssh.PasswordCallback(func() (string, error) { + fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server) + blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() - return string(blob), err - })) + fmt.Println() + return string(blob), err + })) + } // Resolve the IP address of the remote server addr, err := net.LookupHost(hostname) if err != nil {