Skip to content
This repository was archived by the owner on Jan 17, 2021. It is now read-only.

Add SSH master connection feature #116

Merged
merged 8 commits into from
Jun 28, 2019
23 changes: 13 additions & 10 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ var _ interface {
} = new(rootCmd)

type rootCmd struct {
skipSync bool
syncBack bool
printVersion bool
bindAddr string
sshFlags string
skipSync bool
syncBack bool
printVersion bool
noReuseConnection bool
bindAddr string
sshFlags string
}

func (c *rootCmd) Spec() cli.CommandSpec {
Expand All @@ -53,6 +54,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) {
fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host")
fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination")
fl.BoolVar(&c.printVersion, "version", false, "print version information and exit")
fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket")
fl.StringVar(&c.bindAddr, "bind", "", "local bind address for ssh tunnel")
fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags")
}
Expand All @@ -76,10 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) {
}

err := sshCode(host, dir, options{
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
noReuseConnection: c.noReuseConnection,
})

if err != nil {
Expand All @@ -101,7 +104,7 @@ Environment variables:
More info: https://github.com/cdr/sshcode

Arguments:
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vDIR is optional.`,
helpTab, vsCodeConfigDirEnv,
helpTab, vsCodeExtensionsDirEnv,
Expand Down
139 changes: 128 additions & 11 deletions sshcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path/filepath"
"strconv"
"strings"
"syscall"
"time"

"github.com/pkg/browser"
Expand All @@ -20,19 +21,21 @@ import (
)

const codeServerPath = "~/.cache/sshcode/sshcode-server"
const sshDirectory = "~/.ssh"
const sshDirectoryUnsafeModeMask = 0022
const sshControlPath = sshDirectory + "/control-%h-%p-%r"

type options struct {
skipSync bool
syncBack bool
noOpen bool
bindAddr string
remotePort string
sshFlags string
skipSync bool
syncBack bool
noOpen bool
noReuseConnection bool
bindAddr string
remotePort string
sshFlags string
}

func sshCode(host, dir string, o options) error {
flog.Info("ensuring code-server is updated...")

host, extraSSHFlags, err := parseHost(host)
if err != nil {
return xerrors.Errorf("failed to parse host IP: %w", err)
Expand All @@ -53,6 +56,73 @@ func sshCode(host, dir string, o options) error {
return xerrors.Errorf("failed to find available remote port: %w", err)
}

// Check the SSH directory's permissions and warn the user if it is not safe.
sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory))
if err != nil {
if !o.noReuseConnection {
flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
o.noReuseConnection = true
}
} else {
if !sshDirectoryMode.IsDir() {
if !o.noReuseConnection {
flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
o.noReuseConnection = true
} else {
flog.Info("warning: %v is not a directory", sshDirectory)
}
}
if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 {
flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
"the owner (and files inside should be set to 0600)", sshDirectory)
}
}

// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
// only happens on the initial connection.
if !o.noReuseConnection {
newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, o.sshFlags, sshControlPath)

// -MN means "start a master socket and don't open a session, just connect".
sshCmdStr := fmt.Sprintf(`exec ssh %v -MN %v`, newSSHFlags, host)
sshMasterCmd := exec.Command("sh", "-c", sshCmdStr)
sshMasterCmd.Stdin = os.Stdin
sshMasterCmd.Stdout = os.Stdout
sshMasterCmd.Stderr = os.Stderr
stopSSHMaster := func() {
if sshMasterCmd.Process != nil {
err := sshMasterCmd.Process.Signal(syscall.Signal(0))
if err != nil {
return
}
err = sshMasterCmd.Process.Signal(syscall.SIGTERM)
if err != nil {
flog.Error("failed to send SIGTERM to SSH master process: %v", err)
}
}
}
defer stopSSHMaster()

err = sshMasterCmd.Start()
go sshMasterCmd.Wait()
if err != nil {
flog.Error("failed to start SSH master connection, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
stopSSHMaster()
} else {
err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
if err != nil {
flog.Error("SSH master failed to be ready in time, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
stopSSHMaster()
} else {
sshMasterCmd.Stdin = nil
o.sshFlags = newSSHFlags
}
}
}

flog.Info("ensuring code-server is updated...")
dlScript := downloadScript(codeServerPath)

// Downloads the latest code-server and allows it to be executed.
Expand Down Expand Up @@ -147,26 +217,43 @@ func sshCode(host, dir string, o options) error {
case <-c:
}

flog.Info("shutting down")
if !o.syncBack || o.skipSync {
flog.Info("shutting down")
return nil
}

flog.Info("synchronizing VS Code back to local")

err = syncExtensions(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync extensions back: %w", err)
return xerrors.Errorf("failed to sync extensions back: %v", err)
}

err = syncUserSettings(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync user settings settings back: %w", err)
return xerrors.Errorf("failed to sync user settings settings back: %v", err)
}

return nil
}

// expandPath returns an expanded version of path.
func expandPath(path string) string {
path = filepath.Clean(os.ExpandEnv(path))

// Replace tilde notation in path with the home directory.
homedir := os.Getenv("HOME")
if homedir != "" {
if path == "~" {
path = homedir
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(homedir, path[2:])
}
}

return filepath.Clean(path)
}

func parseBindAddr(bindAddr string) (string, error) {
if bindAddr == "" {
bindAddr = ":"
Expand Down Expand Up @@ -263,6 +350,36 @@ func randomPort() (string, error) {
return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// checkSSHMaster polls every second for 30 seconds to check if the SSH master
// is ready.
func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error {
var (
maxTries = 30
sleepDur = time.Second
err error
)
for i := 0; i < maxTries; i++ {
// Check if the master is running
if sshMasterCmd.Process == nil {
return xerrors.Errorf("SSH master process not running")
}
err = sshMasterCmd.Process.Signal(syscall.Signal(0))
if err != nil {
return xerrors.Errorf("failed to check if SSH master process was alive: %v", err)
}

// Check if it's ready
sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)
sshCmd := exec.Command("sh", "-c", sshCmdStr)
err = sshCmd.Run()
if err == nil {
return nil
}
time.Sleep(sleepDur)
}
return xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

func syncUserSettings(sshFlags string, host string, back bool) error {
localConfDir, err := configDir()
if err != nil {
Expand Down