diff --git a/.gitignore b/.gitignore index 6eec620..dc0daa9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ vendor bin .vscode +sshcode diff --git a/main.go b/main.go index a09df53..bc0de4a 100644 --- a/main.go +++ b/main.go @@ -37,6 +37,7 @@ type rootCmd struct { skipSync bool syncBack bool printVersion bool + bindAddr string sshFlags string } @@ -52,6 +53,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) { fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host") fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination") fl.BoolVar(&c.printVersion, "version", false, "print version information and exit") + fl.StringVar(&c.bindAddr, "bind", "", "local bind address for ssh tunnel") fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags") } @@ -76,6 +78,7 @@ func (c *rootCmd) Run(fl *flag.FlagSet) { err := sshCode(host, dir, options{ skipSync: c.skipSync, sshFlags: c.sshFlags, + bindAddr: c.bindAddr, syncBack: c.syncBack, }) diff --git a/sshcode.go b/sshcode.go index 1d539e5..f7a7b6b 100644 --- a/sshcode.go +++ b/sshcode.go @@ -25,7 +25,7 @@ type options struct { skipSync bool syncBack bool noOpen bool - localPort string + bindAddr string remotePort string sshFlags string } @@ -79,11 +79,9 @@ func sshCode(host, dir string, o options) error { flog.Info("starting code-server...") - if o.localPort == "" { - o.localPort, err = randomPort() - } + o.bindAddr, err = parseBindAddr(o.bindAddr) if err != nil { - return xerrors.Errorf("failed to find available local port: %w", err) + return xerrors.Errorf("failed to parse bind address: %w", err) } if o.remotePort == "" { @@ -93,11 +91,12 @@ func sshCode(host, dir string, o options) error { return xerrors.Errorf("failed to find available remote port: %w", err) } - flog.Info("Tunneling local port %v to remote port %v", o.localPort, o.remotePort) + flog.Info("Tunneling remote port %v to %v", o.remotePort, o.bindAddr) - sshCmdStr = fmt.Sprintf("ssh -tt -q -L %v %v %v 'cd %v; %v --host 127.0.0.1 --allow-http --no-auth --port=%v'", - o.localPort+":localhost:"+o.remotePort, o.sshFlags, host, dir, codeServerPath, o.remotePort, - ) + sshCmdStr = + fmt.Sprintf("ssh -tt -q -L %v:localhost:%v %v %v 'cd %v; %v --host 127.0.0.1 --allow-http --no-auth --port=%v'", + o.bindAddr, o.remotePort, o.sshFlags, host, dir, codeServerPath, o.remotePort, + ) // Starts code-server and forwards the remote port. sshCmd = exec.Command("sh", "-c", sshCmdStr) @@ -109,7 +108,7 @@ func sshCode(host, dir string, o options) error { return xerrors.Errorf("failed to start code-server: %w", err) } - url := "http://127.0.0.1:" + o.localPort + url := fmt.Sprintf("http://%s", o.bindAddr) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() @@ -168,6 +167,23 @@ func sshCode(host, dir string, o options) error { return nil } +func parseBindAddr(bindAddr string) (string, error) { + host, port, err := net.SplitHostPort(bindAddr) + if err != nil { + return "", err + } + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port, err = randomPort() + } + if err != nil { + return "", err + } + return net.JoinHostPort(host, port), nil +} + func openBrowser(url string) { var openCmd *exec.Cmd diff --git a/sshcode_test.go b/sshcode_test.go index b2ae89d..fc6eb7d 100644 --- a/sshcode_test.go +++ b/sshcode_test.go @@ -37,7 +37,7 @@ func TestSSHCode(t *testing.T) { defer wg.Done() err := sshCode("foo@127.0.0.1", "", options{ sshFlags: testSSHArgs(sshPort), - localPort: localPort, + bindAddr: net.JoinHostPort("127.0.0.1", localPort), remotePort: remotePort, noOpen: true, })