Skip to content

fix: Improve Pipeline workflow unit test branch coverage #2878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/_repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from distutils.dir_util import copy_tree


def repack(inference_script, model_archive, dependencies=None, source_dir=None):
def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
"""Repack custom dependencies and code into an existing model TAR archive

Args:
Expand Down Expand Up @@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
copy_tree(src_dir, "/opt/ml/model")


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
parser = argparse.ArgumentParser()
parser.add_argument("--inference_script", type=str, default="inference.py")
parser.add_argument("--dependencies", type=str, default=None)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
artifacts. After the endpoint is created, the inference code
might use the IAM role, if it needs to access an AWS resource.
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file (default: None).
``.tar.gz`` file.
entry_point (str): Path (absolute or relative) to the local Python
source file which should be executed as the entry point to
inference. If ``source_dir`` is specified, then ``entry_point``
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def properties(self):


@attr.s
class JsonGet(Expression):
class JsonGet(Expression): # pragma: no cover
"""Get JSON properties from PropertyFiles.

Attributes:
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class JsonGet(Expression):
@property
def expr(self):
"""The expression dict for a `JsonGet` function."""
if not isinstance(self.step_name, str):
raise ValueError("Please give step name as a string")
if not isinstance(self.step_name, str) or not self.step_name:
raise ValueError("Please give a valid step name as a string")

if isinstance(self.property_file, PropertyFile):
name = self.property_file.name
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/lambda_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def _get_function_arn(self):
partition = "aws"

if self.lambda_func.function_arn is None:
account_id = self.lambda_func.session.account_id()
try:
account_id = self.lambda_func.session.account_id()
response = self.lambda_func.create()
return response["FunctionArn"]
except ValueError as error:
Expand Down
14 changes: 9 additions & 5 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,15 @@ def __init__(
"""
steps = []
if "entry_point" in kwargs:
entry_point = kwargs["entry_point"]
source_dir = kwargs.get("source_dir")
dependencies = kwargs.get("dependencies")
entry_point = kwargs.get("entry_point", None)
source_dir = kwargs.get("source_dir", None)
dependencies = kwargs.get("dependencies", None)
repack_model_step = _RepackModelStep(
name=f"{name}RepackModel",
depends_on=depends_on,
retry_policies=repack_model_step_retry_policies,
sagemaker_session=estimator.sagemaker_session,
role=estimator.sagemaker_session,
role=estimator.role,
model_data=model_data,
entry_point=entry_point,
source_dir=source_dir,
Expand All @@ -357,7 +357,11 @@ def predict_wrapper(endpoint, session):
vpc_config=None,
sagemaker_session=estimator.sagemaker_session,
role=estimator.role,
**kwargs,
env=kwargs.get("env", None),
name=kwargs.get("name", None),
enable_network_isolation=kwargs.get("enable_network_isolation", None),
model_kms_key=kwargs.get("model_kms_key", None),
image_config=kwargs.get("image_config", None),
)
model_step = CreateModelStep(
name=f"{name}CreateModelStep",
Expand Down
5 changes: 2 additions & 3 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
ConditionLessThanOrEqualTo,
)
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet
from sagemaker.workflow.callback_step import (
CallbackStep,
CallbackOutput,
Expand Down Expand Up @@ -2835,8 +2834,8 @@ def test_end_to_end_pipeline_successful_execution(

# define condition step
cond_lte = ConditionLessThanOrEqualTo(
left=ConditionStepJsonGet(
step=step_eval,
left=JsonGet(
step_name=step_eval.name,
property_file=evaluation_report,
json_path="regression_metrics.mse.value",
),
Expand Down
5 changes: 3 additions & 2 deletions tests/integ/test_workflow_with_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.session import get_execution_role
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.parameters import (
ParameterInteger,
ParameterString,
Expand Down Expand Up @@ -237,7 +238,7 @@ def test_workflow_with_clarify(
)

cond_left = JsonGet(
step=step_process,
step_name=step_process.name,
property_file="BiasOutput",
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
)
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/sagemaker/workflow/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join, JsonGet
from sagemaker.workflow.parameters import (
Expand Down Expand Up @@ -97,3 +99,23 @@ def test_json_get_expressions():
"Path": "my-json-path",
},
}


def test_json_get_expressions_with_invalid_step_name():
    """``JsonGet.expr`` must raise ValueError for an empty or non-string step name."""
    expected_message = "Please give a valid step name as a string"

    # First an empty string, then a pipeline parameter object instead of a str.
    for bad_step_name in ("", ParameterString(name="MyString")):
        with pytest.raises(ValueError) as err:
            JsonGet(
                step_name=bad_step_name,
                property_file="my-property-file",
                json_path="my-json-path",
            ).expr

        assert expected_message in str(err.value)
78 changes: 68 additions & 10 deletions tests/unit/sagemaker/workflow/test_lambda_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
from sagemaker.lambda_helper import Lambda
from sagemaker.workflow.steps import CacheConfig


@pytest.fixture()
Expand All @@ -38,10 +39,25 @@ def sagemaker_session():
return session_mock


@pytest.fixture()
def sagemaker_session_cn():
    """Provide a mocked SageMaker session whose region is ``cn-north-1``.

    The mock's ``account_id()`` returns ``"234567890123"``.
    """
    region = "cn-north-1"
    session = MagicMock(
        name="sagemaker_session",
        boto_session=Mock(name="boto_session", region_name=region),
        boto_region_name=region,
        config=None,
        local_mode=False,
    )
    session.account_id.return_value = "234567890123"
    return session


def test_lambda_step(sagemaker_session):
param = ParameterInteger(name="MyInt")
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean)
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
lambda_step = LambdaStep(
name="MyLambdaStep",
depends_on=["TestStep"],
Expand All @@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session):
display_name="MyLambdaStep",
description="MyLambdaStepDescription",
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
outputs=[outputParam1, outputParam2],
outputs=[output_param1, output_param2],
cache_config=cache_config,
)
lambda_step.add_depends_on(["SecondTestStep"])
assert lambda_step.to_request() == {
pipeline = Pipeline(
name="MyPipeline",
parameters=[param],
steps=[lambda_step],
sagemaker_session=sagemaker_session,
)
assert json.loads(pipeline.definition())["Steps"][0] == {
"Name": "MyLambdaStep",
"Type": "Lambda",
"DependsOn": ["TestStep", "SecondTestStep"],
Expand All @@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session):
{"OutputName": "output1", "OutputType": "String"},
{"OutputName": "output2", "OutputType": "Boolean"},
],
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}},
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}


Expand Down Expand Up @@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session):

def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
parameter = ParameterString("MyStr")
outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String)
output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String)
lambda_step1 = LambdaStep(
name="MyLambdaStep1",
depends_on=["TestStep"],
Expand All @@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
session=sagemaker_session,
),
inputs={"arg1": "foo"},
outputs=[outputParam1],
outputs=[output_param1],
)
lambda_step2 = LambdaStep(
name="MyLambdaStep2",
Expand All @@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
session=sagemaker_session,
),
inputs={"arg1": outputParam1},
outputs=[outputParam2],
inputs={"arg1": output_param1},
outputs=[output_param2],
)

pipeline = Pipeline(
Expand Down Expand Up @@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session):
)
lambda_step._get_function_arn()
sagemaker_session.account_id.assert_called_once()


def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn):
    """Exercise ``_get_function_arn`` when creating the Lambda raises ValueError.

    A ValueError mentioning ``ResourceConflictException`` is expected to yield
    an ARN assembled from the session's region and account id; any other
    ValueError is expected to propagate to the caller.
    """
    lambda_func = MagicMock(
        function_arn=None,
        function_name="name",
        execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role",
        zipped_code_dir="",
        handler="",
        session=sagemaker_session_cn,
    )

    def build_step(step_name):
        # Both scenarios use the same step shape; only the name differs.
        return LambdaStep(
            name=step_name,
            depends_on=["TestStep"],
            lambda_func=lambda_func,
            inputs={},
            outputs=[],
        )

    # The raised ValueError contains ResourceConflictException
    lambda_func.create.side_effect = ValueError("ResourceConflictException")
    conflict_step = build_step("MyLambdaStep1")
    assert (
        conflict_step._get_function_arn()
        == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name"
    )

    # The raised ValueError does not contain ResourceConflictException
    lambda_func.create.side_effect = ValueError()
    failing_step = build_step("MyLambdaStep2")
    with pytest.raises(ValueError):
        failing_step._get_function_arn()
Loading