diff --git a/spring-integration-sftp/src/main/java/org/springframework/integration/sftp/session/ResourceKnownHostsServerKeyVerifier.java b/spring-integration-sftp/src/main/java/org/springframework/integration/sftp/session/ResourceKnownHostsServerKeyVerifier.java index 9d6a852a574..7a17f61c2fd 100644 --- a/spring-integration-sftp/src/main/java/org/springframework/integration/sftp/session/ResourceKnownHostsServerKeyVerifier.java +++ b/spring-integration-sftp/src/main/java/org/springframework/integration/sftp/session/ResourceKnownHostsServerKeyVerifier.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,18 +66,19 @@ public ResourceKnownHostsServerKeyVerifier(Resource knownHostsResource) { @Override public boolean verifyServerKey(ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey) { Collection knownHosts = this.keysSupplier.get(); - KnownHostsServerKeyVerifier.HostEntryPair match = findKnownHostEntry(clientSession, remoteAddress, knownHosts); - if (match == null) { + List matches = + findKnownHostEntries(clientSession, remoteAddress, knownHosts); + + if (matches.isEmpty()) { return false; } - KnownHostEntry entry = match.getHostEntry(); - PublicKey expected = match.getServerKey(); - if (KeyUtils.compareKeys(expected, serverKey)) { - return !"revoked".equals(entry.getMarker()); - } + String serverKeyType = KeyUtils.getKeyType(serverKey); - return false; + return matches.stream() + .filter(match -> serverKeyType.equals(match.getHostEntry().getKeyEntry().getKeyType())) + .filter(match -> KeyUtils.compareKeys(match.getServerKey(), serverKey)) + .anyMatch(match -> !"revoked".equals(match.getHostEntry().getMarker())); } private static Supplier> getKnownHostSupplier( @@ -106,26 +107,32 @@ private static PublicKey resolveHostKey(KnownHostEntry entry) throws IOException return authEntry.resolvePublicKey(null, PublicKeyEntryResolver.IGNORING); } - private static KnownHostsServerKeyVerifier.HostEntryPair findKnownHostEntry( + private static List findKnownHostEntries( ClientSession clientSession, SocketAddress remoteAddress, Collection knownHosts) { + if (GenericUtils.isEmpty(knownHosts)) { + return Collections.emptyList(); + } + Collection candidates = resolveHostNetworkIdentities(clientSession, remoteAddress); if (GenericUtils.isEmpty(candidates)) { - return null; + return Collections.emptyList(); } + List matches = new ArrayList<>(); for (KnownHostsServerKeyVerifier.HostEntryPair match : knownHosts) { KnownHostEntry entry = match.getHostEntry(); for (SshdSocketAddress host : candidates) { if (entry.isHostMatch(host.getHostName(), host.getPort())) { - return match; + matches.add(match); + break; } } } - return null; // no match found + return matches; } private static Collection resolveHostNetworkIdentities( diff --git a/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java b/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java index f03798a8568..187beee44bd 100644 --- a/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java +++ b/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; import java.net.ConnectException; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -121,7 +122,7 @@ public void concurrentGetSessionDoesntCauseFailure() throws IOException { asyncTaskExecutor.execute(() -> concurrentSessions.add(sftpSessionFactory.getSession())); } - await().until(() -> concurrentSessions.size() == 3); + await().atMost(Duration.ofSeconds(30)).until(() -> concurrentSessions.size() == 3); assertThat(concurrentSessions.get(0)) .isNotEqualTo(concurrentSessions.get(1))