Skip to content

Commit 05a33d0

Browse files
author
Rohan Gujarathi
committed
fix: unit tests
1 parent db7c577 commit 05a33d0

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

tests/unit/sagemaker/workflow/test_lambda_step.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,29 @@
2424
from sagemaker.lambda_helper import Lambda
2525

2626

27-
@pytest.fixture
28-
def sagemaker_session_mock():
29-
return Mock()
27+
@pytest.fixture()
28+
def sagemaker_session():
29+
boto_mock = Mock(name="boto_session", region_name="us-west-2")
30+
session_mock = Mock(
31+
name="sagemaker_session",
32+
boto_session=boto_mock,
33+
boto_region_name="us-west-2",
34+
config=None,
35+
local_mode=False,
36+
)
37+
return session_mock
3038

3139

32-
def test_lambda_step():
40+
def test_lambda_step(sagemaker_session):
3341
param = ParameterInteger(name="MyInt")
3442
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
3543
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
3644
lambda_step = LambdaStep(
3745
name="MyLambdaStep",
3846
depends_on=["TestStep"],
3947
lambda_func=Lambda(
40-
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"
48+
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
49+
session=sagemaker_session,
4150
),
4251
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
4352
outputs=[outputParam1, outputParam2],
@@ -56,15 +65,16 @@ def test_lambda_step():
5665
}
5766

5867

59-
def test_lambda_step_output_expr():
68+
def test_lambda_step_output_expr(sagemaker_session):
6069
param = ParameterInteger(name="MyInt")
6170
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
6271
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
6372
lambda_step = LambdaStep(
6473
name="MyLambdaStep",
6574
depends_on=["TestStep"],
6675
lambda_func=Lambda(
67-
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"
76+
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
77+
session=sagemaker_session,
6878
),
6979
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
7080
outputs=[outputParam1, outputParam2],
@@ -78,15 +88,16 @@ def test_lambda_step_output_expr():
7888
}
7989

8090

81-
def test_pipeline_interpolates_lambda_outputs():
91+
def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
8292
parameter = ParameterString("MyStr")
8393
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
8494
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
8595
lambda_step1 = LambdaStep(
8696
name="MyLambdaStep1",
8797
depends_on=["TestStep"],
8898
lambda_func=Lambda(
89-
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"
99+
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
100+
session=sagemaker_session,
90101
),
91102
inputs={"arg1": "foo"},
92103
outputs=[outputParam1],
@@ -105,7 +116,7 @@ def test_pipeline_interpolates_lambda_outputs():
105116
name="MyPipeline",
106117
parameters=[parameter],
107118
steps=[lambda_step1, lambda_step2],
108-
sagemaker_session=sagemaker_session_mock,
119+
sagemaker_session=sagemaker_session,
109120
)
110121

111122
assert json.loads(pipeline.definition()) == {
@@ -137,12 +148,13 @@ def test_pipeline_interpolates_lambda_outputs():
137148
}
138149

139150

140-
def test_lambda_step_no_inputs_outputs():
151+
def test_lambda_step_no_inputs_outputs(sagemaker_session):
141152
lambda_step = LambdaStep(
142153
name="MyLambdaStep",
143154
depends_on=["TestStep"],
144155
lambda_func=Lambda(
145-
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"
156+
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
157+
session=sagemaker_session,
146158
),
147159
inputs={},
148160
outputs=[],

0 commit comments

Comments
 (0)