-
Notifications
You must be signed in to change notification settings - Fork 16
feat: add configuration options to support mtls #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider { | |||||
if (token == null) { // User aborted. | ||||||
throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing") | ||||||
} | ||||||
val client = CoderRestClient(deploymentURL, token.first, settings.headerCommand, null) | ||||||
val client = CoderRestClient(deploymentURL, token.first,null, settings) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
return try { | ||||||
Pair(client, client.me().username) | ||||||
} catch (ex: AuthenticationResponseException) { | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace | |
import com.coder.gateway.sdk.v2.models.WorkspaceBuild | ||
import com.coder.gateway.sdk.v2.models.WorkspaceTransition | ||
import com.coder.gateway.sdk.v2.models.toAgentModels | ||
import com.coder.gateway.services.CoderSettingsState | ||
import com.google.gson.Gson | ||
import com.google.gson.GsonBuilder | ||
import com.intellij.ide.plugins.PluginManagerCore | ||
import com.intellij.openapi.components.Service | ||
import com.intellij.openapi.extensions.PluginId | ||
import com.intellij.openapi.util.SystemInfo | ||
import okhttp3.OkHttpClient | ||
import okhttp3.internal.tls.OkHostnameVerifier | ||
import okhttp3.logging.HttpLoggingInterceptor | ||
import org.zeroturnaround.exec.ProcessExecutor | ||
import retrofit2.Retrofit | ||
import retrofit2.converter.gson.GsonConverterFactory | ||
import java.io.File | ||
import java.io.FileInputStream | ||
import java.net.HttpURLConnection.HTTP_CREATED | ||
import java.net.InetAddress | ||
import java.net.Socket | ||
import java.net.URL | ||
import java.nio.file.Path | ||
import java.security.KeyFactory | ||
import java.security.KeyStore | ||
import java.security.PrivateKey | ||
import java.security.cert.CertificateException | ||
import java.security.cert.CertificateFactory | ||
import java.security.cert.X509Certificate | ||
import java.security.spec.InvalidKeySpecException | ||
import java.security.spec.PKCS8EncodedKeySpec | ||
import java.time.Instant | ||
import java.util.Base64 | ||
import java.util.Locale | ||
import java.util.UUID | ||
import javax.net.ssl.HostnameVerifier | ||
import javax.net.ssl.KeyManagerFactory | ||
import javax.net.ssl.SNIHostName | ||
import javax.net.ssl.SSLContext | ||
import javax.net.ssl.SSLSession | ||
import javax.net.ssl.SSLSocket | ||
import javax.net.ssl.SSLSocketFactory | ||
import javax.net.ssl.TrustManagerFactory | ||
import javax.net.ssl.TrustManager | ||
import javax.net.ssl.X509TrustManager | ||
|
||
@Service(Service.Level.APP) | ||
class CoderRestClientService { | ||
|
@@ -44,18 +71,19 @@ class CoderRestClientService { | |
* | ||
* @throws [AuthenticationResponseException] if authentication failed. | ||
*/ | ||
fun initClientSession(url: URL, token: String, headerCommand: String?): User { | ||
client = CoderRestClient(url, token, headerCommand, null) | ||
fun initClientSession(url: URL, token: String, settings: CoderSettingsState): User { | ||
client = CoderRestClient(url, token, null, settings) | ||
me = client.me() | ||
buildVersion = client.buildInfo().version | ||
isReady = true | ||
return me | ||
} | ||
} | ||
|
||
class CoderRestClient(var url: URL, var token: String, | ||
private var headerCommand: String?, | ||
class CoderRestClient( | ||
var url: URL, var token: String, | ||
private var pluginVersion: String?, | ||
private var settings: CoderSettingsState, | ||
) { | ||
private var httpClient: OkHttpClient | ||
private var retroRestClient: CoderV2RestFacade | ||
|
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String, | |
pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml | ||
} | ||
|
||
val socketFactory = coderSocketFactory(settings) | ||
val trustManagers = coderTrustManagers(settings.tlsCAPath) | ||
httpClient = OkHttpClient.Builder() | ||
.sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager) | ||
.hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname)) | ||
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) } | ||
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) } | ||
.addInterceptor { | ||
var request = it.request() | ||
val headers = getHeaders(url, headerCommand) | ||
val headers = getHeaders(url, settings.headerCommand) | ||
if (headers.size > 0) { | ||
val builder = request.newBuilder() | ||
headers.forEach { h -> builder.addHeader(h.key, h.value) } | ||
|
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String, | |
} | ||
} | ||
} | ||
|
||
fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory { | ||
if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) { | ||
return SSLSocketFactory.getDefault() as SSLSocketFactory | ||
} | ||
|
||
val certificateFactory = CertificateFactory.getInstance("X.509") | ||
val certInputStream = FileInputStream(expandPath(settings.tlsCertPath)) | ||
val certChain = certificateFactory.generateCertificates(certInputStream) | ||
certInputStream.close() | ||
|
||
// ideally we would use something like PemReader from BouncyCastle, but | ||
// BC is used by the IDE. This makes using BC very impractical since | ||
// type casting will mismatch due to the different class loaders. | ||
val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText() | ||
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----") | ||
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start) | ||
val pemBytes: ByteArray = Base64.getDecoder().decode( | ||
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end) | ||
.replace("\\s+".toRegex(), "") | ||
) | ||
|
||
var privateKey : PrivateKey | ||
try { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forgot to mention you can do |
||
val kf = KeyFactory.getInstance("RSA") | ||
val keySpec = PKCS8EncodedKeySpec(pemBytes) | ||
privateKey = kf.generatePrivate(keySpec) | ||
} catch (e: InvalidKeySpecException) { | ||
val kf = KeyFactory.getInstance("EC") | ||
val keySpec = PKCS8EncodedKeySpec(pemBytes) | ||
privateKey = kf.generatePrivate(keySpec) | ||
} | ||
|
||
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()) | ||
keyStore.load(null) | ||
certChain.withIndex().forEach { | ||
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) | ||
} | ||
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray()) | ||
|
||
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) | ||
keyManagerFactory.init(keyStore, null) | ||
|
||
val sslContext = SSLContext.getInstance("TLS") | ||
|
||
val trustManagers = coderTrustManagers(settings.tlsCAPath) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it a problem that we only use |
||
sslContext.init(keyManagerFactory.keyManagers, trustManagers, null) | ||
|
||
if (settings.tlsAlternateHostname.isBlank()) { | ||
return sslContext.socketFactory | ||
} | ||
|
||
return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname) | ||
} | ||
|
||
fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> { | ||
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) | ||
if (tlsCAPath.isBlank()) { | ||
// return default trust managers | ||
trustManagerFactory.init(null as KeyStore?) | ||
return trustManagerFactory.trustManagers | ||
} | ||
|
||
|
||
val certificateFactory = CertificateFactory.getInstance("X.509") | ||
val caInputStream = FileInputStream(expandPath(tlsCAPath)) | ||
val certChain = certificateFactory.generateCertificates(caInputStream) | ||
|
||
val truststore = KeyStore.getInstance(KeyStore.getDefaultType()) | ||
truststore.load(null) | ||
certChain.withIndex().forEach { | ||
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) | ||
} | ||
trustManagerFactory.init(truststore) | ||
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray() | ||
} | ||
|
||
fun expandPath(path: String): String { | ||
if (path.startsWith("~/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(1)).toString() | ||
} | ||
if (path.startsWith("\$HOME/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(5)).toString() | ||
} | ||
if (path.startsWith("\${user.home}/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(12)).toString() | ||
} | ||
return path | ||
} | ||
|
||
class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() { | ||
override fun getDefaultCipherSuites(): Array<String> { | ||
return delegate.defaultCipherSuites | ||
} | ||
|
||
override fun getSupportedCipherSuites(): Array<String> { | ||
return delegate.supportedCipherSuites | ||
} | ||
|
||
override fun createSocket(): Socket { | ||
val socket = delegate.createSocket() as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
override fun createSocket(host: String?, port: Int): Socket { | ||
val socket = delegate.createSocket(host, port) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket { | ||
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
override fun createSocket(host: InetAddress?, port: Int): Socket { | ||
val socket = delegate.createSocket(host, port) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket { | ||
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket { | ||
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
|
||
private fun customizeSocket(socket: SSLSocket) { | ||
val params = socket.sslParameters | ||
params.serverNames = listOf(SNIHostName(alternateName)) | ||
socket.sslParameters = params | ||
} | ||
} | ||
|
||
class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier { | ||
override fun verify(host: String, session: SSLSession): Boolean { | ||
if (alternateName.isEmpty()) { | ||
println("using default hostname verifier, alternateName is empty") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should copy the logger pattern used elsewhere instead of using |
||
return OkHostnameVerifier.verify(host, session) | ||
} | ||
println("Looking for alternate hostname: $alternateName") | ||
val certs = session.peerCertificates ?: return false | ||
for (cert in certs) { | ||
if (cert !is X509Certificate) { | ||
continue | ||
} | ||
val entries = cert.subjectAlternativeNames ?: continue | ||
for (entry in entries) { | ||
val kind = entry[0] as Int | ||
if (kind != 2) { // DNS Name | ||
continue | ||
} | ||
val hostname = entry[1] as String | ||
println("Found cert hostname: $hostname") | ||
if (hostname.lowercase(Locale.getDefault()) == alternateName) { | ||
return true | ||
} | ||
} | ||
} | ||
println("No matching hostname found") | ||
return false | ||
} | ||
} | ||
|
||
class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager { | ||
private val systemTrustManager : X509TrustManager | ||
init { | ||
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) | ||
trustManagerFactory.init(null as KeyStore?) | ||
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager | ||
} | ||
|
||
override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) { | ||
try { | ||
otherTrustManager.checkClientTrusted(chain, authType) | ||
} catch (e: CertificateException) { | ||
systemTrustManager.checkClientTrusted(chain, authType) | ||
} | ||
} | ||
|
||
override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) { | ||
try { | ||
otherTrustManager.checkServerTrusted(chain, authType) | ||
} catch (e: CertificateException) { | ||
systemTrustManager.checkServerTrusted(chain, authType) | ||
} | ||
} | ||
|
||
override fun getAcceptedIssuers(): Array<X509Certificate> { | ||
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like we need the kotlin stdlib, without this I was getting:
I suspect this is required since the kotlin 9.x upgrade?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where do you see this error? I tried building with this set to
false
and running the resulting plugin in Gateway (2023.2.4) but so far all seems well.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see when running
./gradlew build --info
or equivalent command via IntelliJ. Strange that you don't see it also 🤔