diff --git a/main.go b/main.go index fd7ead3..2d0866f 100644 --- a/main.go +++ b/main.go @@ -74,7 +74,7 @@ Environment variables: More info: https://github.com/cdr/sshcode Arguments: -%vHOST is passed into the ssh command. +%vHOST is passed into the ssh command. Valid formats are '' or 'gcp:'. %vDIR is optional. %v`, diff --git a/sshcode.go b/sshcode.go index c644c7d..cfff58b 100644 --- a/sshcode.go +++ b/sshcode.go @@ -32,6 +32,14 @@ type options struct { func sshCode(host, dir string, o options) error { flog.Info("ensuring code-server is updated...") + host, extraSSHFlags, err := parseIP(host) + if err != nil { + return xerrors.Errorf("failed to parse host IP: %w", err) + } + if extraSSHFlags != "" { + o.sshFlags = strings.Join([]string{extraSSHFlags, o.sshFlags}, " ") + } + dlScript := downloadScript(codeServerPath) // Downloads the latest code-server and allows it to be executed. @@ -41,7 +49,7 @@ func sshCode(host, dir string, o options) error { sshCmd.Stdout = os.Stdout sshCmd.Stderr = os.Stderr sshCmd.Stdin = strings.NewReader(dlScript) - err := sshCmd.Run() + err = sshCmd.Run() if err != nil { return xerrors.Errorf("failed to update code-server: \n---ssh cmd---\n%s\n---download script---\n%s: %w", sshCmdStr, @@ -341,3 +349,55 @@ func ensureDir(path string) error { return nil } + +// parseIP parses the host to a valid IP address. If 'gcp:' is prefixed to the +// host then a lookup is done using gcloud to determine the external IP and any +// additional SSH arguments that should be used for ssh commands. +func parseIP(host string) (ip string, additionalFlags string, err error) { + host = strings.TrimSpace(host) + switch { + case strings.HasPrefix(host, "gcp:"): + instance := strings.TrimPrefix(host, "gcp:") + return parseGCPSSHCmd(instance) + default: + if net.ParseIP(host) == nil { + return "", "", xerrors.New("host argument is not a valid IP address") + } + return host, "", nil + } +} + +// parseGCPSSHCmd parses the IP address and flags used by 'gcloud' when +// ssh'ing to an instance. +func parseGCPSSHCmd(instance string) (ip, sshFlags string, err error) { + dryRunCmd := fmt.Sprintf("gcloud compute ssh --dry-run %v", instance) + + out, err := exec.Command("sh", "-c", dryRunCmd).CombinedOutput() + if err != nil { + return "", "", xerrors.Errorf("%s: %w", out, err) + } + + toks := strings.Split(string(out), " ") + if len(toks) < 2 { + return "", "", xerrors.Errorf("unexpected output for '%v' command, %s", dryRunCmd, out) + } + + // Slice off the '/usr/bin/ssh' prefix and the '@' suffix. + sshFlags = strings.Join(toks[1:len(toks)-1], " ") + + // E.g. foo@1.2.3.4. + userIP := toks[len(toks)-1] + toks = strings.Split(userIP, "@") + // Assume the '@' is missing. + if len(toks) < 2 { + ip = strings.TrimSpace(toks[0]) + } else { + ip = strings.TrimSpace(toks[1]) + } + + if net.ParseIP(ip) == nil { + return "", "", xerrors.Errorf("parsed invalid ip address %v", ip) + } + + return ip, sshFlags, nil +} diff --git a/sshcode_test.go b/sshcode_test.go index e486460..318f7da 100644 --- a/sshcode_test.go +++ b/sshcode_test.go @@ -261,7 +261,7 @@ func waitForSSHCode(t *testing.T, port string, timeout time.Duration) { } } -// fakeRSAKey isn't used for anything other than the trashh ssh +// fakeRSAKey isn't used for anything other than the trassh ssh // server. const fakeRSAKey = `-----BEGIN RSA PRIVATE KEY----- MIIEpQIBAAKCAQEAsbbGAxPQeqti2OgdzuMgJGBAwXe/bFhQTPuk0bIvavkZwX/a