diff --git a/gradle.properties b/gradle.properties index 92b80e02..fad3afe0 100644 --- a/gradle.properties +++ b/gradle.properties @@ -29,4 +29,4 @@ gradleVersion=7.4 # Opt-out flag for bundling Kotlin standard library. # See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details. # suppress inspection "UnusedProperty" -kotlin.stdlib.default.dependency=false +kotlin.stdlib.default.dependency=true diff --git a/src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt b/src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt index e6acb61c..b19e9f7a 100644 --- a/src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt +++ b/src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt @@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider { if (token == null) { // User aborted. throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing") } - val client = CoderRestClient(deploymentURL, token.first, settings.headerCommand, null) + val client = CoderRestClient(deploymentURL, token.first,null, settings) return try { Pair(client, client.me().username) } catch (ex: AuthenticationResponseException) { diff --git a/src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt b/src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt index c92a2d71..e73482a6 100644 --- a/src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt +++ b/src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt @@ -39,7 +39,7 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") { .comment( CoderGatewayBundle.message( "gateway.connector.settings.binary-source.comment", - CoderCLIManager(URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path, + CoderCLIManager(state, URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path, ) ) }.layout(RowLayout.PARENT_GRID) @@ -73,6 +73,34 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") { CoderGatewayBundle.message("gateway.connector.settings.header-command.comment") ) }.layout(RowLayout.PARENT_GRID) + row(CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.title")) { + textField().resizableColumn().align(AlignX.FILL) + .bindText(state::tlsCertPath) + .comment( + CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.comment") + ) + }.layout(RowLayout.PARENT_GRID) + row(CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.title")) { + textField().resizableColumn().align(AlignX.FILL) + .bindText(state::tlsKeyPath) + .comment( + CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.comment") + ) + }.layout(RowLayout.PARENT_GRID) + row(CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.title")) { + textField().resizableColumn().align(AlignX.FILL) + .bindText(state::tlsCAPath) + .comment( + CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.comment") + ) + }.layout(RowLayout.PARENT_GRID) + row(CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.title")) { + textField().resizableColumn().align(AlignX.FILL) + .bindText(state::tlsAlternateHostname) + .comment( + CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.comment") + ) + }.layout(RowLayout.PARENT_GRID) } } diff --git a/src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt b/src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt index bc9e31eb..57705bec 100644 --- a/src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt +++ b/src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt @@ -22,6 +22,7 @@ import java.nio.file.StandardCopyOption import java.security.DigestInputStream import java.security.MessageDigest import java.util.zip.GZIPInputStream +import javax.net.ssl.HttpsURLConnection import javax.xml.bind.annotation.adapters.HexBinaryAdapter @@ -29,6 +30,7 @@ import javax.xml.bind.annotation.adapters.HexBinaryAdapter * Manage the CLI for a single deployment. */ class CoderCLIManager @JvmOverloads constructor( + private val settings: CoderSettingsState, private val deploymentURL: URL, dataDir: Path, cliDir: Path? = null, @@ -104,6 +106,10 @@ class CoderCLIManager @JvmOverloads constructor( conn.setRequestProperty("If-None-Match", "\"$etag\"") } conn.setRequestProperty("Accept-Encoding", "gzip") + if (conn is HttpsURLConnection) { + conn.sslSocketFactory = coderSocketFactory(settings) + conn.hostnameVerifier = CoderHostnameVerifier(settings.tlsAlternateHostname) + } try { conn.connect() @@ -463,7 +469,7 @@ class CoderCLIManager @JvmOverloads constructor( if (settings.binaryDirectory.isBlank()) null else Path.of(settings.binaryDirectory).toAbsolutePath() - val cli = CoderCLIManager(deploymentURL, dataDir, binDir, settings.binarySource) + val cli = CoderCLIManager(settings, deploymentURL, dataDir, binDir, settings.binarySource) // Short-circuit if we already have the expected version. This // lets us bypass the 304 which is slower and may not be @@ -490,7 +496,7 @@ class CoderCLIManager @JvmOverloads constructor( } // Try falling back to the data directory. - val dataCLI = CoderCLIManager(deploymentURL, dataDir, null, settings.binarySource) + val dataCLI = CoderCLIManager(settings, deploymentURL, dataDir, null, settings.binarySource) val dataCLIMatches = dataCLI.matchesVersion(buildVersion) if (dataCLIMatches == true) { return dataCLI diff --git a/src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt b/src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt index 829d76fe..7d84a639 100644 --- a/src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt +++ b/src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt @@ -14,6 +14,7 @@ import com.coder.gateway.sdk.v2.models.Workspace import com.coder.gateway.sdk.v2.models.WorkspaceBuild import com.coder.gateway.sdk.v2.models.WorkspaceTransition import com.coder.gateway.sdk.v2.models.toAgentModels +import com.coder.gateway.services.CoderSettingsState import com.google.gson.Gson import com.google.gson.GsonBuilder import com.intellij.ide.plugins.PluginManagerCore @@ -21,14 +22,40 @@ import com.intellij.openapi.components.Service import com.intellij.openapi.extensions.PluginId import com.intellij.openapi.util.SystemInfo import okhttp3.OkHttpClient +import okhttp3.internal.tls.OkHostnameVerifier import okhttp3.logging.HttpLoggingInterceptor import org.zeroturnaround.exec.ProcessExecutor import retrofit2.Retrofit import retrofit2.converter.gson.GsonConverterFactory +import java.io.File +import java.io.FileInputStream import java.net.HttpURLConnection.HTTP_CREATED +import java.net.InetAddress +import java.net.Socket import java.net.URL +import java.nio.file.Path +import java.security.KeyFactory +import java.security.KeyStore +import java.security.PrivateKey +import java.security.cert.CertificateException +import java.security.cert.CertificateFactory +import java.security.cert.X509Certificate +import java.security.spec.InvalidKeySpecException +import java.security.spec.PKCS8EncodedKeySpec import java.time.Instant +import java.util.Base64 +import java.util.Locale import java.util.UUID +import javax.net.ssl.HostnameVerifier +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SNIHostName +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLSession +import javax.net.ssl.SSLSocket +import javax.net.ssl.SSLSocketFactory +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.TrustManager +import javax.net.ssl.X509TrustManager @Service(Service.Level.APP) class CoderRestClientService { @@ -44,8 +71,8 @@ class CoderRestClientService { * * @throws [AuthenticationResponseException] if authentication failed. */ - fun initClientSession(url: URL, token: String, headerCommand: String?): User { - client = CoderRestClient(url, token, headerCommand, null) + fun initClientSession(url: URL, token: String, settings: CoderSettingsState): User { + client = CoderRestClient(url, token, null, settings) me = client.me() buildVersion = client.buildInfo().version isReady = true @@ -53,9 +80,10 @@ class CoderRestClientService { } } -class CoderRestClient(var url: URL, var token: String, - private var headerCommand: String?, +class CoderRestClient( + var url: URL, var token: String, private var pluginVersion: String?, + private var settings: CoderSettingsState, ) { private var httpClient: OkHttpClient private var retroRestClient: CoderV2RestFacade @@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String, pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml } + val socketFactory = coderSocketFactory(settings) + val trustManagers = coderTrustManagers(settings.tlsCAPath) httpClient = OkHttpClient.Builder() + .sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager) + .hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname)) .addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) } .addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) } .addInterceptor { var request = it.request() - val headers = getHeaders(url, headerCommand) + val headers = getHeaders(url, settings.headerCommand) if (headers.size > 0) { val builder = request.newBuilder() headers.forEach { h -> builder.addHeader(h.key, h.value) } @@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String, } } } + +fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory { + if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) { + return SSLSocketFactory.getDefault() as SSLSocketFactory + } + + val certificateFactory = CertificateFactory.getInstance("X.509") + val certInputStream = FileInputStream(expandPath(settings.tlsCertPath)) + val certChain = certificateFactory.generateCertificates(certInputStream) + certInputStream.close() + + // ideally we would use something like PemReader from BouncyCastle, but + // BC is used by the IDE. This makes using BC very impractical since + // type casting will mismatch due to the different class loaders. + val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText() + val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----") + val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start) + val pemBytes: ByteArray = Base64.getDecoder().decode( + privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end) + .replace("\\s+".toRegex(), "") + ) + + var privateKey : PrivateKey + try { + val kf = KeyFactory.getInstance("RSA") + val keySpec = PKCS8EncodedKeySpec(pemBytes) + privateKey = kf.generatePrivate(keySpec) + } catch (e: InvalidKeySpecException) { + val kf = KeyFactory.getInstance("EC") + val keySpec = PKCS8EncodedKeySpec(pemBytes) + privateKey = kf.generatePrivate(keySpec) + } + + val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()) + keyStore.load(null) + certChain.withIndex().forEach { + keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) + } + keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray()) + + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + keyManagerFactory.init(keyStore, null) + + val sslContext = SSLContext.getInstance("TLS") + + val trustManagers = coderTrustManagers(settings.tlsCAPath) + sslContext.init(keyManagerFactory.keyManagers, trustManagers, null) + + if (settings.tlsAlternateHostname.isBlank()) { + return sslContext.socketFactory + } + + return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname) +} + +fun coderTrustManagers(tlsCAPath: String) : Array { + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + if (tlsCAPath.isBlank()) { + // return default trust managers + trustManagerFactory.init(null as KeyStore?) + return trustManagerFactory.trustManagers + } + + + val certificateFactory = CertificateFactory.getInstance("X.509") + val caInputStream = FileInputStream(expandPath(tlsCAPath)) + val certChain = certificateFactory.generateCertificates(caInputStream) + + val truststore = KeyStore.getInstance(KeyStore.getDefaultType()) + truststore.load(null) + certChain.withIndex().forEach { + truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) + } + trustManagerFactory.init(truststore) + return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray() +} + +fun expandPath(path: String): String { + if (path.startsWith("~/")) { + return Path.of(System.getProperty("user.home"), path.substring(1)).toString() + } + if (path.startsWith("\$HOME/")) { + return Path.of(System.getProperty("user.home"), path.substring(5)).toString() + } + if (path.startsWith("\${user.home}/")) { + return Path.of(System.getProperty("user.home"), path.substring(12)).toString() + } + return path +} + +class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() { + override fun getDefaultCipherSuites(): Array { + return delegate.defaultCipherSuites + } + + override fun getSupportedCipherSuites(): Array { + return delegate.supportedCipherSuites + } + + override fun createSocket(): Socket { + val socket = delegate.createSocket() as SSLSocket + customizeSocket(socket) + return socket + } + + override fun createSocket(host: String?, port: Int): Socket { + val socket = delegate.createSocket(host, port) as SSLSocket + customizeSocket(socket) + return socket + } + + override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket { + val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket + customizeSocket(socket) + return socket + } + + override fun createSocket(host: InetAddress?, port: Int): Socket { + val socket = delegate.createSocket(host, port) as SSLSocket + customizeSocket(socket) + return socket + } + + override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket { + val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket + customizeSocket(socket) + return socket + } + + override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket { + val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket + customizeSocket(socket) + return socket + } + + private fun customizeSocket(socket: SSLSocket) { + val params = socket.sslParameters + params.serverNames = listOf(SNIHostName(alternateName)) + socket.sslParameters = params + } +} + +class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier { + override fun verify(host: String, session: SSLSession): Boolean { + if (alternateName.isEmpty()) { + println("using default hostname verifier, alternateName is empty") + return OkHostnameVerifier.verify(host, session) + } + println("Looking for alternate hostname: $alternateName") + val certs = session.peerCertificates ?: return false + for (cert in certs) { + if (cert !is X509Certificate) { + continue + } + val entries = cert.subjectAlternativeNames ?: continue + for (entry in entries) { + val kind = entry[0] as Int + if (kind != 2) { // DNS Name + continue + } + val hostname = entry[1] as String + println("Found cert hostname: $hostname") + if (hostname.lowercase(Locale.getDefault()) == alternateName) { + return true + } + } + } + println("No matching hostname found") + return false + } +} + +class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager { + private val systemTrustManager : X509TrustManager + init { + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(null as KeyStore?) + systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager + } + + override fun checkClientTrusted(chain: Array, authType: String?) { + try { + otherTrustManager.checkClientTrusted(chain, authType) + } catch (e: CertificateException) { + systemTrustManager.checkClientTrusted(chain, authType) + } + } + + override fun checkServerTrusted(chain: Array, authType: String?) { + try { + otherTrustManager.checkServerTrusted(chain, authType) + } catch (e: CertificateException) { + systemTrustManager.checkServerTrusted(chain, authType) + } + } + + override fun getAcceptedIssuers(): Array { + return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/coder/gateway/services/CoderSettingsState.kt b/src/main/kotlin/com/coder/gateway/services/CoderSettingsState.kt index e75a6ef9..0f2ab9e4 100644 --- a/src/main/kotlin/com/coder/gateway/services/CoderSettingsState.kt +++ b/src/main/kotlin/com/coder/gateway/services/CoderSettingsState.kt @@ -19,6 +19,10 @@ class CoderSettingsState : PersistentStateComponent { var enableDownloads: Boolean = true var enableBinaryDirectoryFallback: Boolean = false var headerCommand: String = "" + var tlsCertPath: String = "" + var tlsKeyPath: String = "" + var tlsCAPath: String = "" + var tlsAlternateHostname: String = "" override fun getState(): CoderSettingsState { return this } diff --git a/src/main/kotlin/com/coder/gateway/views/CoderGatewayRecentWorkspaceConnectionsView.kt b/src/main/kotlin/com/coder/gateway/views/CoderGatewayRecentWorkspaceConnectionsView.kt index 4e5a5fcf..af2b8837 100644 --- a/src/main/kotlin/com/coder/gateway/views/CoderGatewayRecentWorkspaceConnectionsView.kt +++ b/src/main/kotlin/com/coder/gateway/views/CoderGatewayRecentWorkspaceConnectionsView.kt @@ -256,7 +256,7 @@ class CoderGatewayRecentWorkspaceConnectionsView(private val setContentCallback: deployments[dir] ?: try { val url = Path.of(dir).resolve("url").readText() val token = Path.of(dir).resolve("session").readText() - DeploymentInfo(CoderRestClient(url.toURL(), token, settings.headerCommand, null)) + DeploymentInfo(CoderRestClient(url.toURL(), token,null, settings)) } catch (e: Exception) { logger.error("Unable to create client from $dir", e) DeploymentInfo(error = "Error trying to read $dir: ${e.message}") diff --git a/src/main/kotlin/com/coder/gateway/views/steps/CoderWorkspacesStepView.kt b/src/main/kotlin/com/coder/gateway/views/steps/CoderWorkspacesStepView.kt index 2ddefd4e..9eb2be94 100644 --- a/src/main/kotlin/com/coder/gateway/views/steps/CoderWorkspacesStepView.kt +++ b/src/main/kotlin/com/coder/gateway/views/steps/CoderWorkspacesStepView.kt @@ -533,7 +533,7 @@ class CoderWorkspacesStepView(val setNextButtonEnabled: (Boolean) -> Unit) : Cod */ private fun authenticate(url: URL, token: String) { logger.info("Authenticating to $url...") - clientService.initClientSession(url, token, settings.headerCommand) + clientService.initClientSession(url, token, settings) try { logger.info("Checking compatibility with Coder version ${clientService.buildVersion}...") diff --git a/src/main/resources/messages/CoderGatewayBundle.properties b/src/main/resources/messages/CoderGatewayBundle.properties index 6e93c557..19941750 100644 --- a/src/main/resources/messages/CoderGatewayBundle.properties +++ b/src/main/resources/messages/CoderGatewayBundle.properties @@ -93,3 +93,21 @@ gateway.connector.settings.header-command.comment=An external command that \ outputs additional HTTP headers added to all requests. The command must \ output each header as `key=value` on its own line. The following \ environment variables will be available to the process: CODER_URL. +gateway.connector.settings.tls-cert-path.title=Cert Path: +gateway.connector.settings.tls-cert-path.comment=Optionally set this to \ + the path of a certificate to use for TLS connections. The certificate \ + should be in X.509 PEM format. +gateway.connector.settings.tls-key-path.title=Key Path: +gateway.connector.settings.tls-key-path.comment=Optionally set this to \ + the path of the private key that corresponds to the above cert path to use \ + for TLS connections. The key should be in X.509 PEM format. +gateway.connector.settings.tls-ca-path.title=CA Path: +gateway.connector.settings.tls-ca-path.comment=Optionally set this to \ + the path of a file containing certificates for an alternate certificate \ + authority used to verify TLS certs returned by the Coder service. \ + The file should be in X.509 PEM format. +gateway.connector.settings.tls-alt-name.title=Alt Hostname: +gateway.connector.settings.tls-alt-name.comment=Optionally set this to \ + an alternate hostname used for verifying TLS connections. This is useful \ + when the hostname used to connect to the Coder service does not match the \ + hostname in the TLS certificate. diff --git a/src/test/groovy/CoderCLIManagerTest.groovy b/src/test/groovy/CoderCLIManagerTest.groovy index c6b7e1a6..139e71dc 100644 --- a/src/test/groovy/CoderCLIManagerTest.groovy +++ b/src/test/groovy/CoderCLIManagerTest.groovy @@ -19,6 +19,7 @@ import java.security.MessageDigest class CoderCLIManagerTest extends Specification { @Shared private Path tmpdir = Path.of(System.getProperty("java.io.tmpdir")).resolve("coder-gateway-test/cli-manager") + private CoderSettingsState settings = new CoderSettingsState() /** * Create, start, and return a server that mocks Coder. @@ -82,7 +83,7 @@ class CoderCLIManagerTest extends Specification { def "uses a sub-directory"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid"), tmpdir) expect: ccm.localBinaryPath.getParent() == tmpdir.resolve("test.coder.invalid") @@ -90,7 +91,7 @@ class CoderCLIManagerTest extends Specification { def "includes port in sub-directory if included"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid:3000"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid:3000"), tmpdir) expect: ccm.localBinaryPath.getParent() == tmpdir.resolve("test.coder.invalid-3000") @@ -98,7 +99,7 @@ class CoderCLIManagerTest extends Specification { def "encodes IDN with punycode"() { given: - def ccm = new CoderCLIManager(new URL("https://test.😉.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.😉.invalid"), tmpdir) expect: ccm.localBinaryPath.getParent() == tmpdir.resolve("test.xn--n28h.invalid") @@ -107,7 +108,7 @@ class CoderCLIManagerTest extends Specification { def "fails to download"() { given: def (srv, url) = mockServer(HttpURLConnection.HTTP_INTERNAL_ERROR) - def ccm = new CoderCLIManager(new URL(url), tmpdir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir) when: ccm.downloadCLI() @@ -125,7 +126,7 @@ class CoderCLIManagerTest extends Specification { given: def (srv, url) = mockServer() def dir = tmpdir.resolve("cli-dir-fallver") - def ccm = new CoderCLIManager(new URL(url), tmpdir, dir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir, dir) Files.createDirectories(ccm.localBinaryPath.getParent()) ccm.localBinaryPath.parent.toFile().setWritable(false) @@ -148,7 +149,7 @@ class CoderCLIManagerTest extends Specification { if (url == null) { url = "https://dev.coder.com" } - def ccm = new CoderCLIManager(new URL(url), tmpdir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir) ccm.localBinaryPath.getParent().toFile().deleteDir() when: @@ -170,7 +171,7 @@ class CoderCLIManagerTest extends Specification { def "downloads a mocked cli"() { given: def (srv, url) = mockServer() - def ccm = new CoderCLIManager(new URL(url), tmpdir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir) ccm.localBinaryPath.getParent().toFile().deleteDir() when: @@ -189,7 +190,7 @@ class CoderCLIManagerTest extends Specification { def "fails to run non-existent binary"() { given: - def ccm = new CoderCLIManager(new URL("https://foo"), tmpdir.resolve("does-not-exist")) + def ccm = new CoderCLIManager(settings, new URL("https://foo"), tmpdir.resolve("does-not-exist")) when: ccm.login("token") @@ -201,7 +202,7 @@ class CoderCLIManagerTest extends Specification { def "overwrites cli if incorrect version"() { given: def (srv, url) = mockServer() - def ccm = new CoderCLIManager(new URL(url), tmpdir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir) Files.createDirectories(ccm.localBinaryPath.getParent()) ccm.localBinaryPath.toFile().write("cli") ccm.localBinaryPath.toFile().setLastModified(0) @@ -222,7 +223,7 @@ class CoderCLIManagerTest extends Specification { def "skips cli download if it already exists"() { given: def (srv, url) = mockServer() - def ccm = new CoderCLIManager(new URL(url), tmpdir) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir) when: def downloaded1 = ccm.downloadCLI() @@ -243,8 +244,8 @@ class CoderCLIManagerTest extends Specification { setup: def (srv1, url1) = mockServer() def (srv2, url2) = mockServer() - def ccm1 = new CoderCLIManager(new URL(url1), tmpdir) - def ccm2 = new CoderCLIManager(new URL(url2), tmpdir) + def ccm1 = new CoderCLIManager(settings, new URL(url1), tmpdir) + def ccm2 = new CoderCLIManager(settings, new URL(url2), tmpdir) when: ccm1.downloadCLI() @@ -263,7 +264,7 @@ class CoderCLIManagerTest extends Specification { def "overrides binary URL"() { given: def (srv, url) = mockServer() - def ccm = new CoderCLIManager(new URL(url), tmpdir, null, override.replace("{{url}}", url)) + def ccm = new CoderCLIManager(settings, new URL(url), tmpdir, null, override.replace("{{url}}", url)) when: def downloaded = ccm.downloadCLI() @@ -398,7 +399,7 @@ class CoderCLIManagerTest extends Specification { def "configures an SSH file"() { given: def sshConfigPath = tmpdir.resolve(input + "_to_" + output + ".conf") - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir, null, null, sshConfigPath) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid"), tmpdir, null, null, sshConfigPath) if (input != null) { Files.createDirectories(sshConfigPath.getParent()) def originalConf = Path.of("src/test/fixtures/inputs").resolve(input + ".conf").toFile().text @@ -445,7 +446,7 @@ class CoderCLIManagerTest extends Specification { def "fails if config is malformed"() { given: def sshConfigPath = tmpdir.resolve("configured" + input + ".conf") - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir, null, null, sshConfigPath) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid"), tmpdir, null, null, sshConfigPath) Files.createDirectories(sshConfigPath.getParent()) Files.copy( Path.of("src/test/fixtures/inputs").resolve(input + ".conf"), @@ -470,7 +471,7 @@ class CoderCLIManagerTest extends Specification { def "fails if header command is malformed"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid"), tmpdir) when: ccm.configSsh(["foo", "bar"].collect { DataGen.workspaceAgentModel(it) }, headerCommand) @@ -487,7 +488,7 @@ class CoderCLIManagerTest extends Specification { @IgnoreIf({ os.windows }) def "parses version"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings,new URL("https://test.coder.invalid"), tmpdir) Files.createDirectories(ccm.localBinaryPath.parent) when: @@ -506,7 +507,7 @@ class CoderCLIManagerTest extends Specification { @IgnoreIf({ os.windows }) def "fails to parse version"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.parse-fail.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.parse-fail.invalid"), tmpdir) Files.createDirectories(ccm.localBinaryPath.parent) when: @@ -532,7 +533,7 @@ class CoderCLIManagerTest extends Specification { @IgnoreIf({ os.windows }) def "checks if version matches"() { given: - def ccm = new CoderCLIManager(new URL("https://test.coder.version-matches.invalid"), tmpdir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.version-matches.invalid"), tmpdir) Files.createDirectories(ccm.localBinaryPath.parent) when: @@ -567,7 +568,7 @@ class CoderCLIManagerTest extends Specification { def "separately configures cli path from data dir"() { given: def dir = tmpdir.resolve("cli-dir") - def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir, dir) + def ccm = new CoderCLIManager(settings, new URL("https://test.coder.invalid"), tmpdir, dir) expect: ccm.localBinaryPath.getParent() == dir.resolve("test.coder.invalid") @@ -585,11 +586,10 @@ class CoderCLIManagerTest extends Specification { def (srv, url) = mockServer() def dataDir = tmpdir.resolve("data-dir") def binDir = tmpdir.resolve("bin-dir") - def mainCCM = new CoderCLIManager(new URL(url), dataDir, binDir) - def fallbackCCM = new CoderCLIManager(new URL(url), dataDir) + def mainCCM = new CoderCLIManager(settings, new URL(url), dataDir, binDir) + def fallbackCCM = new CoderCLIManager(settings, new URL(url), dataDir) when: - def settings = new CoderSettingsState() settings.binaryDirectory = binDir.toAbsolutePath() settings.dataDirectory = dataDir.toAbsolutePath() settings.enableDownloads = download diff --git a/src/test/groovy/CoderRestClientTest.groovy b/src/test/groovy/CoderRestClientTest.groovy index cadea39a..493640df 100644 --- a/src/test/groovy/CoderRestClientTest.groovy +++ b/src/test/groovy/CoderRestClientTest.groovy @@ -4,6 +4,7 @@ import com.coder.gateway.sdk.convertors.InstantConverter import com.coder.gateway.sdk.v2.models.Workspace import com.coder.gateway.sdk.v2.models.WorkspaceResource import com.coder.gateway.sdk.v2.models.WorkspacesResponse +import com.coder.gateway.services.CoderSettingsState import com.google.gson.GsonBuilder import com.sun.net.httpserver.HttpExchange import com.sun.net.httpserver.HttpHandler @@ -18,6 +19,7 @@ import java.time.Instant @Unroll class CoderRestClientTest extends Specification { + private CoderSettingsState settings = new CoderSettingsState() /** * Create, start, and return a server that mocks the Coder API. * @@ -63,7 +65,7 @@ class CoderRestClientTest extends Specification { def "gets workspaces"() { given: def (srv, url) = mockServer(workspaces) - def client = new CoderRestClient(new URL(url), "token", null, "test") + def client = new CoderRestClient(new URL(url), "token", "test", settings) expect: client.workspaces()*.name == expected @@ -81,7 +83,7 @@ class CoderRestClientTest extends Specification { def "gets resources"() { given: def (srv, url) = mockServer(workspaces, resources) - def client = new CoderRestClient(new URL(url), "token", null, "test") + def client = new CoderRestClient(new URL(url), "token", "test", settings) expect: client.agents(workspaces).collect { it.agentID.toString() } == expected