Skip to content

feature: Add Extra Parameters to Lambda Function Wrapper #3594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 27, 2023
Merged
19 changes: 19 additions & 0 deletions src/sagemaker/lambda_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def __init__(
timeout: int = 120,
memory_size: int = 128,
runtime: str = "python3.8",
vpc_config: dict = None,
architectures: list = None,
environment: dict = None,
layers: list = None,
):
"""Constructs a Lambda instance.

Expand Down Expand Up @@ -66,6 +70,11 @@ def __init__(
timeout (int): Timeout of the Lambda function in seconds. Default is 120 seconds.
memory_size (int): Memory of the Lambda function in megabytes. Default is 128 MB.
runtime (str): Runtime of the Lambda function. Default is set to python3.8.
vpc_config (dict): VPC to deploy the Lambda function to. Default is None.
architectures (list): Which architecture to deploy to. Valid Values are
'x86_64' and 'arm64', default is None.
environment (dict): Environment Variables for the Lambda function. Default is None.
layers (list): List of Lambda layers for the Lambda function. Default is None.
"""
self.function_arn = function_arn
self.function_name = function_name
Expand All @@ -78,6 +87,10 @@ def __init__(
self.timeout = timeout
self.memory_size = memory_size
self.runtime = runtime
self.vpc_config = vpc_config
self.environment = environment
self.architectures = architectures
self.layers = layers

if function_arn is None and function_name is None:
raise ValueError("Either function_arn or function_name must be provided.")
Expand Down Expand Up @@ -127,6 +140,10 @@ def create(self):
Code=code,
Timeout=self.timeout,
MemorySize=self.memory_size,
VpcConfig=self.vpc_config,
Environment=self.environment,
Architectures=self.architectures,
Layers=self.layers,
)
return response
except ClientError as e:
Expand All @@ -146,6 +163,7 @@ def update(self):
response = lambda_client.update_function_code(
FunctionName=self.function_name or self.function_arn,
ZipFile=_zip_lambda_code(self.script),
Architectures=self.architectures,
)
else:
bucket = self.s3_bucket or self.session.default_bucket()
Expand All @@ -168,6 +186,7 @@ def update(self):
zipped_code_dir=self.zipped_code_dir,
s3_bucket=bucket,
),
Architectures=self.architectures,
)
return response
except ClientError as e:
Expand Down
115 changes: 107 additions & 8 deletions tests/unit/test_lambda_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def test_create_lambda_happycase1(sagemaker_session):
Code=code,
Timeout=120,
MemorySize=128,
Architectures=None,
VpcConfig=None,
Environment=None,
Layers=None,
)


Expand All @@ -212,6 +216,45 @@ def test_create_lambda_happycase2(sagemaker_session):
Code=code,
Timeout=120,
MemorySize=128,
Architectures=None,
VpcConfig=None,
Environment=None,
Layers=None,
)


@patch("sagemaker.lambda_helper._zip_lambda_code", return_value=ZIPPED_CODE)
def test_create_lambda_happycase3(sagemaker_session):
lambda_obj = lambda_helper.Lambda(
function_name=FUNCTION_NAME,
execution_role_arn=EXECUTION_ROLE,
script=SCRIPT,
handler=HANDLER,
session=sagemaker_session,
architectures=["x86_64"],
environment={"Name": "my-test-lambda"},
vpc_config={
"SubnetIds": ["test-subnet-1"],
"SecurityGroupIds": ["sec-group-1"],
},
layers=["my-test-layer-1", "my-test-layer-2"],
)

lambda_obj.create()
code = {"ZipFile": ZIPPED_CODE}

sagemaker_session.lambda_client.create_function.assert_called_with(
FunctionName=FUNCTION_NAME,
Runtime="python3.8",
Handler=HANDLER,
Role=EXECUTION_ROLE,
Code=code,
Timeout=120,
MemorySize=128,
Architectures=["x86_64"],
VpcConfig={"SubnetIds": ["test-subnet-1"], "SecurityGroupIds": ["sec-group-1"]},
Environment={"Name": "my-test-lambda"},
Layers=["my-test-layer-1", "my-test-layer-2"],
)


Expand Down Expand Up @@ -241,7 +284,12 @@ def test_create_lambda_client_error(sagemaker_session):
session=sagemaker_session,
)
sagemaker_session.lambda_client.create_function.side_effect = ClientError(
{"Error": {"Code": "ResourceConflictException", "Message": "Function already exists"}},
{
"Error": {
"Code": "ResourceConflictException",
"Message": "Function already exists",
}
},
"CreateFunction",
)

Expand All @@ -264,7 +312,9 @@ def test_update_lambda_happycase1(sagemaker_session):
lambda_obj.update()

sagemaker_session.lambda_client.update_function_code.assert_called_with(
FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE
FunctionName=FUNCTION_NAME,
ZipFile=ZIPPED_CODE,
Architectures=None,
)


Expand All @@ -282,7 +332,35 @@ def test_update_lambda_happycase2(sagemaker_session):
lambda_obj.update()

sagemaker_session.lambda_client.update_function_code.assert_called_with(
FunctionName=LAMBDA_ARN, S3Bucket=S3_BUCKET, S3Key=S3_KEY
FunctionName=LAMBDA_ARN,
S3Bucket=S3_BUCKET,
S3Key=S3_KEY,
Architectures=None,
)


@patch("sagemaker.lambda_helper._zip_lambda_code", return_value=ZIPPED_CODE)
def test_update_lambda_happycase3(sagemaker_session):
lambda_obj = lambda_helper.Lambda(
function_name=FUNCTION_NAME,
execution_role_arn=EXECUTION_ROLE,
script=SCRIPT,
handler=HANDLER,
session=sagemaker_session,
architectures=["x86_64"],
environment={"Name": "my-test-lambda"},
vpc_config={
"SubnetIds": ["test-subnet-1"],
"SecurityGroupIds": ["sec-group-1"],
},
)

lambda_obj.update()

sagemaker_session.lambda_client.update_function_code.assert_called_with(
FunctionName=FUNCTION_NAME,
ZipFile=ZIPPED_CODE,
Architectures=["x86_64"],
)


Expand All @@ -302,6 +380,7 @@ def test_update_lambda_s3bucket_not_provided(s3_upload, sagemaker_session):
FunctionName=LAMBDA_ARN,
S3Bucket=sagemaker_session.default_bucket(),
S3Key=s3_upload.return_value,
Architectures=None,
)


Expand Down Expand Up @@ -346,6 +425,10 @@ def test_upsert_lambda_happycase1(sagemaker_session):
Code=code,
Timeout=120,
MemorySize=128,
Architectures=None,
VpcConfig=None,
Environment=None,
Layers=None,
)


Expand All @@ -360,14 +443,19 @@ def test_upsert_lambda_happycase2(sagemaker_session):
)

sagemaker_session.lambda_client.create_function.side_effect = ClientError(
{"Error": {"Code": "ResourceConflictException", "Message": "Lambda already exists"}},
{
"Error": {
"Code": "ResourceConflictException",
"Message": "Lambda already exists",
}
},
"CreateFunction",
)

lambda_obj.upsert()

sagemaker_session.lambda_client.update_function_code.assert_called_once_with(
FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE
FunctionName=FUNCTION_NAME, ZipFile=ZIPPED_CODE, Architectures=None
)


Expand All @@ -382,12 +470,22 @@ def test_upsert_lambda_client_error(sagemaker_session):
)

sagemaker_session.lambda_client.create_function.side_effect = ClientError(
{"Error": {"Code": "ResourceConflictException", "Message": "Lambda already exists"}},
{
"Error": {
"Code": "ResourceConflictException",
"Message": "Lambda already exists",
}
},
"CreateFunction",
)

sagemaker_session.lambda_client.update_function_code.side_effect = ClientError(
{"Error": {"Code": "ResourceConflictException", "Message": "Cannot update code"}},
{
"Error": {
"Code": "ResourceConflictException",
"Message": "Cannot update code",
}
},
"UpdateFunctionCode",
)

Expand All @@ -410,7 +508,8 @@ def test_invoke_lambda_client_error(sagemaker_session):
lambda_obj = lambda_helper.Lambda(function_arn=LAMBDA_ARN, session=sagemaker_session)

sagemaker_session.lambda_client.invoke.side_effect = ClientError(
{"Error": {"Code": "InvalidCodeException", "Message": "invoke failed"}}, "Invoke"
{"Error": {"Code": "InvalidCodeException", "Message": "invoke failed"}},
"Invoke",
)
with pytest.raises(ValueError) as error:
lambda_obj.invoke()
Expand Down