diff --git a/src/sagemaker/lambda_helper.py b/src/sagemaker/lambda_helper.py index c1e787d8d3..568ed7eb0c 100644 --- a/src/sagemaker/lambda_helper.py +++ b/src/sagemaker/lambda_helper.py @@ -36,6 +36,10 @@ def __init__( timeout: int = 120, memory_size: int = 128, runtime: str = "python3.8", + vpc_config: dict = None, + architectures: list = None, + environment: dict = None, + layers: list = None, ): """Constructs a Lambda instance. @@ -66,6 +70,11 @@ def __init__( timeout (int): Timeout of the Lambda function in seconds. Default is 120 seconds. memory_size (int): Memory of the Lambda function in megabytes. Default is 128 MB. runtime (str): Runtime of the Lambda function. Default is set to python3.8. + vpc_config (dict): VPC to deploy the Lambda function to. Default is None. + architectures (list): Which architecture to deploy to. Valid Values are + 'x86_64' and 'arm64', default is None. + environment (dict): Environment Variables for the Lambda function. Default is None. + layers (list): List of Lambda layers for the Lambda function. Default is None. """ self.function_arn = function_arn self.function_name = function_name @@ -78,6 +87,10 @@ def __init__( self.timeout = timeout self.memory_size = memory_size self.runtime = runtime + self.vpc_config = vpc_config + self.environment = environment + self.architectures = architectures + self.layers = layers if function_arn is None and function_name is None: raise ValueError("Either function_arn or function_name must be provided.") @@ -127,6 +140,10 @@ def create(self): Code=code, Timeout=self.timeout, MemorySize=self.memory_size, + VpcConfig=self.vpc_config, + Environment=self.environment, + Architectures=self.architectures, + Layers=self.layers, ) return response except ClientError as e: @@ -146,6 +163,7 @@ def update(self): response = lambda_client.update_function_code( FunctionName=self.function_name or self.function_arn, ZipFile=_zip_lambda_code(self.script), + Architectures=self.architectures, ) else: bucket = self.s3_bucket or self.session.default_bucket() @@ -168,6 +186,7 @@ def update(self): zipped_code_dir=self.zipped_code_dir, s3_bucket=bucket, ), + Architectures=self.architectures, ) return response except ClientError as e: diff --git a/tests/unit/test_lambda_helper.py b/tests/unit/test_lambda_helper.py index 4a542a1db7..64dc50fa68 100644 --- a/tests/unit/test_lambda_helper.py +++ b/tests/unit/test_lambda_helper.py @@ -187,6 +187,10 @@ def test_create_lambda_happycase1(sagemaker_session): Code=code, Timeout=120, MemorySize=128, + Architectures=None, + VpcConfig=None, + Environment=None, + Layers=None, ) @@ -212,6 +216,45 @@ def test_create_lambda_happycase2(sagemaker_session): Code=code, Timeout=120, MemorySize=128, + Architectures=None, + VpcConfig=None, + Environment=None, + Layers=None, + ) + + +@patch("sagemaker.lambda_helper._zip_lambda_code", return_value=ZIPPED_CODE) +def test_create_lambda_happycase3(sagemaker_session): + lambda_obj = lambda_helper.Lambda( + function_name=FUNCTION_NAME, + execution_role_arn=EXECUTION_ROLE, + script=SCRIPT, + handler=HANDLER, + session=sagemaker_session, + architectures=["x86_64"], + environment={"Name": "my-test-lambda"}, + vpc_config={ + "SubnetIds": ["test-subnet-1"], + "SecurityGroupIds": ["sec-group-1"], + }, + layers=["my-test-layer-1", "my-test-layer-2"], + ) + + lambda_obj.create() + code = {"ZipFile": ZIPPED_CODE} + + sagemaker_session.lambda_client.create_function.assert_called_with( + FunctionName=FUNCTION_NAME, + Runtime="python3.8", + Handler=HANDLER, + Role=EXECUTION_ROLE, + Code=code, + Timeout=120, + MemorySize=128, + Architectures=["x86_64"], + VpcConfig={"SubnetIds": ["test-subnet-1"], "SecurityGroupIds": ["sec-group-1"]}, + Environment={"Name": "my-test-lambda"}, + Layers=["my-test-layer-1", "my-test-layer-2"], ) @@ -241,7 +284,12 @@ def test_create_lambda_client_error(sagemaker_session): session=sagemaker_session, ) sagemaker_session.lambda_client.create_function.side_effect = ClientError( - {"Error": {"Code": "ResourceConflictException", "Message": "Function already exists"}}, + { + "Error": { + "Code": "ResourceConflictException", + "Message": "Function already exists", + } + }, "CreateFunction", ) @@ -264,7 +312,9 @@ def test_update_lambda_happycase1(sagemaker_session): lambda_obj.update() sagemaker_session.lambda_client.update_function_code.assert_called_with( - FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE + FunctionName=FUNCTION_NAME, + ZipFile=ZIPPED_CODE, + Architectures=None, ) @@ -282,7 +332,35 @@ def test_update_lambda_happycase2(sagemaker_session): lambda_obj.update() sagemaker_session.lambda_client.update_function_code.assert_called_with( - FunctionName=LAMBDA_ARN, S3Bucket=S3_BUCKET, S3Key=S3_KEY + FunctionName=LAMBDA_ARN, + S3Bucket=S3_BUCKET, + S3Key=S3_KEY, + Architectures=None, + ) + + +@patch("sagemaker.lambda_helper._zip_lambda_code", return_value=ZIPPED_CODE) +def test_update_lambda_happycase3(sagemaker_session): + lambda_obj = lambda_helper.Lambda( + function_name=FUNCTION_NAME, + execution_role_arn=EXECUTION_ROLE, + script=SCRIPT, + handler=HANDLER, + session=sagemaker_session, + architectures=["x86_64"], + environment={"Name": "my-test-lambda"}, + vpc_config={ + "SubnetIds": ["test-subnet-1"], + "SecurityGroupIds": ["sec-group-1"], + }, + ) + + lambda_obj.update() + + sagemaker_session.lambda_client.update_function_code.assert_called_with( + FunctionName=FUNCTION_NAME, + ZipFile=ZIPPED_CODE, + Architectures=["x86_64"], ) @@ -302,6 +380,7 @@ def test_update_lambda_s3bucket_not_provided(s3_upload, sagemaker_session): FunctionName=LAMBDA_ARN, S3Bucket=sagemaker_session.default_bucket(), S3Key=s3_upload.return_value, + Architectures=None, ) @@ -346,6 +425,10 @@ def test_upsert_lambda_happycase1(sagemaker_session): Code=code, Timeout=120, MemorySize=128, + Architectures=None, + VpcConfig=None, + Environment=None, + Layers=None, ) @@ -360,14 +443,19 @@ def test_upsert_lambda_happycase2(sagemaker_session): ) sagemaker_session.lambda_client.create_function.side_effect = ClientError( - {"Error": {"Code": "ResourceConflictException", "Message": "Lambda already exists"}}, + { + "Error": { + "Code": "ResourceConflictException", + "Message": "Lambda already exists", + } + }, "CreateFunction", ) lambda_obj.upsert() sagemaker_session.lambda_client.update_function_code.assert_called_once_with( - FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE + FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE, Architectures=None ) @@ -382,12 +470,22 @@ def test_upsert_lambda_client_error(sagemaker_session): ) sagemaker_session.lambda_client.create_function.side_effect = ClientError( - {"Error": {"Code": "ResourceConflictException", "Message": "Lambda already exists"}}, + { + "Error": { + "Code": "ResourceConflictException", + "Message": "Lambda already exists", + } + }, "CreateFunction", ) sagemaker_session.lambda_client.update_function_code.side_effect = ClientError( - {"Error": {"Code": "ResourceConflictException", "Message": "Cannot update code"}}, + { + "Error": { + "Code": "ResourceConflictException", + "Message": "Cannot update code", + } + }, "UpdateFunctionCode", ) @@ -410,7 +508,8 @@ def test_invoke_lambda_client_error(sagemaker_session): lambda_obj = lambda_helper.Lambda(function_arn=LAMBDA_ARN, session=sagemaker_session) sagemaker_session.lambda_client.invoke.side_effect = ClientError( - {"Error": {"Code": "InvalidCodeException", "Message": "invoke failed"}}, "Invoke" + {"Error": {"Code": "InvalidCodeException", "Message": "invoke failed"}}, + "Invoke", ) with pytest.raises(ValueError) as error: lambda_obj.invoke()