Skip to content

Commit e6c210a

Browse files
add parameterized tests to transformer (#3155)
add parameterized tests to transformer
1 parent f14a4f5 commit e6c210a

File tree

2 files changed

+61
-14
lines changed

2 files changed

+61
-14
lines changed

tests/unit/sagemaker/workflow/helpers.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def __init__(self, name, display_name=None, description=None, depends_on=None):
4141
super(CustomStep, self).__init__(
4242
name, display_name, description, StepTypeEnum.TRAINING, depends_on
4343
)
44-
self._properties = Properties(path=f"Steps.{name}")
44+
# for testing property reference, we just use DescribeTrainingJobResponse shape here.
45+
self._properties = Properties(
46+
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
47+
)
4548

4649
@property
4750
def arguments(self):

tests/unit/sagemaker/workflow/test_transform_step.py

+57-13
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
from sagemaker.parameter import IntegerParameter
2424
from sagemaker.tuner import HyperparameterTuner
2525
from sagemaker.workflow.pipeline_context import PipelineSession
26+
from tests.unit.sagemaker.workflow.helpers import CustomStep
2627

2728
from sagemaker.workflow.steps import TransformStep, TransformInput
2829
from sagemaker.workflow.pipeline import Pipeline
2930
from sagemaker.workflow.parameters import ParameterString
31+
from sagemaker.workflow.functions import Join
32+
from sagemaker.workflow import is_pipeline_variable
3033

3134
from sagemaker.transformer import Transformer
3235

@@ -53,6 +56,7 @@ def client():
5356
client_mock._client_config.user_agent = (
5457
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
5558
)
59+
client_mock.describe_model.return_value = {"PrimaryContainer": {}, "Containers": {}}
5660
return client_mock
5761

5862

@@ -80,18 +84,44 @@ def pipeline_session(boto_session, client):
8084
)
8185

8286

83-
def test_transform_step_with_transformer(pipeline_session):
84-
model_name = ParameterString("ModelName")
87+
@pytest.mark.parametrize(
88+
"model_name",
89+
[
90+
"my-model",
91+
ParameterString("ModelName"),
92+
ParameterString("ModelName", default_value="my-model"),
93+
Join(on="-", values=["my", "model"]),
94+
CustomStep(name="custom-step").properties.RoleArn,
95+
],
96+
)
97+
@pytest.mark.parametrize(
98+
"data",
99+
[
100+
"s3://my-bucket/my-data",
101+
ParameterString("MyTransformInput"),
102+
ParameterString("MyTransformInput", default_value="s3://my-model"),
103+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "input"]),
104+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath,
105+
],
106+
)
107+
@pytest.mark.parametrize(
108+
"output_path",
109+
[
110+
"s3://my-bucket/my-output-path",
111+
ParameterString("MyOutputPath"),
112+
ParameterString("MyOutputPath", default_value="s3://my-output"),
113+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "output"]),
114+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath,
115+
],
116+
)
117+
def test_transform_step_with_transformer(model_name, data, output_path, pipeline_session):
85118
transformer = Transformer(
86119
model_name=model_name,
87120
instance_type="ml.m5.xlarge",
88121
instance_count=1,
89-
output_path=f"s3://{pipeline_session.default_bucket()}/Transform",
122+
output_path=output_path,
90123
sagemaker_session=pipeline_session,
91124
)
92-
data = ParameterString(
93-
name="Data", default_value=f"s3://{pipeline_session.default_bucket()}/batch-data"
94-
)
95125
transform_inputs = TransformInput(data=data)
96126

97127
with warnings.catch_warnings(record=True) as w:
@@ -123,13 +153,27 @@ def test_transform_step_with_transformer(pipeline_session):
123153
parameters=[model_name, data],
124154
sagemaker_session=pipeline_session,
125155
)
126-
step_args.args["ModelName"] = model_name.expr
127-
step_args.args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr
128-
assert json.loads(pipeline.definition())["Steps"][0] == {
129-
"Name": "MyTransformStep",
130-
"Type": "Transform",
131-
"Arguments": step_args.args,
132-
}
156+
step_args = step_args.args
157+
step_def = json.loads(pipeline.definition())["Steps"][0]
158+
step_args["ModelName"] = model_name.expr if is_pipeline_variable(model_name) else model_name
159+
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = (
160+
data.expr if is_pipeline_variable(data) else data
161+
)
162+
step_args["TransformOutput"]["S3OutputPath"] = (
163+
output_path.expr if is_pipeline_variable(output_path) else output_path
164+
)
165+
166+
del (
167+
step_args["ModelName"],
168+
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"],
169+
step_args["TransformOutput"]["S3OutputPath"],
170+
)
171+
del (
172+
step_def["Arguments"]["ModelName"],
173+
step_def["Arguments"]["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"],
174+
step_def["Arguments"]["TransformOutput"]["S3OutputPath"],
175+
)
176+
assert step_def == {"Name": "MyTransformStep", "Type": "Transform", "Arguments": step_args}
133177

134178

135179
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)