Skip to content

Commit 0ba6102

Browse files
arturobernalgok2c
authored andcommitted
Simplify ProtocolSwitchStrategy by Leveraging ProtocolVersionParser (#627)
Unify HTTP and TLS token parsing in the Upgrade header by replacing custom version parsing with ProtocolVersionParser. This change removes redundant code and ensures that only supported protocols (HTTP/ and TLS tokens) are accepted, while all other upgrade protocols are rejected as unsupported.
1 parent ffc12f1 commit 0ba6102

File tree

2 files changed

+232
-21
lines changed

2 files changed

+232
-21
lines changed

httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java

Lines changed: 97 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,23 @@
2727
package org.apache.hc.client5.http.impl;
2828

2929
import java.util.Iterator;
30+
import java.util.concurrent.atomic.AtomicReference;
3031

3132
import org.apache.hc.core5.annotation.Internal;
33+
import org.apache.hc.core5.http.FormattedHeader;
34+
import org.apache.hc.core5.http.Header;
3235
import org.apache.hc.core5.http.HttpHeaders;
3336
import org.apache.hc.core5.http.HttpMessage;
37+
import org.apache.hc.core5.http.HttpVersion;
3438
import org.apache.hc.core5.http.ParseException;
3539
import org.apache.hc.core5.http.ProtocolException;
3640
import org.apache.hc.core5.http.ProtocolVersion;
37-
import org.apache.hc.core5.http.message.MessageSupport;
41+
import org.apache.hc.core5.http.ProtocolVersionParser;
42+
import org.apache.hc.core5.http.message.ParserCursor;
3843
import org.apache.hc.core5.http.ssl.TLS;
44+
import org.apache.hc.core5.util.Args;
45+
import org.apache.hc.core5.util.CharArrayBuffer;
46+
import org.apache.hc.core5.util.Tokenizer;
3947

4048
/**
4149
* Protocol switch handler.
@@ -45,31 +53,100 @@
4553
@Internal
4654
public final class ProtocolSwitchStrategy {
4755

48-
enum ProtocolSwitch { FAILURE, TLS }
56+
private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE;
57+
private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE;
58+
private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(',');
59+
private static final Tokenizer.Delimiter LAX_PROTO_DELIMITER = Tokenizer.delimiters('/', ',');
60+
61+
@FunctionalInterface
62+
private interface HeaderConsumer {
63+
64+
void accept(CharSequence buffer, ParserCursor cursor) throws ProtocolException;
65+
66+
}
4967

5068
public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
51-
final Iterator<String> it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE);
69+
final AtomicReference<ProtocolVersion> tlsUpgrade = new AtomicReference<>();
5270

53-
ProtocolVersion tlsUpgrade = null;
54-
while (it.hasNext()) {
55-
final String token = it.next();
56-
if (token.startsWith("TLS")) {
57-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
58-
try {
59-
tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv"));
60-
} catch (final ParseException ex) {
61-
throw new ProtocolException("Invalid protocol: " + token);
71+
parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> {
72+
final ProtocolVersion protocolVersion = parseProtocolVersion(buffer, cursor);
73+
if (protocolVersion != null) {
74+
if ("TLS".equalsIgnoreCase(protocolVersion.getProtocol())) {
75+
tlsUpgrade.set(protocolVersion);
76+
} else if (!protocolVersion.equals(HttpVersion.HTTP_1_1)) {
77+
throw new ProtocolException("Unsupported protocol or HTTP version: " + protocolVersion);
6278
}
63-
} else if (token.equals("HTTP/1.1")) {
64-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
65-
} else {
66-
throw new ProtocolException("Unsupported protocol: " + token);
6779
}
80+
});
81+
82+
final ProtocolVersion result = tlsUpgrade.get();
83+
if (result != null) {
84+
return result;
85+
} else {
86+
throw new ProtocolException("Invalid protocol switch response: no TLS version found");
6887
}
69-
if (tlsUpgrade == null) {
70-
throw new ProtocolException("Invalid protocol switch response");
88+
}
89+
90+
private ProtocolVersion parseProtocolVersion(final CharSequence buffer, final ParserCursor cursor) throws ProtocolException {
91+
TOKENIZER.skipWhiteSpace(buffer, cursor);
92+
final String proto = TOKENIZER.parseToken(buffer, cursor, LAX_PROTO_DELIMITER);
93+
if (!cursor.atEnd()) {
94+
final char ch = buffer.charAt(cursor.getPos());
95+
if (ch == '/') {
96+
if (proto.isEmpty()) {
97+
throw new ParseException("Invalid protocol", buffer, cursor.getLowerBound(), cursor.getUpperBound(), cursor.getPos());
98+
}
99+
cursor.updatePos(cursor.getPos() + 1);
100+
return PROTOCOL_VERSION_PARSER.parse(proto, null, buffer, cursor, UPGRADE_TOKEN_DELIMITER);
101+
}
102+
}
103+
if (proto.isEmpty()) {
104+
return null;
105+
} else if (proto.equalsIgnoreCase("TLS")) {
106+
return TLS.V_1_2.getVersion();
107+
} else {
108+
throw new ProtocolException("Unsupported or invalid protocol: " + proto);
109+
}
110+
}
111+
112+
113+
private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer)
114+
throws ProtocolException {
115+
final Iterator<Header> it = message.headerIterator(name);
116+
while (it.hasNext()) {
117+
parseHeader(it.next(), consumer);
118+
}
119+
}
120+
121+
private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException {
122+
Args.notNull(header, "Header");
123+
if (header instanceof FormattedHeader) {
124+
final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer();
125+
final ParserCursor cursor = new ParserCursor(0, buf.length());
126+
cursor.updatePos(((FormattedHeader) header).getValuePos());
127+
parseHeaderElements(buf, cursor, consumer);
128+
} else {
129+
final String value = header.getValue();
130+
if (value == null) {
131+
return;
132+
}
133+
final ParserCursor cursor = new ParserCursor(0, value.length());
134+
parseHeaderElements(value, cursor, consumer);
135+
}
136+
}
137+
138+
private void parseHeaderElements(final CharSequence buffer,
139+
final ParserCursor cursor,
140+
final HeaderConsumer consumer) throws ProtocolException {
141+
while (!cursor.atEnd()) {
142+
consumer.accept(buffer, cursor);
143+
if (!cursor.atEnd()) {
144+
final char ch = buffer.charAt(cursor.getPos());
145+
if (ch == ',') {
146+
cursor.updatePos(cursor.getPos() + 1);
147+
}
148+
}
71149
}
72-
return tlsUpgrade;
73150
}
74151

75-
}
152+
}

httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030
import org.apache.hc.core5.http.HttpResponse;
3131
import org.apache.hc.core5.http.HttpStatus;
3232
import org.apache.hc.core5.http.ProtocolException;
33+
import org.apache.hc.core5.http.ProtocolVersion;
3334
import org.apache.hc.core5.http.message.BasicHttpResponse;
3435
import org.apache.hc.core5.http.ssl.TLS;
3536
import org.junit.jupiter.api.Assertions;
3637
import org.junit.jupiter.api.BeforeEach;
3738
import org.junit.jupiter.api.Test;
3839

3940
/**
40-
* Simple tests for {@link DefaultAuthenticationStrategy}.
41+
* Simple tests for {@link ProtocolSwitchStrategy}.
4142
*/
4243
class TestProtocolSwitchStrategy {
4344

@@ -95,4 +96,137 @@ void testSwitchInvalid() {
9596
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
9697
}
9798

99+
@Test
100+
void testNullToken() throws ProtocolException {
101+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
102+
response.addHeader(HttpHeaders.UPGRADE, "TLS,");
103+
response.addHeader(HttpHeaders.UPGRADE, null);
104+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
105+
}
106+
107+
@Test
108+
void testWhitespaceOnlyToken() throws ProtocolException {
109+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
110+
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
111+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
112+
}
113+
114+
@Test
115+
void testUnsupportedTlsVersion() throws Exception {
116+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
117+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
118+
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
119+
}
120+
121+
@Test
122+
void testUnsupportedTlsMajorVersion() throws Exception {
123+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
124+
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
125+
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
126+
}
127+
128+
@Test
129+
void testUnsupportedHttpVersion() {
130+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
131+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
132+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
133+
switchStrategy.switchProtocol(response));
134+
Assertions.assertEquals("Unsupported protocol or HTTP version: HTTP/2.0", ex.getMessage());
135+
}
136+
137+
@Test
138+
void testInvalidTlsFormat() {
139+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
140+
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
141+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
142+
switchStrategy.switchProtocol(response));
143+
Assertions.assertEquals("Invalid TLS major version number; error at offset 7: <TLS/abc>", ex.getMessage());
144+
}
145+
146+
@Test
147+
void testHttp11Only() {
148+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
149+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
150+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
151+
switchStrategy.switchProtocol(response));
152+
Assertions.assertEquals("Invalid protocol switch response: no TLS version found", ex.getMessage());
153+
}
154+
155+
@Test
156+
void testSwitchToTlsValid_TLS_1_2() throws Exception {
157+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
158+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
159+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
160+
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
161+
}
162+
163+
@Test
164+
void testSwitchToTlsValid_TLS_1_0() throws Exception {
165+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
166+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
167+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
168+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
169+
}
170+
171+
@Test
172+
void testSwitchToTlsValid_TLS_1_1() throws Exception {
173+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
174+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
175+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
176+
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
177+
}
178+
179+
@Test
180+
void testInvalidTlsFormat_NoSlash() {
181+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
182+
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
183+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
184+
switchStrategy.switchProtocol(response));
185+
Assertions.assertEquals("Unsupported or invalid protocol: TLSv1", ex.getMessage());
186+
}
187+
188+
@Test
189+
void testSwitchToTlsValid_TLS_1() throws Exception {
190+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
191+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
192+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
193+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
194+
}
195+
196+
@Test
197+
void testInvalidTlsFormat_MissingMajor() {
198+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
199+
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
200+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
201+
switchStrategy.switchProtocol(response));
202+
Assertions.assertEquals("Invalid TLS major version number; error at offset 4: <TLS/.1>", ex.getMessage());
203+
}
204+
205+
@Test
206+
void testMultipleHttp11Tokens() {
207+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
208+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
209+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
210+
switchStrategy.switchProtocol(response));
211+
Assertions.assertEquals("Invalid protocol switch response: no TLS version found", ex.getMessage());
212+
}
213+
214+
@Test
215+
void testMixedInvalidAndValidTokens() {
216+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
217+
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
218+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
219+
switchStrategy.switchProtocol(response));
220+
Assertions.assertEquals("Unsupported or invalid protocol: Crap", ex.getMessage());
221+
}
222+
223+
@Test
224+
void testInvalidTlsFormat_NoProtocolName() {
225+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
226+
response.addHeader(HttpHeaders.UPGRADE, ",,/1.1");
227+
final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () ->
228+
switchStrategy.switchProtocol(response));
229+
Assertions.assertEquals("Invalid protocol; error at offset 2: <,,/1.1>", ex.getMessage());
230+
}
231+
98232
}

0 commit comments

Comments
 (0)