Skip to content

Commit c895e9f

Browse files
authored
Merge pull request #84 from evanuk/migrate-to-imdsv2
Migrate from IMDSv1 to IMDSv2
2 parents 8ac0e9f + fba1075 commit c895e9f

File tree

5 files changed

+115
-14
lines changed

5 files changed

+115
-14
lines changed

src/main/java/software/amazon/cloudwatchlogs/emf/environment/EC2Environment.java

+36-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
2020
import java.net.URI;
21+
import java.util.Collections;
2122
import lombok.Data;
2223
import lombok.extern.slf4j.Slf4j;
24+
import org.javatuples.Pair;
2325
import software.amazon.cloudwatchlogs.emf.Constants;
2426
import software.amazon.cloudwatchlogs.emf.config.Configuration;
2527
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
@@ -33,7 +35,13 @@ public class EC2Environment extends AgentBasedEnvironment {
3335

3436
private static final String INSTANCE_IDENTITY_URL =
3537
"http://169.254.169.254/latest/dynamic/instance-identity/document";
38+
39+
private static final String INSTANCE_TOKEN_URL = "http://169.254.169.254/latest/api/token";
3640
private static final String CFN_EC2_TYPE = "AWS::EC2::Instance";
41+
private static final String TOKEN_REQUEST_HEADER_KEY = "X-aws-ec2-metadata-token-ttl-seconds";
42+
private static final String TOKEN_REQUEST_HEADER_VALUE = "21600";
43+
44+
private static final String METADATA_REQUEST_TOKEN_HEADER_KEY = "X-aws-ec2-metadata-token";
3745

3846
EC2Environment(Configuration config, ResourceFetcher fetcher) {
3947
super(config);
@@ -43,6 +51,28 @@ public class EC2Environment extends AgentBasedEnvironment {
4351

4452
@Override
4553
public boolean probe() {
54+
String token;
55+
Pair<String, String> tokenRequestHeader =
56+
new Pair<>(TOKEN_REQUEST_HEADER_KEY, TOKEN_REQUEST_HEADER_VALUE);
57+
58+
URI tokenEndpoint = null;
59+
try {
60+
tokenEndpoint = new URI(INSTANCE_TOKEN_URL);
61+
} catch (Exception ex) {
62+
log.debug("Failed to construct url: " + INSTANCE_IDENTITY_URL);
63+
return false;
64+
}
65+
try {
66+
token =
67+
fetcher.fetch(
68+
tokenEndpoint, "PUT", Collections.singletonList(tokenRequestHeader));
69+
} catch (EMFClientException ex) {
70+
log.debug("Failed to get response from: " + tokenEndpoint, ex);
71+
return false;
72+
}
73+
74+
Pair<String, String> metadataRequestTokenHeader =
75+
new Pair<>(METADATA_REQUEST_TOKEN_HEADER_KEY, token);
4676
URI endpoint = null;
4777
try {
4878
endpoint = new URI(INSTANCE_IDENTITY_URL);
@@ -51,7 +81,12 @@ public boolean probe() {
5181
return false;
5282
}
5383
try {
54-
metadata = fetcher.fetch(endpoint, EC2Metadata.class);
84+
metadata =
85+
fetcher.fetch(
86+
endpoint,
87+
"GET",
88+
EC2Metadata.class,
89+
Collections.singletonList(metadataRequestTokenHeader));
5590
return true;
5691
} catch (EMFClientException ex) {
5792
log.debug("Failed to get response from: " + endpoint, ex);

src/main/java/software/amazon/cloudwatchlogs/emf/environment/ResourceFetcher.java

+25-6
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import java.net.HttpURLConnection;
2424
import java.net.Proxy;
2525
import java.net.URI;
26+
import java.util.Collections;
27+
import java.util.List;
2628
import lombok.extern.slf4j.Slf4j;
29+
import org.javatuples.Pair;
2730
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
2831
import software.amazon.cloudwatchlogs.emf.util.IOUtils;
2932
import software.amazon.cloudwatchlogs.emf.util.Jackson;
@@ -33,24 +36,37 @@ public class ResourceFetcher {
3336

3437
/** Fetch a json object from a given uri and deserialize it to the specified class: clazz. */
3538
<T> T fetch(URI endpoint, Class<T> clazz) {
36-
String response = doReadResource(endpoint, "GET");
39+
String response = doReadResource(endpoint, "GET", Collections.emptyList());
3740
return Jackson.fromJsonString(response, clazz);
3841
}
3942

43+
/**
44+
* Request a json object from a given uri with the provided headers and deserialize it to the
45+
* specified class: clazz.
46+
*/
47+
<T> T fetch(URI endpoint, String method, Class<T> clazz, List<Pair<String, String>> headers) {
48+
String response = doReadResource(endpoint, method, headers);
49+
return Jackson.fromJsonString(response, clazz);
50+
}
51+
52+
/** Request a string from a given uri with the provided headers */
53+
String fetch(URI endpoint, String method, List<Pair<String, String>> headers) {
54+
return doReadResource(endpoint, method, headers);
55+
}
56+
4057
/**
4158
* Fetch a json object from a given uri and deserialize it to the specified class with a given
4259
* Jackson ObjectMapper.
4360
*/
4461
<T> T fetch(URI endpoint, ObjectMapper objectMapper, Class<T> clazz) {
45-
String response = doReadResource(endpoint, "GET");
62+
String response = doReadResource(endpoint, "GET", Collections.emptyList());
4663
return Jackson.fromJsonString(response, objectMapper, clazz);
4764
}
4865

49-
private String doReadResource(URI endpoint, String method) {
66+
private String doReadResource(URI endpoint, String method, List<Pair<String, String>> headers) {
5067
InputStream inputStream = null;
5168
try {
52-
53-
HttpURLConnection connection = connectToEndpoint(endpoint, method);
69+
HttpURLConnection connection = connectToEndpoint(endpoint, method, headers);
5470

5571
int statusCode = connection.getResponseCode();
5672

@@ -105,13 +121,16 @@ private void handleErrorResponse(InputStream errorStream, String responseMessage
105121
}
106122
}
107123

108-
private HttpURLConnection connectToEndpoint(URI endpoint, String method) throws IOException {
124+
private HttpURLConnection connectToEndpoint(
125+
URI endpoint, String method, List<Pair<String, String>> headers) throws IOException {
109126
HttpURLConnection connection =
110127
(HttpURLConnection) endpoint.toURL().openConnection(Proxy.NO_PROXY);
111128
connection.setConnectTimeout(1000);
112129
connection.setReadTimeout(1000);
113130
connection.setRequestMethod(method);
114131
connection.setDoOutput(true);
132+
headers.forEach(
133+
header -> connection.setRequestProperty(header.getValue0(), header.getValue1()));
115134

116135
connection.connect();
117136

src/test/java/software/amazon/cloudwatchlogs/emf/environment/EC2EnvironmentTest.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ public void setUp() {
4949
environment = new EC2Environment(config, fetcher);
5050
}
5151

52+
@SuppressWarnings("unchecked")
5253
@Test
5354
public void testProbeReturnFalse() {
54-
when(fetcher.fetch(any(), any())).thenThrow(new EMFClientException("Invalid URL"));
55+
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
56+
.thenThrow(new EMFClientException("Invalid URL"));
5557

5658
assertFalse(environment.probe());
5759
}
@@ -71,8 +73,10 @@ public void testGetTypeWhenNoMetadata() {
7173
}
7274

7375
@Test
76+
@SuppressWarnings("unchecked")
7477
public void testGetTypeReturnDefined() {
75-
when(fetcher.fetch(any(), any())).thenReturn(new EC2Environment.EC2Metadata());
78+
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
79+
.thenReturn(new EC2Environment.EC2Metadata());
7680
environment.probe();
7781
assertEquals(environment.getType(), "AWS::EC2::Instance");
7882
}
@@ -87,10 +91,11 @@ public void testGetTypeFromConfiguration() {
8791
}
8892

8993
@Test
94+
@SuppressWarnings("unchecked")
9095
public void testConfigureContext() {
9196
EC2Environment.EC2Metadata metadata = new EC2Environment.EC2Metadata();
9297
getRandomMetadata(metadata);
93-
when(fetcher.fetch(any(), any())).thenReturn(metadata);
98+
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any())).thenReturn(metadata);
9499
environment.probe();
95100

96101
MetricsContext context = new MetricsContext();

src/test/java/software/amazon/cloudwatchlogs/emf/environment/ECSEnvironmentTest.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static org.mockito.ArgumentMatchers.any;
2121
import static org.mockito.Mockito.*;
2222

23+
import com.fasterxml.jackson.databind.ObjectMapper;
2324
import com.github.javafaker.Faker;
2425
import java.net.InetAddress;
2526
import java.net.UnknownHostException;
@@ -68,7 +69,7 @@ public void testReturnTrueWithCorrectURL() {
6869
String uri = "http://ecs-metata.com";
6970
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
7071
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
71-
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
72+
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);
7273

7374
assertTrue(environment.probe());
7475
}
@@ -81,7 +82,7 @@ public void testFormatImageName() {
8182
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
8283
metadata.image = "testAccount.dkr.ecr.us-west-2.amazonaws.com/testImage:latest";
8384
metadata.labels = new HashMap<>();
84-
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
85+
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);
8586

8687
assertTrue(environment.probe());
8788
assertEquals(environment.getName(), "testImage:latest");
@@ -122,7 +123,8 @@ public void testSetFluentBit() {
122123
PowerMockito.when(SystemWrapper.getenv("FLUENT_HOST")).thenReturn(fluentHost);
123124

124125
environment.probe();
125-
when(fetcher.fetch(any(), any(), any())).thenReturn(new ECSEnvironment.ECSMetadata());
126+
when(fetcher.fetch(any(), (ObjectMapper) any(), any()))
127+
.thenReturn(new ECSEnvironment.ECSMetadata());
126128
ArgumentCaptor<String> argument = ArgumentCaptor.forClass(String.class);
127129
Mockito.verify(config, times(1)).setAgentEndpoint(argument.capture());
128130
assertEquals(
@@ -155,7 +157,7 @@ public void testConfigureContext() throws UnknownHostException {
155157
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
156158
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
157159
getRandomMetadata(metadata);
158-
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
160+
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);
159161

160162
environment.probe();
161163

src/test/java/software/amazon/cloudwatchlogs/emf/environment/ResourceFetcherTest.java

+40
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
import java.net.ServerSocket;
2727
import java.net.URI;
2828
import java.net.URISyntaxException;
29+
import java.util.Collections;
2930
import lombok.Data;
31+
import org.javatuples.Pair;
3032
import org.junit.Before;
3133
import org.junit.ClassRule;
3234
import org.junit.Test;
@@ -94,6 +96,33 @@ public void testReadDataWith200Response() {
9496
assertEquals(data.size, 10);
9597
}
9698

99+
@Test
100+
public void testReadDataWithHeaders200Response() {
101+
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
102+
generateStub(200, "{\"name\":\"test\",\"size\":10}");
103+
TestData data =
104+
fetcher.fetch(
105+
endpoint, "GET", TestData.class, Collections.singletonList(mockHeader));
106+
107+
verify(
108+
getRequestedFor(urlEqualTo(endpoint_path))
109+
.withHeader("X-mock-header-key", equalTo("headerValue")));
110+
assertEquals(data.name, "test");
111+
assertEquals(data.size, 10);
112+
}
113+
114+
@Test
115+
public void testWithProvidedMethodAndHeadersWith200Response() {
116+
generatePutStub(200, "putResponseData");
117+
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
118+
String data = fetcher.fetch(endpoint, "PUT", Collections.singletonList(mockHeader));
119+
120+
verify(
121+
putRequestedFor(urlEqualTo(endpoint_path))
122+
.withHeader("X-mock-header-key", equalTo("headerValue")));
123+
assertEquals(data, "putResponseData");
124+
}
125+
97126
@Test
98127
public void testReadCaseInsensitiveDataWith200Response() {
99128
generateStub(200, "{\"Name\":\"test\",\"Size\":10}");
@@ -136,6 +165,17 @@ private void generateStub(int statusCode, String message) {
136165
.withBody(message)));
137166
}
138167

168+
private void generatePutStub(int statusCode, String message) {
169+
stubFor(
170+
put(urlPathEqualTo(endpoint_path))
171+
.willReturn(
172+
aResponse()
173+
.withStatus(statusCode)
174+
.withHeader("Content-Type", "application/json")
175+
.withHeader("charset", "utf-8")
176+
.withBody(message)));
177+
}
178+
139179
@Data
140180
private static class TestData {
141181
private String name;

0 commit comments

Comments
 (0)