diff --git a/main.go b/main.go index f637438..2a63211 100644 --- a/main.go +++ b/main.go @@ -34,11 +34,12 @@ var _ interface { } = new(rootCmd) type rootCmd struct { - skipSync bool - syncBack bool - printVersion bool - bindAddr string - sshFlags string + skipSync bool + syncBack bool + printVersion bool + noReuseConnection bool + bindAddr string + sshFlags string } func (c *rootCmd) Spec() cli.CommandSpec { @@ -53,6 +54,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) { fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host") fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination") fl.BoolVar(&c.printVersion, "version", false, "print version information and exit") + fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket") fl.StringVar(&c.bindAddr, "bind", "", "local bind address for SSH tunnel, in [HOST][:PORT] syntax (default: 127.0.0.1)") fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags") } @@ -76,10 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) { } err := sshCode(host, dir, options{ - skipSync: c.skipSync, - sshFlags: c.sshFlags, - bindAddr: c.bindAddr, - syncBack: c.syncBack, + skipSync: c.skipSync, + sshFlags: c.sshFlags, + bindAddr: c.bindAddr, + syncBack: c.syncBack, + reuseConnection: !c.noReuseConnection, }) if err != nil { @@ -101,7 +104,7 @@ Environment variables: More info: https://github.com/cdr/sshcode Arguments: -%vHOST is passed into the ssh command. Valid formats are '' or 'gcp:'. +%vHOST is passed into the ssh command. Valid formats are '' or 'gcp:'. %vDIR is optional.`, helpTab, vsCodeConfigDirEnv, helpTab, vsCodeExtensionsDirEnv, diff --git a/sshcode.go b/sshcode.go index f1a0e2f..f525126 100644 --- a/sshcode.go +++ b/sshcode.go @@ -12,6 +12,7 @@ import ( "path/filepath" "strconv" "strings" + "syscall" "time" "github.com/pkg/browser" @@ -21,18 +22,23 @@ import ( const codeServerPath = "~/.cache/sshcode/sshcode-server" +const ( + sshDirectory = "~/.ssh" + sshDirectoryUnsafeModeMask = 0022 + sshControlPath = sshDirectory + "/control-%h-%p-%r" +) + type options struct { - skipSync bool - syncBack bool - noOpen bool - bindAddr string - remotePort string - sshFlags string + skipSync bool + syncBack bool + noOpen bool + reuseConnection bool + bindAddr string + remotePort string + sshFlags string } func sshCode(host, dir string, o options) error { - flog.Info("ensuring code-server is updated...") - host, extraSSHFlags, err := parseHost(host) if err != nil { return xerrors.Errorf("failed to parse host IP: %w", err) @@ -53,6 +59,24 @@ func sshCode(host, dir string, o options) error { return xerrors.Errorf("failed to find available remote port: %w", err) } + // Check the SSH directory's permissions and warn the user if it is not safe. + o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection) + + // Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication + // only happens on the initial connection. + if o.reuseConnection { + flog.Info("starting SSH master connection...") + newSSHFlags, cancel, err := startSSHMaster(o.sshFlags, sshControlPath, host) + defer cancel() + if err != nil { + flog.Error("failed to start SSH master connection: %v", err) + o.reuseConnection = false + } else { + o.sshFlags = newSSHFlags + } + } + + flog.Info("ensuring code-server is updated...") dlScript := downloadScript(codeServerPath) // Downloads the latest code-server and allows it to be executed. @@ -147,8 +171,8 @@ func sshCode(host, dir string, o options) error { case <-c: } + flog.Info("shutting down") if !o.syncBack || o.skipSync { - flog.Info("shutting down") return nil } @@ -167,6 +191,24 @@ func sshCode(host, dir string, o options) error { return nil } +// expandPath returns an expanded version of path. +func expandPath(path string) string { + path = filepath.Clean(os.ExpandEnv(path)) + + // Replace tilde notation in path with the home directory. You can't replace the first instance of `~` in the + // string with the homedir as having a tilde in the middle of a filename is valid. + homedir := os.Getenv("HOME") + if homedir != "" { + if path == "~" { + path = homedir + } else if strings.HasPrefix(path, "~/") { + path = filepath.Join(homedir, path[2:]) + } + } + + return filepath.Clean(path) +} + func parseBindAddr(bindAddr string) (string, error) { if !strings.Contains(bindAddr, ":") { bindAddr += ":" @@ -263,6 +305,100 @@ func randomPort() (string, error) { return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries) } +// checkSSHDirectory performs sanity and safety checks on sshDirectory, and +// returns a new value for o.reuseConnection depending on the checks. +func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool { + sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory)) + if err != nil { + if reuseConnection { + flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err) + } + reuseConnection = false + } else { + if !sshDirectoryMode.IsDir() { + if reuseConnection { + flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory) + } else { + flog.Info("warning: %v is not a directory", sshDirectory) + } + reuseConnection = false + } + if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 { + flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+ + "the owner (and files inside should be set to 0600)", sshDirectory) + } + } + return reuseConnection +} + +// startSSHMaster starts an SSH master connection and waits for it to be ready. +// It returns a new set of SSH flags for child SSH processes to use. +func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) { + ctx, cancel := context.WithCancel(context.Background()) + + newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath) + + // -MN means "start a master socket and don't open a session, just connect". + sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host) + sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr) + sshMasterCmd.Stdin = os.Stdin + sshMasterCmd.Stderr = os.Stderr + + // Gracefully stop the SSH master. + stopSSHMaster := func() { + if sshMasterCmd.Process != nil { + if sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited() { + return + } + err := sshMasterCmd.Process.Signal(syscall.SIGTERM) + if err != nil { + flog.Error("failed to send SIGTERM to SSH master process: %v", err) + } + } + cancel() + } + + // Start ssh master and wait. Waiting prevents the process from becoming a zombie process if it dies before + // sshcode does, and allows sshMasterCmd.ProcessState to be populated. + err := sshMasterCmd.Start() + go sshMasterCmd.Wait() + if err != nil { + return "", stopSSHMaster, err + } + err = checkSSHMaster(sshMasterCmd, newSSHFlags, host) + if err != nil { + stopSSHMaster() + return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err) + } + return newSSHFlags, stopSSHMaster, nil +} + +// checkSSHMaster polls every second for 30 seconds to check if the SSH master +// is ready. +func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error { + var ( + maxTries = 30 + sleepDur = time.Second + err error + ) + for i := 0; i < maxTries; i++ { + // Check if the master is running. + if sshMasterCmd.Process == nil || (sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited()) { + return xerrors.Errorf("SSH master process is not running") + } + + // Check if it's ready. + sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host) + sshCmd := exec.Command("sh", "-c", sshCmdStr) + err = sshCmd.Run() + if err == nil { + return nil + } + time.Sleep(sleepDur) + } + return xerrors.Errorf("max number of tries exceeded: %d", maxTries) +} + func syncUserSettings(sshFlags string, host string, back bool) error { localConfDir, err := configDir() if err != nil {