diff --git a/driver/src/main/java/org/neo4j/driver/internal/security/SecurityPlan.java b/driver/src/main/java/org/neo4j/driver/internal/security/SecurityPlan.java index af10095338..66d5f7999c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/security/SecurityPlan.java +++ b/driver/src/main/java/org/neo4j/driver/internal/security/SecurityPlan.java @@ -19,13 +19,6 @@ package org.neo4j.driver.internal.security; -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.v1.*; - -import javax.net.ssl.KeyManager; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; import java.io.File; import java.io.IOException; import java.security.GeneralSecurityException; @@ -33,6 +26,11 @@ import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; import static org.neo4j.driver.internal.util.CertificateTool.loadX509Cert; @@ -68,13 +66,34 @@ public static SecurityPlan forSystemCertificates() throws NoSuchAlgorithmExcepti } - public static SecurityPlan forTrustOnFirstUse( File knownHosts, BoltServerAddress address, Logger logger ) + public static SecurityPlan forTrustOnFirstUse( File knownHosts ) throws IOException, KeyManagementException, NoSuchAlgorithmException { - SSLContext sslContext = SSLContext.getInstance( "TLS" ); - sslContext.init( new KeyManager[0], new TrustManager[]{new TrustOnFirstUseTrustManager( address, knownHosts, logger )}, null ); + ConcurrentHashMap preLoadKnownHostsMap = TrustOnFirstUseTrustManager.createKnownHostsMap( knownHosts ); + return new TrustOnFirstUseSecurityPlan( knownHosts, preLoadKnownHostsMap ); + } - return new SecurityPlan( true, sslContext); + public static class TrustOnFirstUseSecurityPlan extends SecurityPlan + { + private final ConcurrentHashMap trustOnFirstUseMap; + private final File knownHostsFile; + + private TrustOnFirstUseSecurityPlan( File knownHostsFile, ConcurrentHashMap trustOnFirstUseMap ) + { + super( true, null ); + this.trustOnFirstUseMap = trustOnFirstUseMap; + this.knownHostsFile = knownHostsFile; + } + + public ConcurrentHashMap trustOnFirstUseMap() + { + return this.trustOnFirstUseMap; + } + + public File knownHostFile() + { + return this.knownHostsFile; + } } public static SecurityPlan insecure() @@ -85,7 +104,7 @@ public static SecurityPlan insecure() private final boolean requiresEncryption; private final SSLContext sslContext; - private SecurityPlan( boolean requiresEncryption, SSLContext sslContext) + private SecurityPlan( boolean requiresEncryption, SSLContext sslContext ) { this.requiresEncryption = requiresEncryption; this.sslContext = sslContext; diff --git a/driver/src/main/java/org/neo4j/driver/internal/security/TLSSocketChannel.java b/driver/src/main/java/org/neo4j/driver/internal/security/TLSSocketChannel.java index e0c4c88383..8e296a96dd 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/security/TLSSocketChannel.java +++ b/driver/src/main/java/org/neo4j/driver/internal/security/TLSSocketChannel.java @@ -22,18 +22,18 @@ import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.security.GeneralSecurityException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; import javax.net.ssl.KeyManager; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.TrustManager; import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.v1.Logger; -import org.neo4j.driver.internal.util.BytePrinter; import org.neo4j.driver.internal.util.BytePrinter; -import org.neo4j.driver.v1.Config.TrustStrategy; import org.neo4j.driver.v1.Logger; import org.neo4j.driver.v1.exceptions.ClientException; @@ -70,7 +70,7 @@ public class TLSSocketChannel implements ByteChannel public TLSSocketChannel( BoltServerAddress address, SecurityPlan securityPlan, ByteChannel channel, Logger logger ) throws GeneralSecurityException, IOException { - this( channel, logger, createSSLEngine( address, securityPlan.sslContext() ) ); + this( channel, logger, createSSLEngine( address, securityPlan, logger ) ); } public TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws GeneralSecurityException, IOException @@ -356,10 +356,23 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to ) /** * Create SSLEngine with the SSLContext just created. * @param address the host to connect to - * @param sslContext the current ssl context + * @param securityPlan the security plan which holds the current ssl context + * @param logger the logger */ - private static SSLEngine createSSLEngine( BoltServerAddress address, SSLContext sslContext ) + private static SSLEngine createSSLEngine( BoltServerAddress address, SecurityPlan securityPlan, Logger logger ) + throws IOException, KeyManagementException, NoSuchAlgorithmException { + SSLContext sslContext = securityPlan.sslContext(); + if( securityPlan instanceof SecurityPlan.TrustOnFirstUseSecurityPlan ) + { + // It require a new sslContext for each connection + sslContext = SSLContext.getInstance( "TLS" ); + SecurityPlan.TrustOnFirstUseSecurityPlan plan = (SecurityPlan.TrustOnFirstUseSecurityPlan) securityPlan; + sslContext.init( new KeyManager[0], new TrustManager[]{ + new TrustOnFirstUseTrustManager( address.toString(), plan.knownHostFile(), plan.trustOnFirstUseMap(), logger )}, + null ); + } + SSLEngine sslEngine = sslContext.createSSLEngine( address.host(), address.port() ); sslEngine.setUseClientMode( true ); return sslEngine; diff --git a/driver/src/main/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManager.java b/driver/src/main/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManager.java index e1606d70d4..e535703e23 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManager.java +++ b/driver/src/main/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManager.java @@ -28,11 +28,12 @@ import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import javax.net.ssl.X509TrustManager; -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.v1.Logger; import org.neo4j.driver.internal.util.BytePrinter; +import org.neo4j.driver.v1.Logger; import static java.lang.String.format; import static org.neo4j.driver.internal.util.CertificateTool.X509CertToString; @@ -50,21 +51,27 @@ public class TrustOnFirstUseTrustManager implements X509TrustManager * Then when we try to connect to a known server again, we will authenticate the server by checking if it provides * the same certificate as the one saved in this file. */ - private final File knownHosts; + private final File knownHostsFile; - /** The server ip:port (in digits) of the server that we are currently connected to */ - private final String serverId; + /** The map of server ip:port (in digits) and its known certificate we've registered */ + private final ConcurrentHashMap knownHosts; private final Logger logger; + private final String serverId; - /** The known certificate we've registered for this server */ - private String fingerprint; - - TrustOnFirstUseTrustManager( BoltServerAddress address, File knownHosts, Logger logger ) throws IOException + TrustOnFirstUseTrustManager( String serverId, File knownHosts, ConcurrentHashMap preLoadedKnownHosts, Logger logger ) + throws IOException { this.logger = logger; - this.serverId = address.toString(); - this.knownHosts = knownHosts; - load(); + this.knownHostsFile = knownHosts; + this.knownHosts = preLoadedKnownHosts; + this.serverId = serverId; + } + + public static ConcurrentHashMap createKnownHostsMap( File knownHostsFile ) throws IOException + { + ConcurrentHashMap knownHosts = new ConcurrentHashMap<>(); + load( knownHostsFile, knownHosts ); + return knownHosts; } /** @@ -72,27 +79,27 @@ public class TrustOnFirstUseTrustManager implements X509TrustManager * * @throws IOException */ - private void load() throws IOException + private static synchronized void load( File knownHostsFile, Map knownHosts ) throws IOException { - if ( !knownHosts.exists() ) + if ( !knownHostsFile.exists() ) { return; } - assertKnownHostFileReadable(); + assertKnownHostFileReadable( knownHostsFile ); - BufferedReader reader = new BufferedReader( new FileReader( knownHosts ) ); + BufferedReader reader = new BufferedReader( new FileReader( knownHostsFile ) ); String line; while ( (line = reader.readLine()) != null ) { if ( (!line.trim().startsWith( "#" )) ) { String[] strings = line.split( " " ); - if ( strings[0].trim().equals( serverId ) ) + if(strings.length == 2) { - // load the certificate - fingerprint = strings[1].trim(); - return; + // we need to load all serverId and finger prints from the file as we do not know which one is + // our current connection. + knownHosts.put( strings[0], strings[1] ); } } } @@ -100,45 +107,45 @@ private void load() throws IOException } /** - * Save a new (server_ip, cert) pair into knownHosts file + * Save a new (server_ip, cert) pair into knownHostsFile file * * @param fingerprint the SHA-512 fingerprint of the host certificate */ - private void saveTrustedHost( String fingerprint ) throws IOException + private synchronized void saveTrustedHost( String serverId, String fingerprint ) throws IOException { - this.fingerprint = fingerprint; + knownHosts.put( serverId, fingerprint ); logger.warn( "Adding %s as known and trusted certificate for %s.", fingerprint, serverId ); createKnownCertFileIfNotExists(); assertKnownHostFileWritable(); - BufferedWriter writer = new BufferedWriter( new FileWriter( knownHosts, true ) ); - writer.write( serverId + " " + this.fingerprint ); + BufferedWriter writer = new BufferedWriter( new FileWriter( knownHostsFile, true ) ); + writer.write( serverId + " " + fingerprint ); writer.newLine(); writer.close(); } - private void assertKnownHostFileReadable() throws IOException + private static void assertKnownHostFileReadable( File knownHostsFile ) throws IOException { - if( !knownHosts.canRead() ) + if( !knownHostsFile.canRead() ) { throw new IOException( format( "Failed to load certificates from file %s as you have no read permissions to it.\n" + "Try configuring the Neo4j driver to use a file system location you do have read permissions to.", - knownHosts.getAbsolutePath() + knownHostsFile.getAbsolutePath() ) ); } } private void assertKnownHostFileWritable() throws IOException { - if( !knownHosts.canWrite() ) + if( !knownHostsFile.canWrite() ) { throw new IOException( format( "Failed to write certificates to file %s as you have no write permissions to it.\n" + "Try configuring the Neo4j driver to use a file system location you do have write permissions to.", - knownHosts.getAbsolutePath() + knownHostsFile.getAbsolutePath() ) ); } } @@ -161,24 +168,25 @@ public void checkServerTrusted( X509Certificate[] chain, String authType ) X509Certificate certificate = chain[0]; String cert = fingerprint( certificate ); + String fingerprint = this.knownHosts.get( serverId ); - if ( this.fingerprint == null ) + if ( fingerprint == null ) { try { - saveTrustedHost( cert ); + saveTrustedHost( serverId, cert ); } catch ( IOException e ) { throw new CertificateException( format( "Failed to save the server ID and the certificate received from the server to file %s.\n" + "Server ID: %s\nReceived cert:\n%s", - knownHosts.getAbsolutePath(), serverId, X509CertToString( cert ) ), e ); + knownHostsFile.getAbsolutePath(), serverId, X509CertToString( cert ) ), e ); } } else { - if ( !this.fingerprint.equals( cert ) ) + if ( !fingerprint.equals( cert ) ) { throw new CertificateException( format( "Unable to connect to neo4j at `%s`, because the certificate the server uses has changed. " + @@ -187,8 +195,8 @@ public void checkServerTrusted( X509Certificate[] chain, String authType ) "`%s` " + "in the file `%s`.\n" + "The old certificate saved in file is:\n%sThe New certificate received is:\n%s", - serverId, serverId, knownHosts.getAbsolutePath(), - X509CertToString( this.fingerprint ), X509CertToString( cert ) ) ); + serverId, serverId, knownHostsFile.getAbsolutePath(), + X509CertToString( fingerprint ), X509CertToString( cert ) ) ); } } } @@ -213,34 +221,38 @@ public static String fingerprint( X509Certificate cert ) throws CertificateExcep private File createKnownCertFileIfNotExists() throws IOException { - if ( !knownHosts.exists() ) + if ( !knownHostsFile.exists() ) { - File parentDir = knownHosts.getParentFile(); + File parentDir = knownHostsFile.getParentFile(); try { if ( parentDir != null && !parentDir.exists() ) { if ( !parentDir.mkdirs() ) { - throw new IOException( "Failed to create directories for the known hosts file in " + knownHosts.getAbsolutePath() + + throw new IOException( "Failed to create directories for the known hosts file in " + knownHostsFile + + .getAbsolutePath() + ". This is usually because you do not have write permissions to the directory. " + "Try configuring the Neo4j driver to use a file system location you do have write permissions to." ); } } - if ( !knownHosts.createNewFile() ) + if ( !knownHostsFile.createNewFile() ) { - throw new IOException( "Failed to create a known hosts file at " + knownHosts.getAbsolutePath() + + throw new IOException( "Failed to create a known hosts file at " + knownHostsFile + .getAbsolutePath() + ". This is usually because you do not have write permissions to the directory. " + "Try configuring the Neo4j driver to use a file system location you do have write permissions to." ); } } catch( SecurityException e ) { - throw new IOException( "Failed to create known host file and/or parent directories at " + knownHosts.getAbsolutePath() + + throw new IOException( "Failed to create known host file and/or parent directories at " + knownHostsFile + .getAbsolutePath() + ". This is usually because you do not have write permission to the directory. " + "Try configuring the Neo4j driver to use a file location you have write permissions to." ); } - BufferedWriter writer = new BufferedWriter( new FileWriter( knownHosts ) ); + BufferedWriter writer = new BufferedWriter( new FileWriter( knownHostsFile ) ); writer.write( "# This file contains trusted certificates for Neo4j servers, it's created by Neo4j drivers." ); writer.newLine(); writer.write( "# You can configure the location of this file in `org.neo4j.driver.Config`" ); @@ -248,7 +260,7 @@ private File createKnownCertFileIfNotExists() throws IOException writer.close(); } - return knownHosts; + return knownHostsFile; } /** diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index a031815e5e..9de9219434 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -34,7 +34,6 @@ import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.util.BiFunction; import org.neo4j.driver.v1.util.Function; import static java.lang.String.format; @@ -222,8 +221,7 @@ private static SecurityPlan createSecurityPlan( BoltServerAddress address, Confi case TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: return SecurityPlan.forSignedCertificates( config.trustStrategy().certFile() ); case TRUST_ON_FIRST_USE: - return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile(), - address, logger ); + return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile() ); default: throw new ClientException( "Unknown TLS authentication strategy: " + config.trustStrategy().strategy().name() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java b/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java index e503b5f8cf..5b3df126df 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java @@ -30,11 +30,13 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Scanner; +import java.util.concurrent.ConcurrentHashMap; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.v1.Logger; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -88,8 +90,9 @@ public void shouldLoadExistingCert() throws Throwable // Given BoltServerAddress knownServerAddress = new BoltServerAddress( knownServerIp, knownServerPort ); Logger logger = mock(Logger.class); + ConcurrentHashMap knownHostsMap = TrustOnFirstUseTrustManager.createKnownHostsMap( knownCertsFile ); TrustOnFirstUseTrustManager manager = - new TrustOnFirstUseTrustManager( knownServerAddress, knownCertsFile, logger ); + new TrustOnFirstUseTrustManager( knownServerAddress.toString(), knownCertsFile, knownHostsMap, logger ); X509Certificate wrongCertificate = mock( X509Certificate.class ); when( wrongCertificate.getEncoded() ).thenReturn( "fake certificate".getBytes() ); @@ -115,7 +118,9 @@ public void shouldSaveNewCert() throws Throwable int newPort = 200; BoltServerAddress address = new BoltServerAddress( knownServerIp, newPort ); Logger logger = mock(Logger.class); - TrustOnFirstUseTrustManager manager = new TrustOnFirstUseTrustManager( address, knownCertsFile, logger ); + ConcurrentHashMap knownHostsMap = TrustOnFirstUseTrustManager.createKnownHostsMap( knownCertsFile ); + TrustOnFirstUseTrustManager manager = + new TrustOnFirstUseTrustManager( address.toString(), knownCertsFile, knownHostsMap, logger ); String fingerprint = fingerprint( knownCertificate ); @@ -125,6 +130,7 @@ public void shouldSaveNewCert() throws Throwable // Then no exception should've been thrown, and we should've logged that we now trust this certificate verify(logger).warn( "Adding %s as known and trusted certificate for %s.", fingerprint, "1.2.3.4:200" ); + assertThat(knownHostsMap.get( address.toString() ), equalTo( fingerprint ) ); // And the file should contain the right info Scanner reader = new Scanner( knownCertsFile ); @@ -159,7 +165,7 @@ public void shouldThrowMeaningfulExceptionIfHasNoReadPermissionToKnownHostFile() // When & Then try { - new TrustOnFirstUseTrustManager( new BoltServerAddress( knownServerIp, knownServerPort ), knownHostFile, null ); + TrustOnFirstUseTrustManager.createKnownHostsMap( knownHostFile ); fail( "Should have failed in load certs" ); } catch( IOException e ) @@ -177,14 +183,15 @@ public void shouldThrowMeaningfulExceptionIfHasNoWritePermissionToKnownHostFile( { // Given File knownHostFile = mock( File.class ); - when( knownHostFile.exists() ).thenReturn( false /*skip reading*/, true ); + when( knownHostFile.exists() ).thenReturn( true ); when( knownHostFile.canWrite() ).thenReturn( false ); // When & Then try { TrustOnFirstUseTrustManager manager = - new TrustOnFirstUseTrustManager( new BoltServerAddress( knownServerIp, knownServerPort ), knownHostFile, mock( Logger.class ) ); + new TrustOnFirstUseTrustManager( new BoltServerAddress( knownServerIp, knownServerPort ).toString(), + knownHostFile, new ConcurrentHashMap(), mock( Logger.class ) ); manager.checkServerTrusted( new X509Certificate[]{ knownCertificate}, null ); fail( "Should have failed in write to certs" ); } diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java index 45e233329b..7660a12c8e 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java @@ -34,13 +34,17 @@ import java.security.cert.X509Certificate; import javax.net.ssl.SSLHandshakeException; -import org.neo4j.driver.internal.logging.DevNullLogger; +import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.security.TLSSocketChannel; -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.v1.*; import org.neo4j.driver.internal.util.CertificateTool; - +import org.neo4j.driver.v1.Config; +import org.neo4j.driver.v1.Driver; +import org.neo4j.driver.v1.GraphDatabase; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Logging; +import org.neo4j.driver.v1.Session; +import org.neo4j.driver.v1.StatementResult; import org.neo4j.driver.v1.util.CertificateToolTest; import org.neo4j.driver.v1.util.Neo4jRunner; import org.neo4j.driver.v1.util.Neo4jSettings; @@ -184,7 +188,7 @@ public void shouldFailTLSHandshakeDueToWrongCertInKnownCertsFile() throws Throwa createFakeServerCertPairInKnownCerts( address, knownCerts ); // When & Then - SecurityPlan securityPlan = SecurityPlan.forTrustOnFirstUse( knownCerts, address, new DevNullLogger() ); + SecurityPlan securityPlan = SecurityPlan.forTrustOnFirstUse( knownCerts ); TLSSocketChannel sslChannel = null; try { @@ -333,7 +337,7 @@ private void performTLSHandshakeUsingKnownCerts( File knownCerts ) throws Throwa // When - SecurityPlan securityPlan = SecurityPlan.forTrustOnFirstUse( knownCerts, address, new DevNullLogger() ); + SecurityPlan securityPlan = SecurityPlan.forTrustOnFirstUse( knownCerts ); TLSSocketChannel sslChannel = new TLSSocketChannel( address, securityPlan, channel, logger ); sslChannel.close();