Skip to content

Migrate from IMDSv1 to IMDSv2 #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.net.URI;
import java.util.Collections;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.javatuples.Pair;
import software.amazon.cloudwatchlogs.emf.Constants;
import software.amazon.cloudwatchlogs.emf.config.Configuration;
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
Expand All @@ -33,7 +35,13 @@ public class EC2Environment extends AgentBasedEnvironment {

private static final String INSTANCE_IDENTITY_URL =
"http://169.254.169.254/latest/dynamic/instance-identity/document";

private static final String INSTANCE_TOKEN_URL = "http://169.254.169.254/latest/api/token";
private static final String CFN_EC2_TYPE = "AWS::EC2::Instance";
private static final String TOKEN_REQUEST_HEADER_KEY = "X-aws-ec2-metadata-token-ttl-seconds";
private static final String TOKEN_REQUEST_HEADER_VALUE = "21600";

private static final String METADATA_REQUEST_TOKEN_HEADER_KEY = "X-aws-ec2-metadata-token";

EC2Environment(Configuration config, ResourceFetcher fetcher) {
super(config);
Expand All @@ -43,6 +51,28 @@ public class EC2Environment extends AgentBasedEnvironment {

@Override
public boolean probe() {
String token;
Pair<String, String> tokenRequestHeader =
new Pair<>(TOKEN_REQUEST_HEADER_KEY, TOKEN_REQUEST_HEADER_VALUE);

URI tokenEndpoint = null;
try {
tokenEndpoint = new URI(INSTANCE_TOKEN_URL);
} catch (Exception ex) {
log.debug("Failed to construct url: " + INSTANCE_IDENTITY_URL);
return false;
}
try {
token =
fetcher.fetch(
tokenEndpoint, "PUT", Collections.singletonList(tokenRequestHeader));
} catch (EMFClientException ex) {
log.debug("Failed to get response from: " + tokenEndpoint, ex);
return false;
}

Pair<String, String> metadataRequestTokenHeader =
new Pair<>(METADATA_REQUEST_TOKEN_HEADER_KEY, token);
URI endpoint = null;
try {
endpoint = new URI(INSTANCE_IDENTITY_URL);
Expand All @@ -51,7 +81,12 @@ public boolean probe() {
return false;
}
try {
metadata = fetcher.fetch(endpoint, EC2Metadata.class);
metadata =
fetcher.fetch(
endpoint,
"GET",
EC2Metadata.class,
Collections.singletonList(metadataRequestTokenHeader));
return true;
} catch (EMFClientException ex) {
log.debug("Failed to get response from: " + endpoint, ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import java.net.HttpURLConnection;
import java.net.Proxy;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.javatuples.Pair;
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
import software.amazon.cloudwatchlogs.emf.util.IOUtils;
import software.amazon.cloudwatchlogs.emf.util.Jackson;
Expand All @@ -33,24 +36,37 @@ public class ResourceFetcher {

/** Fetch a json object from a given uri and deserialize it to the specified class: clazz. */
<T> T fetch(URI endpoint, Class<T> clazz) {
String response = doReadResource(endpoint, "GET");
String response = doReadResource(endpoint, "GET", Collections.emptyList());
return Jackson.fromJsonString(response, clazz);
}

/**
* Request a json object from a given uri with the provided headers and deserialize it to the
* specified class: clazz.
*/
<T> T fetch(URI endpoint, String method, Class<T> clazz, List<Pair<String, String>> headers) {
String response = doReadResource(endpoint, method, headers);
return Jackson.fromJsonString(response, clazz);
}

/** Request a string from a given uri with the provided headers */
String fetch(URI endpoint, String method, List<Pair<String, String>> headers) {
return doReadResource(endpoint, method, headers);
}

/**
* Fetch a json object from a given uri and deserialize it to the specified class with a given
* Jackson ObjectMapper.
*/
<T> T fetch(URI endpoint, ObjectMapper objectMapper, Class<T> clazz) {
String response = doReadResource(endpoint, "GET");
String response = doReadResource(endpoint, "GET", Collections.emptyList());
return Jackson.fromJsonString(response, objectMapper, clazz);
}

private String doReadResource(URI endpoint, String method) {
private String doReadResource(URI endpoint, String method, List<Pair<String, String>> headers) {
InputStream inputStream = null;
try {

HttpURLConnection connection = connectToEndpoint(endpoint, method);
HttpURLConnection connection = connectToEndpoint(endpoint, method, headers);

int statusCode = connection.getResponseCode();

Expand Down Expand Up @@ -105,13 +121,16 @@ private void handleErrorResponse(InputStream errorStream, String responseMessage
}
}

private HttpURLConnection connectToEndpoint(URI endpoint, String method) throws IOException {
private HttpURLConnection connectToEndpoint(
URI endpoint, String method, List<Pair<String, String>> headers) throws IOException {
HttpURLConnection connection =
(HttpURLConnection) endpoint.toURL().openConnection(Proxy.NO_PROXY);
connection.setConnectTimeout(1000);
connection.setReadTimeout(1000);
connection.setRequestMethod(method);
connection.setDoOutput(true);
headers.forEach(
header -> connection.setRequestProperty(header.getValue0(), header.getValue1()));

connection.connect();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ public void setUp() {
environment = new EC2Environment(config, fetcher);
}

@SuppressWarnings("unchecked")
@Test
public void testProbeReturnFalse() {
when(fetcher.fetch(any(), any())).thenThrow(new EMFClientException("Invalid URL"));
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
.thenThrow(new EMFClientException("Invalid URL"));

assertFalse(environment.probe());
}
Expand All @@ -71,8 +73,10 @@ public void testGetTypeWhenNoMetadata() {
}

@Test
@SuppressWarnings("unchecked")
public void testGetTypeReturnDefined() {
when(fetcher.fetch(any(), any())).thenReturn(new EC2Environment.EC2Metadata());
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
.thenReturn(new EC2Environment.EC2Metadata());
environment.probe();
assertEquals(environment.getType(), "AWS::EC2::Instance");
}
Expand All @@ -87,10 +91,11 @@ public void testGetTypeFromConfiguration() {
}

@Test
@SuppressWarnings("unchecked")
public void testConfigureContext() {
EC2Environment.EC2Metadata metadata = new EC2Environment.EC2Metadata();
getRandomMetadata(metadata);
when(fetcher.fetch(any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any())).thenReturn(metadata);
environment.probe();

MetricsContext context = new MetricsContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.javafaker.Faker;
import java.net.InetAddress;
import java.net.UnknownHostException;
Expand Down Expand Up @@ -68,7 +69,7 @@ public void testReturnTrueWithCorrectURL() {
String uri = "http://ecs-metata.com";
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

assertTrue(environment.probe());
}
Expand All @@ -81,7 +82,7 @@ public void testFormatImageName() {
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
metadata.image = "testAccount.dkr.ecr.us-west-2.amazonaws.com/testImage:latest";
metadata.labels = new HashMap<>();
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

assertTrue(environment.probe());
assertEquals(environment.getName(), "testImage:latest");
Expand Down Expand Up @@ -122,7 +123,8 @@ public void testSetFluentBit() {
PowerMockito.when(SystemWrapper.getenv("FLUENT_HOST")).thenReturn(fluentHost);

environment.probe();
when(fetcher.fetch(any(), any(), any())).thenReturn(new ECSEnvironment.ECSMetadata());
when(fetcher.fetch(any(), (ObjectMapper) any(), any()))
.thenReturn(new ECSEnvironment.ECSMetadata());
ArgumentCaptor<String> argument = ArgumentCaptor.forClass(String.class);
Mockito.verify(config, times(1)).setAgentEndpoint(argument.capture());
assertEquals(
Expand Down Expand Up @@ -155,7 +157,7 @@ public void testConfigureContext() throws UnknownHostException {
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
getRandomMetadata(metadata);
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

environment.probe();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import java.net.ServerSocket;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import lombok.Data;
import org.javatuples.Pair;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
Expand Down Expand Up @@ -94,6 +96,33 @@ public void testReadDataWith200Response() {
assertEquals(data.size, 10);
}

@Test
public void testReadDataWithHeaders200Response() {
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
generateStub(200, "{\"name\":\"test\",\"size\":10}");
TestData data =
fetcher.fetch(
endpoint, "GET", TestData.class, Collections.singletonList(mockHeader));

verify(
getRequestedFor(urlEqualTo(endpoint_path))
.withHeader("X-mock-header-key", equalTo("headerValue")));
assertEquals(data.name, "test");
assertEquals(data.size, 10);
}

@Test
public void testWithProvidedMethodAndHeadersWith200Response() {
generatePutStub(200, "putResponseData");
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
String data = fetcher.fetch(endpoint, "PUT", Collections.singletonList(mockHeader));

verify(
putRequestedFor(urlEqualTo(endpoint_path))
.withHeader("X-mock-header-key", equalTo("headerValue")));
assertEquals(data, "putResponseData");
}

@Test
public void testReadCaseInsensitiveDataWith200Response() {
generateStub(200, "{\"Name\":\"test\",\"Size\":10}");
Expand Down Expand Up @@ -136,6 +165,17 @@ private void generateStub(int statusCode, String message) {
.withBody(message)));
}

private void generatePutStub(int statusCode, String message) {
stubFor(
put(urlPathEqualTo(endpoint_path))
.willReturn(
aResponse()
.withStatus(statusCode)
.withHeader("Content-Type", "application/json")
.withHeader("charset", "utf-8")
.withBody(message)));
}

@Data
private static class TestData {
private String name;
Expand Down