From 6ec51c2517db0f6dda86138117827d32185a1185 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 4 Feb 2022 21:28:12 +0000 Subject: [PATCH 1/8] feat: override jumpstart content bucket --- src/sagemaker/jumpstart/constants.py | 2 ++ src/sagemaker/jumpstart/utils.py | 7 +++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 8 ++++++++ 3 files changed, 17 insertions(+) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 363e542b02..a748beac89 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -122,3 +122,5 @@ TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) + +ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 16bdd9fc4f..72092a0ea0 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -13,6 +13,7 @@ """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import import logging +import os from typing import Dict, List, Optional from urllib.parse import urlparse from packaging.version import Version @@ -60,6 +61,12 @@ def get_jumpstart_content_bucket(region: str) -> str: Raises: RuntimeError: If JumpStart is not launched in ``region``. """ + + if ( + constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + return os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE] try: return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket except KeyError: diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe494eb459..69d0412e01 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -11,11 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import
+import os
 from mock.mock import Mock, patch
 import pytest
 import random
 from sagemaker.jumpstart import utils
 from sagemaker.jumpstart.constants import (
+    ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
     JUMPSTART_BUCKET_NAME_SET,
     JUMPSTART_REGION_NAME_SET,
     JumpStartScriptScope,
@@ -40,6 +42,12 @@ def test_get_jumpstart_content_bucket():
         utils.get_jumpstart_content_bucket(bad_region)
 
 
+def test_get_jumpstart_content_bucket_override():
+    with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
+        random_region = "random_region"
+        assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
+
+
 def test_get_jumpstart_launched_regions_message():
 
     with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):

From 7aa181d633d8571dbcaa74ff02300f30ab31d893 Mon Sep 17 00:00:00 2001
From: Evan Kravitz
Date: Tue, 8 Feb 2022 18:47:48 +0000
Subject: [PATCH 2/8] chore: log info msg when overriding jumpstart bucket

---
 src/sagemaker/jumpstart/utils.py             | 4 +++-
 tests/unit/sagemaker/jumpstart/test_utils.py | 9 +++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py
index 72092a0ea0..c59966d1b5 100644
--- a/src/sagemaker/jumpstart/utils.py
+++ b/src/sagemaker/jumpstart/utils.py
@@ -66,7 +66,9 @@ def get_jumpstart_content_bucket(region: str) -> str:
         constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
         and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
     ):
-        return os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
+        bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
+        LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
+        return bucket_override
     try:
         return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
     except KeyError:
diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py
index 69d0412e01..04eddced08 100644
--- a/tests/unit/sagemaker/jumpstart/test_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_utils.py
@@ -44,8 +44,13 @@ def test_get_jumpstart_content_bucket():
 
 def test_get_jumpstart_content_bucket_override():
     with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
-        random_region = "random_region"
-        assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
+        with patch("logging.Logger.info") as mocked_info_log:
+            random_region = "random_region"
+            assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
+            mocked_info_log.assert_called_once_with(
+                "Using JumpStart bucket override: '%s'",
+                "some-val",
+            )
 
 
 def test_get_jumpstart_launched_regions_message():

From b095a469cd6e6511622a9eef5ff022ac42dfe18a Mon Sep 17 00:00:00 2001
From: Evan Kravitz
Date: Thu, 10 Feb 2022 21:11:29 +0000
Subject: [PATCH 3/8] fix: jumpstart docs

---
 doc/overview.rst | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/doc/overview.rst b/doc/overview.rst
index 103d431b05..7c0b0c4818 100644
--- a/doc/overview.rst
+++ b/doc/overview.rst
@@ -741,6 +741,7 @@ see `Model

Date: Tue, 15 Feb 2022 00:30:50 -0800
Subject: [PATCH 4/8] fix: Update Static Endpoint (#2931)

---
 tests/integ/sagemaker/lineage/conftest.py | 6 +++---
 tests/integ/test_workflow.py              | 3 ++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/integ/sagemaker/lineage/conftest.py
b/tests/integ/sagemaker/lineage/conftest.py index 672af41de9..0139a5b658 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -45,9 +45,9 @@ SLEEP_TIME_SECONDS = 1 SLEEP_TIME_TWO_SECONDS = 2 -STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17" -STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17" -STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup" +STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline20" +STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint20" +STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline20ModelPackageGroup" @pytest.fixture diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index d2c142ee38..160f9f934b 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -67,6 +67,7 @@ ConditionLessThanOrEqualTo, ) from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet from sagemaker.workflow.callback_step import ( CallbackStep, CallbackOutput, @@ -2831,7 +2832,7 @@ def test_end_to_end_pipeline_successful_execution( # define condition step cond_lte = ConditionLessThanOrEqualTo( - left=JsonGet( + left=ConditionStepJsonGet( step=step_eval, property_file=evaluation_report, json_path="regression_metrics.mse.value", From 3c5ea3af8168e8085fe775614c2de2e2fda2da99 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Tue, 15 Feb 2022 18:27:32 -0800 Subject: [PATCH 5/8] Add exception in test_action (#2938) --- tests/integ/sagemaker/lineage/test_action.py | 24 ++++++++++++-------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/integ/sagemaker/lineage/test_action.py b/tests/integ/sagemaker/lineage/test_action.py index 8b462279ca..cae7a395aa 100644 --- a/tests/integ/sagemaker/lineage/test_action.py +++ b/tests/integ/sagemaker/lineage/test_action.py @@ -139,12 +139,15 @@ def test_downstream_artifacts(static_approval_action): def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session): + try: + sagemaker_session.sagemaker_client.add_association( + SourceArn=static_dataset_artifact.artifact_arn, + DestinationArn=static_approval_action.action_arn, + AssociationType="ContributedTo", + ) + except Exception: + print("Source and Destination association already exists.") - sagemaker_session.sagemaker_client.add_association( - SourceArn=static_dataset_artifact.artifact_arn, - DestinationArn=static_approval_action.action_arn, - AssociationType="ContributedTo", - ) time.sleep(3) artifacts_from_query = static_approval_action.datasets() @@ -153,10 +156,13 @@ def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_ses assert "artifact" in artifact.artifact_arn assert artifact.artifact_type == "DataSet" - sagemaker_session.sagemaker_client.delete_association( - SourceArn=static_dataset_artifact.artifact_arn, - DestinationArn=static_approval_action.action_arn, - ) + try: + sagemaker_session.sagemaker_client.delete_association( + SourceArn=static_dataset_artifact.artifact_arn, + DestinationArn=static_approval_action.action_arn, + ) + except Exception: + pass def test_endpoints(static_approval_action): From 9d84e2e3e33243492ecdfb24fde68e1c9262d3e6 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Tue, 15 Feb 2022 19:37:28 -0800 Subject: [PATCH 6/8] change: pin test dependencies (#2929) --- README.rst | 1 + setup.py | 44 +++++++++++++++++++++++--------------------- tox.ini | 
22 +++++++++++----------- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/README.rst b/README.rst index d646153516..ab62eddad0 100644 --- a/README.rst +++ b/README.rst @@ -90,6 +90,7 @@ SageMaker Python SDK is tested on: - Python 3.6 - Python 3.7 - Python 3.8 +- Python 3.9 AWS Permissions ~~~~~~~~~~~~~~~ diff --git a/setup.py b/setup.py index 2eb5838a64..df10a002af 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def read_version(): # Declare minimal set for installation required_packages = [ - "attrs", + "attrs==20.3.0", "boto3>=1.20.21", "google-pasta", "numpy>=1.9.0", @@ -49,12 +49,12 @@ def read_version(): # Specific use case dependencies extras = { "local": [ - "urllib3>=1.21.1,!=1.25,!=1.25.1", - "docker-compose>=1.25.2", - "docker==5.0.0", - "PyYAML>=5.3, <6", # PyYAML version has to match docker-compose requirements + "urllib3==1.26.8", + "docker-compose==1.29.2", + "docker~=5.0.0", + "PyYAML==5.4.1", # PyYAML version has to match docker-compose requirements ], - "scipy": ["scipy>=0.19.0"], + "scipy": ["scipy==1.5.4"], } # Meta dependency groups extras["all"] = [item for group in extras.values() for item in group] @@ -62,23 +62,25 @@ def read_version(): extras["test"] = ( [ extras["all"], - "tox", - "flake8", - "pytest<6.1.0", - "pytest-cov", - "pytest-rerunfailures", - "pytest-timeout", + "tox==3.24.5", + "flake8==4.0.1", + "pytest==6.0.2", + "pytest-cov==3.0.0", + "pytest-rerunfailures==10.2", + "pytest-timeout==2.1.0", "pytest-xdist==2.4.0", - "coverage<6.2", - "mock", - "contextlib2", - "awslogs", - "black", + "coverage>=5.2, <6.2", + "mock==4.0.3", + "contextlib2==21.6.0", + "awslogs==0.14.0", + "black==22.1.0", "stopit==1.1.2", - "apache-airflow==1.10.11", - "fabric>=2.0", - "requests>=2.20.0, <3", - "sagemaker-experiments", + "apache-airflow==2.2.3", + "apache-airflow-providers-amazon==3.0.0", + "attrs==20.3.0", + "fabric==2.6.0", + "requests==2.27.1", + "sagemaker-experiments==0.1.35", ], ) diff --git a/tox.ini b/tox.ini index 50c75cf18c..a0379214a2 100644 --- a/tox.ini +++ b/tox.ini @@ -59,7 +59,7 @@ markers = timeout: mark a test as a timeout. [testenv] -pip_version = pip==20.2 +pip_version = pip==21.3 passenv = AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY @@ -80,8 +80,8 @@ depends = skipdist = true skip_install = true deps = - flake8 - flake8-future-import + flake8==4.0.1 + flake8-future-import==0.4.6 commands = flake8 [testenv:pylint] @@ -106,7 +106,7 @@ commands = # twine check was added starting in 1.12.0 # https://github.com/pypa/twine/blob/master/docs/changelog.rst deps = - twine>=1.12.0 + twine==3.8.0 # https://packaging.python.org/guides/making-a-pypi-friendly-readme/#validating-restructuredtext-markup commands = python setup.py sdist @@ -118,15 +118,15 @@ changedir = doc # having the requirements.txt installed in deps above results in Double Requirement exception # https://github.com/pypa/pip/issues/988 deps = - pip==20.2 + pip==21.3 commands = pip install --exists-action=w -r requirements.txt sphinx-build -T -W -b html -d _build/doctrees-readthedocs -D language=en . 
_build/html [testenv:doc8] deps = - doc8 - Pygments + doc8==0.10.1 + Pygments==2.11.2 commands = doc8 [testenv:black-format] @@ -134,7 +134,7 @@ commands = doc8 setenv = LC_ALL=C.UTF-8 LANG=C.UTF-8 -deps = black +deps = black==22.1.0 commands = black -l 100 ./ @@ -143,12 +143,12 @@ commands = setenv = LC_ALL=C.UTF-8 LANG=C.UTF-8 -deps = black +deps = black==22.1.0 commands = black -l 100 --check ./ [testenv:clean] -deps = coverage +deps = coverage==6.2 skip_install = true commands = coverage erase @@ -158,7 +158,7 @@ commands = mypy src/sagemaker [testenv:docstyle] -deps = pydocstyle +deps = pydocstyle==6.1.1 commands = pydocstyle src/sagemaker From cede5faf5820a53685cb02c1e5cdea4dbbdb0619 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Tue, 15 Feb 2022 21:30:55 -0800 Subject: [PATCH 7/8] feature: Add FailStep Support for Sagemaker Pipeline (#2872) --- .../sagemaker.workflow.pipelines.rst | 2 + src/sagemaker/workflow/fail_step.py | 71 ++++ src/sagemaker/workflow/steps.py | 228 ++++++------ tests/integ/test_workflow_with_fail_steps.py | 325 ++++++++++++++++++ .../unit/sagemaker/workflow/test_fail_step.py | 121 +++++++ 5 files changed, 634 insertions(+), 113 deletions(-) create mode 100644 src/sagemaker/workflow/fail_step.py create mode 100644 tests/integ/test_workflow_with_fail_steps.py create mode 100644 tests/unit/sagemaker/workflow/test_fail_step.py diff --git a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst index 1999ca89ef..00e692852e 100644 --- a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst +++ b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst @@ -147,3 +147,5 @@ Steps .. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckConfig .. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckStep + +.. autoclass:: sagemaker.workflow.fail_step.FailStep diff --git a/src/sagemaker/workflow/fail_step.py b/src/sagemaker/workflow/fail_step.py new file mode 100644 index 0000000000..cc908a2a2a --- /dev/null +++ b/src/sagemaker/workflow/fail_step.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The `Step` definitions for SageMaker Pipelines Workflows.""" +from __future__ import absolute_import + +from typing import List, Union + +from sagemaker.workflow import PipelineNonPrimitiveInputTypes +from sagemaker.workflow.entities import ( + RequestType, +) +from sagemaker.workflow.steps import Step, StepTypeEnum + + +class FailStep(Step): + """`FailStep` for SageMaker Pipelines Workflows.""" + + def __init__( + self, + name: str, + error_message: Union[str, PipelineNonPrimitiveInputTypes] = None, + display_name: str = None, + description: str = None, + depends_on: Union[List[str], List[Step]] = None, + ): + """Constructs a `FailStep`. + + Args: + name (str): The name of the `FailStep`. A name is required and must be + unique within a pipeline. 
+ error_message (str or PipelineNonPrimitiveInputTypes): + An error message defined by the user. + Once the `FailStep` is reached, the execution fails and the + error message is set as the failure reason (default: None). + display_name (str): The display name of the `FailStep`. + The display name provides better UI readability. (default: None). + description (str): The description of the `FailStep` (default: None). + depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances + that this `FailStep` depends on. + If a listed `Step` name does not exist, an error is returned (default: None). + """ + super(FailStep, self).__init__( + name, display_name, description, StepTypeEnum.FAIL, depends_on + ) + self.error_message = error_message if error_message is not None else "" + + @property + def arguments(self) -> RequestType: + """The arguments dictionary that is used to define the `FailStep`.""" + return dict(ErrorMessage=self.error_message) + + @property + def properties(self): + """A `Properties` object is not available for the `FailStep`. + + Executing a `FailStep` will terminate the pipeline. + `FailStep` properties should not be referenced. + """ + raise RuntimeError( + "FailStep is a terminal step and the Properties object is not available for it." + ) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 6f5a78b8ad..99f3444f23 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The `Step` definitions for SageMaker Pipelines Workflows.""" from __future__ import absolute_import import abc @@ -48,7 +48,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): - """Enum of step types.""" + """Enum of `Step` types.""" CONDITION = "Condition" CREATE_MODEL = "Model" @@ -62,20 +62,21 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): QUALITY_CHECK = "QualityCheck" CLARIFY_CHECK = "ClarifyCheck" EMR = "EMR" + FAIL = "Fail" @attr.s class Step(Entity): - """Pipeline step for workflow. + """Pipeline `Step` for SageMaker Pipelines Workflows. Attributes: - name (str): The name of the step. - display_name (str): The display name of the step. - description (str): The description of the step. - step_type (StepTypeEnum): The type of the step. - depends_on (List[str] or List[Step]): The list of step names or step - instances the current step depends on - retry_policies (List[RetryPolicy]): The custom retry policy configuration + name (str): The name of the `Step`. + display_name (str): The display name of the `Step`. + description (str): The description of the `Step`. + step_type (StepTypeEnum): The type of the `Step`. + depends_on (List[str] or List[Step]): The list of `Step` names or `Step` + instances that the current `Step` depends on. + retry_policies (List[RetryPolicy]): The custom retry policy configuration. 
""" name: str = attr.ib(factory=str) @@ -87,12 +88,12 @@ class Step(Entity): @property @abc.abstractmethod def arguments(self) -> RequestType: - """The arguments to the particular step service call.""" + """The arguments to the particular `Step` service call.""" @property @abc.abstractmethod def properties(self): - """The properties of the particular step.""" + """The properties of the particular `Step`.""" def to_request(self) -> RequestType: """Gets the request structure for workflow service calls.""" @@ -111,7 +112,7 @@ def to_request(self) -> RequestType: return request_dict def add_depends_on(self, step_names: Union[List[str], List["Step"]]): - """Add step names or step instances to the current step depends on list""" + """Add `Step` names or `Step` instances to the current `Step` depends on list.""" if not step_names: return @@ -122,12 +123,12 @@ def add_depends_on(self, step_names: Union[List[str], List["Step"]]): @property def ref(self) -> Dict[str, str]: - """Gets a reference dict for steps""" + """Gets a reference dictionary for `Step` instances.""" return {"Name": self.name} @staticmethod def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]) -> List[str]: - """Resolve the step depends on list""" + """Resolve the `Step` depends on list.""" depends_on = [] for step in depends_on_list: if isinstance(step, Step): @@ -141,18 +142,19 @@ def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]) -> List @attr.s class CacheConfig: - """Configuration class to enable caching in pipeline workflow. + """Configuration class to enable caching in SageMaker Pipelines Workflows. - If caching is enabled, the pipeline attempts to find a previous execution of a step - that was called with the same arguments. Step caching only considers successful execution. + If caching is enabled, the pipeline attempts to find a previous execution of a `Step` + that was called with the same arguments. `Step` caching only considers successful execution. If a successful previous execution is found, the pipeline propagates the values - from previous execution rather than recomputing the step. When multiple successful executions - exist within the timeout period, it uses the result for the most recent successful execution. + from the previous execution rather than recomputing the `Step`. + When multiple successful executions exist within the timeout period, + it uses the result for the most recent successful execution. Attributes: - enable_caching (bool): To enable step caching. Defaults to `False`. - expire_after (str): If step caching is enabled, a timeout also needs to defined. + enable_caching (bool): To enable `Step` caching. Defaults to `False`. + expire_after (str): If `Step` caching is enabled, a timeout also needs to defined. It defines how old a previous execution can be to be considered for reuse. Value should be an ISO 8601 duration string. Defaults to `None`. 
@@ -170,7 +172,7 @@ class CacheConfig: @property def config(self): - """Configures caching in pipeline steps.""" + """Configures `Step` caching for SageMaker Pipelines Workflows.""" config = {"Enabled": self.enable_caching} if self.expire_after is not None: config["ExpireAfter"] = self.expire_after @@ -178,7 +180,7 @@ def config(self): class ConfigurableRetryStep(Step): - """ConfigurableRetryStep step for workflow.""" + """`ConfigurableRetryStep` for SageMaker Pipelines Workflows.""" def __init__( self, @@ -199,7 +201,7 @@ def __init__( self.retry_policies = [] if not retry_policies else retry_policies def add_retry_policy(self, retry_policy: RetryPolicy): - """Add a retry policy to the current step retry policies list.""" + """Add a policy to the current `ConfigurableRetryStep` retry policies list.""" if not retry_policy: return @@ -208,7 +210,7 @@ def add_retry_policy(self, retry_policy: RetryPolicy): self.retry_policies.append(retry_policy) def to_request(self) -> RequestType: - """Gets the request structure for ConfigurableRetryStep""" + """Gets the request structure for `ConfigurableRetryStep`.""" step_dict = super().to_request() if self.retry_policies: step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies) @@ -216,12 +218,12 @@ def to_request(self) -> RequestType: @staticmethod def _resolve_retry_policy(retry_policy_list: List[RetryPolicy]) -> List[RequestType]: - """Resolve the step retry policy list""" + """Resolve the `ConfigurableRetryStep` retry policy list.""" return [retry_policy.to_request() for retry_policy in retry_policy_list] class TrainingStep(ConfigurableRetryStep): - """Training step for workflow.""" + """`TrainingStep` for SageMaker Pipelines Workflows.""" def __init__( self, @@ -234,23 +236,23 @@ def __init__( depends_on: Union[List[str], List[Step]] = None, retry_policies: List[RetryPolicy] = None, ): - """Construct a TrainingStep, given an `EstimatorBase` instance. + """Construct a `TrainingStep`, given an `EstimatorBase` instance. - In addition to the estimator instance, the other arguments are those that are supplied to - the `fit` method of the `sagemaker.estimator.Estimator`. + In addition to the `EstimatorBase` instance, the other arguments are those + that are supplied to the `fit` method of the `sagemaker.estimator.Estimator`. Args: - name (str): The name of the training step. + name (str): The name of the `TrainingStep`. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. - display_name (str): The display name of the training step. - description (str): The description of the training step. + display_name (str): The display name of the `TrainingStep`. + description (str): The description of the `TrainingStep`. inputs (Union[str, dict, TrainingInput, FileSystemInput]): Information about the training data. This can be one of three types: * (str) the S3 location where training data is saved, or a file:// path in local mode. * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple - channels for training data, you can specify a dict mapping channel names to + channels for training data, you can specify a dictionary mapping channel names to strings or :func:`~sagemaker.inputs.TrainingInput` objects. * (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources that can provide additional information as well as the path to the training @@ -261,9 +263,9 @@ def __init__( the path to the training dataset. 
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
-            depends_on (List[str] or List[Step]): A list of step names or step instances
-                this `sagemaker.workflow.steps.TrainingStep` depends on
-            retry_policies (List[RetryPolicy]): A list of retry policy
+            depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
+                this `sagemaker.workflow.steps.TrainingStep` depends on.
+            retry_policies (List[RetryPolicy]): A list of retry policies.
         """
         super(TrainingStep, self).__init__(
             name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies
@@ -287,10 +289,10 @@ def __init__(
 
     @property
     def arguments(self) -> RequestType:
-        """The arguments dict that is used to call `create_training_job`.
+        """The arguments dictionary that is used to call `create_training_job`.
 
-        NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
-        The TrainingJobName and ExperimentConfig attributes cannot be included.
+        NOTE: The `CreateTrainingJob` request is not quite the args list that workflow needs.
+        The `TrainingJobName` and `ExperimentConfig` attributes cannot be included.
         """
         self.estimator._prepare_for_training()
 
@@ -304,11 +306,11 @@ def arguments(self) -> RequestType:
 
     @property
     def properties(self):
-        """A Properties object representing the DescribeTrainingJobResponse data model."""
+        """A `Properties` object representing the `DescribeTrainingJobResponse` data model."""
         return self._properties
 
     def to_request(self) -> RequestType:
-        """Updates the dictionary with cache configuration."""
+        """Updates the request dictionary with cache configuration."""
         request_dict = super().to_request()
         if self.cache_config:
             request_dict.update(self.cache_config.config)
@@ -317,7 +319,7 @@ def to_request(self) -> RequestType:
 
 class CreateModelStep(ConfigurableRetryStep):
-    """CreateModel step for workflow."""
+    """`CreateModelStep` for SageMaker Pipelines Workflows."""
 
     def __init__(
         self,
@@ -329,22 +331,22 @@ def __init__(
         display_name: str = None,
         description: str = None,
     ):
-        """Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
+        """Construct a `CreateModelStep`, given a `sagemaker.model.Model` instance.
 
-        In addition to the Model instance, the other arguments are those that are supplied to
+        In addition to the `Model` instance, the other arguments are those that are supplied to
         the `_create_sagemaker_model` method of the
         `sagemaker.model.Model._create_sagemaker_model`.
 
         Args:
-            name (str): The name of the CreateModel step.
+            name (str): The name of the `CreateModelStep`.
             model (Model or PipelineModel): A `sagemaker.model.Model`
                 or `sagemaker.pipeline.PipelineModel` instance.
             inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
                 Defaults to `None`.
-            depends_on (List[str] or List[Step]): A list of step names or step instances
-                this `sagemaker.workflow.steps.CreateModelStep` depends on
-            retry_policies (List[RetryPolicy]): A list of retry policy
-            display_name (str): The display name of the CreateModel step.
-            description (str): The description of the CreateModel step.
+            depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
+                this `sagemaker.workflow.steps.CreateModelStep` depends on.
+            retry_policies (List[RetryPolicy]): A list of retry policies.
+            display_name (str): The display name of the `CreateModelStep`.
+            description (str): The description of the `CreateModelStep`.
""" super(CreateModelStep, self).__init__( name, StepTypeEnum.CREATE_MODEL, display_name, description, depends_on, retry_policies @@ -356,10 +358,10 @@ def __init__( @property def arguments(self) -> RequestType: - """The arguments dict that is used to call `create_model`. + """The arguments dictionary that is used to call `create_model`. - NOTE: The CreateModelRequest is not quite the args list that workflow needs. - ModelName cannot be included in the arguments. + NOTE: The `CreateModelRequest` is not quite the args list that workflow needs. + `ModelName` cannot be included in the arguments. """ if isinstance(self.model, PipelineModel): @@ -387,12 +389,12 @@ def arguments(self) -> RequestType: @property def properties(self): - """A Properties object representing the DescribeModelResponse data model.""" + """A `Properties` object representing the `DescribeModelResponse` data model.""" return self._properties class TransformStep(ConfigurableRetryStep): - """Transform step for workflow.""" + """`TransformStep` for SageMaker Pipelines Workflows.""" def __init__( self, @@ -405,21 +407,21 @@ def __init__( depends_on: Union[List[str], List[Step]] = None, retry_policies: List[RetryPolicy] = None, ): - """Constructs a TransformStep, given an `Transformer` instance. + """Constructs a `TransformStep`, given a `Transformer` instance. - In addition to the transformer instance, the other arguments are those that are supplied to - the `transform` method of the `sagemaker.transformer.Transformer`. + In addition to the `Transformer` instance, the other arguments are those + that are supplied to the `transform` method of the `sagemaker.transformer.Transformer`. Args: - name (str): The name of the transform step. + name (str): The name of the `TransformStep`. transformer (Transformer): A `sagemaker.transformer.Transformer` instance. inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. - cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - display_name (str): The display name of the transform step. - description (str): The description of the transform step. - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep` - depends on - retry_policies (List[RetryPolicy]): A list of retry policy + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. + display_name (str): The display name of the `TransformStep`. + description (str): The description of the `TransformStep`. + depends_on (List[str]): A list of `Step` names that this `sagemaker.workflow.steps.TransformStep` + depends on. + retry_policies (List[RetryPolicy]): A list of retry policies. """ super(TransformStep, self).__init__( name, StepTypeEnum.TRANSFORM, display_name, description, depends_on, retry_policies @@ -433,10 +435,10 @@ def __init__( @property def arguments(self) -> RequestType: - """The arguments dict that is used to call `create_transform_job`. + """The arguments dictionary that is used to call `create_transform_job`. - NOTE: The CreateTransformJob request is not quite the args list that workflow needs. - TransformJobName and ExperimentConfig cannot be included in the arguments. + NOTE: The `CreateTransformJob` request is not quite the args list that workflow needs. + `TransformJobName` and `ExperimentConfig` cannot be included in the arguments. 
""" transform_args = _TransformJob._get_transform_args( transformer=self.transformer, @@ -459,7 +461,7 @@ def arguments(self) -> RequestType: @property def properties(self): - """A Properties object representing the DescribeTransformJobResponse data model.""" + """A `Properties` object representing the `DescribeTransformJobResponse` data model.""" return self._properties def to_request(self) -> RequestType: @@ -472,7 +474,7 @@ def to_request(self) -> RequestType: class ProcessingStep(ConfigurableRetryStep): - """Processing step for workflow.""" + """`ProcessingStep` for SageMaker Pipelines Workflows.""" def __init__( self, @@ -490,16 +492,16 @@ def __init__( retry_policies: List[RetryPolicy] = None, kms_key=None, ): - """Construct a ProcessingStep, given a `Processor` instance. + """Construct a `ProcessingStep`, given a `Processor` instance. - In addition to the processor instance, the other arguments are those that are supplied to + In addition to the `Processor` instance, the other arguments are those that are supplied to the `process` method of the `sagemaker.processing.Processor`. Args: - name (str): The name of the processing step. + name (str): The name of the `ProcessingStep`. processor (Processor): A `sagemaker.processing.Processor` instance. - display_name (str): The display name of the processing step. - description (str): The description of the processing step. + display_name (str): The display name of the `ProcessingStep`. + description (str): The description of the `ProcessingStep` inputs (List[ProcessingInput]): A list of `sagemaker.processing.ProcessorInput` instances. Defaults to `None`. outputs (List[ProcessingOutput]): A list of `sagemaker.processing.ProcessorOutput` @@ -511,9 +513,9 @@ def __init__( property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of step names or step instance - this `sagemaker.workflow.steps.ProcessingStep` depends on - retry_policies (List[RetryPolicy]): A list of retry policy + depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances that + this `sagemaker.workflow.steps.ProcessingStep` depends on. + retry_policies (List[RetryPolicy]): A list of retry policies. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file. Defaults to `None`. """ @@ -529,8 +531,8 @@ def __init__( self.job_name = None self.kms_key = kms_key - # Examine why run method in sagemaker.processing.Processor mutates the processor instance - # by setting the instance's arguments attribute. Refactor Processor.run, if possible. + # Examine why run method in `sagemaker.processing.Processor` mutates the processor instance + # by setting the instance's arguments attribute. Refactor `Processor.run`, if possible. self.processor.arguments = job_arguments self._properties = Properties( @@ -541,20 +543,20 @@ def __init__( if code: code_url = urlparse(code) if code_url.scheme == "" or code_url.scheme == "file": - # By default, Processor will upload the local code to an S3 path + # By default, `Processor` will upload the local code to an S3 path # containing a timestamp. This causes cache misses whenever a # pipeline is updated, even if the underlying script hasn't changed. 
# To avoid this, hash the contents of the script and include it - # in the job_name passed to the Processor, which will be used + # in the `job_name` passed to the `Processor`, which will be used # instead of the timestamped path. self.job_name = self._generate_code_upload_path() @property def arguments(self) -> RequestType: - """The arguments dict that is used to call `create_processing_job`. + """The arguments dictionary that is used to call `create_processing_job`. - NOTE: The CreateProcessingJob request is not quite the args list that workflow needs. - ProcessingJobName and ExperimentConfig cannot be included in the arguments. + NOTE: The `CreateProcessingJob` request is not quite the args list that workflow needs. + `ProcessingJobName` and `ExperimentConfig` cannot be included in the arguments. """ normalized_inputs, normalized_outputs = self.processor._normalize_args( job_name=self.job_name, @@ -574,7 +576,7 @@ def arguments(self) -> RequestType: @property def properties(self): - """A Properties object representing the DescribeProcessingJobResponse data model.""" + """A `Properties` object representing the `DescribeProcessingJobResponse` data model.""" return self._properties def to_request(self) -> RequestType: @@ -589,7 +591,7 @@ def to_request(self) -> RequestType: return request_dict def _generate_code_upload_path(self) -> str: - """Generate an upload path for local processing scripts based on its contents""" + """Generate an upload path for local processing scripts based on its contents.""" from sagemaker.workflow.utilities import hash_file code_hash = hash_file(self.code) @@ -597,7 +599,7 @@ def _generate_code_upload_path(self) -> str: class TuningStep(ConfigurableRetryStep): - """Tuning step for workflow.""" + """`TuningStep` for SageMaker Pipelines Workflows.""" def __init__( self, @@ -611,24 +613,24 @@ def __init__( depends_on: Union[List[str], List[Step]] = None, retry_policies: List[RetryPolicy] = None, ): - """Construct a TuningStep, given a `HyperparameterTuner` instance. + """Construct a `TuningStep`, given a `HyperparameterTuner` instance. - In addition to the tuner instance, the other arguments are those that are supplied to - the `fit` method of the `sagemaker.tuner.HyperparameterTuner`. + In addition to the `HyperparameterTuner` instance, the other arguments are those + that are supplied to the `fit` method of the `sagemaker.tuner.HyperparameterTuner`. Args: - name (str): The name of the tuning step. + name (str): The name of the `TuningStep`. tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance. - display_name (str): The display name of the tuning step. - description (str): The description of the tuning step. + display_name (str): The display name of the `TuningStep`. + description (str): The description of the `TuningStep`. inputs: Information about the training data. Please refer to the - ``fit()`` method of the associated estimator, as this can take + `fit()` method of the associated estimator, as this can take any of the following forms: * (str) - The S3 location where training data is saved. * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple channels for training data, you can specify - a dict mapping channel names to strings or + a dictionary mapping channel names to strings or :func:`~sagemaker.inputs.TrainingInput` objects. * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can provide additional information about the training dataset. 
@@ -651,9 +653,9 @@ def __init__(
             job_arguments (List[str]): A list of strings to be passed into the processing job.
                 Defaults to `None`.
             cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
-            depends_on (List[str] or List[Step]): A list of step names or step instance
-                this `sagemaker.workflow.steps.ProcessingStep` depends on
-            retry_policies (List[RetryPolicy]): A list of retry policy
+            depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances that
+                this `sagemaker.workflow.steps.TuningStep` depends on.
+            retry_policies (List[RetryPolicy]): A list of retry policies.
         """
         super(TuningStep, self).__init__(
             name, StepTypeEnum.TUNING, display_name, description, depends_on, retry_policies
@@ -672,11 +674,11 @@ def __init__(
 
     @property
     def arguments(self) -> RequestType:
-        """The arguments dict that is used to call `create_hyper_parameter_tuning_job`.
+        """The arguments dictionary that is used to call `create_hyper_parameter_tuning_job`.
 
-        NOTE: The CreateHyperParameterTuningJob request is not quite the
+        NOTE: The `CreateHyperParameterTuningJob` request is not quite the
         args list that workflow needs.
-        The HyperParameterTuningJobName attribute cannot be included.
+        The `HyperParameterTuningJobName` attribute cannot be included.
         """
         if self.tuner.estimator is not None:
             self.tuner.estimator._prepare_for_training()
@@ -693,9 +695,9 @@ def arguments(self) -> RequestType:
 
     @property
     def properties(self):
-        """A Properties object representing
+        """A `Properties` object.
 
-        `DescribeHyperParameterTuningJobResponse` and
+        A `Properties` object representing `DescribeHyperParameterTuningJobResponse` and
         `ListTrainingJobsForHyperParameterTuningJobResponse` data model.
         """
         return self._properties
@@ -709,15 +711,15 @@ def to_request(self) -> RequestType:
         return request_dict
 
     def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") -> Join:
-        """Get the model artifact s3 uri from the top performing training jobs.
+        """Get the model artifact S3 URI from the top performing training jobs.
 
         Args:
-            top_k (int): the index of the top performing training job
-                tuning step stores up to 50 top performing training jobs, hence
-                a valid top_k value is from 0 to 49. The best training job
-                model is at index 0
-            s3_bucket (str): the s3 bucket to store the training job output artifact
-            prefix (str): the s3 key prefix to store the training job output artifact
+            top_k (int): The index of the top performing training job. The
+                tuning step stores up to 50 top performing training jobs.
+                A valid top_k value is from 0 to 49. The best training job
+                model is at index 0.
+            s3_bucket (str): The S3 bucket to store the training job output artifact.
+            prefix (str): The S3 key prefix to store the training job output artifact.
         """
         values = ["s3:/", s3_bucket]
         if prefix != "" and prefix is not None:
diff --git a/tests/integ/test_workflow_with_fail_steps.py b/tests/integ/test_workflow_with_fail_steps.py
new file mode 100644
index 0000000000..ba00b4f972
--- /dev/null
+++ b/tests/integ/test_workflow_with_fail_steps.py
@@ -0,0 +1,325 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file.
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from botocore.exceptions import WaiterError + +from sagemaker import get_execution_role, utils +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals +from sagemaker.workflow.fail_step import FailStep + +from sagemaker.workflow.functions import Join +from sagemaker.workflow.parameters import ParameterInteger, ParameterString +from sagemaker.workflow.pipeline import Pipeline + + +@pytest.fixture +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture +def pipeline_name(): + return utils.unique_name_from_base("my-pipeline-fail-step") + + +def test_two_step_fail_pipeline_with_str_err_msg(sagemaker_session, role, pipeline_name): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + step_fail = FailStep( + name="FailStep", + error_message="Failed due to hitting in else branch", + ) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[step_fail], + ) + pipeline = Pipeline( + name=pipeline_name, + steps=[step_cond], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + + try: + response = pipeline.create(role) + pipeline_arn = response["PipelineArn"] + execution = pipeline.start(parameters={}) + response = execution.describe() + assert response["PipelineArn"] == pipeline_arn + + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert execution_step["FailureReason"] == "Failed due to hitting in else branch" + metadata = execution_steps[0]["Metadata"]["Fail"] + assert metadata["ErrorMessage"] == "Failed due to hitting in else branch" + + # Check FailureReason field in ListPipelineExecutions + executions = sagemaker_session.sagemaker_client.list_pipeline_executions( + PipelineName=pipeline.name + )["PipelineExecutionSummaries"] + + assert len(executions) == 1 + assert executions[0]["PipelineExecutionStatus"] == "Failed" + assert ( + "Step failure: One or multiple steps failed" + in executions[0]["PipelineExecutionFailureReason"] + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_two_step_fail_pipeline_with_parameter_err_msg(sagemaker_session, role, pipeline_name): + cond_param = ParameterInteger(name="MyInt") + cond = ConditionEquals(left=cond_param, right=1) + err_msg_param = ParameterString(name="MyString") + step_fail = FailStep( + name="FailStep", + error_message=err_msg_param, + ) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[step_fail], + ) + pipeline = Pipeline( + name=pipeline_name, + steps=[step_cond], + sagemaker_session=sagemaker_session, + parameters=[cond_param, err_msg_param], + ) + + try: + response = pipeline.create(role) + pipeline_arn = response["PipelineArn"] + execution = pipeline.start( + parameters={ + "MyInt": 3, + "MyString": "Failed due to 
hitting in else branch", + } + ) + response = execution.describe() + assert response["PipelineArn"] == pipeline_arn + + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert execution_step["FailureReason"] == "Failed due to hitting in else branch" + metadata = execution_steps[0]["Metadata"]["Fail"] + assert metadata["ErrorMessage"] == "Failed due to hitting in else branch" + + # Check FailureReason field in ListPipelineExecutions + executions = sagemaker_session.sagemaker_client.list_pipeline_executions( + PipelineName=pipeline.name + )["PipelineExecutionSummaries"] + + assert len(executions) == 1 + assert executions[0]["PipelineExecutionStatus"] == "Failed" + assert ( + "Step failure: One or multiple steps failed" + in executions[0]["PipelineExecutionFailureReason"] + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_two_step_fail_pipeline_with_join_fn(sagemaker_session, role, pipeline_name): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[], + ) + step_fail = FailStep( + name="FailStep", + error_message=Join( + on=": ", values=["Failed due to xxx == yyy returns", step_cond.properties.Outcome] + ), + ) + pipeline = Pipeline( + name=pipeline_name, + steps=[step_cond, step_fail], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + + try: + response = pipeline.create(role) + pipeline_arn = response["PipelineArn"] + execution = pipeline.start( + parameters={"MyInt": 3}, + ) + response = execution.describe() + assert response["PipelineArn"] == pipeline_arn + + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert execution_step["FailureReason"] == "Failed due to xxx == yyy returns: false" + metadata = execution_steps[0]["Metadata"]["Fail"] + assert metadata["ErrorMessage"] == "Failed due to xxx == yyy returns: false" + + # Check FailureReason field in ListPipelineExecutions + executions = sagemaker_session.sagemaker_client.list_pipeline_executions( + PipelineName=pipeline.name + )["PipelineExecutionSummaries"] + + assert len(executions) == 1 + assert executions[0]["PipelineExecutionStatus"] == "Failed" + assert ( + "Step failure: One or multiple steps failed" + in executions[0]["PipelineExecutionFailureReason"] + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_two_step_fail_pipeline_with_no_err_msg(sagemaker_session, role, pipeline_name): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + step_fail = FailStep( + name="FailStep", + ) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[step_fail], + ) + pipeline = Pipeline( + name=pipeline_name, + 
steps=[step_cond], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + + try: + response = pipeline.create(role) + pipeline_arn = response["PipelineArn"] + execution = pipeline.start(parameters={}) + response = execution.describe() + assert response["PipelineArn"] == pipeline_arn + + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert execution_step.get("FailureReason", None) is None + metadata = execution_steps[0]["Metadata"]["Fail"] + assert metadata["ErrorMessage"] == "" + + # Check FailureReason field in ListPipelineExecutions + executions = sagemaker_session.sagemaker_client.list_pipeline_executions( + PipelineName=pipeline.name + )["PipelineExecutionSummaries"] + + assert len(executions) == 1 + assert executions[0]["PipelineExecutionStatus"] == "Failed" + assert ( + "Step failure: One or multiple steps failed" + in executions[0]["PipelineExecutionFailureReason"] + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_invalid_pipeline_depended_on_fail_step(sagemaker_session, role, pipeline_name): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + step_fail = FailStep( + name="FailStep", + error_message="Failed pipeline execution", + ) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[], + depends_on=["FailStep"], + ) + pipeline = Pipeline( + name=pipeline_name, + steps=[step_cond, step_fail], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + + try: + with pytest.raises(Exception) as error: + pipeline.create(role) + + assert "CondStep can not depends on FailStep" in str(error.value) + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/workflow/test_fail_step.py b/tests/unit/sagemaker/workflow/test_fail_step.py new file mode 100644 index 0000000000..04edaf0ac5 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_fail_step.py @@ -0,0 +1,121 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import json + +import pytest + +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals +from sagemaker.workflow.fail_step import FailStep +from sagemaker.workflow.functions import Join +from sagemaker.workflow.parameters import ParameterInteger +from sagemaker.workflow.pipeline import Pipeline + + +def test_fail_step(): + fail_step = FailStep( + name="MyFailStep", + depends_on=["TestStep"], + error_message="Test error message", + ) + fail_step.add_depends_on(["SecondTestStep"]) + assert fail_step.to_request() == { + "Name": "MyFailStep", + "Type": "Fail", + "DependsOn": ["TestStep", "SecondTestStep"], + "Arguments": {"ErrorMessage": "Test error message"}, + } + + +def test_fail_step_with_no_error_message(): + fail_step = FailStep( + name="MyFailStep", + depends_on=["TestStep"], + ) + fail_step.add_depends_on(["SecondTestStep"]) + assert fail_step.to_request() == { + "Name": "MyFailStep", + "Type": "Fail", + "DependsOn": ["TestStep", "SecondTestStep"], + "Arguments": {"ErrorMessage": ""}, + } + + +def test_fail_step_with_join_fn_in_error_message(): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[], + ) + step_fail = FailStep( + name="FailStep", + error_message=Join( + on=": ", values=["Failed due to xxx == yyy returns", step_cond.properties.Outcome] + ), + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step_cond, step_fail], + parameters=[param], + ) + + _expected_dsl = [ + { + "Name": "CondStep", + "Type": "Condition", + "Arguments": { + "Conditions": [ + {"Type": "Equals", "LeftValue": {"Get": "Parameters.MyInt"}, "RightValue": 1} + ], + "IfSteps": [], + "ElseSteps": [], + }, + }, + { + "Name": "FailStep", + "Type": "Fail", + "Arguments": { + "ErrorMessage": { + "Std:Join": { + "On": ": ", + "Values": [ + "Failed due to xxx == yyy returns", + {"Get": "Steps.CondStep.Outcome"}, + ], + } + } + }, + }, + ] + + assert json.loads(pipeline.definition())["Steps"] == _expected_dsl + + +def test_fail_step_with_properties_ref(): + fail_step = FailStep( + name="MyFailStep", + error_message="Test error message", + ) + + with pytest.raises(Exception) as error: + fail_step.properties() + + assert ( + str(error.value) + == "FailStep is a terminal step and the Properties object is not available for it." + ) From a765512536fd5bd95733ce9aa497783a06cc8bdf Mon Sep 17 00:00:00 2001 From: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> Date: Wed, 16 Feb 2022 03:48:34 -0800 Subject: [PATCH 8/8] change: use recommended inference image uri from Neo API (#2923) --- src/sagemaker/model.py | 36 +------------ tests/unit/sagemaker/model/test_neo.py | 73 ++++++++++++++------------ tests/unit/test_mxnet.py | 15 +++--- 3 files changed, 46 insertions(+), 78 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ede78c7cce..00a04a3199 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -23,7 +23,6 @@ import sagemaker from sagemaker import ( fw_utils, - image_uris, local, s3, session, @@ -657,34 +656,6 @@ def _compilation_job_config( "job_name": job_name, } - def _compilation_image_uri(self, region, target_instance_type, framework, framework_version): - """Retrieve the Neo or Inferentia image URI. - - Args: - region (str): The AWS region. 
- target_instance_type (str): Identifies the device on which you want to run - your model after compilation, for example: ml_c5. For valid values, see - https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html. - framework (str): The framework name. - framework_version (str): The framework version. - """ - framework_prefix = "" - framework_suffix = "" - - if framework == "xgboost": - framework_suffix = "-neo" - elif target_instance_type.startswith("ml_inf"): - framework_prefix = "inferentia-" - else: - framework_prefix = "neo-" - - return image_uris.retrieve( - "{}{}{}".format(framework_prefix, framework, framework_suffix), - region, - instance_type=target_instance_type, - version=framework_version, - ) - def package_for_edge( self, output_path, @@ -849,12 +820,7 @@ def compile( if target_instance_family == "ml_eia2": pass elif target_instance_family.startswith("ml_"): - self.image_uri = self._compilation_image_uri( - self.sagemaker_session.boto_region_name, - target_instance_family, - framework, - framework_version, - ) + self.image_uri = job_status.get("InferenceImage", None) self._is_compiled_model = True else: LOGGER.warning( diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 16b5bc6ee6..2357c771f9 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -20,12 +20,15 @@ MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" +IMAGE_URI = "inference-container-uri" + REGION = "us-west-2" NEO_REGION_ACCOUNT = "301217895009" DESCRIBE_COMPILATION_JOB_RESPONSE = { "CompilationJobStatus": "Completed", "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, + "InferenceImage": IMAGE_URI, } @@ -52,12 +55,7 @@ def test_compile_model_for_inferentia(sagemaker_session): framework_version="1.15.0", job_name="compile-model", ) - assert ( - "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format( - NEO_REGION_ACCOUNT, REGION - ) - == model.image_uri - ) + assert DESCRIBE_COMPILATION_JOB_RESPONSE["InferenceImage"] == model.image_uri assert model._is_compiled_model is True @@ -271,11 +269,12 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem assert model.endpoint_name.startswith("{}-ml-c4".format(model_name)) -@patch("sagemaker.session.Session") -def test_compile_with_framework_version_15(session): - session.return_value.boto_region_name = REGION +def test_compile_with_framework_version_15(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) - model = _create_model() + model = _create_model(sagemaker_session) model.compile( target_instance_family="ml_c4", input_shape={"data": [1, 3, 1024, 1024]}, @@ -286,14 +285,15 @@ def test_compile_with_framework_version_15(session): job_name="compile-model", ) - assert "1.5" in model.image_uri + assert IMAGE_URI == model.image_uri -@patch("sagemaker.session.Session") -def test_compile_with_framework_version_16(session): - session.return_value.boto_region_name = REGION +def test_compile_with_framework_version_16(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) - model = _create_model() + model = _create_model(sagemaker_session) model.compile( target_instance_family="ml_c4", input_shape={"data": [1, 3, 1024, 1024]}, @@ -304,26 +304,7 @@ def test_compile_with_framework_version_16(session): job_name="compile-model", ) - assert "1.6" in 
model.image_uri - - -@patch("sagemaker.session.Session") -def test_compile_validates_framework_version(session): - session.return_value.boto_region_name = REGION - - model = _create_model() - with pytest.raises(ValueError) as e: - model.compile( - target_instance_family="ml_c4", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="pytorch", - framework_version="1.6.1", - job_name="compile-model", - ) - - assert "Unsupported neo-pytorch version: 1.6.1." in str(e) + assert IMAGE_URI == model.image_uri @patch("sagemaker.session.Session") @@ -347,3 +328,25 @@ def test_compile_with_pytorch_neo_in_ml_inf(session): ) != model.image_uri ) + + +def test_compile_validates_framework_version(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( + return_value={ + "CompilationJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, + "InferenceImage": None, + } + ) + model = _create_model(sagemaker_session) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="pytorch", + framework_version="1.6.1", + job_name="compile-model", + ) + + assert model.image_uri is None diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 7e6f63eb4e..991eeac2ec 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -68,6 +68,8 @@ ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} +INFERENCE_IMAGE_URI = "inference-uri" + @pytest.fixture() def sagemaker_session(): @@ -83,7 +85,10 @@ def sagemaker_session(): ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} - describe_compilation = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}} + describe_compilation = { + "ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}, + "InferenceImage": INFERENCE_IMAGE_URI, + } session.sagemaker_client.create_model_package.side_effect = MODEL_PKG_RESPONSE session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) @@ -195,12 +200,6 @@ def _create_compilation_job(input_shape, output_location): } -def _neo_inference_image(mxnet_version): - return "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-inference-{}:{}-cpu-py3".format( - FRAMEWORK.lower(), mxnet_version - ) - - @patch("sagemaker.estimator.name_from_base") @patch("sagemaker.utils.create_tar_file", MagicMock()) def test_create_model( @@ -422,7 +421,7 @@ def test_mxnet_neo(time, strftime, sagemaker_session, neo_mxnet_version): actual_compile_model_args = sagemaker_session.method_calls[3][2] assert expected_compile_model_args == actual_compile_model_args - assert compiled_model.image_uri == _neo_inference_image(neo_mxnet_version) + assert compiled_model.image_uri == INFERENCE_IMAGE_URI predictor = mx.deploy(1, CPU, use_compiled_model=True) assert isinstance(predictor, MXNetPredictor)
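
The two user-visible features in this series are the JumpStart content-bucket override (patch 1) and `FailStep` (patch 7). The sketch below, distilled from the tests added above, shows how they would be exercised; the bucket and pipeline names are illustrative placeholders, and creating or starting the pipeline would additionally require AWS credentials and an execution role, as in the integration tests.

import os

from sagemaker.jumpstart import utils
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionEquals
from sagemaker.workflow.fail_step import FailStep
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.pipeline import Pipeline

# Patch 1: any non-empty value in this environment variable takes precedence
# over the region-to-bucket lookup ("my-jumpstart-bucket" is a placeholder).
os.environ["AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"] = "my-jumpstart-bucket"
assert utils.get_jumpstart_content_bucket("us-west-2") == "my-jumpstart-bucket"

# Patch 7: a FailStep terminates the execution and surfaces a user-defined
# error message whenever the condition's else branch is taken.
param = ParameterInteger(name="MyInt", default_value=2)
step_fail = FailStep(
    name="FailStep",
    error_message="Failed due to hitting in else branch",
)
step_cond = ConditionStep(
    name="CondStep",
    conditions=[ConditionEquals(left=param, right=1)],
    if_steps=[],
    else_steps=[step_fail],
)
pipeline = Pipeline(name="my-pipeline", steps=[step_cond], parameters=[param])
# pipeline.create(role) followed by pipeline.start() would then fail with the
# message above whenever MyInt != 1, as test_two_step_fail_pipeline_with_str_err_msg shows.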