Skip to content

feat: add configuration options to support mtls #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ gradleVersion=7.4
# Opt-out flag for bundling Kotlin standard library.
# See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details.
# suppress inspection "UnusedProperty"
kotlin.stdlib.default.dependency=false
kotlin.stdlib.default.dependency=true
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we need the kotlin stdlib, without this I was getting:

java.lang.ClassNotFoundException: kotlin.enums.EnumEntries

I suspect this is required since the kotlin 9.x upgrade?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do you see this error? I tried building with this set to false and running the resulting plugin in Gateway (2023.2.4) but so far all seems well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see when running ./gradlew build --info or equivalent command via IntelliJ. Strange that you don't see it also 🤔

Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider {
if (token == null) { // User aborted.
throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing")
}
val client = CoderRestClient(deploymentURL, token.first, settings.headerCommand, null)
val client = CoderRestClient(deploymentURL, token.first,null, settings)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
val client = CoderRestClient(deploymentURL, token.first,null, settings)
val client = CoderRestClient(deploymentURL, token.first, null, settings)

return try {
Pair(client, client.me().username)
} catch (ex: AuthenticationResponseException) {
Expand Down
30 changes: 29 additions & 1 deletion src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
.comment(
CoderGatewayBundle.message(
"gateway.connector.settings.binary-source.comment",
CoderCLIManager(URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
CoderCLIManager(state, URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
)
)
}.layout(RowLayout.PARENT_GRID)
Expand Down Expand Up @@ -73,6 +73,34 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
CoderGatewayBundle.message("gateway.connector.settings.header-command.comment")
)
}.layout(RowLayout.PARENT_GRID)
row(CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.title")) {
textField().resizableColumn().align(AlignX.FILL)
.bindText(state::tlsCertPath)
.comment(
CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.comment")
)
}.layout(RowLayout.PARENT_GRID)
row(CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.title")) {
textField().resizableColumn().align(AlignX.FILL)
.bindText(state::tlsKeyPath)
.comment(
CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.comment")
)
}.layout(RowLayout.PARENT_GRID)
row(CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.title")) {
textField().resizableColumn().align(AlignX.FILL)
.bindText(state::tlsCAPath)
.comment(
CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.comment")
)
}.layout(RowLayout.PARENT_GRID)
row(CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.title")) {
textField().resizableColumn().align(AlignX.FILL)
.bindText(state::tlsAlternateHostname)
.comment(
CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.comment")
)
}.layout(RowLayout.PARENT_GRID)
}
}

Expand Down
10 changes: 8 additions & 2 deletions src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import java.nio.file.StandardCopyOption
import java.security.DigestInputStream
import java.security.MessageDigest
import java.util.zip.GZIPInputStream
import javax.net.ssl.HttpsURLConnection
import javax.xml.bind.annotation.adapters.HexBinaryAdapter


/**
* Manage the CLI for a single deployment.
*/
class CoderCLIManager @JvmOverloads constructor(
private val settings: CoderSettingsState,
private val deploymentURL: URL,
dataDir: Path,
cliDir: Path? = null,
Expand Down Expand Up @@ -104,6 +106,10 @@ class CoderCLIManager @JvmOverloads constructor(
conn.setRequestProperty("If-None-Match", "\"$etag\"")
}
conn.setRequestProperty("Accept-Encoding", "gzip")
if (conn is HttpsURLConnection) {
conn.sslSocketFactory = coderSocketFactory(settings)
conn.hostnameVerifier = CoderHostnameVerifier(settings.tlsAlternateHostname)
}

try {
conn.connect()
Expand Down Expand Up @@ -463,7 +469,7 @@ class CoderCLIManager @JvmOverloads constructor(
if (settings.binaryDirectory.isBlank()) null
else Path.of(settings.binaryDirectory).toAbsolutePath()

val cli = CoderCLIManager(deploymentURL, dataDir, binDir, settings.binarySource)
val cli = CoderCLIManager(settings, deploymentURL, dataDir, binDir, settings.binarySource)

// Short-circuit if we already have the expected version. This
// lets us bypass the 304 which is slower and may not be
Expand All @@ -490,7 +496,7 @@ class CoderCLIManager @JvmOverloads constructor(
}

// Try falling back to the data directory.
val dataCLI = CoderCLIManager(deploymentURL, dataDir, null, settings.binarySource)
val dataCLI = CoderCLIManager(settings, deploymentURL, dataDir, null, settings.binarySource)
val dataCLIMatches = dataCLI.matchesVersion(buildVersion)
if (dataCLIMatches == true) {
return dataCLI
Expand Down
242 changes: 237 additions & 5 deletions src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
import com.coder.gateway.sdk.v2.models.toAgentModels
import com.coder.gateway.services.CoderSettingsState
import com.google.gson.Gson
import com.google.gson.GsonBuilder
import com.intellij.ide.plugins.PluginManagerCore
import com.intellij.openapi.components.Service
import com.intellij.openapi.extensions.PluginId
import com.intellij.openapi.util.SystemInfo
import okhttp3.OkHttpClient
import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.logging.HttpLoggingInterceptor
import org.zeroturnaround.exec.ProcessExecutor
import retrofit2.Retrofit
import retrofit2.converter.gson.GsonConverterFactory
import java.io.File
import java.io.FileInputStream
import java.net.HttpURLConnection.HTTP_CREATED
import java.net.InetAddress
import java.net.Socket
import java.net.URL
import java.nio.file.Path
import java.security.KeyFactory
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.security.spec.InvalidKeySpecException
import java.security.spec.PKCS8EncodedKeySpec
import java.time.Instant
import java.util.Base64
import java.util.Locale
import java.util.UUID
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.TrustManager
import javax.net.ssl.X509TrustManager

@Service(Service.Level.APP)
class CoderRestClientService {
Expand All @@ -44,18 +71,19 @@ class CoderRestClientService {
*
* @throws [AuthenticationResponseException] if authentication failed.
*/
fun initClientSession(url: URL, token: String, headerCommand: String?): User {
client = CoderRestClient(url, token, headerCommand, null)
fun initClientSession(url: URL, token: String, settings: CoderSettingsState): User {
client = CoderRestClient(url, token, null, settings)
me = client.me()
buildVersion = client.buildInfo().version
isReady = true
return me
}
}

class CoderRestClient(var url: URL, var token: String,
private var headerCommand: String?,
class CoderRestClient(
var url: URL, var token: String,
private var pluginVersion: String?,
private var settings: CoderSettingsState,
) {
private var httpClient: OkHttpClient
private var retroRestClient: CoderV2RestFacade
Expand All @@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String,
pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml
}

val socketFactory = coderSocketFactory(settings)
val trustManagers = coderTrustManagers(settings.tlsCAPath)
httpClient = OkHttpClient.Builder()
.sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager)
.hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname))
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) }
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) }
.addInterceptor {
var request = it.request()
val headers = getHeaders(url, headerCommand)
val headers = getHeaders(url, settings.headerCommand)
if (headers.size > 0) {
val builder = request.newBuilder()
headers.forEach { h -> builder.addHeader(h.key, h.value) }
Expand Down Expand Up @@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String,
}
}
}

fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory {
if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) {
return SSLSocketFactory.getDefault() as SSLSocketFactory
}

val certificateFactory = CertificateFactory.getInstance("X.509")
val certInputStream = FileInputStream(expandPath(settings.tlsCertPath))
val certChain = certificateFactory.generateCertificates(certInputStream)
certInputStream.close()

// ideally we would use something like PemReader from BouncyCastle, but
// BC is used by the IDE. This makes using BC very impractical since
// type casting will mismatch due to the different class loaders.
val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText()
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----")
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start)
val pemBytes: ByteArray = Base64.getDecoder().decode(
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end)
.replace("\\s+".toRegex(), "")
)

var privateKey : PrivateKey
try {
Copy link
Member

@code-asher code-asher Oct 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to mention you can do val privateKey = try which is pretty neat.

val kf = KeyFactory.getInstance("RSA")
val keySpec = PKCS8EncodedKeySpec(pemBytes)
privateKey = kf.generatePrivate(keySpec)
} catch (e: InvalidKeySpecException) {
val kf = KeyFactory.getInstance("EC")
val keySpec = PKCS8EncodedKeySpec(pemBytes)
privateKey = kf.generatePrivate(keySpec)
}

val keyStore = KeyStore.getInstance(KeyStore.getDefaultType())
keyStore.load(null)
certChain.withIndex().forEach {
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
}
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray())

val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore, null)

val sslContext = SSLContext.getInstance("TLS")

val trustManagers = coderTrustManagers(settings.tlsCAPath)
Copy link
Member

@code-asher code-asher Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a problem that we only use tlsCAPath if we did not return early above when tlsCertPath or tlsKeyPath are blank? Thinking of a use case were someone only sets tlsCAPath but not the others. I see we separately call coderTrustManagers() in the rest client so maybe it works there, but possibly the binary download could have issues since we are only using the socket factory there.

sslContext.init(keyManagerFactory.keyManagers, trustManagers, null)

if (settings.tlsAlternateHostname.isBlank()) {
return sslContext.socketFactory
}

return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname)
}

fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
if (tlsCAPath.isBlank()) {
// return default trust managers
trustManagerFactory.init(null as KeyStore?)
return trustManagerFactory.trustManagers
}


val certificateFactory = CertificateFactory.getInstance("X.509")
val caInputStream = FileInputStream(expandPath(tlsCAPath))
val certChain = certificateFactory.generateCertificates(caInputStream)

val truststore = KeyStore.getInstance(KeyStore.getDefaultType())
truststore.load(null)
certChain.withIndex().forEach {
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
}
trustManagerFactory.init(truststore)
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray()
}

fun expandPath(path: String): String {
if (path.startsWith("~/")) {
return Path.of(System.getProperty("user.home"), path.substring(1)).toString()
}
if (path.startsWith("\$HOME/")) {
return Path.of(System.getProperty("user.home"), path.substring(5)).toString()
}
if (path.startsWith("\${user.home}/")) {
return Path.of(System.getProperty("user.home"), path.substring(12)).toString()
}
return path
}

class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() {
override fun getDefaultCipherSuites(): Array<String> {
return delegate.defaultCipherSuites
}

override fun getSupportedCipherSuites(): Array<String> {
return delegate.supportedCipherSuites
}

override fun createSocket(): Socket {
val socket = delegate.createSocket() as SSLSocket
customizeSocket(socket)
return socket
}

override fun createSocket(host: String?, port: Int): Socket {
val socket = delegate.createSocket(host, port) as SSLSocket
customizeSocket(socket)
return socket
}

override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket {
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
customizeSocket(socket)
return socket
}

override fun createSocket(host: InetAddress?, port: Int): Socket {
val socket = delegate.createSocket(host, port) as SSLSocket
customizeSocket(socket)
return socket
}

override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket {
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
customizeSocket(socket)
return socket
}

override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket {
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
customizeSocket(socket)
return socket
}

private fun customizeSocket(socket: SSLSocket) {
val params = socket.sslParameters
params.serverNames = listOf(SNIHostName(alternateName))
socket.sslParameters = params
}
}

class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier {
override fun verify(host: String, session: SSLSession): Boolean {
if (alternateName.isEmpty()) {
println("using default hostname verifier, alternateName is empty")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should copy the logger pattern used elsewhere instead of using println directly.

return OkHostnameVerifier.verify(host, session)
}
println("Looking for alternate hostname: $alternateName")
val certs = session.peerCertificates ?: return false
for (cert in certs) {
if (cert !is X509Certificate) {
continue
}
val entries = cert.subjectAlternativeNames ?: continue
for (entry in entries) {
val kind = entry[0] as Int
if (kind != 2) { // DNS Name
continue
}
val hostname = entry[1] as String
println("Found cert hostname: $hostname")
if (hostname.lowercase(Locale.getDefault()) == alternateName) {
return true
}
}
}
println("No matching hostname found")
return false
}
}

class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager {
private val systemTrustManager : X509TrustManager
init {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(null as KeyStore?)
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
}

override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) {
try {
otherTrustManager.checkClientTrusted(chain, authType)
} catch (e: CertificateException) {
systemTrustManager.checkClientTrusted(chain, authType)
}
}

override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) {
try {
otherTrustManager.checkServerTrusted(chain, authType)
} catch (e: CertificateException) {
systemTrustManager.checkServerTrusted(chain, authType)
}
}

override fun getAcceptedIssuers(): Array<X509Certificate> {
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
}
}
Loading