Skip to content

Commit cb830cc

Browse files
qidewenwhen authored and Dewen Qi committed
fix: Improve Pipeline workflow unit test branch coverage (#2878)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 3f4b3a7 commit cb830cc

14 files changed

+273 -35 lines changed

src/sagemaker/workflow/_repack_model.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from distutils.dir_util import copy_tree
3535

3636

37-
def repack(inference_script, model_archive, dependencies=None, source_dir=None):
37+
def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
3838
"""Repack custom dependencies and code into an existing model TAR archive
3939
4040
Args:
@@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
9595
copy_tree(src_dir, "/opt/ml/model")
9696

9797

98-
if __name__ == "__main__":
98+
if __name__ == "__main__": # pragma: no cover
9999
parser = argparse.ArgumentParser()
100100
parser.add_argument("--inference_script", type=str, default="inference.py")
101101
parser.add_argument("--dependencies", type=str, default=None)

src/sagemaker/workflow/_utils.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def __init__(
8080
artifacts. After the endpoint is created, the inference code
8181
might use the IAM role, if it needs to access an AWS resource.
8282
model_data (str): The S3 location of a SageMaker model data
83-
``.tar.gz`` file (default: None).
83+
``.tar.gz`` file.
8484
entry_point (str): Path (absolute or relative) to the local Python
8585
source file which should be executed as the entry point to
8686
inference. If ``source_dir`` is specified, then ``entry_point``

src/sagemaker/workflow/condition_step.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def properties(self):
9595

9696

9797
@attr.s
98-
class JsonGet(Expression):
98+
class JsonGet(Expression): # pragma: no cover
9999
"""Get JSON properties from PropertyFiles.
100100
101101
Attributes:

src/sagemaker/workflow/functions.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,8 @@ class JsonGet(Expression):
7575
@property
7676
def expr(self):
7777
"""The expression dict for a `JsonGet` function."""
78-
if not isinstance(self.step_name, str):
79-
raise ValueError("Please give step name as a string")
78+
if not isinstance(self.step_name, str) or not self.step_name:
79+
raise ValueError("Please give a valid step name as a string")
8080

8181
if isinstance(self.property_file, PropertyFile):
8282
name = self.property_file.name

src/sagemaker/workflow/lambda_step.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -161,8 +161,8 @@ def _get_function_arn(self):
161161
partition = "aws"
162162

163163
if self.lambda_func.function_arn is None:
164+
account_id = self.lambda_func.session.account_id()
164165
try:
165-
account_id = self.lambda_func.session.account_id()
166166
response = self.lambda_func.create()
167167
return response["FunctionArn"]
168168
except ValueError as error:

src/sagemaker/workflow/step_collections.py

+9 -5
Original file line number | Diff line number | Diff line change
@@ -323,15 +323,15 @@ def __init__(
323323
"""
324324
steps = []
325325
if "entry_point" in kwargs:
326-
entry_point = kwargs["entry_point"]
327-
source_dir = kwargs.get("source_dir")
328-
dependencies = kwargs.get("dependencies")
326+
entry_point = kwargs.get("entry_point", None)
327+
source_dir = kwargs.get("source_dir", None)
328+
dependencies = kwargs.get("dependencies", None)
329329
repack_model_step = _RepackModelStep(
330330
name=f"{name}RepackModel",
331331
depends_on=depends_on,
332332
retry_policies=repack_model_step_retry_policies,
333333
sagemaker_session=estimator.sagemaker_session,
334-
role=estimator.sagemaker_session,
334+
role=estimator.role,
335335
model_data=model_data,
336336
entry_point=entry_point,
337337
source_dir=source_dir,
@@ -357,7 +357,11 @@ def predict_wrapper(endpoint, session):
357357
vpc_config=None,
358358
sagemaker_session=estimator.sagemaker_session,
359359
role=estimator.role,
360-
**kwargs,
360+
env=kwargs.get("env", None),
361+
name=kwargs.get("name", None),
362+
enable_network_isolation=kwargs.get("enable_network_isolation", None),
363+
model_kms_key=kwargs.get("model_kms_key", None),
364+
image_config=kwargs.get("image_config", None),
361365
)
362366
model_step = CreateModelStep(
363367
name=f"{name}CreateModelStep",

tests/integ/test_workflow.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@
6767
ConditionLessThanOrEqualTo,
6868
)
6969
from sagemaker.workflow.condition_step import ConditionStep
70-
from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet
7170
from sagemaker.workflow.callback_step import (
7271
CallbackStep,
7372
CallbackOutput,
@@ -2835,8 +2834,8 @@ def test_end_to_end_pipeline_successful_execution(
28352834

28362835
# define condition step
28372836
cond_lte = ConditionLessThanOrEqualTo(
2838-
left=ConditionStepJsonGet(
2839-
step=step_eval,
2837+
left=JsonGet(
2838+
step_name=step_eval.name,
28402839
property_file=evaluation_report,
28412840
json_path="regression_metrics.mse.value",
28422841
),

tests/integ/test_workflow_with_clarify.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@
3333
from sagemaker.processing import ProcessingInput, ProcessingOutput
3434
from sagemaker.session import get_execution_role
3535
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
36-
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
36+
from sagemaker.workflow.condition_step import ConditionStep
37+
from sagemaker.workflow.functions import JsonGet
3738
from sagemaker.workflow.parameters import (
3839
ParameterInteger,
3940
ParameterString,
@@ -237,7 +238,7 @@ def test_workflow_with_clarify(
237238
)
238239

239240
cond_left = JsonGet(
240-
step=step_process,
241+
step_name=step_process.name,
241242
property_file="BiasOutput",
242243
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
243244
)

tests/unit/sagemaker/workflow/test_functions.py

+22
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
# language governing permissions and limitations under the License.
1414
from __future__ import absolute_import
1515

16+
import pytest
17+
1618
from sagemaker.workflow.execution_variables import ExecutionVariables
1719
from sagemaker.workflow.functions import Join, JsonGet
1820
from sagemaker.workflow.parameters import (
@@ -97,3 +99,23 @@ def test_json_get_expressions():
9799
"Path": "my-json-path",
98100
},
99101
}
102+
103+
104+
def test_json_get_expressions_with_invalid_step_name():
105+
with pytest.raises(ValueError) as err:
106+
JsonGet(
107+
step_name="",
108+
property_file="my-property-file",
109+
json_path="my-json-path",
110+
).expr
111+
112+
assert "Please give a valid step name as a string" in str(err.value)
113+
114+
with pytest.raises(ValueError) as err:
115+
JsonGet(
116+
step_name=ParameterString(name="MyString"),
117+
property_file="my-property-file",
118+
json_path="my-json-path",
119+
).expr
120+
121+
assert "Please give a valid step name as a string" in str(err.value)

tests/unit/sagemaker/workflow/test_lambda_step.py

+68 -10
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.workflow.pipeline import Pipeline
2323
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
2424
from sagemaker.lambda_helper import Lambda
25+
from sagemaker.workflow.steps import CacheConfig
2526

2627

2728
@pytest.fixture()
@@ -38,10 +39,25 @@ def sagemaker_session():
3839
return session_mock
3940

4041

42+
@pytest.fixture()
43+
def sagemaker_session_cn():
44+
boto_mock = Mock(name="boto_session", region_name="cn-north-1")
45+
session_mock = MagicMock(
46+
name="sagemaker_session",
47+
boto_session=boto_mock,
48+
boto_region_name="cn-north-1",
49+
config=None,
50+
local_mode=False,
51+
)
52+
session_mock.account_id.return_value = "234567890123"
53+
return session_mock
54+
55+
4156
def test_lambda_step(sagemaker_session):
4257
param = ParameterInteger(name="MyInt")
43-
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
44-
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
58+
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
59+
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
60+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
4561
lambda_step = LambdaStep(
4662
name="MyLambdaStep",
4763
depends_on=["TestStep"],
@@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session):
5268
display_name="MyLambdaStep",
5369
description="MyLambdaStepDescription",
5470
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
55-
outputs=[outputParam1, outputParam2],
71+
outputs=[output_param1, output_param2],
72+
cache_config=cache_config,
5673
)
5774
lambda_step.add_depends_on(["SecondTestStep"])
58-
assert lambda_step.to_request() == {
75+
pipeline = Pipeline(
76+
name="MyPipeline",
77+
parameters=[param],
78+
steps=[lambda_step],
79+
sagemaker_session=sagemaker_session,
80+
)
81+
assert json.loads(pipeline.definition())["Steps"][0] == {
5982
"Name": "MyLambdaStep",
6083
"Type": "Lambda",
6184
"DependsOn": ["TestStep", "SecondTestStep"],
@@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session):
6689
{"OutputName": "output1", "OutputType": "String"},
6790
{"OutputName": "output2", "OutputType": "Boolean"},
6891
],
69-
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
92+
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}},
93+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
7094
}
7195

7296

@@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session):
95119

96120
def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
97121
parameter = ParameterString("MyStr")
98-
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
99-
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
122+
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
123+
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
100124
lambda_step1 = LambdaStep(
101125
name="MyLambdaStep1",
102126
depends_on=["TestStep"],
@@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
105129
session=sagemaker_session,
106130
),
107131
inputs={"arg1": "foo"},
108-
outputs=[outputParam1],
132+
outputs=[output_param1],
109133
)
110134
lambda_step2 = LambdaStep(
111135
name="MyLambdaStep2",
@@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
114138
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
115139
session=sagemaker_session,
116140
),
117-
inputs={"arg1": outputParam1},
118-
outputs=[outputParam2],
141+
inputs={"arg1": output_param1},
142+
outputs=[output_param2],
119143
)
120144

121145
pipeline = Pipeline(
@@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session):
207231
)
208232
lambda_step._get_function_arn()
209233
sagemaker_session.account_id.assert_called_once()
234+
235+
236+
def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn):
237+
lambda_func = MagicMock(
238+
function_arn=None,
239+
function_name="name",
240+
execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role",
241+
zipped_code_dir="",
242+
handler="",
243+
session=sagemaker_session_cn,
244+
)
245+
# The raised ValueError contains ResourceConflictException
246+
lambda_func.create.side_effect = ValueError("ResourceConflictException")
247+
lambda_step1 = LambdaStep(
248+
name="MyLambdaStep1",
249+
depends_on=["TestStep"],
250+
lambda_func=lambda_func,
251+
inputs={},
252+
outputs=[],
253+
)
254+
function_arn = lambda_step1._get_function_arn()
255+
assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name"
256+
257+
# The raised ValueError does not contain ResourceConflictException
258+
lambda_func.create.side_effect = ValueError()
259+
lambda_step2 = LambdaStep(
260+
name="MyLambdaStep2",
261+
depends_on=["TestStep"],
262+
lambda_func=lambda_func,
263+
inputs={},
264+
outputs=[],
265+
)
266+
with pytest.raises(ValueError):
267+
lambda_step2._get_function_arn()

0 commit comments

Comments
 (0)