diff --git a/src/main/java/org/tarantool/TarantoolBase.java b/src/main/java/org/tarantool/TarantoolBase.java index 3fb1ce40..bc375c25 100644 --- a/src/main/java/org/tarantool/TarantoolBase.java +++ b/src/main/java/org/tarantool/TarantoolBase.java @@ -40,11 +40,13 @@ public TarantoolBase() { public TarantoolBase(String username, String password, Socket socket) { super(); try { - this.is = new DataInputStream(cis = new CountInputStreamImpl(socket.getInputStream())); + cis = new CountInputStreamImpl(socket.getInputStream()); + is = new DataInputStream(cis); byte[] bytes = new byte[64]; is.readFully(bytes); String firstLine = new String(bytes); if (!firstLine.startsWith(WELCOME)) { + closeStreams(); close(); throw new CommunicationException("Welcome message should starts with tarantool but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet")); } @@ -56,23 +58,15 @@ public TarantoolBase(String username, String password, Socket socket) { OutputStream os = socket.getOutputStream(); os.write(authPacket.array(), 0, authPacket.remaining()); os.flush(); - readPacket(is); + readPacket(); Long code = (Long) headers.get(Key.CODE.getId()); if (code != 0) { + closeStreams(); throw serverError(code, body.get(Key.ERROR.getId())); } } } catch (IOException e) { - try { - is.close(); - } catch (IOException ignored) { - - } - try { - cis.close(); - } catch (IOException ignored) { - - } + closeStreams(); throw new CommunicationException("Couldn't connect to tarantool", e); } } @@ -130,7 +124,7 @@ protected ByteBuffer createPacket(Code code, Long syncId, Long schemaId, Object. return buffer; } - protected void readPacket(DataInputStream is) throws IOException { + protected void readPacket() throws IOException { int size = ((Number) msgPackLite.unpack(is)).intValue(); long mark = cis.getBytesRead(); headers = (Map) msgPackLite.unpack(is); @@ -185,7 +179,6 @@ protected List> readSqlResult(List> data) { return values; } - protected Long getSqlRowCount() { Map info = (Map) body.get(Key.SQL_INFO.getId()); Number rowCount; @@ -220,6 +213,21 @@ protected void closeChannel(SocketChannel channel) { } } + protected void closeStreams() { + if (is != null) { + try { + is.close(); + } catch (IOException ignored) { + } + } + if (cis != null) { + try { + cis.close(); + } catch (IOException ignored) { + } + } + } + protected void validateArgs(Object[] args) { if (args != null) { for (int i = 0; i < args.length; i += 2) { diff --git a/src/main/java/org/tarantool/TarantoolClientImpl.java b/src/main/java/org/tarantool/TarantoolClientImpl.java index 0c510287..977ed90a 100644 --- a/src/main/java/org/tarantool/TarantoolClientImpl.java +++ b/src/main/java/org/tarantool/TarantoolClientImpl.java @@ -132,8 +132,10 @@ protected void reconnect(int retry, Throwable lastError) { } protected void connect(final SocketChannel channel) throws Exception { + closeStreams(); try { - DataInputStream is = new DataInputStream(cis = new ByteBufferInputStream(channel)); + cis = new ByteBufferInputStream(channel); + is = new DataInputStream(cis); byte[] bytes = new byte[64]; is.readFully(bytes); String firstLine = new String(bytes); @@ -141,6 +143,7 @@ protected void connect(final SocketChannel channel) throws Exception { CommunicationException e = new CommunicationException("Welcome message should starts with tarantool " + "but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet")); + closeStreams(); close(e); throw e; } @@ -148,24 +151,15 @@ protected void connect(final SocketChannel channel) throws Exception { this.salt = new String(bytes); if (config.username != null && config.password != null) { writeFully(channel, createAuthPacket(config.username, config.password)); - readPacket(is); + readPacket(); Long code = (Long) headers.get(Key.CODE.getId()); if (code != 0) { + closeStreams(); throw serverError(code, body.get(Key.ERROR.getId())); } } - this.is = is; } catch (IOException e) { - try { - is.close(); - } catch (IOException ignored) { - - } - try { - cis.close(); - } catch (IOException ignored) { - - } + closeStreams(); throw new CommunicationException("Couldn't connect to tarantool", e); } channel.configureBlocking(false); @@ -358,7 +352,7 @@ protected void readThread() { while (!Thread.currentThread().isInterrupted()) { try { long code; - readPacket(is); + readPacket(); code = (Long) headers.get(Key.CODE.getId()); Long syncId = (Long) headers.get(Key.SYNC.getId()); CompletableFuture future = futures.remove(syncId); diff --git a/src/main/java/org/tarantool/TarantoolConnection.java b/src/main/java/org/tarantool/TarantoolConnection.java index b817988f..62ed613b 100644 --- a/src/main/java/org/tarantool/TarantoolConnection.java +++ b/src/main/java/org/tarantool/TarantoolConnection.java @@ -28,7 +28,7 @@ protected List exec(Code code, Object... args) { ByteBuffer packet = createPacket(code, syncId.incrementAndGet(), null, args); out.write(packet.array(), 0, packet.remaining()); out.flush(); - readPacket(is); + readPacket(); Long c = (Long) headers.get(Key.CODE.getId()); if (c == 0) { return (List) body.get(Key.DATA.getId()); diff --git a/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java b/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java index 4c494f50..739a1cf3 100644 --- a/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java +++ b/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java @@ -135,6 +135,10 @@ protected TarantoolClient makeClient() { return new TarantoolClientImpl(socketChannelProvider, makeClientConfig()); } + protected TarantoolClient makeClient(SocketChannelProvider provider) { + return new TarantoolClientImpl(provider, makeClientConfig()); + } + protected static TarantoolClientConfig makeClientConfig() { return fillClientConfig(new TarantoolClientConfig()); } diff --git a/src/test/java/org/tarantool/ClientReconnectIT.java b/src/test/java/org/tarantool/ClientReconnectIT.java index 2472bf05..59c4011a 100644 --- a/src/test/java/org/tarantool/ClientReconnectIT.java +++ b/src/test/java/org/tarantool/ClientReconnectIT.java @@ -20,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -215,7 +216,12 @@ public void run() { /** * Test concurrent operations, reconnects and close. + * * Expected situation is nothing gets stuck. + * + * The test sets SO_LINGER to 0 for outgoing connections to avoid producing + * many TIME_WAIT sockets, because an available port range can be + * exhausted. */ @Test public void testLongParallelCloseReconnects() { @@ -223,11 +229,14 @@ public void testLongParallelCloseReconnects() { int numClients = 4; int timeBudget = 30*1000; + SocketChannelProvider provider = new TestSocketChannelProvider(host, + port, RESTART_TIMEOUT).setSoLinger(0); + final AtomicReferenceArray clients = new AtomicReferenceArray(numClients); for (int idx = 0; idx < clients.length(); idx++) { - clients.set(idx, makeClient()); + clients.set(idx, makeClient(provider)); } final Random rnd = new Random(); @@ -256,7 +265,7 @@ public void run() { cli.close(); - TarantoolClient next = makeClient(); + TarantoolClient next = makeClient(provider); if (!clients.compareAndSet(idx, cli, next)) { next.close(); } @@ -284,7 +293,9 @@ public void run() { fail(e); } if (deadline > System.currentTimeMillis()) { - System.out.println("" + (deadline - System.currentTimeMillis())/1000 + "s remains."); + System.out.println("testLongParallelCloseReconnects: " + + (deadline - System.currentTimeMillis()) / 1000 + + "s remain"); } } @@ -302,4 +313,46 @@ public void run() { assertTrue(cnt.get() > threads.length); } + + /** + * Verify that we don't exceed a file descriptor limit (and so likely don't + * leak file descriptors) when trying to connect to an existing node with + * wrong authentification credentials. + * + * The test sets SO_LINGER to 0 for outgoing connections to avoid producing + * many TIME_WAIT sockets, because an available port range can be + * exhausted. + */ + @Test + public void testReconnectWrongAuth() throws Exception { + SocketChannelProvider provider = new TestSocketChannelProvider(host, + port, RESTART_TIMEOUT).setSoLinger(0); + TarantoolClientConfig config = makeClientConfig(); + config.initTimeoutMillis = 100; + config.password = config.password + 'x'; + for (int i = 0; i < 100; ++i) { + if (i % 10 == 0) + System.out.println("testReconnectWrongAuth: " + (100 - i) + + " iterations remain"); + CommunicationException e = assertThrows(CommunicationException.class, + new Executable() { + @Override + public void execute() throws Throwable { + client = new TarantoolClientImpl(provider, config); + } + } + ); + assertEquals(e.getMessage(), "100ms is exceeded when waiting " + + "for client initialization. You could configure init " + + "timeout in TarantoolConfig"); + } + + /* + * Verify we don't exceed a file descriptor limit. If we exceed it, a + * client will not able to connect to tarantool. + */ + TarantoolClient client = makeClient(); + client.syncOps().ping(); + client.close(); + } } diff --git a/src/test/java/org/tarantool/TestSocketChannelProvider.java b/src/test/java/org/tarantool/TestSocketChannelProvider.java index 469bc77c..57391f70 100644 --- a/src/test/java/org/tarantool/TestSocketChannelProvider.java +++ b/src/test/java/org/tarantool/TestSocketChannelProvider.java @@ -2,6 +2,7 @@ import java.net.InetSocketAddress; import java.nio.channels.SocketChannel; +import static java.net.StandardSocketOptions.SO_LINGER; /** * Socket channel provider to be used throughout the tests. @@ -9,20 +10,34 @@ public class TestSocketChannelProvider implements SocketChannelProvider { String host; int port; - int restart_timeout; + int restartTimeout; + int soLinger; - public TestSocketChannelProvider(String host, int port, int restart_timeout) { + public TestSocketChannelProvider(String host, int port, int restartTimeout) { this.host = host; this.port = port; - this.restart_timeout = restart_timeout; + this.restartTimeout = restartTimeout; + this.soLinger = -1; + } + + public TestSocketChannelProvider setSoLinger(int soLinger) { + this.soLinger = soLinger; + return this; } @Override public SocketChannel get(int retryNumber, Throwable lastError) { - long budget = System.currentTimeMillis() + restart_timeout; + long budget = System.currentTimeMillis() + restartTimeout; while (!Thread.currentThread().isInterrupted()) { try { - return SocketChannel.open(new InetSocketAddress(host, port)); + SocketChannel channel = SocketChannel.open(); + /* + * A value less then zero means disable lingering (it is a + * default behaviour). + */ + channel.setOption(SO_LINGER, soLinger); + channel.connect(new InetSocketAddress(host, port)); + return channel; } catch (Exception e) { if (budget < System.currentTimeMillis()) throw new RuntimeException(e);