diff --git a/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/AWSLambda.java b/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/AWSLambda.java index 986f8b7b..2eeb14e3 100644 --- a/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/AWSLambda.java +++ b/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/AWSLambda.java @@ -2,6 +2,7 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 */ + package com.amazonaws.services.lambda.runtime.api.client; import com.amazonaws.services.lambda.crac.Core; @@ -35,7 +36,6 @@ import java.security.Security; import java.util.Properties; - /** * The entrypoint of this class is {@link AWSLambda#startRuntime}. It performs two main tasks: * @@ -137,6 +137,42 @@ private static LambdaRequestHandler findRequestHandler(final String handlerStrin return requestHandler; } + private static LambdaRequestHandler getLambdaRequestHandlerObject(String handler, LambdaContextLogger lambdaLogger) throws ClassNotFoundException, IOException { + UnsafeUtil.disableIllegalAccessWarning(); + + System.setOut(new PrintStream(new LambdaOutputStream(System.out), false, "UTF-8")); + System.setErr(new PrintStream(new LambdaOutputStream(System.err), false, "UTF-8")); + setupRuntimeLogger(lambdaLogger); + + runtimeClient = new LambdaRuntimeApiClientImpl(LambdaEnvironment.RUNTIME_API); + + String taskRoot = System.getProperty("user.dir"); + String libRoot = "/opt/java"; + // Make system classloader the customer classloader's parent to ensure any aws-lambda-java-core classes + // are loaded from the system classloader. + customerClassLoader = new CustomerClassLoader(taskRoot, libRoot, ClassLoader.getSystemClassLoader()); + Thread.currentThread().setContextClassLoader(customerClassLoader); + + // Load the user's handler + LambdaRequestHandler requestHandler = null; + try { + requestHandler = findRequestHandler(handler, customerClassLoader); + } catch (UserFault userFault) { + lambdaLogger.log(userFault.reportableError(), lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED); + LambdaError error = new LambdaError( + LambdaErrorConverter.fromUserFault(userFault), + RapidErrorType.BadFunctionCode); + runtimeClient.reportInitError(error); + System.exit(1); + } + + if (INIT_TYPE_SNAP_START.equals(AWS_LAMBDA_INITIALIZATION_TYPE)) { + onInitComplete(lambdaLogger); + } + + return requestHandler; + } + public static void setupRuntimeLogger(LambdaLogger lambdaLogger) throws ClassNotFoundException { ReflectUtil.setStaticField( @@ -176,55 +212,27 @@ private static LogSink createLogSink() { } } - public static void main(String[] args) { - startRuntime(args[0]); - } + public static void main(String[] args) throws Throwable { + try (LambdaContextLogger logger = initLogger()) { + LambdaRequestHandler lambdaRequestHandler = getLambdaRequestHandlerObject(args[0], logger); + startRuntimeLoop(lambdaRequestHandler, logger); - private static void startRuntime(String handler) { - try (LogSink logSink = createLogSink()) { - LambdaContextLogger logger = new LambdaContextLogger( - logSink, - LogLevel.fromString(LambdaEnvironment.LAMBDA_LOG_LEVEL), - LogFormat.fromString(LambdaEnvironment.LAMBDA_LOG_FORMAT) - ); - startRuntime(handler, logger); - } catch (Throwable t) { + } catch (IOException | ClassNotFoundException t) { throw new Error(t); } } - private static void startRuntime(String handler, LambdaContextLogger lambdaLogger) throws Throwable { - UnsafeUtil.disableIllegalAccessWarning(); - - System.setOut(new PrintStream(new LambdaOutputStream(System.out), false, "UTF-8")); - System.setErr(new PrintStream(new LambdaOutputStream(System.err), false, "UTF-8")); - setupRuntimeLogger(lambdaLogger); + private static LambdaContextLogger initLogger() { + LogSink logSink = createLogSink(); + LambdaContextLogger logger = new LambdaContextLogger( + logSink, + LogLevel.fromString(LambdaEnvironment.LAMBDA_LOG_LEVEL), + LogFormat.fromString(LambdaEnvironment.LAMBDA_LOG_FORMAT)); - runtimeClient = new LambdaRuntimeApiClientImpl(LambdaEnvironment.RUNTIME_API); - - String taskRoot = System.getProperty("user.dir"); - String libRoot = "/opt/java"; - // Make system classloader the customer classloader's parent to ensure any aws-lambda-java-core classes - // are loaded from the system classloader. - customerClassLoader = new CustomerClassLoader(taskRoot, libRoot, ClassLoader.getSystemClassLoader()); - Thread.currentThread().setContextClassLoader(customerClassLoader); + return logger; + } - // Load the user's handler - LambdaRequestHandler requestHandler; - try { - requestHandler = findRequestHandler(handler, customerClassLoader); - } catch (UserFault userFault) { - lambdaLogger.log(userFault.reportableError(), lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED); - LambdaError error = new LambdaError( - LambdaErrorConverter.fromUserFault(userFault), - RapidErrorType.BadFunctionCode); - runtimeClient.reportInitError(error); - System.exit(1); - return; - } - if (INIT_TYPE_SNAP_START.equals(AWS_LAMBDA_INITIALIZATION_TYPE)) { - onInitComplete(lambdaLogger); - } + private static void startRuntimeLoop(LambdaRequestHandler requestHandler, LambdaContextLogger lambdaLogger) throws Throwable { boolean shouldExit = false; while (!shouldExit) { UserFault userFault = null; @@ -240,7 +248,7 @@ private static void startRuntime(String handler, LambdaContextLogger lambdaLogge payload = requestHandler.call(request); runtimeClient.reportInvocationSuccess(request.getId(), payload.toByteArray()); // clear interrupted flag in case if it was set by user's code - boolean ignored = Thread.interrupted(); + Thread.interrupted(); } catch (UserFault f) { shouldExit = f.fatal; userFault = f; @@ -278,6 +286,7 @@ static void onInitComplete(final LambdaContextLogger lambdaLogger) throws IOExce RapidErrorType.BeforeCheckpointError)); System.exit(64); } + try { Core.getGlobalContext().afterRestore(null); } catch (Exception restoreExc) { diff --git a/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/logging/LambdaContextLogger.java b/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/logging/LambdaContextLogger.java index 693eb015..dd356912 100644 --- a/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/logging/LambdaContextLogger.java +++ b/aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/logging/LambdaContextLogger.java @@ -7,9 +7,11 @@ import com.amazonaws.services.lambda.runtime.logging.LogFormat; import com.amazonaws.services.lambda.runtime.logging.LogLevel; +import java.io.Closeable; +import java.io.IOException; import static java.nio.charset.StandardCharsets.UTF_8; -public class LambdaContextLogger extends AbstractLambdaLogger { +public class LambdaContextLogger extends AbstractLambdaLogger implements Closeable { // If a null string is passed in, replace it with "null", // replicating the behavior of System.out.println(null); private static final byte[] NULL_BYTES_VALUE = "null".getBytes(UTF_8); @@ -29,4 +31,10 @@ protected void logMessage(byte[] message, LogLevel logLevel) { sink.log(logLevel, this.logFormat, message); } } + + @Override + public void close() throws IOException { + sink.close(); + + } }