@@ -14,21 +14,47 @@ import com.coder.gateway.sdk.v2.models.Workspace
14
14
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
15
15
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
16
16
import com.coder.gateway.sdk.v2.models.toAgentModels
17
+ import com.coder.gateway.services.CoderSettingsState
17
18
import com.google.gson.Gson
18
19
import com.google.gson.GsonBuilder
19
20
import com.intellij.ide.plugins.PluginManagerCore
20
21
import com.intellij.openapi.components.Service
22
+ import com.intellij.openapi.components.service
21
23
import com.intellij.openapi.extensions.PluginId
22
24
import com.intellij.openapi.util.SystemInfo
23
25
import okhttp3.OkHttpClient
26
+ import okhttp3.internal.tls.OkHostnameVerifier
24
27
import okhttp3.logging.HttpLoggingInterceptor
25
28
import org.zeroturnaround.exec.ProcessExecutor
26
29
import retrofit2.Retrofit
27
30
import retrofit2.converter.gson.GsonConverterFactory
31
+ import java.io.File
32
+ import java.io.FileInputStream
28
33
import java.net.HttpURLConnection.HTTP_CREATED
34
+ import java.net.InetAddress
35
+ import java.net.Socket
29
36
import java.net.URL
37
+ import java.security.KeyFactory
38
+ import java.security.KeyStore
39
+ import java.security.PrivateKey
40
+ import java.security.cert.CertificateFactory
41
+ import java.security.cert.X509Certificate
42
+ import java.security.spec.InvalidKeySpecException
43
+ import java.security.spec.PKCS8EncodedKeySpec
30
44
import java.time.Instant
45
+ import java.util.Base64
46
+ import java.util.Locale
31
47
import java.util.UUID
48
+ import javax.net.ssl.HostnameVerifier
49
+ import javax.net.ssl.KeyManagerFactory
50
+ import javax.net.ssl.SNIHostName
51
+ import javax.net.ssl.SSLContext
52
+ import javax.net.ssl.SSLSession
53
+ import javax.net.ssl.SSLSocket
54
+ import javax.net.ssl.SSLSocketFactory
55
+ import javax.net.ssl.TrustManagerFactory
56
+ import javax.net.ssl.TrustManager
57
+ import javax.net.ssl.X509TrustManager
32
58
33
59
@Service(Service .Level .APP )
34
60
class CoderRestClientService {
@@ -66,7 +92,11 @@ class CoderRestClient(var url: URL, var token: String,
66
92
pluginVersion = PluginManagerCore .getPlugin(PluginId .getId(" com.coder.gateway" ))!! .version // this is the id from the plugin.xml
67
93
}
68
94
95
+ val socketFactory = coderSocketFactory()
96
+ val trustManagers = coderTrustManagers()
69
97
httpClient = OkHttpClient .Builder ()
98
+ .sslSocketFactory(socketFactory, trustManagers[0 ] as X509TrustManager )
99
+ .hostnameVerifier(CoderHostnameVerifier ())
70
100
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" Coder-Session-Token" , token).build()) }
71
101
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" User-Agent" , " Coder Gateway/${pluginVersion} (${SystemInfo .getOsNameAndVersion()} ; ${SystemInfo .OS_ARCH } )" ).build()) }
72
102
.addInterceptor {
@@ -218,3 +248,168 @@ class CoderRestClient(var url: URL, var token: String,
218
248
}
219
249
}
220
250
}
251
+
252
+ fun coderSocketFactory () : SSLSocketFactory {
253
+ val state: CoderSettingsState = service()
254
+
255
+ if (state.tlsCertPath.isBlank() || state.tlsKeyPath.isBlank()) {
256
+ return SSLSocketFactory .getDefault() as SSLSocketFactory
257
+ }
258
+
259
+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
260
+ val certInputStream = FileInputStream (state.tlsCertPath)
261
+ val certChain = certificateFactory.generateCertificates(certInputStream)
262
+ certInputStream.close()
263
+
264
+ // ideally we would use something like PemReader from BouncyCastle, but
265
+ // BC is used by the IDE. This makes using BC very impractical since
266
+ // type casting will mismatch due to the different class loaders.
267
+ val privateKeyPem = File (state.tlsKeyPath).readText()
268
+ val start: Int = privateKeyPem.indexOf(" -----BEGIN PRIVATE KEY-----" )
269
+ val end: Int = privateKeyPem.indexOf(" -----END PRIVATE KEY-----" , start)
270
+ val pemBytes: ByteArray = Base64 .getDecoder().decode(
271
+ privateKeyPem.substring(start + " -----BEGIN PRIVATE KEY-----" .length, end)
272
+ .replace(" \\ s+" .toRegex(), " " )
273
+ )
274
+
275
+ var privateKey : PrivateKey
276
+ try {
277
+ val kf = KeyFactory .getInstance(" RSA" )
278
+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
279
+ privateKey = kf.generatePrivate(keySpec)
280
+ } catch (e: InvalidKeySpecException ) {
281
+ val kf = KeyFactory .getInstance(" EC" )
282
+ val keySpec = PKCS8EncodedKeySpec (pemBytes)
283
+ privateKey = kf.generatePrivate(keySpec)
284
+ }
285
+
286
+ val keyStore = KeyStore .getInstance(KeyStore .getDefaultType())
287
+ keyStore.load(null )
288
+ certChain.withIndex().forEach {
289
+ keyStore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
290
+ }
291
+ keyStore.setKeyEntry(" key" , privateKey, null , certChain.toTypedArray())
292
+
293
+ val keyManagerFactory = KeyManagerFactory .getInstance(KeyManagerFactory .getDefaultAlgorithm())
294
+ keyManagerFactory.init (keyStore, null )
295
+
296
+ val sslContext = SSLContext .getInstance(" TLS" )
297
+
298
+ val trustManagers = coderTrustManagers()
299
+ sslContext.init (keyManagerFactory.keyManagers, trustManagers, null )
300
+
301
+ if (state.tlsAlternateHostname.isBlank()) {
302
+ return sslContext.socketFactory
303
+ }
304
+
305
+ return AlternateNameSSLSocketFactory (sslContext.socketFactory, state.tlsAlternateHostname)
306
+ }
307
+
308
+ fun coderTrustManagers () : Array <TrustManager > {
309
+ val state: CoderSettingsState = service()
310
+
311
+ val trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
312
+ if (state.tlsCAPath.isBlank()) {
313
+ // return default trust managers
314
+ trustManagerFactory.init (null as KeyStore ? )
315
+ return trustManagerFactory.trustManagers
316
+ }
317
+
318
+
319
+ val certificateFactory = CertificateFactory .getInstance(" X.509" )
320
+ val caInputStream = FileInputStream (state.tlsCAPath)
321
+ val certChain = certificateFactory.generateCertificates(caInputStream)
322
+
323
+ val truststore = KeyStore .getInstance(KeyStore .getDefaultType())
324
+ truststore.load(null )
325
+ certChain.withIndex().forEach {
326
+ truststore.setCertificateEntry(" cert${it.index} " , it.value as X509Certificate )
327
+ }
328
+ trustManagerFactory.init (truststore)
329
+ return trustManagerFactory.trustManagers
330
+ }
331
+
332
+ class AlternateNameSSLSocketFactory (private val delegate : SSLSocketFactory , private val alternateName : String ) : SSLSocketFactory() {
333
+ override fun getDefaultCipherSuites (): Array <String > {
334
+ return delegate.defaultCipherSuites
335
+ }
336
+
337
+ override fun getSupportedCipherSuites (): Array <String > {
338
+ return delegate.supportedCipherSuites
339
+ }
340
+
341
+ override fun createSocket (): Socket {
342
+ val socket = delegate.createSocket() as SSLSocket
343
+ customizeSocket(socket)
344
+ return socket
345
+ }
346
+
347
+ override fun createSocket (host : String? , port : Int ): Socket {
348
+ val socket = delegate.createSocket(host, port) as SSLSocket
349
+ customizeSocket(socket)
350
+ return socket
351
+ }
352
+
353
+ override fun createSocket (host : String? , port : Int , localHost : InetAddress ? , localPort : Int ): Socket {
354
+ val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
355
+ customizeSocket(socket)
356
+ return socket
357
+ }
358
+
359
+ override fun createSocket (host : InetAddress ? , port : Int ): Socket {
360
+ val socket = delegate.createSocket(host, port) as SSLSocket
361
+ customizeSocket(socket)
362
+ return socket
363
+ }
364
+
365
+ override fun createSocket (address : InetAddress ? , port : Int , localAddress : InetAddress ? , localPort : Int ): Socket {
366
+ val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
367
+ customizeSocket(socket)
368
+ return socket
369
+ }
370
+
371
+ override fun createSocket (s : Socket ? , host : String? , port : Int , autoClose : Boolean ): Socket {
372
+ val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
373
+ customizeSocket(socket)
374
+ return socket
375
+ }
376
+
377
+ private fun customizeSocket (socket : SSLSocket ) {
378
+ val params = socket.sslParameters
379
+ params.serverNames = listOf (SNIHostName (alternateName))
380
+ socket.sslParameters = params
381
+ }
382
+ }
383
+
384
+ class CoderHostnameVerifier () : HostnameVerifier {
385
+ private val alternateName: String
386
+
387
+ init {
388
+ val state: CoderSettingsState = service()
389
+ this .alternateName = state.tlsAlternateHostname.lowercase(Locale .getDefault())
390
+ }
391
+
392
+ override fun verify (host : String , session : SSLSession ): Boolean {
393
+ if (alternateName.isEmpty()) {
394
+ return OkHostnameVerifier .verify(host, session)
395
+ }
396
+ val certs = session.peerCertificates ? : return false
397
+ for (cert in certs) {
398
+ if (cert !is X509Certificate ) {
399
+ continue
400
+ }
401
+ val entries = cert.subjectAlternativeNames ? : continue
402
+ for (entry in entries) {
403
+ val kind = entry[0 ] as Int
404
+ if (kind != 2 ) { // DNS Name
405
+ continue
406
+ }
407
+ val hostname = entry[1 ] as String
408
+ if (hostname.lowercase(Locale .getDefault()) == alternateName) {
409
+ return true
410
+ }
411
+ }
412
+ }
413
+ return false
414
+ }
415
+ }
0 commit comments