Skip to content

fix: jsonGet interpolation issue 2426 + allow step depends on pass in step instance #2477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/workflows/pipelines/sagemaker.workflow.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ ConditionStep

.. autoclass:: sagemaker.workflow.condition_step.ConditionStep

.. autoclass:: sagemaker.workflow.condition_step.JsonGet

Conditions
----------
Expand Down Expand Up @@ -55,6 +54,7 @@ Functions
---------

.. autoclass:: sagemaker.workflow.functions.Join
.. autoclass:: sagemaker.workflow.functions.JsonGet

Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/workflow/callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Dict
from typing import List, Dict, Union
from enum import Enum

import attr
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
inputs: dict,
outputs: List[CallbackOutput],
cache_config: CacheConfig = None,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
):
"""Constructs a CallbackStep.

Expand All @@ -95,8 +95,8 @@ def __init__(
in the SQS message body of callback messages.
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
depends on
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.CallbackStep` depends on
"""
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
self.sqs_queue_url = sqs_queue_url
Expand Down
44 changes: 3 additions & 41 deletions src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,9 @@

from typing import List, Union

import attr

from sagemaker.workflow.conditions import Condition
from sagemaker.workflow.entities import (
Expression,
RequestType,
)
from sagemaker.workflow.properties import (
Properties,
PropertyFile,
)
from sagemaker.workflow.entities import RequestType
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import (
Step,
StepTypeEnum,
Expand All @@ -40,7 +32,7 @@ class ConditionStep(Step):
def __init__(
self,
name: str,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
conditions: List[Condition] = None,
if_steps: List[Union[Step, StepCollection]] = None,
else_steps: List[Union[Step, StepCollection]] = None,
Expand Down Expand Up @@ -84,33 +76,3 @@ def arguments(self) -> RequestType:
def properties(self):
"""A simple Properties object with `Outcome` as the only property"""
return self._properties


@attr.s
class JsonGet(Expression):
    """Expression that looks up a value inside a step's JSON property file.

    Attributes:
        step (Step): The step from which to get the property file; its `name`
            is interpolated into the property file reference.
        property_file (Union[PropertyFile, str]): Either a PropertyFile instance
            (its `name` is used) or the name of a property file.
        json_path (str): The JSON path expression to the requested value.
    """

    step: Step = attr.ib()
    property_file: Union[PropertyFile, str] = attr.ib()
    json_path: str = attr.ib()

    @property
    def expr(self):
        """Build the `Std:JsonGet` expression dict for this lookup."""
        # Accept either a PropertyFile object or a bare property file name.
        source_file = self.property_file
        file_name = source_file.name if isinstance(source_file, PropertyFile) else source_file

        property_file_ref = f"Steps.{self.step.name}.PropertyFiles.{file_name}"
        return {
            "Std:JsonGet": {
                "PropertyFile": {"Get": property_file_ref},
                "Path": self.json_path,
            }
        }
38 changes: 37 additions & 1 deletion src/sagemaker/workflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List
from typing import List, Union

import attr

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.properties import PropertyFile


@attr.s
Expand All @@ -44,3 +45,38 @@ def expr(self):
],
},
}


@attr.s
class JsonGet(Expression):
    """Expression that looks up a value inside a step's JSON property file.

    Attributes:
        processing_step_name (str): The step name of the
            `sagemaker.workflow.steps.ProcessingStep` from which to get the
            property file. It must be a `str`; this is checked when `expr`
            is evaluated.
        property_file (Union[PropertyFile, str]): Either a PropertyFile instance
            (its `name` is used) or the name of a property file.
        json_path (str): The JSON path expression to the requested value.
    """

    processing_step_name: str = attr.ib()
    property_file: Union[PropertyFile, str] = attr.ib()
    json_path: str = attr.ib()

    @property
    def expr(self):
        """Build the `Std:JsonGet` expression dict for this lookup.

        Raises:
            ValueError: If `processing_step_name` is not a `str`.
        """
        # Accept either a PropertyFile object or a bare property file name.
        source_file = self.property_file
        file_name = source_file.name if isinstance(source_file, PropertyFile) else source_file

        # Only a plain string can be interpolated into the reference below;
        # anything else is rejected explicitly.
        step_name = self.processing_step_name
        if not isinstance(step_name, str):
            raise ValueError("processing_step_name passed in is not instance of a str")

        property_file_ref = f"Steps.{step_name}.PropertyFiles.{file_name}"
        return {
            "Std:JsonGet": {
                "PropertyFile": {"Get": property_file_ref},
                "Path": self.json_path,
            }
        }
14 changes: 7 additions & 7 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List
from typing import List, Union

import attr

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
response_types,
inference_instances,
transform_instances,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
model_package_group_name=None,
model_metrics=None,
approval_status=None,
Expand All @@ -82,8 +82,8 @@ def __init__(
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed (default: None).
depends_on (List[str]): The list of step names the first step in the collection
depends on
depends_on (List[str] or List[Step]): The list of step names or step instances
the first step in the collection depends on
model_package_group_name (str): The Model Package Group name, exclusive to
`model_package_name`, using `model_package_group_name` makes the Model Package
versioned (default: None).
Expand Down Expand Up @@ -179,7 +179,7 @@ def __init__(
max_payload=None,
tags=None,
volume_kms_key=None,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
**kwargs,
):
"""Construct steps required for a Transformer step collection:
Expand Down Expand Up @@ -216,8 +216,8 @@ def __init__(
it will be the format of the batch transform output.
env (dict): The Environment variables to be set for use during the
transform job (default: None).
depends_on (List[str]): The list of step names the first step in
the collection depends on
depends_on (List[str] or List[Step]): The list of step names or step instances
the first step in the collection depends on
"""
steps = []
if "entry_point" in kwargs:
Expand Down
54 changes: 37 additions & 17 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ class Step(Entity):
Attributes:
name (str): The name of the step.
step_type (StepTypeEnum): The type of the step.
depends_on (List[str]): The list of step names the current step depends on
depends_on (List[str] or List[Step]): The list of step names or step
instances the current step depends on
"""

name: str = attr.ib(factory=str)
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
depends_on: List[str] = attr.ib(default=None)
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)

@property
@abc.abstractmethod
Expand All @@ -85,11 +86,13 @@ def to_request(self) -> RequestType:
"Arguments": self.arguments,
}
if self.depends_on:
request_dict["DependsOn"] = self.depends_on
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)

return request_dict

def add_depends_on(self, step_names: List[str]):
"""Add step names to the current step depends on list"""
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
"""Add step names or step instances to the current step depends on list"""

if not step_names:
return
if not self.depends_on:
Expand All @@ -101,6 +104,19 @@ def ref(self) -> Dict[str, str]:
"""Gets a reference dict for steps"""
return {"Name": self.name}

@staticmethod
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
    """Resolve a list of step names and/or step instances into step names.

    Args:
        depends_on_list (List[str] or List[Step]): Entries that are either
            step names or step instances.

    Returns:
        list: The step names, in the same order as the input; a step
        instance contributes its `name`.

    Raises:
        ValueError: If an entry is neither a `Step` nor a `str`.
    """
    resolved_names = []
    for entry in depends_on_list:
        if isinstance(entry, Step):
            # Step instances are referenced by their name.
            resolved_names.append(entry.name)
            continue
        if isinstance(entry, str):
            resolved_names.append(entry)
            continue
        raise ValueError(f"Invalid input step name: {entry}")
    return resolved_names


@attr.s
class CacheConfig:
Expand Down Expand Up @@ -143,7 +159,7 @@ def __init__(
estimator: EstimatorBase,
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
):
"""Construct a TrainingStep, given an `EstimatorBase` instance.

Expand Down Expand Up @@ -171,8 +187,8 @@ def __init__(
the path to the training dataset.

cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
depends on
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.TrainingStep` depends on
"""
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
self.estimator = estimator
Expand Down Expand Up @@ -217,7 +233,11 @@ class CreateModelStep(Step):
"""CreateModel step for workflow."""

def __init__(
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
self,
name: str,
model: Model,
inputs: CreateModelInput,
depends_on: Union[List[str], List[Step]] = None,
):
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.

Expand All @@ -229,8 +249,8 @@ def __init__(
model (Model): A `sagemaker.model.Model` instance.
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
Defaults to `None`.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
depends on
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.CreateModelStep` depends on
"""
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
self.model = model
Expand Down Expand Up @@ -275,7 +295,7 @@ def __init__(
transformer: Transformer,
inputs: TransformInput,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
):
"""Constructs a TransformStep, given an `Transformer` instance.

Expand All @@ -287,8 +307,8 @@ def __init__(
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
depends on
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.TransformStep` depends on
"""
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
self.transformer = transformer
Expand Down Expand Up @@ -351,7 +371,7 @@ def __init__(
code: str = None,
property_files: List[PropertyFile] = None,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
):
"""Construct a ProcessingStep, given a `Processor` instance.

Expand All @@ -372,8 +392,8 @@ def __init__(
property_files (List[PropertyFile]): A list of property files that workflow looks
for and resolves from the configured processing output list.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
depends on
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.ProcessingStep` depends on
"""
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
self.processor = processor
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ def test_two_processing_job_depends_on(

step_pyspark_2 = ProcessingStep(
name="pyspark-process-2",
depends_on=[step_pyspark_1.name],
depends_on=[step_pyspark_1],
processor=pyspark_processor,
inputs=spark_run_args.inputs,
outputs=spark_run_args.outputs,
Expand Down
5 changes: 3 additions & 2 deletions tests/integ/test_workflow_with_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.session import get_execution_role
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.parameters import (
ParameterInteger,
ParameterString,
Expand Down Expand Up @@ -237,7 +238,7 @@ def test_workflow_with_clarify(
)

cond_left = JsonGet(
step=step_process,
processing_step_name=step_process.name,
property_file="BiasOutput",
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
)
Expand Down
Loading