diff --git a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst index 720444a763..f2081d82ba 100644 --- a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst +++ b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst @@ -6,7 +6,6 @@ ConditionStep .. autoclass:: sagemaker.workflow.condition_step.ConditionStep -.. autoclass:: sagemaker.workflow.condition_step.JsonGet Conditions ---------- @@ -55,6 +54,7 @@ Functions --------- .. autoclass:: sagemaker.workflow.functions.Join +.. autoclass:: sagemaker.workflow.functions.JsonGet Parameters ---------- diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index 2c0f0f6355..ad961c0eb7 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -13,7 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List, Dict +from typing import List, Dict, Union from enum import Enum import attr @@ -84,7 +84,7 @@ def __init__( inputs: dict, outputs: List[CallbackOutput], cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, ): """Constructs a CallbackStep. @@ -95,8 +95,8 @@ def __init__( in the SQS message body of callback messages. outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. 
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep` - depends on + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.CallbackStep` depends on """ super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on) self.sqs_queue_url = sqs_queue_url diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index 13ab405805..b5130cc780 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -15,17 +15,9 @@ from typing import List, Union -import attr - from sagemaker.workflow.conditions import Condition -from sagemaker.workflow.entities import ( - Expression, - RequestType, -) -from sagemaker.workflow.properties import ( - Properties, - PropertyFile, -) +from sagemaker.workflow.entities import RequestType +from sagemaker.workflow.properties import Properties from sagemaker.workflow.steps import ( Step, StepTypeEnum, @@ -40,7 +32,7 @@ class ConditionStep(Step): def __init__( self, name: str, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, conditions: List[Condition] = None, if_steps: List[Union[Step, StepCollection]] = None, else_steps: List[Union[Step, StepCollection]] = None, @@ -84,33 +76,3 @@ def arguments(self) -> RequestType: def properties(self): """A simple Properties object with `Outcome` as the only property""" return self._properties - - -@attr.s -class JsonGet(Expression): - """Get JSON properties from PropertyFiles. - - Attributes: - step (Step): The step from which to get the property file. - property_file (Union[PropertyFile, str]): Either a PropertyFile instance - or the name of a property file. - json_path (str): The JSON path expression to the requested value. 
- """ - - step: Step = attr.ib() - property_file: Union[PropertyFile, str] = attr.ib() - json_path: str = attr.ib() - - @property - def expr(self): - """The expression dict for a `JsonGet` function.""" - if isinstance(self.property_file, PropertyFile): - name = self.property_file.name - else: - name = self.property_file - return { - "Std:JsonGet": { - "PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"}, - "Path": self.json_path, - } - } diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index a5ebc68158..d60ecb2073 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -13,11 +13,12 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List +from typing import List, Union import attr from sagemaker.workflow.entities import Expression +from sagemaker.workflow.properties import PropertyFile @attr.s @@ -44,3 +45,38 @@ def expr(self): ], }, } + + +@attr.s +class JsonGet(Expression): + """Get JSON properties from PropertyFiles. + + Attributes: + processing_step_name (str): The step name of the `sagemaker.workflow.steps.ProcessingStep` + from which to get the property file. + property_file (Union[PropertyFile, str]): Either a PropertyFile instance + or the name of a property file. + json_path (str): The JSON path expression to the requested value. 
+ """ + + processing_step_name: str = attr.ib() + property_file: Union[PropertyFile, str] = attr.ib() + json_path: str = attr.ib() + + @property + def expr(self): + """The expression dict for a `JsonGet` function.""" + if isinstance(self.property_file, PropertyFile): + name = self.property_file.name + else: + name = self.property_file + + if not isinstance(self.processing_step_name, str): + raise ValueError("processing_step_name passed in is not instance of a str") + + return { + "Std:JsonGet": { + "PropertyFile": {"Get": f"Steps.{self.processing_step_name}.PropertyFiles.{name}"}, + "Path": self.json_path, + } + } diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 143c59395c..dbdeee2c10 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -13,7 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List +from typing import List, Union import attr @@ -60,7 +60,7 @@ def __init__( response_types, inference_instances, transform_instances, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, model_package_group_name=None, model_metrics=None, approval_status=None, @@ -82,8 +82,8 @@ def __init__( generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). - depends_on (List[str]): The list of step names the first step in the collection - depends on + depends_on (List[str] or List[Step]): The list of step names or step instances + the first step in the collection depends on model_package_group_name (str): The Model Package Group name, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). 
@@ -179,7 +179,7 @@ def __init__( max_payload=None, tags=None, volume_kms_key=None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, **kwargs, ): """Construct steps required for a Transformer step collection: @@ -216,8 +216,8 @@ def __init__( it will be the format of the batch transform output. env (dict): The Environment variables to be set for use during the transform job (default: None). - depends_on (List[str]): The list of step names the first step in - the collection depends on + depends_on (List[str] or List[Step]): The list of step names or step instances + the first step in the collection depends on """ steps = [] if "entry_point" in kwargs: diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 5e36392b70..c1f2171de4 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -60,12 +60,13 @@ class Step(Entity): Attributes: name (str): The name of the step. step_type (StepTypeEnum): The type of the step. 
- depends_on (List[str]): The list of step names the current step depends on + depends_on (List[str] or List[Step]): The list of step names or step + instances the current step depends on """ name: str = attr.ib(factory=str) step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory) - depends_on: List[str] = attr.ib(default=None) + depends_on: Union[List[str], List["Step"]] = attr.ib(default=None) @property @abc.abstractmethod @@ -85,11 +86,13 @@ def to_request(self) -> RequestType: "Arguments": self.arguments, } if self.depends_on: - request_dict["DependsOn"] = self.depends_on + request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on) + return request_dict - def add_depends_on(self, step_names: List[str]): - """Add step names to the current step depends on list""" + def add_depends_on(self, step_names: Union[List[str], List["Step"]]): + """Add step names or step instances to the current step depends on list""" + if not step_names: return if not self.depends_on: @@ -101,6 +104,19 @@ def ref(self) -> Dict[str, str]: """Gets a reference dict for steps""" return {"Name": self.name} + @staticmethod + def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]): + """Resolve the step depends on list""" + depends_on = [] + for step in depends_on_list: + if isinstance(step, Step): + depends_on.append(step.name) + elif isinstance(step, str): + depends_on.append(step) + else: + raise ValueError(f"Invalid input step name: {step}") + return depends_on + @attr.s class CacheConfig: @@ -143,7 +159,7 @@ def __init__( estimator: EstimatorBase, inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, ): """Construct a TrainingStep, given an `EstimatorBase` instance. @@ -171,8 +187,8 @@ def __init__( the path to the training dataset. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. 
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` - depends on + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.TrainingStep` depends on """ super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on) self.estimator = estimator @@ -217,7 +233,11 @@ class CreateModelStep(Step): """CreateModel step for workflow.""" def __init__( - self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None + self, + name: str, + model: Model, + inputs: CreateModelInput, + depends_on: Union[List[str], List[Step]] = None, ): """Construct a CreateModelStep, given an `sagemaker.model.Model` instance. @@ -229,8 +249,8 @@ def __init__( model (Model): A `sagemaker.model.Model` instance. inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance. Defaults to `None`. - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep` - depends on + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.CreateModelStep` depends on """ super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on) self.model = model @@ -275,7 +295,7 @@ def __init__( transformer: Transformer, inputs: TransformInput, cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, ): """Constructs a TransformStep, given an `Transformer` instance. @@ -287,8 +307,8 @@ def __init__( transformer (Transformer): A `sagemaker.transformer.Transformer` instance. inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. 
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep` - depends on + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.TransformStep` depends on """ super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on) self.transformer = transformer @@ -351,7 +371,7 @@ def __init__( code: str = None, property_files: List[PropertyFile] = None, cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, ): """Construct a ProcessingStep, given a `Processor` instance. @@ -372,8 +392,8 @@ def __init__( property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep` - depends on + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.ProcessingStep` depends on """ super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on) self.processor = processor diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index c85e2f3aa2..810484cca6 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1128,7 +1128,7 @@ def test_two_processing_job_depends_on( step_pyspark_2 = ProcessingStep( name="pyspark-process-2", - depends_on=[step_pyspark_1.name], + depends_on=[step_pyspark_1], processor=pyspark_processor, inputs=spark_run_args.inputs, outputs=spark_run_args.outputs, diff --git a/tests/integ/test_workflow_with_clarify.py b/tests/integ/test_workflow_with_clarify.py index 53dae54e84..4664df9894 100644 --- a/tests/integ/test_workflow_with_clarify.py +++ b/tests/integ/test_workflow_with_clarify.py @@ -33,7 +33,8 @@ from sagemaker.processing import 
ProcessingInput, ProcessingOutput from sagemaker.session import get_execution_role from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo -from sagemaker.workflow.condition_step import ConditionStep, JsonGet +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.functions import JsonGet from sagemaker.workflow.parameters import ( ParameterInteger, ParameterString, @@ -237,7 +238,7 @@ def test_workflow_with_clarify( ) cond_left = JsonGet( - step=step_process, + processing_step_name=step_process.name, property_file="BiasOutput", json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value", ) diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 3f662a77ea..c820897fd8 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -14,13 +14,14 @@ from __future__ import absolute_import from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.workflow.functions import Join +from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.parameters import ( ParameterFloat, ParameterInteger, ParameterString, ) from sagemaker.workflow.properties import Properties +from sagemaker.workflow.properties import PropertyFile def test_join_primitives_default_on(): @@ -66,3 +67,16 @@ def test_join_expressions(): ], }, } + + +def test_json_get_expressions(): + params = PropertyFile(name="params", output_name="params", path="params.json") + + assert JsonGet( + processing_step_name="processing_step", property_file=params, json_path="alpha" + ).expr == { + "Std:JsonGet": { + "PropertyFile": {"Get": "Steps.processing_step.PropertyFiles.params"}, + "Path": "alpha", + } + } diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 34f1856c19..e3f22a5559 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ 
b/tests/unit/sagemaker/workflow/test_steps.py @@ -488,3 +488,56 @@ def test_properties_describe_processing_job_response(): assert prop.ProcessingOutputConfig.Outputs["MyOutputName"].S3Output.S3Uri.expr == { "Get": "Steps.MyStep.ProcessingOutputConfig.Outputs['MyOutputName'].S3Output.S3Uri" } + + +def test_add_depends_on(sagemaker_session): + processing_input_data_uri_parameter = ParameterString( + name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest" + ) + instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.m4.4xlarge") + instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1) + processor = Processor( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=instance_count_parameter, + instance_type=instance_type_parameter, + sagemaker_session=sagemaker_session, + ) + inputs = [ + ProcessingInput( + source=processing_input_data_uri_parameter, + destination="processing_manifest", + ) + ] + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + step_1 = ProcessingStep( + name="MyProcessingStep-1", + processor=processor, + inputs=inputs, + outputs=[], + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="MyProcessingStep-2", + depends_on=[step_1], + processor=processor, + inputs=inputs, + outputs=[], + cache_config=cache_config, + ) + + step_3 = ProcessingStep( + name="MyProcessingStep-3", + depends_on=[step_1], + processor=processor, + inputs=inputs, + outputs=[], + cache_config=cache_config, + ) + step_3.add_depends_on([step_2.name]) + + assert "DependsOn" not in step_1.to_request() + assert step_2.to_request()["DependsOn"] == ["MyProcessingStep-1"] + assert step_3.to_request()["DependsOn"] == ["MyProcessingStep-1", "MyProcessingStep-2"]