Skip to content

Commit e7a0519

Browse files
committed
reformatting
1 parent 9711fd6 commit e7a0519

File tree

2 files changed

+51
-35
lines changed

2 files changed

+51
-35
lines changed

tests/unit/sagemaker/workflow/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def __init__(self, name, display_name=None, description=None, depends_on=None):
4242
name, display_name, description, StepTypeEnum.TRAINING, depends_on
4343
)
4444
# for testing property reference, we just use DescribeTrainingJobResponse shape here.
45-
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse")
45+
self._properties = Properties(
46+
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
47+
)
4648

4749
@property
4850
def arguments(self):

tests/unit/sagemaker/workflow/test_transform_step.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,36 @@ def pipeline_session(boto_session, client):
8484
)
8585

8686

87-
@pytest.mark.parametrize("model_name", [
88-
"my-model",
89-
ParameterString("ModelName"),
90-
ParameterString("ModelName", default_value="my-model"),
91-
Join(on="-", values=["my", "model"]),
92-
CustomStep(name="custom-step").properties.RoleArn
93-
])
94-
@pytest.mark.parametrize("data", [
95-
"s3://my-bucket/my-data",
96-
ParameterString("MyTransformInput"),
97-
ParameterString("MyTransformInput", default_value="s3://my-model"),
98-
Join(on="/", values=["s3://my-bucket", "my-transform-data", "input"]),
99-
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath
100-
])
101-
@pytest.mark.parametrize("output_path", [
102-
"s3://my-bucket/my-output-path",
103-
ParameterString("MyOutputPath"),
104-
ParameterString("MyOutputPath", default_value="s3://my-output"),
105-
Join(on="/", values=["s3://my-bucket", "my-transform-data", "output"]),
106-
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath
107-
])
87+
@pytest.mark.parametrize(
88+
"model_name",
89+
[
90+
"my-model",
91+
ParameterString("ModelName"),
92+
ParameterString("ModelName", default_value="my-model"),
93+
Join(on="-", values=["my", "model"]),
94+
CustomStep(name="custom-step").properties.RoleArn,
95+
],
96+
)
97+
@pytest.mark.parametrize(
98+
"data",
99+
[
100+
"s3://my-bucket/my-data",
101+
ParameterString("MyTransformInput"),
102+
ParameterString("MyTransformInput", default_value="s3://my-model"),
103+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "input"]),
104+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath,
105+
],
106+
)
107+
@pytest.mark.parametrize(
108+
"output_path",
109+
[
110+
"s3://my-bucket/my-output-path",
111+
ParameterString("MyOutputPath"),
112+
ParameterString("MyOutputPath", default_value="s3://my-output"),
113+
Join(on="/", values=["s3://my-bucket", "my-transform-data", "output"]),
114+
CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath,
115+
],
116+
)
108117
def test_transform_step_with_transformer(model_name, data, output_path, pipeline_session):
109118
transformer = Transformer(
110119
model_name=model_name,
@@ -147,19 +156,24 @@ def test_transform_step_with_transformer(model_name, data, output_path, pipeline
147156
step_args = step_args.args
148157
step_def = json.loads(pipeline.definition())["Steps"][0]
149158
step_args["ModelName"] = model_name.expr if is_pipeline_variable(model_name) else model_name
150-
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr if is_pipeline_variable(data) else data
151-
step_args["TransformOutput"]["S3OutputPath"] = output_path.expr if is_pipeline_variable(
152-
output_path) else output_path
153-
154-
del step_args["ModelName"], step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], \
155-
step_args["TransformOutput"]["S3OutputPath"]
156-
del step_def['Arguments']["ModelName"], step_def['Arguments']["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], \
157-
step_def['Arguments']["TransformOutput"]["S3OutputPath"]
158-
assert step_def == {
159-
"Name": "MyTransformStep",
160-
"Type": "Transform",
161-
"Arguments": step_args
162-
}
159+
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = (
160+
data.expr if is_pipeline_variable(data) else data
161+
)
162+
step_args["TransformOutput"]["S3OutputPath"] = (
163+
output_path.expr if is_pipeline_variable(output_path) else output_path
164+
)
165+
166+
del (
167+
step_args["ModelName"],
168+
step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"],
169+
step_args["TransformOutput"]["S3OutputPath"],
170+
)
171+
del (
172+
step_def["Arguments"]["ModelName"],
173+
step_def["Arguments"]["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"],
174+
step_def["Arguments"]["TransformOutput"]["S3OutputPath"],
175+
)
176+
assert step_def == {"Name": "MyTransformStep", "Type": "Transform", "Arguments": step_args}
163177

164178

165179
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)