package main

import (
	"context"
	"fmt"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/pkg/browser"
	"go.coder.com/flog"
	"golang.org/x/xerrors"
)

// codeServerPath is where the code-server binary is installed on the remote
// host. It is interpolated unquoted into remote shell commands (the download
// script and the code-server launch command), so the leading "~" is expanded
// by the remote shell.
const codeServerPath = "~/.cache/sshcode/sshcode-server"

const (
	// sshDirectory is the local SSH directory. checkSSHDirectory inspects it,
	// and the master connection's control socket is placed inside it.
	sshDirectory               = "~/.ssh"
	// sshDirectoryUnsafeModeMask selects the group- and world-writable
	// permission bits; if any are set on sshDirectory a warning is logged.
	sshDirectoryUnsafeModeMask = 0022
	// sshControlPath is passed to ssh as its ControlPath option when connection
	// reuse is enabled. %h, %p and %r are ssh_config(5) tokens (remote host,
	// port and user), giving each destination its own socket.
	sshControlPath             = sshDirectory + "/control-%h-%p-%r"
)

// options holds the settings for a single sshCode invocation.
type options struct {
	// skipSync disables syncing settings and extensions to the host, and also
	// disables syncing them back on shutdown.
	skipSync        bool
	// syncBack syncs extensions and user settings from the host back to this
	// machine once the session ends (ignored when skipSync is set).
	syncBack        bool
	// noOpen disables opening a browser once code-server is reachable.
	noOpen          bool
	// reuseConnection enables an SSH master connection (ssh -M) so that
	// authentication only happens once; it may be turned off at runtime if the
	// SSH directory checks or the master startup fail.
	reuseConnection bool
	// bindAddr is the local [host][:port] the tunnel listens on; an empty host
	// defaults to 127.0.0.1 and an empty port to a random one.
	bindAddr        string
	// remotePort is the port code-server listens on remotely; chosen randomly
	// when empty.
	remotePort      string
	// sshFlags are extra flags inserted into every ssh command line, including
	// the ssh command rsync uses.
	sshFlags        string
}

// sshCode runs code-server on host and makes it reachable from this machine.
//
// It resolves host (see parseHost), optionally starts an SSH master connection
// for reuse, installs/updates code-server remotely via downloadScript, and
// optionally syncs local settings and extensions to the host. It then starts
// code-server in dir and tunnels the remote port to o.bindAddr.
//
// Once the tunnel answers HTTP, it optionally opens a browser and blocks until
// the remote command exits or an interrupt is received. Finally, if o.syncBack
// is set (and o.skipSync is not), extensions and settings are synced back.
func sshCode(host, dir string, o options) error {
	host, extraSSHFlags, err := parseHost(host)
	if err != nil {
		return xerrors.Errorf("failed to parse host IP: %w", err)
	}
	if extraSSHFlags != "" {
		o.sshFlags = strings.Join([]string{extraSSHFlags, o.sshFlags}, " ")
	}

	o.bindAddr, err = parseBindAddr(o.bindAddr)
	if err != nil {
		return xerrors.Errorf("failed to parse bind address: %w", err)
	}

	if o.remotePort == "" {
		o.remotePort, err = randomPort()
		if err != nil {
			return xerrors.Errorf("failed to find available remote port: %w", err)
		}
	}

	// Check the SSH directory's permissions and warn the user if it is not safe.
	o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection)

	// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
	// only happens on the initial connection.
	if o.reuseConnection {
		flog.Info("starting SSH master connection...")
		newSSHFlags, stopSSHMaster, err := startSSHMaster(o.sshFlags, sshControlPath, host)
		// startSSHMaster hands back a non-nil stop function even when it fails.
		defer stopSSHMaster()
		if err != nil {
			flog.Error("failed to start SSH master connection: %v", err)
			o.reuseConnection = false
		} else {
			o.sshFlags = newSSHFlags
		}
	}

	flog.Info("ensuring code-server is updated...")
	dlScript := downloadScript(codeServerPath)

	// Downloads the latest code-server and allows it to be executed. The script
	// is fed over stdin to a remote bash login shell.
	sshCmdStr := fmt.Sprintf("ssh %v %v '/usr/bin/env bash -l'", o.sshFlags, host)

	sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr)
	sshCmd.Stdout = os.Stdout
	sshCmd.Stderr = os.Stderr
	sshCmd.Stdin = strings.NewReader(dlScript)
	err = sshCmd.Run()
	if err != nil {
		return xerrors.Errorf("failed to update code-server: \n---ssh cmd---\n%s\n---download script---\n%s: %w",
			sshCmdStr,
			dlScript,
			err,
		)
	}

	if !o.skipSync {
		start := time.Now()
		flog.Info("syncing settings")
		err = syncUserSettings(o.sshFlags, host, false)
		if err != nil {
			return xerrors.Errorf("failed to sync settings: %w", err)
		}
		flog.Info("synced settings in %s", time.Since(start))

		// Restart the clock so the reported duration covers only the extensions.
		start = time.Now()
		flog.Info("syncing extensions")
		err = syncExtensions(o.sshFlags, host, false)
		if err != nil {
			return xerrors.Errorf("failed to sync extensions: %w", err)
		}
		flog.Info("synced extensions in %s", time.Since(start))
	}

	flog.Info("starting code-server...")

	flog.Info("Tunneling remote port %v to %v", o.remotePort, o.bindAddr)

	sshCmdStr =
		fmt.Sprintf("ssh -tt -q -L %v:localhost:%v %v %v 'cd %v; %v --host 127.0.0.1 --allow-http --no-auth --port=%v'",
			o.bindAddr, o.remotePort, o.sshFlags, host, dir, codeServerPath, o.remotePort,
		)

	// Starts code-server and forwards the remote port.
	sshCmd = exec.Command("sh", "-l", "-c", sshCmdStr)
	sshCmd.Stdin = os.Stdin
	sshCmd.Stdout = os.Stdout
	sshCmd.Stderr = os.Stderr
	err = sshCmd.Start()
	if err != nil {
		return xerrors.Errorf("failed to start code-server: %w", err)
	}

	url := fmt.Sprintf("http://%s", o.bindAddr)
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	client := http.Client{
		Timeout: time.Second * 3,
	}
	// Wait for code-server to answer through the tunnel before opening the
	// browser. Pause between attempts: until the tunnel is up a request
	// typically fails immediately, and retrying with no delay would spin the
	// CPU until the deadline.
	const pollInterval = 250 * time.Millisecond
	for {
		resp, err := client.Get(url)
		if err == nil {
			resp.Body.Close()
			break
		}
		select {
		case <-ctx.Done():
			return xerrors.Errorf("code-server didn't start in time: %w", ctx.Err())
		case <-time.After(pollInterval):
		}
	}

	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or an interrupt arriving before the select below is lost.
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)

	if !o.noOpen {
		openBrowser(url)
	}

	// exited is closed once the remote code-server ssh session ends.
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		_ = sshCmd.Wait()
	}()

	select {
	case <-exited:
	case <-interrupts:
	}

	flog.Info("shutting down")
	if !o.syncBack || o.skipSync {
		return nil
	}

	flog.Info("synchronizing VS Code back to local")

	err = syncExtensions(o.sshFlags, host, true)
	if err != nil {
		return xerrors.Errorf("failed to sync extensions back: %w", err)
	}

	err = syncUserSettings(o.sshFlags, host, true)
	if err != nil {
		return xerrors.Errorf("failed to sync user settings back: %w", err)
	}

	return nil
}

// expandPath expands environment variables in path, then replaces a leading
// "~" (either the whole path or a "~/" prefix) with $HOME, and returns the
// cleaned result. A "~" anywhere else is left untouched because it is a valid
// filename character. If $HOME is unset, no tilde expansion happens.
func expandPath(path string) string {
	expanded := filepath.Clean(os.ExpandEnv(path))

	home := os.Getenv("HOME")
	if home == "" {
		return filepath.Clean(expanded)
	}

	switch {
	case expanded == "~":
		expanded = home
	case strings.HasPrefix(expanded, "~/"):
		expanded = filepath.Join(home, strings.TrimPrefix(expanded, "~/"))
	}

	return filepath.Clean(expanded)
}

// parseBindAddr normalizes a local bind address of the form [host][:port].
// A missing host becomes 127.0.0.1 and a missing port is replaced by one
// from randomPort. The result is joined with net.JoinHostPort.
func parseBindAddr(bindAddr string) (string, error) {
	addr := bindAddr
	if strings.IndexByte(addr, ':') < 0 {
		// No port separator at all: treat the whole value as the host.
		addr = addr + ":"
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return "", err
	}

	if len(host) == 0 {
		host = "127.0.0.1"
	}
	if len(port) == 0 {
		if port, err = randomPort(); err != nil {
			return "", err
		}
	}

	return net.JoinHostPort(host, port), nil
}

// openBrowser opens url in a Chrome/Chromium app window, falling back to the
// system default browser (browser.OpenURL) when no known Chrome binary exists.
// Executables found on $PATH are preferred over the well-known macOS and WSL
// install locations. Failures are logged rather than returned.
func openBrowser(url string) {
	const (
		macPath = "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
		wslPath = "/mnt/c/Program Files (x86)/Google/Chrome/Application/chrome.exe"
	)

	// Pick the first available candidate, in order of preference.
	var chromePath string
	for _, name := range []string{"google-chrome", "google-chrome-stable", "chromium", "chromium-browser"} {
		if commandExists(name) {
			chromePath = name
			break
		}
	}
	if chromePath == "" {
		for _, p := range []string{macPath, wslPath} {
			if pathExists(p) {
				chromePath = p
				break
			}
		}
	}

	if chromePath == "" {
		if err := browser.OpenURL(url); err != nil {
			flog.Error("failed to open browser: %v", err)
		}
		return
	}

	openCmd := exec.Command(chromePath, chromeOptions(url)...)
	// We do not use CombinedOutput because if there is no chrome instance, this will block
	// and become the parent process instead of using an existing chrome instance.
	if err := openCmd.Start(); err != nil {
		flog.Error("failed to open browser: %v", err)
		return
	}
	// Reap the child in the background so that, once it exits, it does not
	// linger as a zombie for the rest of the session.
	go func() {
		_ = openCmd.Wait()
	}()
}

// chromeOptions builds the Chrome command-line flags that open url as an app
// window in incognito mode with extensions and plugins disabled.
func chromeOptions(url string) []string {
	flags := make([]string, 0, 4)
	flags = append(flags, "--app="+url)
	flags = append(flags, "--disable-extensions", "--disable-plugins", "--incognito")
	return flags
}

// commandExists reports whether name can be resolved to a local executable
// by exec.LookPath (i.e. found via $PATH, or directly if it contains a slash).
func commandExists(name string) bool {
	if _, err := exec.LookPath(name); err != nil {
		return false
	}
	return true
}

// pathExists reports whether name can be stat'ed on the local filesystem.
// Any os.Stat error (not only "not found") is treated as absence.
func pathExists(name string) bool {
	if _, err := os.Stat(name); err != nil {
		return false
	}
	return true
}

// randomPort picks a random TCP port in [1024, 65535] that can currently be
// bound on this machine, trying up to 10 candidates. The probe listener is
// closed immediately, so the port is only known to have been free at the
// moment of the check.
//
// NOTE(review): this is also used to choose the *remote* code-server port,
// but availability is only probed locally.
func randomPort() (string, error) {
	const (
		minPort  = 1024
		maxPort  = 65535
		maxTries = 10
	)
	// Use a locally seeded source. Before Go 1.20 the package-level math/rand
	// source is deterministic unless explicitly seeded, which made every run
	// probe the same sequence of ports.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	for i := 0; i < maxTries; i++ {
		port := rng.Intn(maxPort-minPort+1) + minPort
		l, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
		if err == nil {
			_ = l.Close()
			return strconv.Itoa(port), nil
		}
		flog.Info("port taken: %d", port)
	}

	return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// checkSSHDirectory sanity-checks the local SSH directory and returns the value
// o.reuseConnection should take afterwards.
//
// Connection reuse is turned off when the directory cannot be Lstat'ed or is
// not a directory. Lstat does not follow symlinks, so a symlinked directory
// also counts as "not a directory". A group- or world-writable directory only
// produces a warning and does not by itself disable reuse.
func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool {
	info, err := os.Lstat(expandPath(sshDirectory))
	if err != nil {
		if reuseConnection {
			flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
		}
		return false
	}

	switch {
	case info.IsDir():
		// Nothing to report.
	case reuseConnection:
		flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
		reuseConnection = false
	default:
		flog.Info("warning: %v is not a directory", sshDirectory)
	}

	if unsafeBits := info.Mode().Perm() & sshDirectoryUnsafeModeMask; unsafeBits != 0 {
		flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
			"the owner (and files inside should be set to 0600)", sshDirectory)
	}
	return reuseConnection
}

// startSSHMaster starts an SSH master connection and waits for it to be ready.
// It returns a new set of SSH flags for child SSH processes to use (sshFlags
// plus a ControlPath option for sshControlPath) and a stop function.
//
// The stop function is non-nil even when an error is returned. It may be
// called more than once (from a single goroutine); only the first call acts.
// It asks the master to exit with SIGTERM and waits a short grace period for
// it to do so. Only then does it cancel the command's context, which makes
// os/exec kill the process if it is still alive.
func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) {
	const stopGracePeriod = 5 * time.Second

	ctx, cancel := context.WithCancel(context.Background())

	newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath)

	// -MN means "start a master socket and don't open a session, just connect".
	sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host)
	sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr)
	sshMasterCmd.Stdin = os.Stdin
	sshMasterCmd.Stderr = os.Stderr

	// exited is closed by the waiter goroutine once the master has been reaped.
	// The stop function consults it instead of reading ProcessState, which the
	// waiter goroutine writes concurrently.
	exited := make(chan struct{})

	stopped := false
	// Gracefully stop the SSH master.
	stopSSHMaster := func() {
		if stopped {
			return
		}
		stopped = true
		defer cancel()

		if sshMasterCmd.Process == nil {
			return // never started
		}
		select {
		case <-exited:
			return // already gone (normal exit or killed by a signal)
		default:
		}
		err := sshMasterCmd.Process.Signal(syscall.SIGTERM)
		if err != nil {
			flog.Error("failed to send SIGTERM to SSH master process: %v", err)
			return
		}
		// cancel() makes os/exec SIGKILL the process, so give ssh a chance to
		// exit on its own first (presumably letting it clean up its control
		// socket).
		select {
		case <-exited:
		case <-time.After(stopGracePeriod):
		}
	}

	err := sshMasterCmd.Start()
	if err != nil {
		return "", stopSSHMaster, err
	}
	// Wait in the background. This keeps the process from becoming a zombie
	// if it dies before sshcode does, and lets checkSSHMaster see
	// ProcessState populated.
	// NOTE(review): checkSSHMaster reads ProcessState while this goroutine
	// may be writing it; fixing that requires changing its signature.
	go func() {
		_ = sshMasterCmd.Wait()
		close(exited)
	}()

	err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
	if err != nil {
		stopSSHMaster()
		return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err)
	}
	return newSSHFlags, stopSSHMaster, nil
}

// checkSSHMaster polls up to 30 times, one second apart, until
// `ssh -O check` reports that the master connection for host is ready.
// It fails early if the master process was never started or has already
// terminated, and returns an error if every attempt fails.
func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error {
	const (
		maxTries = 30
		sleepDur = time.Second
	)
	// The check command is the same on every attempt, so build it once.
	sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)

	for i := 0; i < maxTries; i++ {
		// ProcessState is populated once Wait has returned, whether the process
		// exited normally or was killed by a signal. ProcessState.Exited() is
		// false for signal deaths, so it must not be used for this check.
		if sshMasterCmd.Process == nil || sshMasterCmd.ProcessState != nil {
			return xerrors.Errorf("SSH master process is not running")
		}

		// Check if it's ready.
		if err := exec.Command("sh", "-c", sshCmdStr).Run(); err == nil {
			return nil
		}
		if i < maxTries-1 {
			time.Sleep(sleepDur)
		}
	}
	return xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// syncUserSettings rsyncs the local settings directory (from configDir) to
// code-server's remote User settings directory on host. When back is true it
// syncs in the opposite direction. The local directory is created first if
// missing. workspaceStorage, logs and CachedData are excluded.
func syncUserSettings(sshFlags string, host string, back bool) error {
	const remoteSettingsDir = "~/.local/share/code-server/User/"

	localConfDir, err := configDir()
	if err != nil {
		return err
	}
	if err = ensureDir(localConfDir); err != nil {
		return err
	}

	// The trailing "/" makes rsync copy the directory's contents rather than
	// the directory itself.
	local := localConfDir + "/"
	remote := host + ":" + remoteSettingsDir

	src, dest := local, remote
	if back {
		src, dest = remote, local
	}
	return rsync(src, dest, sshFlags, "workspaceStorage", "logs", "CachedData")
}

// syncExtensions rsyncs the local extensions directory (from extensionsDir) to
// code-server's remote extensions directory on host. When back is true it
// syncs in the opposite direction. The local directory is created first if
// missing.
func syncExtensions(sshFlags string, host string, back bool) error {
	const remoteExtensionsDir = "~/.local/share/code-server/extensions/"

	localExtensionsDir, err := extensionsDir()
	if err != nil {
		return err
	}
	if err = ensureDir(localExtensionsDir); err != nil {
		return err
	}

	// The trailing "/" makes rsync copy the directory's contents.
	local, remote := localExtensionsDir+"/", host+":"+remoteExtensionsDir
	if back {
		return rsync(remote, local, sshFlags)
	}
	return rsync(local, remote, sshFlags)
}

// rsync copies src to dest with the rsync binary, tunnelling through
// "ssh <sshFlags>". Each entry in excludePaths becomes an --exclude= flag.
// Files newer on the receiving side are kept (-u), but --delete removes
// receiver files that are missing from src. rsync's output is streamed to
// this process's stdout/stderr.
func rsync(src string, dest string, sshFlags string, excludePaths ...string) error {
	args := make([]string, 0, len(excludePaths)+9)
	for _, p := range excludePaths {
		args = append(args, "--exclude="+p)
	}
	args = append(args,
		"-azvr",
		"-e", "ssh "+sshFlags,
		// Only update newer directories, and sync times to keep things simple.
		"-u", "--times",
		// This is more unsafe, but it's obnoxious having to enter VS Code
		// locally in order to properly delete an extension.
		"--delete",
		"--copy-unsafe-links",
		src, dest,
	)

	cmd := exec.Command("rsync", args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return xerrors.Errorf("failed to rsync '%s' to '%s': %w", src, dest, err)
	}
	return nil
}

// downloadScript returns the bash script run on the remote host to install or
// update code-server at codeServerPath. The script:
//   - aborts on non-x86_64 machines;
//   - kills any running instance;
//   - downloads the latest Linux build into codeServerPath's directory with
//     `wget -N`;
//   - hard-links the download to codeServerPath and marks it executable.
// Paths are interpolated unquoted so the remote shell expands "~".
func downloadScript(codeServerPath string) string {
	// %[1]v is the binary path, %[2]v its parent directory.
	return fmt.Sprintf(
		`set -euxo pipefail || exit 1

[ "$(uname -m)" != "x86_64" ] && echo "Unsupported server architecture $(uname -m). code-server only has releases for x86_64 systems." && exit 1
pkill -f %[1]v || true
mkdir -p ~/.local/share/code-server %[2]v
cd %[2]v
wget -N https://codesrv-ci.cdr.sh/latest-linux
[ -f %[1]v ] && rm %[1]v
ln latest-linux %[1]v
chmod +x %[1]v`,
		codeServerPath,
		filepath.Dir(codeServerPath),
	)
}

// ensureDir makes sure path exists as a directory, creating it (and any
// missing parents) with mode 0750 if necessary.
//
// It returns nil when path already is a directory. It returns an error if path
// exists but is not a directory; the previous Stat-only check silently
// accepted a regular file there.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0750)
}

// parseHost parses the host argument after trimming surrounding whitespace.
// If the host is prefixed with "gcp:", the remainder is treated as a GCP
// instance name and resolved via parseGCPSSHCmd. That lookup supplies the
// instance's external IP plus any extra SSH flags gcloud would use. Otherwise
// the trimmed host is returned with no extra flags.
func parseHost(host string) (parsedHost string, additionalFlags string, err error) {
	const gcpPrefix = "gcp:"

	trimmed := strings.TrimSpace(host)
	if !strings.HasPrefix(trimmed, gcpPrefix) {
		return trimmed, "", nil
	}
	return parseGCPSSHCmd(trimmed[len(gcpPrefix):])
}

// parseGCPSSHCmd runs `gcloud compute ssh --dry-run <instance>` in a login
// shell and parses the printed ssh command line. It returns the target IP and
// the ssh flags gcloud would use, i.e. everything between the ssh binary and
// the final [user@]ip argument.
//
// Only stdout is parsed. gcloud and the login shell can write warnings or
// progress messages to stderr, which must not be mistaken for ssh arguments.
// On failure the captured stderr is included in the error.
func parseGCPSSHCmd(instance string) (ip, sshFlags string, err error) {
	dryRunCmd := fmt.Sprintf("gcloud compute ssh --dry-run %v", instance)

	out, err := exec.Command("sh", "-l", "-c", dryRunCmd).Output()
	if err != nil {
		// Output stores the captured stderr on *exec.ExitError.
		var exitErr *exec.ExitError
		if xerrors.As(err, &exitErr) {
			return "", "", xerrors.Errorf("%s: %w", exitErr.Stderr, err)
		}
		return "", "", xerrors.Errorf("%s: %w", out, err)
	}

	// Fields (unlike Split on " ") drops the trailing newline and empty
	// tokens produced by repeated spaces.
	toks := strings.Fields(string(out))
	if len(toks) < 2 {
		return "", "", xerrors.Errorf("unexpected output for '%v' command, %s", dryRunCmd, out)
	}

	// Slice off the '/usr/bin/ssh' prefix and the '<user>@<ip>' suffix.
	sshFlags = strings.Join(toks[1:len(toks)-1], " ")

	// E.g. foo@1.2.3.4.
	userIP := toks[len(toks)-1]
	parts := strings.Split(userIP, "@")
	// Assume the '<user>@' is missing.
	if len(parts) < 2 {
		ip = strings.TrimSpace(parts[0])
	} else {
		ip = strings.TrimSpace(parts[1])
	}

	if net.ParseIP(ip) == nil {
		return "", "", xerrors.Errorf("parsed invalid ip address %v", ip)
	}

	return ip, sshFlags, nil
}