Skip to content

Add support for marshalling lists of strings in HTTP headers... #2588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-641dd1e.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "AWS SDK for Java v2",
"contributor": "Bennett-Lynch",
"type": "feature",
"description": "Add support for marshalling lists of strings in HTTP headers"
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.traits.JsonValueTrait;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;
import software.amazon.awssdk.utils.BinaryUtils;

Expand All @@ -45,6 +48,17 @@ public final class HeaderMarshaller {
public static final JsonMarshaller<Instant> INSTANT
= new SimpleHeaderMarshaller<>(JsonProtocolMarshaller.INSTANT_VALUE_TO_STRING);

public static final JsonMarshaller<List<?>> LIST = (list, context, paramName, sdkField) -> {
if (list.isEmpty()) {
return;
}
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class).memberFieldInfo();
for (Object listValue : list) {
JsonMarshaller marshaller = context.marshallerRegistry().getMarshaller(MarshallLocation.HEADER, listValue);
marshaller.marshall(listValue, context, paramName, memberFieldInfo);
}
};

private HeaderMarshaller() {
}

Expand All @@ -58,8 +72,7 @@ private SimpleHeaderMarshaller(ValueToStringConverter.ValueToString<T> converter

@Override
public void marshall(T val, JsonMarshallerContext context, String paramName, SdkField<T> sdkField) {
context.request().putHeader(paramName, converter.convert(val, sdkField));
context.request().appendHeader(paramName, converter.convert(val, sdkField));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ private static JsonMarshallerRegistry createMarshallerRegistry() {
.headerMarshaller(MarshallingType.FLOAT, HeaderMarshaller.FLOAT)
.headerMarshaller(MarshallingType.BOOLEAN, HeaderMarshaller.BOOLEAN)
.headerMarshaller(MarshallingType.INSTANT, HeaderMarshaller.INSTANT)
.headerMarshaller(MarshallingType.LIST, HeaderMarshaller.LIST)
.headerMarshaller(MarshallingType.NULL, JsonMarshaller.NULL)

.queryParamMarshaller(MarshallingType.STRING, QueryParamMarshaller.STRING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@

package software.amazon.awssdk.protocols.json.internal.unmarshall;

import static java.util.stream.Collectors.toList;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.JsonValueTrait;
import software.amazon.awssdk.protocols.core.StringToValueConverter;
import software.amazon.awssdk.protocols.json.internal.dom.SdkJsonNode;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.http.SdkHttpUtils;

/**
* Header unmarshallers for all the simple types we support.
Expand All @@ -39,6 +43,11 @@ final class HeaderUnmarshaller {
public static final JsonUnmarshaller<Boolean> BOOLEAN = new SimpleHeaderUnmarshaller<>(StringToValueConverter.TO_BOOLEAN);
public static final JsonUnmarshaller<Float> FLOAT = new SimpleHeaderUnmarshaller<>(StringToValueConverter.TO_FLOAT);

// Only supports string value type
public static final JsonUnmarshaller<List<?>> LIST = (context, jsonContent, field) -> {
return SdkHttpUtils.allMatchingHeaders(context.response().headers(), field.locationName()).collect(toList());
};

private HeaderUnmarshaller() {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ private static JsonUnmarshallerRegistry createUnmarshallerRegistry(
.headerUnmarshaller(MarshallingType.BOOLEAN, HeaderUnmarshaller.BOOLEAN)
.headerUnmarshaller(MarshallingType.INSTANT, HeaderUnmarshaller.createInstantHeaderUnmarshaller(instantStringToValue))
.headerUnmarshaller(MarshallingType.FLOAT, HeaderUnmarshaller.FLOAT)
.headerUnmarshaller(MarshallingType.LIST, HeaderUnmarshaller.LIST)

.payloadUnmarshaller(MarshallingType.STRING, new SimpleTypeJsonUnmarshaller<>(StringToValueConverter.TO_STRING))
.payloadUnmarshaller(MarshallingType.INTEGER, new SimpleTypeJsonUnmarshaller<>(StringToValueConverter.TO_INTEGER))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
package software.amazon.awssdk.protocols.xml.internal.marshall;

import java.time.Instant;
import java.util.List;
import java.util.Map;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;

@SdkInternalApi
Expand Down Expand Up @@ -66,6 +68,24 @@ protected boolean shouldEmit(Map map) {
}
};

public static final XmlMarshaller<List<?>> LIST = new SimpleHeaderMarshaller<List<?>>(null) {
@Override
public void marshall(List<?> list, XmlMarshallerContext context, String paramName, SdkField<List<?>> sdkField) {
if (!shouldEmit(list)) {
return;
}
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class).memberFieldInfo();
for (Object listValue : list) {
XmlMarshaller marshaller = context.marshallerRegistry().getMarshaller(MarshallLocation.HEADER, listValue);
marshaller.marshall(listValue, context, paramName, memberFieldInfo);
}
}

@Override
protected boolean shouldEmit(List list) {
return list != null && !list.isEmpty();
}
};

private HeaderMarshaller() {
}
Expand All @@ -83,7 +103,7 @@ public void marshall(T val, XmlMarshallerContext context, String paramName, SdkF
return;
}

context.request().putHeader(paramName, converter.convert(val, sdkField));
context.request().appendHeader(paramName, converter.convert(val, sdkField));
}

protected boolean shouldEmit(T val) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ public final class QueryParamMarshaller {
return;
}

MapTrait mapTrait = sdkField.getOptionalTrait(MapTrait.class)
.orElseThrow(() -> new IllegalStateException("SdkField of list type is missing List trait"));
MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class);
SdkField valueField = mapTrait.valueFieldInfo();

for (Map.Entry<String, ?> entry : map.entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ public void marshall(List<?> val, XmlMarshallerContext context, String paramName
@Override
public void marshall(List<?> list, XmlMarshallerContext context, String paramName,
SdkField<List<?>> sdkField, ValueToStringConverter.ValueToString<List<?>> converter) {
ListTrait listTrait = sdkField
.getOptionalTrait(ListTrait.class)
.orElseThrow(() -> new IllegalStateException(paramName + " member is missing ListTrait"));
ListTrait listTrait = sdkField.getRequiredTrait(ListTrait.class);

if (!listTrait.isFlattened()) {
context.xmlGenerator().startElement(paramName);
Expand Down Expand Up @@ -125,8 +123,7 @@ protected boolean shouldEmit(List list, String paramName) {
public void marshall(Map<String, ?> map, XmlMarshallerContext context, String paramName,
SdkField<Map<String, ?>> sdkField, ValueToStringConverter.ValueToString<Map<String, ?>> converter) {

MapTrait mapTrait = sdkField.getOptionalTrait(MapTrait.class)
.orElseThrow(() -> new IllegalStateException(paramName + " member is missing MapTrait"));
MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class);

for (Map.Entry<String, ?> entry : map.entrySet()) {
context.xmlGenerator().startElement("entry");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ private static XmlMarshallerRegistry createMarshallerRegistry() {
.headerMarshaller(MarshallingType.BOOLEAN, HeaderMarshaller.BOOLEAN)
.headerMarshaller(MarshallingType.INSTANT, HeaderMarshaller.INSTANT)
.headerMarshaller(MarshallingType.MAP, HeaderMarshaller.MAP)
.headerMarshaller(MarshallingType.LIST, HeaderMarshaller.LIST)
.headerMarshaller(MarshallingType.NULL, XmlMarshaller.NULL)

.queryParamMarshaller(MarshallingType.STRING, QueryParamMarshaller.STRING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.protocols.xml.internal.unmarshall;

import static java.util.stream.Collectors.toList;
import static software.amazon.awssdk.utils.StringUtils.replacePrefixIgnoreCase;
import static software.amazon.awssdk.utils.StringUtils.startsWithIgnoreCase;

Expand All @@ -26,6 +27,7 @@
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.protocols.core.StringToValueConverter;
import software.amazon.awssdk.protocols.query.unmarshall.XmlElement;
import software.amazon.awssdk.utils.http.SdkHttpUtils;

@SdkInternalApi
public final class HeaderUnmarshaller {
Expand All @@ -39,6 +41,7 @@ public final class HeaderUnmarshaller {
public static final XmlUnmarshaller<Instant> INSTANT =
new SimpleHeaderUnmarshaller<>(XmlProtocolUnmarshaller.INSTANT_STRING_TO_VALUE);

// Only supports string value type
public static final XmlUnmarshaller<Map<String, ?>> MAP = ((context, content, field) -> {
Map<String, String> result = new HashMap<>();
context.response().headers().entrySet().stream()
Expand All @@ -48,6 +51,11 @@ public final class HeaderUnmarshaller {
return result;
});

// Only supports string value type
public static final XmlUnmarshaller<List<?>> LIST = (context, content, field) -> {
return SdkHttpUtils.allMatchingHeaders(context.response().headers(), field.locationName()).collect(toList());
};

private HeaderUnmarshaller() {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ private static XmlUnmarshallerRegistry createUnmarshallerRegistry() {
.headerUnmarshaller(MarshallingType.INSTANT, HeaderUnmarshaller.INSTANT)
.headerUnmarshaller(MarshallingType.FLOAT, HeaderUnmarshaller.FLOAT)
.headerUnmarshaller(MarshallingType.MAP, HeaderUnmarshaller.MAP)
.headerUnmarshaller(MarshallingType.LIST, HeaderUnmarshaller.LIST)

.payloadUnmarshaller(MarshallingType.STRING, XmlPayloadUnmarshaller.STRING)
.payloadUnmarshaller(MarshallingType.INTEGER, XmlPayloadUnmarshaller.INTEGER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,23 @@ public <T extends Trait> Optional<T> getOptionalTrait(Class<T> clzz) {
return Optional.ofNullable((T) traits.get(clzz));
}

/**
* Gets the trait of the specified class, or throw {@link IllegalStateException} if not available.
*
* @param clzz Trait class to get.
* @param <T> Type of trait.
* @return Trait instance.
* @throws IllegalStateException if trait is not present.
*/
@SuppressWarnings("unchecked")
public <T extends Trait> T getRequiredTrait(Class<T> clzz) throws IllegalStateException {
T trait = (T) traits.get(clzz);
if (trait == null) {
throw new IllegalStateException(memberName + " member is missing " + clzz.getSimpleName());
}
return trait;
}

/**
* Checks if a given {@link Trait} is present on the field.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import com.github.tomakehurst.wiremock.http.HttpHeaders;
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
Expand All @@ -28,11 +29,11 @@
*/
public class HeadersAssertion extends MarshallingAssertion {

private Map<String, String> contains;
private Map<String, List<String>> contains;

private List<String> doesNotContain;

public void setContains(Map<String, String> contains) {
public void setContains(Map<String, List<String>> contains) {
this.contains = contains;
}

Expand All @@ -51,8 +52,11 @@ protected void doAssert(LoggedRequest actual) throws Exception {
}

private void assertHeadersContains(HttpHeaders actual) {
contains.entrySet().forEach(e -> {
assertEquals(e.getValue(), actual.getHeader(e.getKey()).firstValue());
contains.forEach((expectedKey, expectedValues) -> {
assertTrue(String.format("Header '%s' was expected to be present. Actual headers: %s", expectedKey, actual),
actual.getHeader(expectedKey).isPresent());
List<String> actualValues = actual.getHeader(expectedKey).values();
assertEquals(expectedValues, actualValues);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ public abstract class MarshallingAssertion {
* @throws AssertionError If any assertions fail
*/
public final void assertMatches(LoggedRequest actual) throws AssertionError {
// Catches the exception to play nicer with lambda's
// Wrap checked exceptions to play nicer with lambda's
try {
doAssert(actual);
} catch (Error | RuntimeException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
package software.amazon.awssdk.protocol.model;

import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;

public class GivenResponse {

@JsonProperty(value = "status_code")
private Integer statusCode;
private Map<String, String> headers;
private Map<String, List<String>> headers;
private String body;

public Integer getStatusCode() {
Expand All @@ -33,11 +34,11 @@ public void setStatusCode(Integer statusCode) {
this.statusCode = statusCode;
}

public Map<String, String> getHeaders() {
public Map<String, List<String>> getHeaders() {
return headers;
}

public void setHeaders(Map<String, String> headers) {
public void setHeaders(Map<String, List<String>> headers) {
this.headers = headers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ private ResponseDefinitionBuilder toResponseBuilder(GivenResponse givenResponse)

ResponseDefinitionBuilder responseBuilder = aResponse().withStatus(200);
if (givenResponse.getHeaders() != null) {
givenResponse.getHeaders().forEach(responseBuilder::withHeader);
givenResponse.getHeaders().forEach((key, values) -> {
responseBuilder.withHeader(key, values.toArray(new String[0]));
});
}
if (givenResponse.getStatusCode() != null) {
responseBuilder.withStatus(givenResponse.getStatusCode());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,36 @@
"uri": "/2016-03-11/operationWithGreedyLabel/pathParamValue/foo/bar/baz"
}
}
},
{
"description": "ListOfStrings in header is serialized as multi-valued header",
"given": {
"input": {
"StringMember": "singleValue",
"ListOfStringsMember": [
"listValueOne",
"listValueTwo"
]
}
},
"when": {
"action": "marshall",
"operation": "MembersInHeaders"
},
"then": {
"serializedAs": {
"uri": "/2016-03-11/membersInHeaders",
"headers": {
"contains": {
"x-amz-string": "singleValue",
"x-amz-string-list": [
"listValueOne",
"listValueTwo"
]
}
}
}
}
}
// TODO This is a post process customization for API Gateway
// {
Expand Down
Loading