From d2c9cd09278e95b821590d844677467a55e5a134 Mon Sep 17 00:00:00 2001
From: Jaya Talreja
Date: Thu, 23 Sep 2021 16:44:20 -0700
Subject: [PATCH] fix: support maps in step parameters

---
 src/sagemaker/workflow/properties.py        |  34 +++++
 tests/integ/test_workflow.py                | 142 ++++++++++++++++++
 .../sagemaker/workflow/test_properties.py   |   1 +
 tests/unit/sagemaker/workflow/test_steps.py |   1 +
 4 files changed, 178 insertions(+)

diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py
index 15a2a2e3e2..96147e8e8b 100644
--- a/src/sagemaker/workflow/properties.py
+++ b/src/sagemaker/workflow/properties.py
@@ -67,6 +67,8 @@ def __init__(
         for key, info in members.items():
             if Properties._shapes.get(info["shape"], {}).get("type") == "list":
                 self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
+            elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
+                self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
             else:
                 self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
 
@@ -109,6 +111,38 @@ def __getitem__(self, item: Union[int, str]):
         return self._items.get(item)
 
 
+class PropertiesMap(Properties):
+    """PropertiesMap for use in workflow expressions."""
+
+    def __init__(self, path: str, shape_name: str = None):
+        """Create a PropertiesMap instance representing the given shape.
+
+        Args:
+            path (str): The parent path of the PropertiesMap instance.
+            shape_name (str): The botocore sagemaker service model shape name.
+        """
+        super(PropertiesMap, self).__init__(path, shape_name)
+        self.shape_name = shape_name
+        self._items: Dict[Union[int, str], Properties] = dict()
+
+    def __getitem__(self, item: Union[int, str]):
+        """Populate the indexing item with a Property for the given map key.
+
+        Args:
+            item (Union[int, str]): The key of the item in the map.
+        """
+        if item not in self._items:
+            shape = Properties._shapes.get(self.shape_name)
+            member = shape["value"]["shape"]
+            if isinstance(item, str):
+                property_item = Properties(f"{self._path}['{item}']", member)
+            else:
+                property_item = Properties(f"{self._path}[{item}]", member)
+            self._items[item] = property_item
+
+        return self._items.get(item)
+
+
 @attr.s
 class PropertyFile(Expression):
     """Provides a property file struct.
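A minimal sketch of what the new class enables, assuming the botocore DescribeTrainingJobResponse shape models HyperParameters as a map (the unit tests below exercise the same shape); this is illustration, not part of the patch:

    from sagemaker.workflow.properties import Properties

    # Root Properties object for a step named "MyStep", as in the unit tests.
    prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse")

    # Before this change a map-typed member was a plain Properties object, so
    # string indexing raised TypeError; PropertiesMap lazily builds a deferred
    # "Get" expression for each key instead:
    assert prop.HyperParameters["batch-size"].expr == {
        "Get": "Steps.MyStep.HyperParameters['batch-size']"
    }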
diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py
index ade72c74a0..22feff2887 100644
--- a/tests/integ/test_workflow.py
+++ b/tests/integ/test_workflow.py
@@ -855,6 +855,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
             pass
 
 
+def test_steps_with_map_params_pipeline(
+    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
+    framework_version = "0.20.0"
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
+
+    sklearn_processor = SKLearnProcessor(
+        framework_version=framework_version,
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    step_process = ProcessingStep(
+        name="my-process",
+        display_name="ProcessingStep",
+        description="description for Processing step",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(
+                output_name="test_data",
+                source="/opt/ml/processing/test",
+                destination=Join(
+                    on="/",
+                    values=[
+                        "s3:/",
+                        sagemaker_session.default_bucket(),
+                        "test-sklearn",
+                        output_prefix,
+                        ExecutionVariables.PIPELINE_EXECUTION_ID,
+                    ],
+                ),
+            ),
+        ],
+        code=os.path.join(script_dir, "preprocessing.py"),
+    )
+
+    sklearn_train = SKLearn(
+        framework_version=framework_version,
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        role=role,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+    )
+    step_train = TrainingStep(
+        name="my-train",
+        display_name="TrainingStep",
+        description="description for Training step",
+        estimator=sklearn_train,
+        inputs=TrainingInput(
+            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                "train_data"
+            ].S3Output.S3Uri
+        ),
+    )
+
+    model = Model(
+        image_uri=sklearn_train.image_uri,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="my-model",
+        display_name="ModelStep",
+        description="description for Model step",
+        model=model,
+        inputs=model_inputs,
+    )
+
+    # Condition step referencing the map-typed HyperParameters property
+    cond_gte = ConditionGreaterThanOrEqualTo(
+        left=step_train.properties.HyperParameters["batch-size"],
+        right=6.0,
+    )
+
+    step_cond = ConditionStep(
+        name="CustomerChurnAccuracyCond",
+        conditions=[cond_gte],
+        if_steps=[],
+        else_steps=[step_model],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_type, instance_count, output_prefix],
+        steps=[step_process, step_train, step_cond],
+        sagemaker_session=sagemaker_session,
+    )
+
+    definition = json.loads(pipeline.definition())
+    assert definition["Version"] == "2020-12-01"
+
+    steps = definition["Steps"]
+    assert len(steps) == 3
+    training_args = {}
+    condition_args = {}
+    for step in steps:
+        if step["Type"] == "Training":
+            training_args = step["Arguments"]
+        if step["Type"] == "Condition":
+            condition_args = step["Arguments"]
+
+    assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
+        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
+    }
+    assert condition_args["Conditions"][0]["LeftValue"] == {
+        "Get": "Steps.my-train.HyperParameters['batch-size']"
+    }
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_two_step_callback_pipeline_with_output_reference(
     sagemaker_session, role, pipeline_name, region_name
 ):
diff --git a/tests/unit/sagemaker/workflow/test_properties.py b/tests/unit/sagemaker/workflow/test_properties.py
index 5264a304ba..accaf46533 100644
--- a/tests/unit/sagemaker/workflow/test_properties.py
+++ b/tests/unit/sagemaker/workflow/test_properties.py
@@ -22,6 +22,7 @@ def test_properties_describe_training_job_response():
     for name in some_prop_names:
         assert name in prop.__dict__.keys()
     assert prop.CreationTime.expr == {"Get": "Steps.MyStep.CreationTime"}
+    assert prop.HyperParameters.expr == {"Get": "Steps.MyStep.HyperParameters"}
     assert prop.OutputDataConfig.S3OutputPath.expr == {
         "Get": "Steps.MyStep.OutputDataConfig.S3OutputPath"
     }
diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py
index 839ad6a814..f33b12e0f5 100644
--- a/tests/unit/sagemaker/workflow/test_steps.py
+++ b/tests/unit/sagemaker/workflow/test_steps.py
@@ -226,6 +226,7 @@ def test_training_step_base_estimator(sagemaker_session):
         "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
     }
     assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
+    assert step.properties.HyperParameters.expr == {"Get": "Steps.MyTrainingStep.HyperParameters"}
 
 
 def test_training_step_tensorflow(sagemaker_session):