Skip to content

Commit 9eb921e

Browse files
authored
Support non-Json String payloads (#4450)
* Support String payloads for Json protocol * Changelog * Fix Spotbugs error * Refactoring * Remove unused import
1 parent 3ca853c commit 9eb921e

File tree

13 files changed

+482
-16
lines changed

13 files changed

+482
-16
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"category": "AWS SDK for Java v2",
3+
"contributor": "",
4+
"type": "feature",
5+
"description": "Adds support for non-Json String payloads"
6+
}

codegen/src/main/java/software/amazon/awssdk/codegen/AddOperations.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ private static boolean isBlobShape(Shape shape) {
6969
return shape != null && "blob".equals(shape.getType());
7070
}
7171

72+
/**
73+
* @return True if shape is a String type. False otherwise
74+
*/
75+
private static boolean isStringShape(Shape shape) {
76+
return shape != null && "String".equals(shape.getType());
77+
}
78+
7279
/**
7380
* If there is a member in the output shape that is explicitly marked as the payload (with the
7481
* payload trait) this method returns the target shape of that member. Otherwise this method
@@ -192,6 +199,9 @@ public Map<String, OperationModel> constructOperations() {
192199
if (isBlobShape(getPayloadShape(c2jShapes, outputShape))) {
193200
operationModel.setHasBlobMemberAsPayload(true);
194201
}
202+
if (isStringShape(getPayloadShape(c2jShapes, outputShape))) {
203+
operationModel.setHasStringMemberAsPayload(true);
204+
}
195205
}
196206

197207
if (op.getErrors() != null) {

codegen/src/main/java/software/amazon/awssdk/codegen/model/intermediate/OperationModel.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public class OperationModel extends DocumentationModel {
4848

4949
private boolean hasBlobMemberAsPayload;
5050

51+
private boolean hasStringMemberAsPayload;
52+
5153
private boolean isAuthenticated = true;
5254

5355
private AuthType authType;
@@ -211,6 +213,14 @@ public void setHasBlobMemberAsPayload(boolean hasBlobMemberAsPayload) {
211213
this.hasBlobMemberAsPayload = hasBlobMemberAsPayload;
212214
}
213215

216+
public boolean getHasStringMemberAsPayload() {
217+
return this.hasStringMemberAsPayload;
218+
}
219+
220+
public void setHasStringMemberAsPayload(boolean hasStringMemberAsPayload) {
221+
this.hasStringMemberAsPayload = hasStringMemberAsPayload;
222+
}
223+
214224
public boolean hasStreamingInput() {
215225
return inputShape != null && inputShape.isHasStreamingMember();
216226
}

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public CodeBlock responseHandler(IntermediateModel model, OperationModel opModel
142142
CodeBlock.builder()
143143
.add("$T operationMetadata = $T.builder()\n", JsonOperationMetadata.class, JsonOperationMetadata.class)
144144
.add(".hasStreamingSuccessResponse($L)\n", opModel.hasStreamingOutput())
145-
.add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload())
145+
.add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload() && !opModel.getHasStringMemberAsPayload())
146146
.add(".build();");
147147

148148
if (opModel.hasEventStreamOutput()) {

core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import java.io.ByteArrayInputStream;
2525
import java.net.URI;
26+
import java.nio.charset.StandardCharsets;
2627
import java.time.Instant;
2728
import java.util.Collections;
2829
import java.util.EnumMap;
@@ -182,6 +183,11 @@ void doMarshall(SdkPojo pojo) {
182183
if (val != null) {
183184
request.contentStreamProvider(((SdkBytes) val)::asInputStream);
184185
}
186+
} else if (isExplicitStringPayload(field)) {
187+
if (val != null) {
188+
byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8);
189+
request.contentStreamProvider(() -> new ByteArrayInputStream(content));
190+
}
185191
} else if (isExplicitPayloadMember(field)) {
186192
marshallExplicitJsonPayload(field, val);
187193
} else {
@@ -194,6 +200,10 @@ private boolean isExplicitBinaryPayload(SdkField<?> field) {
194200
return isExplicitPayloadMember(field) && MarshallingType.SDK_BYTES.equals(field.marshallingType());
195201
}
196202

203+
private boolean isExplicitStringPayload(SdkField<?> field) {
204+
return isExplicitPayloadMember(field) && MarshallingType.STRING.equals(field.marshallingType());
205+
}
206+
197207
private boolean isExplicitPayloadMember(SdkField<?> field) {
198208
return field.containsTrait(PayloadTrait.class);
199209
}

core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@ public T unmarshall(JsonUnmarshallerContext context,
187187

188188
public <TypeT extends SdkPojo> TypeT unmarshall(SdkPojo sdkPojo,
189189
SdkHttpFullResponse response) throws IOException {
190-
if (hasPayloadMembersOnUnmarshall(sdkPojo) && !hasExplicitBlobPayloadMember(sdkPojo) && response.content().isPresent()) {
190+
if (hasPayloadMembersOnUnmarshall(sdkPojo)
191+
&& !hasExplicitBlobPayloadMember(sdkPojo)
192+
&& !hasExplicitStringPayloadMember(sdkPojo)
193+
&& response.content().isPresent()) {
191194
JsonNode jsonNode = parser.parse(response.content().get());
192195
return unmarshall(sdkPojo, response, jsonNode);
193196
} else {
@@ -201,6 +204,12 @@ private boolean hasExplicitBlobPayloadMember(SdkPojo sdkPojo) {
201204
.anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.SDK_BYTES);
202205
}
203206

207+
private boolean hasExplicitStringPayloadMember(SdkPojo sdkPojo) {
208+
return sdkPojo.sdkFields()
209+
.stream()
210+
.anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.STRING);
211+
}
212+
204213
private static boolean isExplicitPayloadMember(SdkField<?> f) {
205214
return f.containsTrait(PayloadTrait.class);
206215
}
@@ -234,6 +243,13 @@ private static <TypeT extends SdkPojo> TypeT unmarshallStructured(SdkPojo sdkPoj
234243
} else {
235244
field.set(sdkPojo, SdkBytes.fromByteArrayUnsafe(new byte[0]));
236245
}
246+
} else if (isExplicitPayloadMember(field) && field.marshallingType() == MarshallingType.STRING) {
247+
Optional<AbortableInputStream> responseContent = context.response().content();
248+
if (responseContent.isPresent()) {
249+
field.set(sdkPojo, SdkBytes.fromInputStream(responseContent.get()).asUtf8String());
250+
} else {
251+
field.set(sdkPojo, "");
252+
}
237253
} else {
238254
JsonNode jsonFieldContent = getJsonNode(jsonContent, field);
239255
JsonUnmarshaller<Object> unmarshaller = context.getUnmarshaller(field.location(), field.marshallingType());

test/codegen-generated-classes-test/src/main/resources/codegen-resources/customresponsemetadata/service-2.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@
111111
"input":{"shape":"OperationWithExplicitPayloadBlobInput"},
112112
"output":{"shape":"OperationWithExplicitPayloadBlobInput"}
113113
},
114+
"OperationWithExplicitPayloadString":{
115+
"name":"OperationWithExplicitPayloadString",
116+
"http":{
117+
"method":"POST",
118+
"requestUri":"/2016-03-11/operationWithExplicitPayloadString"
119+
},
120+
"input":{"shape":"OperationWithExplicitPayloadStringInput"},
121+
"output":{"shape":"OperationWithExplicitPayloadStringInput"}
122+
},
114123
"OperationWithExplicitPayloadStructure":{
115124
"name":"OperationWithExplicitPayloadStructure",
116125
"http":{
@@ -697,6 +706,13 @@
697706
},
698707
"payload":"PayloadMember"
699708
},
709+
"OperationWithExplicitPayloadStringInput":{
710+
"type":"structure",
711+
"members":{
712+
"PayloadMember":{"shape":"String"}
713+
},
714+
"payload":"PayloadMember"
715+
},
700716
"OperationWithExplicitPayloadStructureInput":{
701717
"type":"structure",
702718
"members":{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.protocolrestjson;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
22+
import static org.assertj.core.api.Assertions.assertThat;
23+
24+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
25+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
26+
import java.net.URI;
27+
import java.util.ArrayList;
28+
import java.util.List;
29+
import org.junit.jupiter.params.ParameterizedTest;
30+
import org.junit.jupiter.params.provider.Arguments;
31+
import org.junit.jupiter.params.provider.MethodSource;
32+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
33+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
34+
import software.amazon.awssdk.regions.Region;
35+
36+
@WireMockTest
37+
public class StringPayloadUnmarshallingTest {
38+
private static final String TEST_PAYLOAD = "X";
39+
40+
private static List<Arguments> testParameters() {
41+
List<Arguments> testCases = new ArrayList<>();
42+
for (ClientType clientType : ClientType.values()) {
43+
for (Protocol protocol : Protocol.values()) {
44+
for (StringLocation value : StringLocation.values()) {
45+
for (ContentLength contentLength : ContentLength.values()) {
46+
testCases.add(Arguments.arguments(clientType, protocol, value, contentLength));
47+
}
48+
}
49+
}
50+
}
51+
return testCases;
52+
}
53+
54+
private enum ClientType {
55+
SYNC,
56+
ASYNC
57+
}
58+
59+
private enum Protocol {
60+
JSON
61+
// TODO - add support for XML
62+
}
63+
64+
private enum StringLocation {
65+
PAYLOAD,
66+
FIELD
67+
}
68+
69+
private enum ContentLength {
70+
ZERO,
71+
NOT_PRESENT
72+
}
73+
74+
@ParameterizedTest
75+
@MethodSource("testParameters")
76+
public void missingStringPayload_unmarshalledCorrectly(ClientType clientType,
77+
Protocol protocol,
78+
StringLocation stringLoc,
79+
ContentLength contentLength,
80+
WireMockRuntimeInfo wm) {
81+
if (contentLength == ContentLength.ZERO) {
82+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("Content-Length", "0").withBody("")));
83+
} else if (contentLength == ContentLength.NOT_PRESENT) {
84+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody("")));
85+
}
86+
87+
String serviceResult = callService(wm, clientType, protocol, stringLoc);
88+
89+
if (stringLoc == StringLocation.PAYLOAD) {
90+
assertThat(serviceResult).isNotNull().isEqualTo("");
91+
} else if (stringLoc == StringLocation.FIELD) {
92+
assertThat(serviceResult).isNull();
93+
}
94+
}
95+
96+
@ParameterizedTest
97+
@MethodSource("testParameters")
98+
public void presentStringPayload_unmarshalledCorrectly(ClientType clientType,
99+
Protocol protocol,
100+
StringLocation stringLoc,
101+
ContentLength contentLength,
102+
WireMockRuntimeInfo wm) {
103+
String responsePayload = presentStringResponse(protocol, stringLoc);
104+
105+
if (contentLength == ContentLength.ZERO) {
106+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200)
107+
.withHeader("Content-Length", Integer.toString(responsePayload.length()))
108+
.withBody(responsePayload)));
109+
} else if (contentLength == ContentLength.NOT_PRESENT) {
110+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(responsePayload)));
111+
}
112+
113+
assertThat(callService(wm, clientType, protocol, stringLoc)).isEqualTo(TEST_PAYLOAD);
114+
}
115+
116+
private String presentStringResponse(Protocol protocol, StringLocation stringLoc) {
117+
switch (stringLoc) {
118+
case PAYLOAD: return TEST_PAYLOAD;
119+
case FIELD:
120+
switch (protocol) {
121+
case JSON: return "{\"StringMember\": \"X\"}";
122+
// TODO - add support for XML
123+
default: throw new UnsupportedOperationException();
124+
}
125+
default: throw new UnsupportedOperationException();
126+
}
127+
128+
}
129+
130+
private String callService(WireMockRuntimeInfo wm, ClientType clientType, Protocol protocol, StringLocation stringLoc) {
131+
switch (clientType) {
132+
case SYNC: return syncCallService(wm, protocol, stringLoc);
133+
case ASYNC: return asyncCallService(wm, protocol, stringLoc);
134+
default: throw new UnsupportedOperationException();
135+
}
136+
}
137+
138+
private String syncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) {
139+
switch (protocol) {
140+
case JSON: return syncJsonCallService(wm, stringLoc);
141+
// TODO - add support for XML
142+
default: throw new UnsupportedOperationException();
143+
}
144+
}
145+
146+
private String asyncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) {
147+
switch (protocol) {
148+
case JSON: return asyncJsonCallService(wm, stringLoc);
149+
// TODO - add support for XML
150+
default: throw new UnsupportedOperationException();
151+
}
152+
}
153+
154+
private String syncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) {
155+
ProtocolRestJsonClient client =
156+
ProtocolRestJsonClient.builder()
157+
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
158+
.region(Region.US_EAST_1)
159+
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
160+
.build();
161+
switch (stringLoc) {
162+
case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).payloadMember();
163+
case FIELD: return client.allTypes(r -> {}).stringMember();
164+
default: throw new UnsupportedOperationException();
165+
}
166+
}
167+
168+
private String asyncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) {
169+
ProtocolRestJsonAsyncClient client =
170+
ProtocolRestJsonAsyncClient.builder()
171+
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
172+
.region(Region.US_EAST_1)
173+
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
174+
.build();
175+
176+
switch (stringLoc) {
177+
case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).join().payloadMember();
178+
case FIELD: return client.allTypes(r -> {}).join().stringMember();
179+
default: throw new UnsupportedOperationException();
180+
}
181+
}
182+
}

0 commit comments

Comments
 (0)