|
21 | 21 | import org.slf4j.LoggerFactory;
|
22 | 22 |
|
23 | 23 | import javax.net.ssl.SSLEngine;
|
| 24 | +import java.io.DataInputStream; |
| 25 | +import java.io.DataOutputStream; |
24 | 26 | import java.io.IOException;
|
25 | 27 | import java.nio.ByteBuffer;
|
26 | 28 | import java.nio.channels.SelectionKey;
|
@@ -56,36 +58,58 @@ public class SocketChannelFrameHandlerState {
|
56 | 58 |
|
57 | 59 | final SSLEngine sslEngine;
|
58 | 60 |
|
59 |
| - /** app data to be crypted before sending */ |
60 |
| - final ByteBuffer localAppData; |
61 |
| - /** crypted data to be sent */ |
62 |
| - final ByteBuffer localNetData; |
63 |
| - /** app data received and decrypted */ |
64 |
| - final ByteBuffer peerAppData; |
65 |
| - /** crypted data received */ |
66 |
| - final ByteBuffer peerNetData; |
67 |
| - |
68 |
| - public SocketChannelFrameHandlerState(SocketChannel channel, SelectorHolder readSelectorState, |
69 |
| - SelectorHolder writeSelectorState, NioParams nioParams, SSLEngine sslEngine) { |
| 61 | + /** outbound app data (to be crypted if TLS is on) */ |
| 62 | + final ByteBuffer plainOut; |
| 63 | + |
| 64 | + /** inbound app data (deciphered if TLS is on) */ |
| 65 | + final ByteBuffer plainIn; |
| 66 | + |
| 67 | + /** outbound net data (ciphered if TLS is on) */ |
| 68 | + final ByteBuffer cipherOut; |
| 69 | + |
| 70 | + /** inbound data (ciphered if TLS is on) */ |
| 71 | + final ByteBuffer cipherIn; |
| 72 | + |
| 73 | + final DataOutputStream outputStream; |
| 74 | + |
| 75 | + final DataInputStream inputStream; |
| 76 | + |
| 77 | + public SocketChannelFrameHandlerState(SocketChannel channel, NioLoopsState nioLoopsState, NioParams nioParams, SSLEngine sslEngine) { |
70 | 78 | this.channel = channel;
|
71 |
| - this.readSelectorState = readSelectorState; |
72 |
| - this.writeSelectorState = writeSelectorState; |
| 79 | + this.readSelectorState = nioLoopsState.readSelectorState; |
| 80 | + this.writeSelectorState = nioLoopsState.writeSelectorState; |
73 | 81 | this.writeQueue = new ArrayBlockingQueue<WriteRequest>(nioParams.getWriteQueueCapacity(), true);
|
74 | 82 | this.writeEnqueuingTimeoutInMs = nioParams.getWriteEnqueuingTimeoutInMs();
|
75 | 83 | this.sslEngine = sslEngine;
|
76 | 84 | if(this.sslEngine == null) {
|
77 | 85 | this.ssl = false;
|
78 |
| - this.localAppData = null; |
79 |
| - this.localNetData = null; |
80 |
| - this.peerAppData = null; |
81 |
| - this.peerNetData = null; |
| 86 | + this.plainOut = nioLoopsState.writeBuffer; |
| 87 | + this.cipherOut = null; |
| 88 | + this.plainIn = nioLoopsState.readBuffer; |
| 89 | + this.cipherIn = null; |
| 90 | + |
| 91 | + this.outputStream = new DataOutputStream( |
| 92 | + new ByteBufferOutputStream(channel, plainOut) |
| 93 | + ); |
| 94 | + this.inputStream = new DataInputStream( |
| 95 | + new ByteBufferInputStream(channel, plainIn) |
| 96 | + ); |
| 97 | + |
82 | 98 | } else {
|
83 | 99 | this.ssl = true;
|
84 |
| - this.localAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); |
85 |
| - this.localNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); |
86 |
| - this.peerAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); |
87 |
| - this.peerNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); |
| 100 | + this.plainOut = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); |
| 101 | + this.cipherOut = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); |
| 102 | + this.plainIn = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); |
| 103 | + this.cipherIn = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); |
| 104 | + |
| 105 | + this.outputStream = new DataOutputStream( |
| 106 | + new SslEngineByteBufferOutputStream(sslEngine, plainOut, cipherOut, channel) |
| 107 | + ); |
| 108 | + this.inputStream = new DataInputStream( |
| 109 | + new SslEngineByteBufferInputStream(sslEngine, plainIn, cipherIn, channel) |
| 110 | + ); |
88 | 111 | }
|
| 112 | + |
89 | 113 | }
|
90 | 114 |
|
91 | 115 | public SocketChannel getChannel() {
|
@@ -136,4 +160,55 @@ public void setLastActivity(long lastActivity) {
|
136 | 160 | public long getLastActivity() {
|
137 | 161 | return lastActivity;
|
138 | 162 | }
|
| 163 | + |
| 164 | + void prepareForWriteSequence() { |
| 165 | + if(ssl) { |
| 166 | + plainOut.clear(); |
| 167 | + cipherOut.clear(); |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + void endWriteSequence() { |
| 172 | + if(!ssl) { |
| 173 | + plainOut.clear(); |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + void prepareForReadSequence() throws IOException { |
| 178 | + if(ssl) { |
| 179 | + cipherIn.clear(); |
| 180 | + plainIn.clear(); |
| 181 | + |
| 182 | + cipherIn.flip(); |
| 183 | + plainIn.flip(); |
| 184 | + } else { |
| 185 | + channel.read(plainIn); |
| 186 | + plainIn.flip(); |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + boolean continueReading() throws IOException { |
| 191 | + if(ssl) { |
| 192 | + if (!plainIn.hasRemaining() && !cipherIn.hasRemaining()) { |
| 193 | + // need to try to read something |
| 194 | + cipherIn.clear(); |
| 195 | + int bytesRead = channel.read(cipherIn); |
| 196 | + if (bytesRead <= 0) { |
| 197 | + return false; |
| 198 | + } else { |
| 199 | + cipherIn.flip(); |
| 200 | + return true; |
| 201 | + } |
| 202 | + } else { |
| 203 | + return true; |
| 204 | + } |
| 205 | + } else { |
| 206 | + if (!plainIn.hasRemaining()) { |
| 207 | + plainIn.clear(); |
| 208 | + channel.read(plainIn); |
| 209 | + plainIn.flip(); |
| 210 | + } |
| 211 | + return plainIn.hasRemaining(); |
| 212 | + } |
| 213 | + } |
139 | 214 | }
|
0 commit comments