Skip to content

Fix a file descriptor leak with wrong user:pass #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions src/main/java/org/tarantool/TarantoolBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ public TarantoolBase() {
public TarantoolBase(String username, String password, Socket socket) {
super();
try {
this.is = new DataInputStream(cis = new CountInputStreamImpl(socket.getInputStream()));
cis = new CountInputStreamImpl(socket.getInputStream());
is = new DataInputStream(cis);
byte[] bytes = new byte[64];
is.readFully(bytes);
String firstLine = new String(bytes);
if (!firstLine.startsWith(WELCOME)) {
closeStreams();
close();
throw new CommunicationException("Welcome message should starts with tarantool but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));
}
Expand All @@ -56,23 +58,15 @@ public TarantoolBase(String username, String password, Socket socket) {
OutputStream os = socket.getOutputStream();
os.write(authPacket.array(), 0, authPacket.remaining());
os.flush();
readPacket(is);
readPacket();
Long code = (Long) headers.get(Key.CODE.getId());
if (code != 0) {
closeStreams();
throw serverError(code, body.get(Key.ERROR.getId()));
}
}
} catch (IOException e) {
try {
is.close();
} catch (IOException ignored) {

}
try {
cis.close();
} catch (IOException ignored) {

}
closeStreams();
throw new CommunicationException("Couldn't connect to tarantool", e);
}
}
Expand Down Expand Up @@ -130,7 +124,7 @@ protected ByteBuffer createPacket(Code code, Long syncId, Long schemaId, Object.
return buffer;
}

protected void readPacket(DataInputStream is) throws IOException {
protected void readPacket() throws IOException {
int size = ((Number) msgPackLite.unpack(is)).intValue();
long mark = cis.getBytesRead();
headers = (Map<Integer, Object>) msgPackLite.unpack(is);
Expand Down Expand Up @@ -185,7 +179,6 @@ protected List<Map<String, Object>> readSqlResult(List<List<?>> data) {
return values;
}


protected Long getSqlRowCount() {
Map<Key, Object> info = (Map<Key, Object>) body.get(Key.SQL_INFO.getId());
Number rowCount;
Expand Down Expand Up @@ -220,6 +213,21 @@ protected void closeChannel(SocketChannel channel) {
}
}

protected void closeStreams() {
if (is != null) {
try {
is.close();
} catch (IOException ignored) {
}
}
if (cis != null) {
try {
cis.close();
} catch (IOException ignored) {
}
}
}

protected void validateArgs(Object[] args) {
if (args != null) {
for (int i = 0; i < args.length; i += 2) {
Expand Down
22 changes: 8 additions & 14 deletions src/main/java/org/tarantool/TarantoolClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,40 +132,34 @@ protected void reconnect(int retry, Throwable lastError) {
}

protected void connect(final SocketChannel channel) throws Exception {
closeStreams();
try {
DataInputStream is = new DataInputStream(cis = new ByteBufferInputStream(channel));
cis = new ByteBufferInputStream(channel);
is = new DataInputStream(cis);
byte[] bytes = new byte[64];
is.readFully(bytes);
String firstLine = new String(bytes);
if (!firstLine.startsWith("Tarantool")) {
CommunicationException e = new CommunicationException("Welcome message should starts with tarantool " +
"but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));

closeStreams();
close(e);
throw e;
}
is.readFully(bytes);
this.salt = new String(bytes);
if (config.username != null && config.password != null) {
writeFully(channel, createAuthPacket(config.username, config.password));
readPacket(is);
readPacket();
Long code = (Long) headers.get(Key.CODE.getId());
if (code != 0) {
closeStreams();
throw serverError(code, body.get(Key.ERROR.getId()));
}
}
this.is = is;
} catch (IOException e) {
try {
is.close();
} catch (IOException ignored) {

}
try {
cis.close();
} catch (IOException ignored) {

}
closeStreams();
throw new CommunicationException("Couldn't connect to tarantool", e);
}
channel.configureBlocking(false);
Expand Down Expand Up @@ -358,7 +352,7 @@ protected void readThread() {
while (!Thread.currentThread().isInterrupted()) {
try {
long code;
readPacket(is);
readPacket();
code = (Long) headers.get(Key.CODE.getId());
Long syncId = (Long) headers.get(Key.SYNC.getId());
CompletableFuture<?> future = futures.remove(syncId);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/TarantoolConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected List<?> exec(Code code, Object... args) {
ByteBuffer packet = createPacket(code, syncId.incrementAndGet(), null, args);
out.write(packet.array(), 0, packet.remaining());
out.flush();
readPacket(is);
readPacket();
Long c = (Long) headers.get(Key.CODE.getId());
if (c == 0) {
return (List) body.get(Key.DATA.getId());
Expand Down
4 changes: 4 additions & 0 deletions src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ protected TarantoolClient makeClient() {
return new TarantoolClientImpl(socketChannelProvider, makeClientConfig());
}

protected TarantoolClient makeClient(SocketChannelProvider provider) {
return new TarantoolClientImpl(provider, makeClientConfig());
}

protected static TarantoolClientConfig makeClientConfig() {
return fillClientConfig(new TarantoolClientConfig());
}
Expand Down
59 changes: 56 additions & 3 deletions src/test/java/org/tarantool/ClientReconnectIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -215,19 +216,27 @@ public void run() {

/**
* Test concurrent operations, reconnects and close.
*
* Expected situation is nothing gets stuck.
*
* The test sets SO_LINGER to 0 for outgoing connections to avoid producing
* many TIME_WAIT sockets, because an available port range can be
* exhausted.
*/
@Test
public void testLongParallelCloseReconnects() {
int numThreads = 4;
int numClients = 4;
int timeBudget = 30*1000;

SocketChannelProvider provider = new TestSocketChannelProvider(host,
port, RESTART_TIMEOUT).setSoLinger(0);

final AtomicReferenceArray<TarantoolClient> clients =
new AtomicReferenceArray<TarantoolClient>(numClients);

for (int idx = 0; idx < clients.length(); idx++) {
clients.set(idx, makeClient());
clients.set(idx, makeClient(provider));
}

final Random rnd = new Random();
Expand Down Expand Up @@ -256,7 +265,7 @@ public void run() {

cli.close();

TarantoolClient next = makeClient();
TarantoolClient next = makeClient(provider);
if (!clients.compareAndSet(idx, cli, next)) {
next.close();
}
Expand Down Expand Up @@ -284,7 +293,9 @@ public void run() {
fail(e);
}
if (deadline > System.currentTimeMillis()) {
System.out.println("" + (deadline - System.currentTimeMillis())/1000 + "s remains.");
System.out.println("testLongParallelCloseReconnects: " +
(deadline - System.currentTimeMillis()) / 1000 +
"s remain");
}
}

Expand All @@ -302,4 +313,46 @@ public void run() {

assertTrue(cnt.get() > threads.length);
}

/**
* Verify that we don't exceed a file descriptor limit (and so likely don't
* leak file descriptors) when trying to connect to an existing node with
* wrong authentification credentials.
*
* The test sets SO_LINGER to 0 for outgoing connections to avoid producing
* many TIME_WAIT sockets, because an available port range can be
* exhausted.
*/
@Test
public void testReconnectWrongAuth() throws Exception {
SocketChannelProvider provider = new TestSocketChannelProvider(host,
port, RESTART_TIMEOUT).setSoLinger(0);
TarantoolClientConfig config = makeClientConfig();
config.initTimeoutMillis = 100;
config.password = config.password + 'x';
for (int i = 0; i < 100; ++i) {
if (i % 10 == 0)
System.out.println("testReconnectWrongAuth: " + (100 - i) +
" iterations remain");
CommunicationException e = assertThrows(CommunicationException.class,
new Executable() {
@Override
public void execute() throws Throwable {
client = new TarantoolClientImpl(provider, config);
}
}
);
assertEquals(e.getMessage(), "100ms is exceeded when waiting " +
"for client initialization. You could configure init " +
"timeout in TarantoolConfig");
}

/*
* Verify we don't exceed a file descriptor limit. If we exceed it, a
* client will not able to connect to tarantool.
*/
TarantoolClient client = makeClient();
client.syncOps().ping();
client.close();
}
}
25 changes: 20 additions & 5 deletions src/test/java/org/tarantool/TestSocketChannelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,42 @@

import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel;
import static java.net.StandardSocketOptions.SO_LINGER;

/**
* Socket channel provider to be used throughout the tests.
*/
public class TestSocketChannelProvider implements SocketChannelProvider {
String host;
int port;
int restart_timeout;
int restartTimeout;
int soLinger;

public TestSocketChannelProvider(String host, int port, int restart_timeout) {
public TestSocketChannelProvider(String host, int port, int restartTimeout) {
this.host = host;
this.port = port;
this.restart_timeout = restart_timeout;
this.restartTimeout = restartTimeout;
this.soLinger = -1;
}

public TestSocketChannelProvider setSoLinger(int soLinger) {
this.soLinger = soLinger;
return this;
}

@Override
public SocketChannel get(int retryNumber, Throwable lastError) {
long budget = System.currentTimeMillis() + restart_timeout;
long budget = System.currentTimeMillis() + restartTimeout;
while (!Thread.currentThread().isInterrupted()) {
try {
return SocketChannel.open(new InetSocketAddress(host, port));
SocketChannel channel = SocketChannel.open();
/*
* A value less then zero means disable lingering (it is a
* default behaviour).
*/
channel.setOption(SO_LINGER, soLinger);
channel.connect(new InetSocketAddress(host, port));
return channel;
} catch (Exception e) {
if (budget < System.currentTimeMillis())
throw new RuntimeException(e);
Expand Down