-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: support maps in step parameters #2661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,8 @@ def __init__( | |
for key, info in members.items(): | ||
if Properties._shapes.get(info["shape"], {}).get("type") == "list": | ||
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"]) | ||
elif Properties._shapes.get(info["shape"], {}).get("type") == "map": | ||
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am little confused here. info[shape] is map here so why does constructor has str in it |
||
else: | ||
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"]) | ||
|
||
|
@@ -109,6 +111,38 @@ def __getitem__(self, item: Union[int, str]): | |
return self._items.get(item) | ||
|
||
|
||
class PropertiesMap(Properties): | ||
"""PropertiesMap for use in workflow expressions.""" | ||
|
||
def __init__(self, path: str, shape_name: str = None): | ||
"""Create a PropertiesMap instance representing the given shape. | ||
|
||
Args: | ||
path (str): The parent path of the PropertiesMap instance. | ||
shape_name (str): The botocore sagemaker service model shape name. | ||
""" | ||
super(PropertiesMap, self).__init__(path, shape_name) | ||
self.shape_name = shape_name | ||
self._items: Dict[Union[int, str], Properties] = dict() | ||
|
||
def __getitem__(self, item: Union[int, str]): | ||
"""Populate the indexing item with a Property, for both lists and dictionaries. | ||
|
||
Args: | ||
item (Union[int, str]): The index of the item in sequence. | ||
""" | ||
if item not in self._items.keys(): | ||
shape = Properties._shapes.get(self.shape_name) | ||
member = shape["value"]["shape"] | ||
if isinstance(item, str): | ||
property_item = Properties(f"{self._path}['{item}']", member) | ||
else: | ||
property_item = Properties(f"{self._path}[{item}]", member) | ||
self._items[item] = property_item | ||
|
||
return self._items.get(item) | ||
|
||
|
||
@attr.s | ||
class PropertyFile(Expression): | ||
"""Provides a property file struct. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -855,6 +855,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi | |
pass | ||
|
||
|
||
def test_steps_with_map_params_pipeline( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be missing it but how is this testing map params? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. assert at line 947 and 948 are checking hyperparameter values? Do we need different asserts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated for testing hyperparameters |
||
sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition | ||
): | ||
instance_count = ParameterInteger(name="InstanceCount", default_value=2) | ||
framework_version = "0.20.0" | ||
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") | ||
output_prefix = ParameterString(name="OutputPrefix", default_value="output") | ||
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" | ||
|
||
sklearn_processor = SKLearnProcessor( | ||
framework_version=framework_version, | ||
instance_type=instance_type, | ||
instance_count=instance_count, | ||
base_job_name="test-sklearn", | ||
sagemaker_session=sagemaker_session, | ||
role=role, | ||
) | ||
step_process = ProcessingStep( | ||
name="my-process", | ||
display_name="ProcessingStep", | ||
description="description for Processing step", | ||
processor=sklearn_processor, | ||
inputs=[ | ||
ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), | ||
ProcessingInput(dataset_definition=athena_dataset_definition), | ||
], | ||
outputs=[ | ||
ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), | ||
ProcessingOutput( | ||
output_name="test_data", | ||
source="/opt/ml/processing/test", | ||
destination=Join( | ||
on="/", | ||
values=[ | ||
"s3:/", | ||
sagemaker_session.default_bucket(), | ||
"test-sklearn", | ||
output_prefix, | ||
ExecutionVariables.PIPELINE_EXECUTION_ID, | ||
], | ||
), | ||
), | ||
], | ||
code=os.path.join(script_dir, "preprocessing.py"), | ||
) | ||
|
||
sklearn_train = SKLearn( | ||
framework_version=framework_version, | ||
entry_point=os.path.join(script_dir, "train.py"), | ||
instance_type=instance_type, | ||
sagemaker_session=sagemaker_session, | ||
role=role, | ||
hyperparameters={ | ||
"batch-size": 500, | ||
"epochs": 5, | ||
}, | ||
) | ||
step_train = TrainingStep( | ||
name="my-train", | ||
display_name="TrainingStep", | ||
description="description for Training step", | ||
estimator=sklearn_train, | ||
inputs=TrainingInput( | ||
s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ | ||
"train_data" | ||
].S3Output.S3Uri | ||
), | ||
) | ||
|
||
model = Model( | ||
image_uri=sklearn_train.image_uri, | ||
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, | ||
sagemaker_session=sagemaker_session, | ||
role=role, | ||
) | ||
model_inputs = CreateModelInput( | ||
instance_type="ml.m5.large", | ||
accelerator_type="ml.eia1.medium", | ||
) | ||
step_model = CreateModelStep( | ||
name="my-model", | ||
display_name="ModelStep", | ||
description="description for Model step", | ||
model=model, | ||
inputs=model_inputs, | ||
) | ||
|
||
# Condition step for evaluating model quality and branching execution | ||
cond_lte = ConditionGreaterThanOrEqualTo( | ||
left=step_train.properties.HyperParameters["batch-size"], | ||
right=6.0, | ||
) | ||
|
||
step_cond = ConditionStep( | ||
name="CustomerChurnAccuracyCond", | ||
conditions=[cond_lte], | ||
if_steps=[], | ||
else_steps=[step_model], | ||
) | ||
|
||
pipeline = Pipeline( | ||
name=pipeline_name, | ||
parameters=[instance_type, instance_count, output_prefix], | ||
steps=[step_process, step_train, step_cond], | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
definition = json.loads(pipeline.definition()) | ||
assert definition["Version"] == "2020-12-01" | ||
|
||
steps = definition["Steps"] | ||
assert len(steps) == 3 | ||
training_args = {} | ||
condition_args = {} | ||
for step in steps: | ||
if step["Type"] == "Training": | ||
training_args = step["Arguments"] | ||
if step["Type"] == "Condition": | ||
condition_args = step["Arguments"] | ||
|
||
assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == { | ||
"Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri" | ||
} | ||
assert condition_args["Conditions"][0]["LeftValue"] == { | ||
"Get": "Steps.my-train.HyperParameters['batch-size']" | ||
} | ||
|
||
try: | ||
response = pipeline.create(role) | ||
create_arn = response["PipelineArn"] | ||
assert re.match( | ||
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", | ||
create_arn, | ||
) | ||
|
||
finally: | ||
try: | ||
pipeline.delete() | ||
except Exception: | ||
pass | ||
|
||
|
||
def test_two_step_callback_pipeline_with_output_reference( | ||
sagemaker_session, role, pipeline_name, region_name | ||
): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think second param should be list here in get