Skip to content

Commit 459e3a9

Browse files
jayatalrahsan-z-khan
authored and committed
fix: support maps in step parameters (aws#2661)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent e01f4c7 commit 459e3a9

File tree

4 files changed

+178
-0
lines changed

4 files changed

+178
-0
lines changed

src/sagemaker/workflow/properties.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def __init__(
7777
for key, info in members.items():
7878
if shapes.get(info["shape"], {}).get("type") == "list":
7979
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
80+
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
81+
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
8082
else:
8183
self.__dict__[key] = Properties(
8284
f"{path}.{key}", info["shape"], service_name=service_name
@@ -122,6 +124,38 @@ def __getitem__(self, item: Union[int, str]):
122124
return self._items.get(item)
123125

124126

127+
class PropertiesMap(Properties):
    """Properties of a map-typed shape, for use in workflow expressions.

    Indexing an instance with a key yields a ``Properties`` object whose path is
    this map's path extended with that key.
    """

    def __init__(self, path: str, shape_name: str = None):
        """Create a PropertiesMap instance representing the given shape.

        Args:
            path (str): The parent path of the PropertiesMap instance.
            shape_name (str): The botocore sagemaker service model shape name.
        """
        super().__init__(path, shape_name)
        self.shape_name = shape_name
        # Entries are created lazily and cached, so repeated lookups of the same
        # key return the same Properties object.
        self._items: Dict[Union[int, str], Properties] = {}

    def __getitem__(self, item: Union[int, str]):
        """Return the Properties for a map entry, creating it on first access.

        Args:
            item (Union[int, str]): The key of the map entry. String keys are
                rendered quoted in the path (``path['key']``); any other key is
                rendered bare (``path[key]``).

        Returns:
            Properties: The property built from the map's value shape at the
                path of the requested key.
        """
        if item not in self._items:
            shape = Properties._shapes.get(self.shape_name)
            # The entry takes the shape declared for the map's values.
            member = shape["value"]["shape"]
            rendered_key = f"'{item}'" if isinstance(item, str) else f"{item}"
            self._items[item] = Properties(f"{self._path}[{rendered_key}]", member)

        return self._items[item]
157+
158+
125159
@attr.s
126160
class PropertyFile(Expression):
127161
"""Provides a property file struct.

tests/integ/test_workflow.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
856856
pass
857857

858858

859+
def test_steps_with_map_params_pipeline(
    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
):
    """Check that map-style lookups on step properties work in a pipeline.

    Builds a processing -> training -> condition pipeline in which step
    properties are indexed by string keys, asserts that those lookups are
    rendered as ``['key']`` paths in the pipeline definition, then creates the
    pipeline and always attempts to delete it afterwards.
    """
    # Pipeline parameters and fixed settings.
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
    framework_version = "0.20.0"
    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

    # Processing step.
    sklearn_processor = SKLearnProcessor(
        framework_version=framework_version,
        instance_type=instance_type,
        instance_count=instance_count,
        base_job_name="test-sklearn",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    processing_inputs = [
        ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
        ProcessingInput(dataset_definition=athena_dataset_definition),
    ]
    train_output = ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train")
    test_destination = Join(
        on="/",
        values=[
            "s3:/",
            sagemaker_session.default_bucket(),
            "test-sklearn",
            output_prefix,
            ExecutionVariables.PIPELINE_EXECUTION_ID,
        ],
    )
    test_output = ProcessingOutput(
        output_name="test_data",
        source="/opt/ml/processing/test",
        destination=test_destination,
    )
    step_process = ProcessingStep(
        name="my-process",
        display_name="ProcessingStep",
        description="description for Processing step",
        processor=sklearn_processor,
        inputs=processing_inputs,
        outputs=[train_output, test_output],
        code=os.path.join(script_dir, "preprocessing.py"),
    )

    # Training step; its input is looked up by output name on the processing step.
    sklearn_train = SKLearn(
        framework_version=framework_version,
        entry_point=os.path.join(script_dir, "train.py"),
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        role=role,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
    )
    processing_outputs = step_process.properties.ProcessingOutputConfig.Outputs
    train_data_uri = processing_outputs["train_data"].S3Output.S3Uri
    step_train = TrainingStep(
        name="my-train",
        display_name="TrainingStep",
        description="description for Training step",
        estimator=sklearn_train,
        inputs=TrainingInput(s3_data=train_data_uri),
    )

    # Model creation step, used as the else-branch of the condition below.
    model = Model(
        image_uri=sklearn_train.image_uri,
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        sagemaker_session=sagemaker_session,
        role=role,
    )
    model_inputs = CreateModelInput(
        instance_type="ml.m5.large",
        accelerator_type="ml.eia1.medium",
    )
    step_model = CreateModelStep(
        name="my-model",
        display_name="ModelStep",
        description="description for Model step",
        model=model,
        inputs=model_inputs,
    )

    # Condition on a map-valued step property: the training step's
    # "batch-size" hyperparameter.
    cond_gte = ConditionGreaterThanOrEqualTo(
        left=step_train.properties.HyperParameters["batch-size"],
        right=6.0,
    )
    step_cond = ConditionStep(
        name="CustomerChurnAccuracyCond",
        conditions=[cond_gte],
        if_steps=[],
        else_steps=[step_model],
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_type, instance_count, output_prefix],
        steps=[step_process, step_train, step_cond],
        sagemaker_session=sagemaker_session,
    )

    definition = json.loads(pipeline.definition())
    assert definition["Version"] == "2020-12-01"

    steps = definition["Steps"]
    assert len(steps) == 3

    # Collect the arguments of the training and condition steps (last one wins).
    arguments_by_type = {
        step["Type"]: step["Arguments"]
        for step in steps
        if step["Type"] in ("Training", "Condition")
    }
    training_args = arguments_by_type.get("Training", {})
    condition_args = arguments_by_type.get("Condition", {})

    expected_train_uri = {
        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
    }
    assert (
        training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
        == expected_train_uri
    )
    assert condition_args["Conditions"][0]["LeftValue"] == {
        "Get": "Steps.my-train.HyperParameters['batch-size']"
    }

    try:
        create_arn = pipeline.create(role)["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass
999+
1000+
8591001
def test_two_step_callback_pipeline_with_output_reference(
8601002
sagemaker_session, role, pipeline_name, region_name
8611003
):

tests/unit/sagemaker/workflow/test_properties.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_properties_describe_training_job_response():
2222
for name in some_prop_names:
2323
assert name in prop.__dict__.keys()
2424
assert prop.CreationTime.expr == {"Get": "Steps.MyStep.CreationTime"}
25+
assert prop.HyperParameters.expr == {"Get": "Steps.MyStep.HyperParameters"}
2526
assert prop.OutputDataConfig.S3OutputPath.expr == {
2627
"Get": "Steps.MyStep.OutputDataConfig.S3OutputPath"
2728
}

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def test_training_step_base_estimator(sagemaker_session):
226226
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
227227
}
228228
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
229+
assert step.properties.HyperParameters.expr == {"Get": "Steps.MyTrainingStep.HyperParameters"}
229230

230231

231232
def test_training_step_tensorflow(sagemaker_session):

0 commit comments

Comments
 (0)