Skip to content

Commit 79a64a6

Browse files
author
Dewen Qi
committed
fix: improve Pipeline workflow unit test branch coverage
fix: Update to use the new JsonGet constructor
1 parent bb7563f commit 79a64a6

14 files changed

+274
-35
lines changed

src/sagemaker/workflow/_repack_model.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from distutils.dir_util import copy_tree
3535

3636

37-
def repack(inference_script, model_archive, dependencies=None, source_dir=None):
37+
def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
3838
"""Repack custom dependencies and code into an existing model TAR archive
3939
4040
Args:
@@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
9595
copy_tree(src_dir, "/opt/ml/model")
9696

9797

98-
if __name__ == "__main__":
98+
if __name__ == "__main__": # pragma: no cover
9999
parser = argparse.ArgumentParser()
100100
parser.add_argument("--inference_script", type=str, default="inference.py")
101101
parser.add_argument("--dependencies", type=str, default=None)

src/sagemaker/workflow/_utils.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def __init__(
5454
name: str,
5555
sagemaker_session,
5656
role,
57-
model_data: str,
57+
model_data: Union[str, Properties],
5858
entry_point: str,
5959
display_name: str = None,
6060
description: str = None,
@@ -79,8 +79,8 @@ def __init__(
7979
endpoints use this role to access training data and model
8080
artifacts. After the endpoint is created, the inference code
8181
might use the IAM role, if it needs to access an AWS resource.
82-
model_data (str): The S3 location of a SageMaker model data
83-
``.tar.gz`` file (default: None).
82+
model_data (str or Properties): The S3 location of a SageMaker model data
83+
``.tar.gz`` file.
8484
entry_point (str): Path (absolute or relative) to the local Python
8585
source file which should be executed as the entry point to
8686
inference. If ``source_dir`` is specified, then ``entry_point``

src/sagemaker/workflow/condition_step.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def properties(self):
9595

9696

9797
@attr.s
98-
class JsonGet(Expression):
98+
class JsonGet(Expression): # pragma: no cover
9999
"""Get JSON properties from PropertyFiles.
100100
101101
Attributes:

src/sagemaker/workflow/functions.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,8 @@ class JsonGet(Expression):
7575
@property
7676
def expr(self):
7777
"""The expression dict for a `JsonGet` function."""
78-
if not isinstance(self.step_name, str):
79-
raise ValueError("Please give step name as a string")
78+
if not isinstance(self.step_name, str) or not self.step_name:
79+
raise ValueError("Please give a valid step name as a string")
8080

8181
if isinstance(self.property_file, PropertyFile):
8282
name = self.property_file.name

src/sagemaker/workflow/lambda_step.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -161,8 +161,8 @@ def _get_function_arn(self):
161161
partition = "aws"
162162

163163
if self.lambda_func.function_arn is None:
164+
account_id = self.lambda_func.session.account_id()
164165
try:
165-
account_id = self.lambda_func.session.account_id()
166166
response = self.lambda_func.create()
167167
return response["FunctionArn"]
168168
except ValueError as error:

src/sagemaker/workflow/step_collections.py

+9-5
Original file line number | Diff line number | Diff line change
@@ -318,15 +318,15 @@ def __init__(
318318
"""
319319
steps = []
320320
if "entry_point" in kwargs:
321-
entry_point = kwargs["entry_point"]
322-
source_dir = kwargs.get("source_dir")
323-
dependencies = kwargs.get("dependencies")
321+
entry_point = kwargs.get("entry_point")
322+
source_dir = kwargs.get("source_dir", None)
323+
dependencies = kwargs.get("dependencies", None)
324324
repack_model_step = _RepackModelStep(
325325
name=f"{name}RepackModel",
326326
depends_on=depends_on,
327327
retry_policies=repack_model_step_retry_policies,
328328
sagemaker_session=estimator.sagemaker_session,
329-
role=estimator.sagemaker_session,
329+
role=estimator.role,
330330
model_data=model_data,
331331
entry_point=entry_point,
332332
source_dir=source_dir,
@@ -352,7 +352,11 @@ def predict_wrapper(endpoint, session):
352352
vpc_config=None,
353353
sagemaker_session=estimator.sagemaker_session,
354354
role=estimator.role,
355-
**kwargs,
355+
env=kwargs.get("env", None),
356+
name=kwargs.get("name", None),
357+
enable_network_isolation=kwargs.get("enable_network_isolation", None),
358+
model_kms_key=kwargs.get("model_kms_key", None),
359+
image_config=kwargs.get("image_config", None),
356360
)
357361
model_step = CreateModelStep(
358362
name=f"{name}CreateModelStep",

tests/integ/test_workflow.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -2832,7 +2832,7 @@ def test_end_to_end_pipeline_successful_execution(
28322832
# define condition step
28332833
cond_lte = ConditionLessThanOrEqualTo(
28342834
left=JsonGet(
2835-
step=step_eval,
2835+
step_name=step_eval.name,
28362836
property_file=evaluation_report,
28372837
json_path="regression_metrics.mse.value",
28382838
),

tests/integ/test_workflow_with_clarify.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@
3333
from sagemaker.processing import ProcessingInput, ProcessingOutput
3434
from sagemaker.session import get_execution_role
3535
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
36-
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
36+
from sagemaker.workflow.condition_step import ConditionStep
37+
from sagemaker.workflow.functions import JsonGet
3738
from sagemaker.workflow.parameters import (
3839
ParameterInteger,
3940
ParameterString,
@@ -237,7 +238,7 @@ def test_workflow_with_clarify(
237238
)
238239

239240
cond_left = JsonGet(
240-
step=step_process,
241+
step_name=step_process.name,
241242
property_file="BiasOutput",
242243
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
243244
)

tests/unit/sagemaker/workflow/test_functions.py

+22
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
# language governing permissions and limitations under the License.
1414
from __future__ import absolute_import
1515

16+
import pytest
17+
1618
from sagemaker.workflow.execution_variables import ExecutionVariables
1719
from sagemaker.workflow.functions import Join, JsonGet
1820
from sagemaker.workflow.parameters import (
@@ -97,3 +99,23 @@ def test_json_get_expressions():
9799
"Path": "my-json-path",
98100
},
99101
}
102+
103+
104+
def test_json_get_expressions_with_invalid_step_name():
105+
with pytest.raises(ValueError) as err:
106+
JsonGet(
107+
step_name="",
108+
property_file="my-property-file",
109+
json_path="my-json-path",
110+
).expr
111+
112+
assert "Please give a valid step name as a string" in str(err.value)
113+
114+
with pytest.raises(ValueError) as err:
115+
JsonGet(
116+
step_name=ParameterString(name="MyString"),
117+
property_file="my-property-file",
118+
json_path="my-json-path",
119+
).expr
120+
121+
assert "Please give a valid step name as a string" in str(err.value)

tests/unit/sagemaker/workflow/test_lambda_step.py

+68-10
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.workflow.pipeline import Pipeline
2323
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
2424
from sagemaker.lambda_helper import Lambda
25+
from sagemaker.workflow.steps import CacheConfig
2526

2627

2728
@pytest.fixture()
@@ -38,10 +39,25 @@ def sagemaker_session():
3839
return session_mock
3940

4041

42+
@pytest.fixture()
43+
def sagemaker_session_cn():
44+
boto_mock = Mock(name="boto_session", region_name="cn-north-1")
45+
session_mock = MagicMock(
46+
name="sagemaker_session",
47+
boto_session=boto_mock,
48+
boto_region_name="cn-north-1",
49+
config=None,
50+
local_mode=False,
51+
)
52+
session_mock.account_id.return_value = "234567890123"
53+
return session_mock
54+
55+
4156
def test_lambda_step(sagemaker_session):
4257
param = ParameterInteger(name="MyInt")
43-
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
44-
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
58+
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
59+
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
60+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
4561
lambda_step = LambdaStep(
4662
name="MyLambdaStep",
4763
depends_on=["TestStep"],
@@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session):
5268
display_name="MyLambdaStep",
5369
description="MyLambdaStepDescription",
5470
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
55-
outputs=[outputParam1, outputParam2],
71+
outputs=[output_param1, output_param2],
72+
cache_config=cache_config,
5673
)
5774
lambda_step.add_depends_on(["SecondTestStep"])
58-
assert lambda_step.to_request() == {
75+
pipeline = Pipeline(
76+
name="MyPipeline",
77+
parameters=[param],
78+
steps=[lambda_step],
79+
sagemaker_session=sagemaker_session,
80+
)
81+
assert json.loads(pipeline.definition())["Steps"][0] == {
5982
"Name": "MyLambdaStep",
6083
"Type": "Lambda",
6184
"DependsOn": ["TestStep", "SecondTestStep"],
@@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session):
6689
{"OutputName": "output1", "OutputType": "String"},
6790
{"OutputName": "output2", "OutputType": "Boolean"},
6891
],
69-
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
92+
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}},
93+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
7094
}
7195

7296

@@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session):
95119

96120
def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
97121
parameter = ParameterString("MyStr")
98-
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
99-
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
122+
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
123+
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
100124
lambda_step1 = LambdaStep(
101125
name="MyLambdaStep1",
102126
depends_on=["TestStep"],
@@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
105129
session=sagemaker_session,
106130
),
107131
inputs={"arg1": "foo"},
108-
outputs=[outputParam1],
132+
outputs=[output_param1],
109133
)
110134
lambda_step2 = LambdaStep(
111135
name="MyLambdaStep2",
@@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
114138
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
115139
session=sagemaker_session,
116140
),
117-
inputs={"arg1": outputParam1},
118-
outputs=[outputParam2],
141+
inputs={"arg1": output_param1},
142+
outputs=[output_param2],
119143
)
120144

121145
pipeline = Pipeline(
@@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session):
207231
)
208232
lambda_step._get_function_arn()
209233
sagemaker_session.account_id.assert_called_once()
234+
235+
236+
def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn):
237+
lambda_func = MagicMock(
238+
function_arn=None,
239+
function_name="name",
240+
execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role",
241+
zipped_code_dir="",
242+
handler="",
243+
session=sagemaker_session_cn,
244+
)
245+
# The raised ValueError contains ResourceConflictException
246+
lambda_func.create.side_effect = ValueError("ResourceConflictException")
247+
lambda_step1 = LambdaStep(
248+
name="MyLambdaStep1",
249+
depends_on=["TestStep"],
250+
lambda_func=lambda_func,
251+
inputs={},
252+
outputs=[],
253+
)
254+
function_arn = lambda_step1._get_function_arn()
255+
assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name"
256+
257+
# The raised ValueError does not contain ResourceConflictException
258+
lambda_func.create.side_effect = ValueError()
259+
lambda_step2 = LambdaStep(
260+
name="MyLambdaStep2",
261+
depends_on=["TestStep"],
262+
lambda_func=lambda_func,
263+
inputs={},
264+
outputs=[],
265+
)
266+
with pytest.raises(ValueError):
267+
lambda_step2._get_function_arn()

0 commit comments

Comments
 (0)