Skip to content

Commit 9711fd6

Browse files
committed
add parameterized tests to transformer
1 parent 60723ed commit 9711fd6

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

tests/unit/sagemaker/workflow/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def __init__(self, name, display_name=None, description=None, depends_on=None):
4141
super(CustomStep, self).__init__(
4242
name, display_name, description, StepTypeEnum.TRAINING, depends_on
4343
)
44-
self._properties = Properties(path=f"Steps.{name}")
44+
# for testing property reference, we just use DescribeTrainingJobResponse shape here.
45+
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse")
4546

4647
@property
4748
def arguments(self):

tests/unit/sagemaker/workflow/test_transform_step.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
from sagemaker.parameter import IntegerParameter
2424
from sagemaker.tuner import HyperparameterTuner
2525
from sagemaker.workflow.pipeline_context import PipelineSession
26+
from tests.unit.sagemaker.workflow.helpers import CustomStep
2627

2728
from sagemaker.workflow.steps import TransformStep, TransformInput
2829
from sagemaker.workflow.pipeline import Pipeline
2930
from sagemaker.workflow.parameters import ParameterString
31+
from sagemaker.workflow.functions import Join
32+
from sagemaker.workflow import is_pipeline_variable
3033

3134
from sagemaker.transformer import Transformer
3235

@@ -53,6 +56,7 @@ def client():
5356
client_mock._client_config.user_agent = (
5457
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
5558
)
59+
client_mock.describe_model.return_value = {"PrimaryContainer": {}, "Containers": {}}
5660
return client_mock
5761

5862

@@ -80,18 +84,35 @@ def pipeline_session(boto_session, client):
8084
)
8185

8286

83-
def test_transform_step_with_transformer(pipeline_session):
84-
model_name = ParameterString("ModelName")
87+
@pytest.mark.parametrize("model_name", [
88+
"my-model",
89+
ParameterString("ModelName"),
90+
ParameterString("ModelName", default_value="my-model"),
91+
Join(on="-", values=["my", "model"]),
92+
CustomStep(name="custom-step").properties.RoleArn
93+
])
94+
@pytest.mark.parametrize("data", [
95+
"s3://my-bucket/my-data",
96+
ParameterString("MyTransformInput"),
97+
ParameterString("MyTransformInput", default_value="s3://my-model"),
98+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "input"]),
99+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath
100+
])
101+
@pytest.mark.parametrize("output_path", [
102+
"s3://my-bucket/my-output-path",
103+
ParameterString("MyOutputPath"),
104+
ParameterString("MyOutputPath", default_value="s3://my-output"),
105+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "output"]),
106+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath
107+
])
108+
def test_transform_step_with_transformer(model_name, data, output_path, pipeline_session):
85109
transformer = Transformer(
86110
model_name=model_name,
87111
instance_type="ml.m5.xlarge",
88112
instance_count=1,
89-
output_path=f"s3://{pipeline_session.default_bucket()}/Transform",
113+
output_path=output_path,
90114
sagemaker_session=pipeline_session,
91115
)
92-
data = ParameterString(
93-
name="Data", default_value=f"s3://{pipeline_session.default_bucket()}/batch-data"
94-
)
95116
transform_inputs = TransformInput(data=data)
96117

97118
with warnings.catch_warnings(record=True) as w:
@@ -123,12 +144,21 @@ def test_transform_step_with_transformer(pipeline_session):
123144
parameters=[model_name, data],
124145
sagemaker_session=pipeline_session,
125146
)
126-
step_args.args["ModelName"] = model_name.expr
127-
step_args.args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr
128-
assert json.loads(pipeline.definition())["Steps"][0] == {
147+
step_args = step_args.args
148+
step_def = json.loads(pipeline.definition())["Steps"][0]
149+
step_args["ModelName"] = model_name.expr if is_pipeline_variable(model_name) else model_name
150+
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr if is_pipeline_variable(data) else data
151+
step_args["TransformOutput"]["S3OutputPath"] = output_path.expr if is_pipeline_variable(
152+
output_path) else output_path
153+
154+
del step_args["ModelName"], step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], \
155+
step_args["TransformOutput"]["S3OutputPath"]
156+
del step_def['Arguments']["ModelName"], step_def['Arguments']["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], \
157+
step_def['Arguments']["TransformOutput"]["S3OutputPath"]
158+
assert step_def == {
129159
"Name": "MyTransformStep",
130160
"Type": "Transform",
131-
"Arguments": step_args.args,
161+
"Arguments": step_args
132162
}
133163

134164

0 commit comments

Comments
 (0)