Skip to content

Commit 44684c9

Browse files
committed
Use same logic for plain/TLS in IO loops
Fixes #11
1 parent d6a7478 commit 44684c9

10 files changed

+208
-160
lines changed

src/main/java/com/rabbitmq/client/impl/nio/ByteBufferOutputStream.java

+5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ public void write(int b) throws IOException {
4242
buffer.put((byte) b);
4343
}
4444

45+
@Override
46+
public void flush() throws IOException {
47+
drain(channel, buffer);
48+
}
49+
4550
public static void drain(WritableByteChannel channel, ByteBuffer buffer) throws IOException {
4651
buffer.flip();
4752
while(buffer.hasRemaining() && channel.write(buffer) != -1);

src/main/java/com/rabbitmq/client/impl/nio/FrameWriteRequest.java

+2-5
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
import java.io.DataOutputStream;
2121
import java.io.IOException;
22-
import java.nio.ByteBuffer;
23-
import java.nio.channels.WritableByteChannel;
2422

2523
/**
2624
*
@@ -34,8 +32,7 @@ public FrameWriteRequest(Frame frame) {
3432
}
3533

3634
@Override
37-
public void handle(WritableByteChannel writableChannel, ByteBuffer buffer) throws IOException {
38-
// FIXME reuse output stream from state
39-
frame.writeTo(new DataOutputStream(new ByteBufferOutputStream(writableChannel, buffer)));
35+
public void handle(DataOutputStream outputStream) throws IOException {
36+
frame.writeTo(outputStream);
4037
}
4138
}

src/main/java/com/rabbitmq/client/impl/nio/HeaderWriteRequest.java

+7-9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.rabbitmq.client.AMQP;
1919

20+
import java.io.DataOutputStream;
2021
import java.io.IOException;
2122
import java.nio.ByteBuffer;
2223
import java.nio.channels.WritableByteChannel;
@@ -27,14 +28,11 @@
2728
public class HeaderWriteRequest implements WriteRequest {
2829

2930
@Override
30-
public void handle(WritableByteChannel writableChannel, ByteBuffer buffer) throws IOException {
31-
buffer.put("AMQP".getBytes("US-ASCII"));
32-
buffer.put((byte) 0);
33-
buffer.put((byte) AMQP.PROTOCOL.MAJOR);
34-
buffer.put((byte) AMQP.PROTOCOL.MINOR);
35-
buffer.put((byte) AMQP.PROTOCOL.REVISION);
36-
buffer.flip();
37-
while(buffer.hasRemaining() && writableChannel.write(buffer) != -1);
38-
buffer.clear();
31+
public void handle(DataOutputStream outputStream) throws IOException {
32+
outputStream.write("AMQP".getBytes("US-ASCII"));
33+
outputStream.write(0);
34+
outputStream.write(AMQP.PROTOCOL.MAJOR);
35+
outputStream.write(AMQP.PROTOCOL.MINOR);
36+
outputStream.write(AMQP.PROTOCOL.REVISION);
3937
}
4038
}

src/main/java/com/rabbitmq/client/impl/nio/NioLoopsState.java

+10-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.slf4j.LoggerFactory;
66

77
import java.io.IOException;
8+
import java.nio.ByteBuffer;
89
import java.nio.channels.Selector;
910
import java.util.concurrent.ExecutorService;
1011
import java.util.concurrent.Future;
@@ -24,6 +25,8 @@ public class NioLoopsState {
2425

2526
private final ThreadFactory threadFactory;
2627

28+
final ByteBuffer readBuffer, writeBuffer;
29+
2730
private Thread readThread, writeThread;
2831

2932
private Future<?> writeTask;
@@ -34,10 +37,12 @@ public class NioLoopsState {
3437
private final AtomicLong nioLoopsConnectionCount = new AtomicLong();
3538

3639
public NioLoopsState(SocketChannelFrameHandlerFactory socketChannelFrameHandlerFactory,
37-
ExecutorService executorService, ThreadFactory threadFactory) {
40+
NioParams nioParams) {
3841
this.socketChannelFrameHandlerFactory = socketChannelFrameHandlerFactory;
39-
this.executorService = executorService;
40-
this.threadFactory = threadFactory;
42+
this.executorService = nioParams.getNioExecutor();
43+
this.threadFactory = nioParams.getThreadFactory();
44+
this.readBuffer = ByteBuffer.allocate(nioParams.getReadByteBufferSize());
45+
this.writeBuffer = ByteBuffer.allocate(nioParams.getWriteByteBufferSize());
4146
}
4247

4348
void notifyNewConnection() {
@@ -62,14 +67,14 @@ private void startIoLoops() {
6267
);
6368
this.writeThread = Environment.newThread(
6469
threadFactory,
65-
new WriteLoop(socketChannelFrameHandlerFactory.nioParams,this.writeSelectorState),
70+
new WriteLoop(socketChannelFrameHandlerFactory.nioParams,this),
6671
"rabbitmq-nio-write"
6772
);
6873
readThread.start();
6974
writeThread.start();
7075
} else {
7176
this.executorService.submit(new ReadLoop(socketChannelFrameHandlerFactory.nioParams, this));
72-
this.writeTask = this.executorService.submit(new WriteLoop(socketChannelFrameHandlerFactory.nioParams,this.writeSelectorState));
77+
this.writeTask = this.executorService.submit(new WriteLoop(socketChannelFrameHandlerFactory.nioParams,this));
7378
}
7479
}
7580

src/main/java/com/rabbitmq/client/impl/nio/ReadLoop.java

+46-27
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,19 @@ public class ReadLoop extends AbstractNioLoop {
3737

3838
private final NioLoopsState state;
3939

40-
public ReadLoop(NioParams nioParams, NioLoopsState readSelectorState) {
40+
public ReadLoop(NioParams nioParams, NioLoopsState loopsState) {
4141
super(nioParams);
42-
this.state = readSelectorState;
42+
this.state = loopsState;
4343
}
4444

4545
@Override
4646
public void run() {
4747
final SelectorHolder selectorState = state.readSelectorState;
4848
final Selector selector = selectorState.selector;
49-
Set<SocketChannelRegistration> registrations = selectorState.registrations;
50-
// FIXME find a better default?
51-
ByteBuffer buffer = ByteBuffer.allocate(nioParams.getReadByteBufferSize());
49+
final Set<SocketChannelRegistration> registrations = selectorState.registrations;
50+
51+
final ByteBuffer buffer = state.readBuffer;
52+
5253
try {
5354
int idlenessCount = 0;
5455
while (true && !Thread.currentThread().isInterrupted()) {
@@ -112,23 +113,44 @@ public void run() {
112113
if (key.isReadable()) {
113114
final SocketChannelFrameHandlerState state = (SocketChannelFrameHandlerState) key.attachment();
114115
try {
115-
if (state.ssl) {
116-
ByteBuffer peerAppData = state.peerAppData;
117-
ByteBuffer peerNetData = state.peerNetData;
118-
SSLEngine engine = state.sslEngine;
119116

120-
peerNetData.clear();
121-
peerAppData.clear();
117+
DataInputStream inputStream = state.inputStream;
122118

123-
peerNetData.flip();
124-
peerAppData.flip();
119+
state.prepareForReadSequence();
125120

126-
// FIXME reuse input stream
127-
SslEngineByteBufferInputStream sslEngineByteBufferInputStream = new SslEngineByteBufferInputStream(
128-
engine, peerAppData, peerNetData, channel
129-
);
121+
while (state.continueReading()) {
122+
Frame frame = Frame.readFrom(inputStream);
130123

131-
DataInputStream inputStream = new DataInputStream(sslEngineByteBufferInputStream);
124+
try {
125+
boolean noProblem = state.getConnection().handleReadFrame(frame);
126+
if (noProblem && (!state.getConnection().isRunning() || state.getConnection().hasBrokerInitiatedShutdown())) {
127+
// looks like the frame was Close-Ok or Close
128+
dispatchShutdownToConnection(state);
129+
key.cancel();
130+
break;
131+
}
132+
} catch (Throwable ex) {
133+
// problem during frame processing, tell connection, and
134+
// we can stop for this channel
135+
handleIoError(state, ex);
136+
key.cancel();
137+
break;
138+
}
139+
140+
}
141+
142+
/*
143+
if (state.ssl) {
144+
ByteBuffer plainIn = state.plainIn;
145+
ByteBuffer cipherIn = state.cipherIn;
146+
147+
cipherIn.clear();
148+
plainIn.clear();
149+
150+
cipherIn.flip();
151+
plainIn.flip();
152+
153+
DataInputStream inputStream = state.inputStream;
132154
133155
while (true) {
134156
Frame frame = Frame.readFrom(inputStream);
@@ -149,25 +171,22 @@ public void run() {
149171
break;
150172
}
151173
152-
if (!peerAppData.hasRemaining() && !peerNetData.hasRemaining()) {
174+
if (!plainIn.hasRemaining() && !cipherIn.hasRemaining()) {
153175
// need to try to read something
154-
peerNetData.clear();
155-
int bytesRead = channel.read(peerNetData);
176+
cipherIn.clear();
177+
int bytesRead = channel.read(cipherIn);
156178
if (bytesRead <= 0) {
157179
break;
158180
} else {
159-
peerNetData.flip();
181+
cipherIn.flip();
160182
}
161183
}
162184
}
163185
} else {
164186
channel.read(buffer);
165187
buffer.flip();
166188
167-
// FIXME reuse input stream
168-
DataInputStream inputStream = new DataInputStream(
169-
new ByteBufferInputStream(channel, buffer)
170-
);
189+
DataInputStream inputStream = state.inputStream;
171190
172191
while (buffer.hasRemaining()) {
173192
Frame frame = Frame.readFrom(inputStream);
@@ -195,7 +214,7 @@ public void run() {
195214
}
196215
}
197216
}
198-
217+
*/
199218
state.setLastActivity(System.currentTimeMillis());
200219
} catch (final Exception e) {
201220
LOGGER.warn("Error during reading frames", e);

src/main/java/com/rabbitmq/client/impl/nio/SocketChannelFrameHandlerFactory.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public SocketChannelFrameHandlerFactory(int connectionTimeout, SocketConfigurato
6767
this.threadFactory = nioParams.getThreadFactory();
6868
this.nioLoopsStates = new ArrayList<NioLoopsState>(this.nioParams.getNbIoThreads() / 2);
6969
for(int i = 0; i < this.nioParams.getNbIoThreads() / 2; i++) {
70-
this.nioLoopsStates.add(new NioLoopsState(this, this.executorService, this.threadFactory));
70+
this.nioLoopsStates.add(new NioLoopsState(this, this.nioParams));
7171
}
7272
}
7373

@@ -112,7 +112,7 @@ public FrameHandler create(Address addr) throws IOException {
112112

113113
SocketChannelFrameHandlerState state = new SocketChannelFrameHandlerState(
114114
channel,
115-
nioLoopsState.readSelectorState, nioLoopsState.writeSelectorState,
115+
nioLoopsState,
116116
nioParams,
117117
sslEngine
118118
);

src/main/java/com/rabbitmq/client/impl/nio/SocketChannelFrameHandlerState.java

+96-21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.slf4j.LoggerFactory;
2222

2323
import javax.net.ssl.SSLEngine;
24+
import java.io.DataInputStream;
25+
import java.io.DataOutputStream;
2426
import java.io.IOException;
2527
import java.nio.ByteBuffer;
2628
import java.nio.channels.SelectionKey;
@@ -56,36 +58,58 @@ public class SocketChannelFrameHandlerState {
5658

5759
final SSLEngine sslEngine;
5860

59-
/** app data to be crypted before sending */
60-
final ByteBuffer localAppData;
61-
/** crypted data to be sent */
62-
final ByteBuffer localNetData;
63-
/** app data received and decrypted */
64-
final ByteBuffer peerAppData;
65-
/** crypted data received */
66-
final ByteBuffer peerNetData;
67-
68-
public SocketChannelFrameHandlerState(SocketChannel channel, SelectorHolder readSelectorState,
69-
SelectorHolder writeSelectorState, NioParams nioParams, SSLEngine sslEngine) {
61+
/** outbound app data (to be crypted if TLS is on) */
62+
final ByteBuffer plainOut;
63+
64+
/** inbound app data (deciphered if TLS is on) */
65+
final ByteBuffer plainIn;
66+
67+
/** outbound net data (ciphered if TLS is on) */
68+
final ByteBuffer cipherOut;
69+
70+
/** inbound data (ciphered if TLS is on) */
71+
final ByteBuffer cipherIn;
72+
73+
final DataOutputStream outputStream;
74+
75+
final DataInputStream inputStream;
76+
77+
public SocketChannelFrameHandlerState(SocketChannel channel, NioLoopsState nioLoopsState, NioParams nioParams, SSLEngine sslEngine) {
7078
this.channel = channel;
71-
this.readSelectorState = readSelectorState;
72-
this.writeSelectorState = writeSelectorState;
79+
this.readSelectorState = nioLoopsState.readSelectorState;
80+
this.writeSelectorState = nioLoopsState.writeSelectorState;
7381
this.writeQueue = new ArrayBlockingQueue<WriteRequest>(nioParams.getWriteQueueCapacity(), true);
7482
this.writeEnqueuingTimeoutInMs = nioParams.getWriteEnqueuingTimeoutInMs();
7583
this.sslEngine = sslEngine;
7684
if(this.sslEngine == null) {
7785
this.ssl = false;
78-
this.localAppData = null;
79-
this.localNetData = null;
80-
this.peerAppData = null;
81-
this.peerNetData = null;
86+
this.plainOut = nioLoopsState.writeBuffer;
87+
this.cipherOut = null;
88+
this.plainIn = nioLoopsState.readBuffer;
89+
this.cipherIn = null;
90+
91+
this.outputStream = new DataOutputStream(
92+
new ByteBufferOutputStream(channel, plainOut)
93+
);
94+
this.inputStream = new DataInputStream(
95+
new ByteBufferInputStream(channel, plainIn)
96+
);
97+
8298
} else {
8399
this.ssl = true;
84-
this.localAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize());
85-
this.localNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
86-
this.peerAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize());
87-
this.peerNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
100+
this.plainOut = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize());
101+
this.cipherOut = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
102+
this.plainIn = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize());
103+
this.cipherIn = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
104+
105+
this.outputStream = new DataOutputStream(
106+
new SslEngineByteBufferOutputStream(sslEngine, plainOut, cipherOut, channel)
107+
);
108+
this.inputStream = new DataInputStream(
109+
new SslEngineByteBufferInputStream(sslEngine, plainIn, cipherIn, channel)
110+
);
88111
}
112+
89113
}
90114

91115
public SocketChannel getChannel() {
@@ -136,4 +160,55 @@ public void setLastActivity(long lastActivity) {
136160
public long getLastActivity() {
137161
return lastActivity;
138162
}
163+
164+
void prepareForWriteSequence() {
165+
if(ssl) {
166+
plainOut.clear();
167+
cipherOut.clear();
168+
}
169+
}
170+
171+
void endWriteSequence() {
172+
if(!ssl) {
173+
plainOut.clear();
174+
}
175+
}
176+
177+
void prepareForReadSequence() throws IOException {
178+
if(ssl) {
179+
cipherIn.clear();
180+
plainIn.clear();
181+
182+
cipherIn.flip();
183+
plainIn.flip();
184+
} else {
185+
channel.read(plainIn);
186+
plainIn.flip();
187+
}
188+
}
189+
190+
boolean continueReading() throws IOException {
191+
if(ssl) {
192+
if (!plainIn.hasRemaining() && !cipherIn.hasRemaining()) {
193+
// need to try to read something
194+
cipherIn.clear();
195+
int bytesRead = channel.read(cipherIn);
196+
if (bytesRead <= 0) {
197+
return false;
198+
} else {
199+
cipherIn.flip();
200+
return true;
201+
}
202+
} else {
203+
return true;
204+
}
205+
} else {
206+
if (!plainIn.hasRemaining()) {
207+
plainIn.clear();
208+
channel.read(plainIn);
209+
plainIn.flip();
210+
}
211+
return plainIn.hasRemaining();
212+
}
213+
}
139214
}

0 commit comments

Comments
 (0)