Skip to content

fix: support maps in step parameters #2661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/sagemaker/workflow/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(
for key, info in members.items():
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the second param should be a list here in `get`.

self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little confused here. `info["shape"]` is a map here, so why does the constructor have `str` in it?

else:
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])

Expand Down Expand Up @@ -109,6 +111,38 @@ def __getitem__(self, item: Union[int, str]):
return self._items.get(item)


class PropertiesMap(Properties):
    """PropertiesMap for use in workflow expressions."""

    def __init__(self, path: str, shape_name: str = None):
        """Create a PropertiesMap instance representing the given shape.

        Args:
            path (str): The parent path of the PropertiesMap instance.
            shape_name (str): The botocore sagemaker service model shape name.
        """
        super().__init__(path, shape_name)
        self.shape_name = shape_name
        # Cache of already-built entries so that indexing the same key twice
        # hands back the same Properties object.
        self._items: Dict[Union[int, str], Properties] = {}

    def __getitem__(self, item: Union[int, str]):
        """Return the Properties for a map key, building and caching it on first access.

        Args:
            item (Union[int, str]): The key of the map entry to reference.

        Returns:
            Properties: The property whose path is this map's path indexed by ``item``.
        """
        if item not in self._items:
            shape = Properties._shapes.get(self.shape_name)
            # The shape of every entry is read from the map shape's "value" member.
            member = shape["value"]["shape"]
            # String keys are single-quoted inside the brackets; any other key
            # is rendered as-is.
            key_expr = f"'{item}'" if isinstance(item, str) else f"{item}"
            self._items[item] = Properties(f"{self._path}[{key_expr}]", member)

        return self._items[item]


@attr.s
class PropertyFile(Expression):
"""Provides a property file struct.
Expand Down
142 changes: 142 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
pass


def test_steps_with_map_params_pipeline(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be missing it but how is this testing map params?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The asserts at lines 947 and 948 are checking the hyperparameter values. Do we need different asserts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated for testing hyperparameters

sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
):
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
framework_version = "0.20.0"
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
output_prefix = ParameterString(name="OutputPrefix", default_value="output")
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

sklearn_processor = SKLearnProcessor(
framework_version=framework_version,
instance_type=instance_type,
instance_count=instance_count,
base_job_name="test-sklearn",
sagemaker_session=sagemaker_session,
role=role,
)
step_process = ProcessingStep(
name="my-process",
display_name="ProcessingStep",
description="description for Processing step",
processor=sklearn_processor,
inputs=[
ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
ProcessingInput(dataset_definition=athena_dataset_definition),
],
outputs=[
ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
ProcessingOutput(
output_name="test_data",
source="/opt/ml/processing/test",
destination=Join(
on="/",
values=[
"s3:/",
sagemaker_session.default_bucket(),
"test-sklearn",
output_prefix,
ExecutionVariables.PIPELINE_EXECUTION_ID,
],
),
),
],
code=os.path.join(script_dir, "preprocessing.py"),
)

sklearn_train = SKLearn(
framework_version=framework_version,
entry_point=os.path.join(script_dir, "train.py"),
instance_type=instance_type,
sagemaker_session=sagemaker_session,
role=role,
hyperparameters={
"batch-size": 500,
"epochs": 5,
},
)
step_train = TrainingStep(
name="my-train",
display_name="TrainingStep",
description="description for Training step",
estimator=sklearn_train,
inputs=TrainingInput(
s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
"train_data"
].S3Output.S3Uri
),
)

model = Model(
image_uri=sklearn_train.image_uri,
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
sagemaker_session=sagemaker_session,
role=role,
)
model_inputs = CreateModelInput(
instance_type="ml.m5.large",
accelerator_type="ml.eia1.medium",
)
step_model = CreateModelStep(
name="my-model",
display_name="ModelStep",
description="description for Model step",
model=model,
inputs=model_inputs,
)

# Condition step for evaluating model quality and branching execution
cond_lte = ConditionGreaterThanOrEqualTo(
left=step_train.properties.HyperParameters["batch-size"],
right=6.0,
)

step_cond = ConditionStep(
name="CustomerChurnAccuracyCond",
conditions=[cond_lte],
if_steps=[],
else_steps=[step_model],
)

pipeline = Pipeline(
name=pipeline_name,
parameters=[instance_type, instance_count, output_prefix],
steps=[step_process, step_train, step_cond],
sagemaker_session=sagemaker_session,
)

definition = json.loads(pipeline.definition())
assert definition["Version"] == "2020-12-01"

steps = definition["Steps"]
assert len(steps) == 3
training_args = {}
condition_args = {}
for step in steps:
if step["Type"] == "Training":
training_args = step["Arguments"]
if step["Type"] == "Condition":
condition_args = step["Arguments"]

assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
"Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
}
assert condition_args["Conditions"][0]["LeftValue"] == {
"Get": "Steps.my-train.HyperParameters['batch-size']"
}

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)

finally:
try:
pipeline.delete()
except Exception:
pass


def test_two_step_callback_pipeline_with_output_reference(
sagemaker_session, role, pipeline_name, region_name
):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/workflow/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_properties_describe_training_job_response():
for name in some_prop_names:
assert name in prop.__dict__.keys()
assert prop.CreationTime.expr == {"Get": "Steps.MyStep.CreationTime"}
assert prop.HyperParameters.expr == {"Get": "Steps.MyStep.HyperParameters"}
assert prop.OutputDataConfig.S3OutputPath.expr == {
"Get": "Steps.MyStep.OutputDataConfig.S3OutputPath"
}
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/workflow/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def test_training_step_base_estimator(sagemaker_session):
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
assert step.properties.HyperParameters.expr == {"Get": "Steps.MyTrainingStep.HyperParameters"}


def test_training_step_tensorflow(sagemaker_session):
Expand Down