From 2bcad2672cbc9d0d4127ce3d857b6bef1584e194 Mon Sep 17 00:00:00 2001 From: Maxime David Date: Wed, 12 Mar 2025 18:17:41 +0000 Subject: [PATCH 1/5] test: Pojo serializer --- .../api/client/PojoSerializerLoaderTest.java | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/PojoSerializerLoaderTest.java diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/PojoSerializerLoaderTest.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/PojoSerializerLoaderTest.java new file mode 100644 index 00000000..7c6e9dcb --- /dev/null +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/PojoSerializerLoaderTest.java @@ -0,0 +1,151 @@ +/* +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package com.amazonaws.services.lambda.runtime.api.client; + +import com.amazonaws.services.lambda.runtime.CustomPojoSerializer; +import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.Field; +import java.lang.reflect.Type; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class PojoSerializerLoaderTest { + + @Mock + private CustomPojoSerializer mockSerializer; + + @BeforeEach + void setUp() throws Exception { + resetStaticFields(); + } + + private void resetStaticFields() throws Exception { + Field serializerField = PojoSerializerLoader.class.getDeclaredField("customPojoSerializer"); + serializerField.setAccessible(true); + serializerField.set(null, null); + + Field initializedField = PojoSerializerLoader.class.getDeclaredField("initialized"); + initializedField.setAccessible(true); + initializedField.set(null, false); + } + + + private void setMockSerializer(CustomPojoSerializer serializer) throws Exception { + Field serializerField = PojoSerializerLoader.class.getDeclaredField("customPojoSerializer"); + serializerField.setAccessible(true); + serializerField.set(null, serializer); + } + + @Test + void testGetCustomerSerializerNoSerializerAvailable() throws Exception { + PojoSerializer serializer = PojoSerializerLoader.getCustomerSerializer(String.class); + assertNull(serializer); + Field initializedField = PojoSerializerLoader.class.getDeclaredField("initialized"); + initializedField.setAccessible(true); + assert((Boolean) initializedField.get(null)); + } + + @Test + void testGetCustomerSerializerWithValidSerializer() throws Exception { + setMockSerializer(mockSerializer); + String testInput = "test input"; + String testOutput = "test output"; + Type testType = String.class; + when(mockSerializer.fromJson(any(InputStream.class), eq(testType))).thenReturn(testOutput); + when(mockSerializer.fromJson(eq(testInput), eq(testType))).thenReturn(testOutput); + + PojoSerializer serializer = PojoSerializerLoader.getCustomerSerializer(testType); + assertNotNull(serializer); + + ByteArrayInputStream inputStream = new ByteArrayInputStream(testInput.getBytes()); + Object result1 = serializer.fromJson(inputStream); + assertEquals(testOutput, result1); + + Object result2 = serializer.fromJson(testInput); + assertEquals(testOutput, result2); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + serializer.toJson(testInput, outputStream); + verify(mockSerializer).toJson(eq(testInput), any(OutputStream.class), eq(testType)); + } + + @Test + void testGetCustomerSerializerCachingBehavior() throws Exception { + setMockSerializer(mockSerializer); + + Type testType = String.class; + PojoSerializer serializer1 = PojoSerializerLoader.getCustomerSerializer(testType); + PojoSerializer serializer2 = PojoSerializerLoader.getCustomerSerializer(testType); + + assertNotNull(serializer1); + assertNotNull(serializer2); + + String testInput = "test"; + serializer1.fromJson(testInput); + serializer2.fromJson(testInput); + + verify(mockSerializer, times(2)).fromJson(eq(testInput), eq(testType)); + } + + @Test + void testGetCustomerSerializerDifferentTypes() throws Exception { + setMockSerializer(mockSerializer); + + PojoSerializer stringSerializer = PojoSerializerLoader.getCustomerSerializer(String.class); + PojoSerializer integerSerializer = PojoSerializerLoader.getCustomerSerializer(Integer.class); + + assertNotNull(stringSerializer); + assertNotNull(integerSerializer); + + String testString = "test"; + Integer testInt = 123; + + stringSerializer.fromJson(testString); + integerSerializer.fromJson(testInt.toString()); + + verify(mockSerializer).fromJson(eq(testString), eq(String.class)); + verify(mockSerializer).fromJson(eq(testInt.toString()), eq(Integer.class)); + } + + @Test + void testGetCustomerSerializerNullType() throws Exception { + setMockSerializer(mockSerializer); + + PojoSerializer serializer = PojoSerializerLoader.getCustomerSerializer(null); + assertNotNull(serializer); + + String testInput = "test"; + serializer.fromJson(testInput); + verify(mockSerializer).fromJson(eq(testInput), eq(null)); + } + + @Test + void testGetCustomerSerializerExceptionHandling() throws Exception { + setMockSerializer(mockSerializer); + + doThrow(new RuntimeException("Test exception")) + .when(mockSerializer) + .fromJson(any(String.class), any(Type.class)); + + PojoSerializer serializer = PojoSerializerLoader.getCustomerSerializer(String.class); + assertNotNull(serializer); + assertThrows(RuntimeException.class, () -> serializer.fromJson("test")); + } +} From 811f907359598c0fe56c4f9e87e939eff3833a61 Mon Sep 17 00:00:00 2001 From: Maxime David Date: Wed, 12 Mar 2025 18:30:28 +0000 Subject: [PATCH 2/5] test: add UserFaultTests --- .../runtime/api/client/UserFaultTest.java | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/UserFaultTest.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/UserFaultTest.java index 5a57e6e0..479162ad 100644 --- a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/UserFaultTest.java +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/UserFaultTest.java @@ -124,4 +124,42 @@ public void testCircularSuppressedExceptionReference() { assertEquals(expectedStackTrace, stackTrace); } } + + private Exception createExceptionWithStackTrace() { + try { + throw new RuntimeException("Test exception"); + } catch (RuntimeException e) { + return e; + } + } + + @Test + void testMakeInitErrorUserFault() { + String className = "com.example.TestClass"; + Exception testException = createExceptionWithStackTrace(); + + UserFault initFault = UserFault.makeInitErrorUserFault(testException, className); + UserFault notFoundFault = UserFault.makeClassNotFoundUserFault(testException, className); + + assertNotNull(initFault.trace); + assertNotNull(notFoundFault.trace); + + assertFalse(initFault.trace.contains("com.amazonaws.services.lambda.runtime")); + assertFalse(notFoundFault.trace.contains("com.amazonaws.services.lambda.runtime")); + } + + @Test + void testMakeClassNotFoundUserFault() { + String className = "com.example.MissingClass"; + Exception testException = new ClassNotFoundException("Class not found in classpath"); + + UserFault fault = UserFault.makeClassNotFoundUserFault(testException, className); + + assertNotNull(fault); + assertEquals("Class not found: com.example.MissingClass", fault.msg); + assertEquals("java.lang.ClassNotFoundException", fault.exception); + assertNotNull(fault.trace); + assertFalse(fault.fatal); + assertTrue(fault.trace.contains("ClassNotFoundException")); + } } From 3ee43636dd45f251f2519ed6444700da89ed6737 Mon Sep 17 00:00:00 2001 From: Maxime David Date: Wed, 12 Mar 2025 18:39:09 +0000 Subject: [PATCH 3/5] test: more tests --- .../api/client/LambdaRequestHandler.java | 142 ++++++++++++++++++ ...anyServiceProvidersFoundExceptionTest.java | 59 ++++++++ 2 files changed, 201 insertions(+) create mode 100644 aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/LambdaRequestHandler.java create mode 100644 aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/TooManyServiceProvidersFoundExceptionTest.java diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/LambdaRequestHandler.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/LambdaRequestHandler.java new file mode 100644 index 00000000..d86b7385 --- /dev/null +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/LambdaRequestHandler.java @@ -0,0 +1,142 @@ +/* +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package com.amazonaws.services.lambda.runtime.api.client; + +import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class LambdaRequestHandlerTest { + + private InvocationRequest mockRequest; + + @BeforeEach + void setUp() { + mockRequest = mock(InvocationRequest.class); + } + + @Test + void testInitErrorHandler() { + String className = "com.example.TestClass"; + Exception testException = new RuntimeException("initialization error"); + + LambdaRequestHandler handler = LambdaRequestHandler.initErrorHandler(testException, className); + + assertNotNull(handler); + assertTrue(handler instanceof LambdaRequestHandler.UserFaultHandler); + + LambdaRequestHandler.UserFaultHandler userFaultHandler = (LambdaRequestHandler.UserFaultHandler) handler; + UserFault fault = userFaultHandler.fault; + + assertNotNull(fault); + assertEquals("Error loading class " + className + ": initialization error", fault.msg); + assertEquals("java.lang.RuntimeException", fault.exception); + assertTrue(fault.fatal); + } + + @Test + void testClassNotFound() { + String className = "com.example.MissingClass"; + Exception testException = new ClassNotFoundException("class not found"); + + LambdaRequestHandler handler = LambdaRequestHandler.classNotFound(testException, className); + + assertNotNull(handler); + assertTrue(handler instanceof LambdaRequestHandler.UserFaultHandler); + + LambdaRequestHandler.UserFaultHandler userFaultHandler = (LambdaRequestHandler.UserFaultHandler) handler; + UserFault fault = userFaultHandler.fault; + + assertNotNull(fault); + assertEquals("Class not found: " + className, fault.msg); + assertEquals("java.lang.ClassNotFoundException", fault.exception); + assertFalse(fault.fatal); + } + + @Test + void testUserFaultHandlerConstructor() { + UserFault testFault = new UserFault("test message", "TestException", "test trace"); + LambdaRequestHandler.UserFaultHandler handler = new LambdaRequestHandler.UserFaultHandler(testFault); + + assertNotNull(handler); + assertSame(testFault, handler.fault); + } + + @Test + void testUserFaultHandlerCallThrowsFault() { + UserFault testFault = new UserFault("test message", "TestException", "test trace"); + LambdaRequestHandler.UserFaultHandler handler = new LambdaRequestHandler.UserFaultHandler(testFault); + + UserFault thrownFault = assertThrows(UserFault.class, () -> handler.call(mockRequest)); + assertSame(testFault, thrownFault); + } + + @Test + void testInitErrorHandlerWithNullMessage() { + String className = "com.example.TestClass"; + Exception testException = new RuntimeException(); + + LambdaRequestHandler handler = LambdaRequestHandler.initErrorHandler(testException, className); + + assertNotNull(handler); + assertTrue(handler instanceof LambdaRequestHandler.UserFaultHandler); + + LambdaRequestHandler.UserFaultHandler userFaultHandler = (LambdaRequestHandler.UserFaultHandler) handler; + UserFault fault = userFaultHandler.fault; + + assertNotNull(fault); + assertEquals("Error loading class " + className, fault.msg); + assertEquals("java.lang.RuntimeException", fault.exception); + assertTrue(fault.fatal); + } + + @Test + void testInitErrorHandlerWithNullClassName() { + Exception testException = new RuntimeException("test error"); + + LambdaRequestHandler handler = LambdaRequestHandler.initErrorHandler(testException, null); + + assertNotNull(handler); + assertTrue(handler instanceof LambdaRequestHandler.UserFaultHandler); + + LambdaRequestHandler.UserFaultHandler userFaultHandler = (LambdaRequestHandler.UserFaultHandler) handler; + UserFault fault = userFaultHandler.fault; + + assertNotNull(fault); + assertEquals("Error loading class null: test error", fault.msg); + assertEquals("java.lang.RuntimeException", fault.exception); + assertTrue(fault.fatal); + } + + @Test + void testClassNotFoundWithNullClassName() { + Exception testException = new ClassNotFoundException("test error"); + + LambdaRequestHandler handler = LambdaRequestHandler.classNotFound(testException, null); + + assertNotNull(handler); + assertTrue(handler instanceof LambdaRequestHandler.UserFaultHandler); + + LambdaRequestHandler.UserFaultHandler userFaultHandler = (LambdaRequestHandler.UserFaultHandler) handler; + UserFault fault = userFaultHandler.fault; + + assertNotNull(fault); + assertEquals("Class not found: null", fault.msg); + assertEquals("java.lang.ClassNotFoundException", fault.exception); + assertFalse(fault.fatal); + } + + @Test + void testUserFaultHandlerCallWithNullRequest() { + UserFault testFault = new UserFault("test message", "TestException", "test trace"); + LambdaRequestHandler.UserFaultHandler handler = new LambdaRequestHandler.UserFaultHandler(testFault); + + UserFault thrownFault = assertThrows(UserFault.class, () -> handler.call(null)); + assertSame(testFault, thrownFault); + } +} diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/TooManyServiceProvidersFoundExceptionTest.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/TooManyServiceProvidersFoundExceptionTest.java new file mode 100644 index 00000000..38d33f63 --- /dev/null +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/TooManyServiceProvidersFoundExceptionTest.java @@ -0,0 +1,59 @@ +/* +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package com.amazonaws.services.lambda.runtime.api.client; + +import org.junit.jupiter.api.Test; + +import com.amazonaws.services.lambda.runtime.api.client.TooManyServiceProvidersFoundException; + +import static org.junit.jupiter.api.Assertions.*; + +class TooManyServiceProvidersFoundExceptionTest { + + @Test + void testDefaultConstructor() { + TooManyServiceProvidersFoundException exception = new TooManyServiceProvidersFoundException(); + + assertNotNull(exception); + assertNull(exception.getMessage()); + assertNull(exception.getCause()); + } + + @Test + void testMessageConstructor() { + String errorMessage = "Too many service providers found"; + TooManyServiceProvidersFoundException exception = + new TooManyServiceProvidersFoundException(errorMessage); + + assertNotNull(exception); + assertEquals(errorMessage, exception.getMessage()); + assertNull(exception.getCause()); + } + + @Test + void testCauseConstructor() { + Throwable cause = new IllegalStateException("Original error"); + TooManyServiceProvidersFoundException exception = + new TooManyServiceProvidersFoundException(cause); + + assertNotNull(exception); + assertEquals(cause.toString(), exception.getMessage()); + assertSame(cause, exception.getCause()); + } + + @Test + void testMessageAndCauseConstructor() { + String errorMessage = "Too many service providers found"; + Throwable cause = new IllegalStateException("Original error"); + TooManyServiceProvidersFoundException exception = + new TooManyServiceProvidersFoundException(errorMessage, cause); + + assertNotNull(exception); + assertEquals(errorMessage, exception.getMessage()); + assertSame(cause, exception.getCause()); + } + +} From 981ec08fc4fd41cbf8d228787f685680504091de Mon Sep 17 00:00:00 2001 From: Maxime David Date: Wed, 12 Mar 2025 20:06:14 +0000 Subject: [PATCH 4/5] test: add test around blocklisting --- .../api/client/ClasspathLoaderTest.java | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/ClasspathLoaderTest.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/ClasspathLoaderTest.java index 547f238c..38147d21 100644 --- a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/ClasspathLoaderTest.java +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/ClasspathLoaderTest.java @@ -109,6 +109,34 @@ void testLoadAllClassesWithMultipleEntries(@TempDir Path tempDir) throws IOExcep } } + @Test + void testLoadAllClassesWithBlocklistedClass(@TempDir Path tempDir) throws IOException { + File jarFile = tempDir.resolve("blocklist-test.jar").toFile(); + + try (JarOutputStream jos = new JarOutputStream(new FileOutputStream(jarFile))) { + JarEntry blockedEntry = new JarEntry("META-INF/versions/9/module-info.class"); + jos.putNextEntry(blockedEntry); + jos.write("dummy content".getBytes()); + jos.closeEntry(); + + JarEntry normalEntry = new JarEntry("com/test/Normal.class"); + jos.putNextEntry(normalEntry); + jos.write("dummy content".getBytes()); + jos.closeEntry(); + } + + String originalClasspath = System.getProperty("java.class.path"); + try { + System.setProperty("java.class.path", jarFile.getAbsolutePath()); + ClasspathLoader.main(new String[]{}); + // The test passes if no exception is thrown and the blocklisted class is skipped + } finally { + if (originalClasspath != null) { + System.setProperty("java.class.path", originalClasspath); + } + } + } + private File createSimpleJar(Path tempDir, String jarName, String className) throws IOException { File jarFile = tempDir.resolve(jarName).toFile(); From 6b29f4038c75e9f9413fdd4fd9b175c602c25070 Mon Sep 17 00:00:00 2001 From: Maxime David Date: Wed, 12 Mar 2025 20:11:07 +0000 Subject: [PATCH 5/5] test: add test for HandlerInfoTest --- .../runtime/api/client/HandlerInfoTest.java | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/HandlerInfoTest.java diff --git a/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/HandlerInfoTest.java b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/HandlerInfoTest.java new file mode 100644 index 00000000..e134ddc8 --- /dev/null +++ b/aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/HandlerInfoTest.java @@ -0,0 +1,132 @@ +/* +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package com.amazonaws.services.lambda.runtime.api.client; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +class HandlerInfoTest { + + @Test + void testConstructor() { + Class testClass = String.class; + String methodName = "testMethod"; + + HandlerInfo info = new HandlerInfo(testClass, methodName); + + assertNotNull(info); + assertEquals(testClass, info.clazz); + assertEquals(methodName, info.methodName); + } + + @Test + void testFromStringWithoutMethod() throws Exception { + String handler = "java.lang.String"; + HandlerInfo info = HandlerInfo.fromString(handler, ClassLoader.getSystemClassLoader()); + + assertEquals(String.class, info.clazz); + assertNull(info.methodName); + } + + @Test + void testFromStringWithMethod() throws Exception { + String handler = "java.lang.String::length"; + HandlerInfo info = HandlerInfo.fromString(handler, ClassLoader.getSystemClassLoader()); + + assertEquals(String.class, info.clazz); + assertEquals("length", info.methodName); + } + + @Test + void testFromStringWithEmptyClass() { + String handler = "::method"; + + assertThrows(HandlerInfo.InvalidHandlerException.class, () -> + HandlerInfo.fromString(handler, ClassLoader.getSystemClassLoader()) + ); + } + + @Test + void testFromStringWithEmptyMethod() { + String handler = "java.lang.String::"; + + assertThrows(HandlerInfo.InvalidHandlerException.class, () -> + HandlerInfo.fromString(handler, ClassLoader.getSystemClassLoader()) + ); + } + + @Test + void testFromStringWithNonexistentClass() { + String handler = "com.nonexistent.TestClass::method"; + + assertThrows(ClassNotFoundException.class, () -> + HandlerInfo.fromString(handler, ClassLoader.getSystemClassLoader()) + ); + } + + @Test + void testFromStringWithNullHandler() { + assertThrows(NullPointerException.class, () -> + HandlerInfo.fromString(null, ClassLoader.getSystemClassLoader()) + ); + } + + @Test + void testClassNameWithoutMethod() { + String handler = "java.lang.String"; + String className = HandlerInfo.className(handler); + + assertEquals("java.lang.String", className); + } + + @Test + void testClassNameWithMethod() { + String handler = "java.lang.String::length"; + String className = HandlerInfo.className(handler); + + assertEquals("java.lang.String", className); + } + + @Test + void testClassNameWithEmptyString() { + String handler = ""; + String className = HandlerInfo.className(handler); + + assertEquals("", className); + } + + @Test + void testClassNameWithOnlyDelimiter() { + String handler = "::"; + String className = HandlerInfo.className(handler); + + assertEquals("", className); + } + + @Test + void testInvalidHandlerExceptionSerialVersionUID() { + assertEquals(-1L, HandlerInfo.InvalidHandlerException.serialVersionUID); + } + + @Test + void testFromStringWithInnerClass() throws Exception { + // Create a custom class loader that can load our test class + ClassLoader cl = new ClassLoader() { + @Override + public Class loadClass(String name) throws ClassNotFoundException { + if (name.equals("com.test.OuterClass$InnerClass")) { + throw new ClassNotFoundException("Test class not found"); + } + return super.loadClass(name); + } + }; + + String handler = "com.test.OuterClass$InnerClass::method"; + assertThrows(ClassNotFoundException.class, () -> + HandlerInfo.fromString(handler, cl) + ); + } +}