From 1cbfe503fc0906af82f6ce1e711186086525ba28 Mon Sep 17 00:00:00 2001 From: Madhubalasri B <79440059+MadhubalasriB@users.noreply.github.com> Date: Fri, 18 Feb 2022 11:39:56 +0530 Subject: [PATCH 01/14] feature: adding customer metadata support to registermodel step (#2935) --- src/sagemaker/estimator.py | 4 ++++ src/sagemaker/model.py | 4 ++++ src/sagemaker/mxnet/model.py | 4 ++++ src/sagemaker/pytorch/model.py | 4 ++++ src/sagemaker/session.py | 24 ++++++++++++++++++++++ src/sagemaker/tensorflow/model.py | 5 +++++ src/sagemaker/workflow/_utils.py | 5 +++++ src/sagemaker/workflow/step_collections.py | 7 ++++++- tests/integ/test_workflow.py | 3 +++ tests/unit/test_session.py | 3 +++ 10 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6431ca8afc..fd74633584 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1263,6 +1263,7 @@ def register( compile_model_family=None, model_name=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1292,6 +1293,8 @@ def register( model will be used (default: None). model_name (str): User defined model name (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1322,6 +1325,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) @property diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 00a04a3199..2d01bb4c0f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -303,6 +303,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -328,6 +329,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -355,6 +358,7 @@ def register( description=description, container_def_list=[container_def], drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index df0dd31a28..0a10cbf3c1 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -158,6 +158,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +184,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -211,6 +214,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 3a0c3a283c..0f51788626 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -157,6 +157,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -182,6 +183,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -210,6 +213,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91b89ea4c9..c50a22d3f8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2778,6 +2778,7 @@ def create_model_package_from_containers( approval_status="PendingManualApproval", description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -2803,6 +2804,9 @@ def create_model_package_from_containers( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + """ request = get_create_model_package_request( @@ -2819,7 +2823,17 @@ def create_model_package_from_containers( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) + if model_package_group_name is not None: + try: + self.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) + except ClientError: + self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5): @@ -4120,6 +4134,7 @@ def get_model_package_args( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get arguments for create_model_package method. @@ -4148,6 +4163,8 @@ def get_model_package_args( (default: None). container_def_list (list): A list of container defintiions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: dict: A dictionary of method argument names and values. """ @@ -4185,6 +4202,8 @@ def get_model_package_args( model_package_args["description"] = description if tags is not None: model_package_args["tags"] = tags + if customer_metadata_properties is not None: + model_package_args["customer_metadata_properties"] = customer_metadata_properties return model_package_args @@ -4203,6 +4222,7 @@ def get_create_model_package_request( description=None, tags=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -4229,6 +4249,8 @@ def get_create_model_package_request( tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4250,6 +4272,8 @@ def get_create_model_package_request( request_dict["DriftCheckBaselines"] = drift_check_baselines if metadata_properties: request_dict["MetadataProperties"] = metadata_properties + if customer_metadata_properties is not None: + request_dict["CustomerMetadataProperties"] = customer_metadata_properties if containers is not None: if not all([content_types, response_types, inference_instances, transform_instances]): raise ValueError( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 0b8d2f7235..9f6a7841d5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -201,6 +201,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -226,6 +227,9 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + Returns: A `sagemaker.model.ModelPackage` instance. @@ -254,6 +258,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ca078fe7ea..d341af211d 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -310,6 +310,7 @@ def __init__( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Constructor of a register model step. @@ -347,6 +348,8 @@ def __init__( this step depends on retry_policies (List[RetryPolicy]): The list of retry policies for the current step drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -362,6 +365,7 @@ def __init__( self.tags = tags self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines + self.customer_metadata_properties = customer_metadata_properties self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -435,6 +439,7 @@ def arguments(self) -> RequestType: description=self.description, tags=self.tags, container_def_list=self.container_def_list, + customer_metadata_properties=self.customer_metadata_properties, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index f4606488b2..27060d928e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -75,6 +75,7 @@ def __init__( tags=None, model: Union[Model, PipelineModel] = None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -95,7 +96,7 @@ def __init__( for the repack model step register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for register model step - model_package_group_name (str): The Model Package Group name, exclusive to + model_package_group_name (str): The Model Package Group name or Arn, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). @@ -113,6 +114,9 @@ def __init__( model (object or Model): A PipelineModel object that comprises a list of models which gets executed as a serial inference pipeline or a Model object. drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -229,6 +233,7 @@ def __init__( tags=tags, container_def_list=self.container_def_list, retry_policies=register_model_step_retry_policies, + customer_metadata_properties=customer_metadata_properties, **kwargs, ) if not repack_model: diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 160f9f934b..14c2cf54b3 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1952,6 +1952,7 @@ def test_model_registration_with_drift_check_baselines( content_type="application/json", ), ) + customer_metadata_properties = {"key1": "value1"} estimator = XGBoost( entry_point="training.py", source_dir=os.path.join(DATA_DIR, "sip"), @@ -1973,6 +1974,7 @@ def test_model_registration_with_drift_check_baselines( model_package_group_name="testModelPackageGroup", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) pipeline = Pipeline( @@ -2043,6 +2045,7 @@ def test_model_registration_with_drift_check_baselines( response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"] == "application/json" ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties break finally: try: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8604835890..4523253a7f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): marketplace_cert = (True,) approval_status = ("Approved",) description = "description" + customer_metadata_properties = {"key1": "value1"} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): approval_status=approval_status, description=description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) expected_args = { "ModelPackageName": model_package_name, @@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From 663983d6b4299c5faee0b6b448a1d10d1e97e22e Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Fri, 18 Feb 2022 17:24:40 -0500 Subject: [PATCH 02/14] feat: jumpstart model id suggestions (#2899) Co-authored-by: Navin Soni Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Co-authored-by: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> --- src/sagemaker/jumpstart/cache.py | 31 +++++++-- src/sagemaker/jumpstart/types.py | 12 ++-- src/sagemaker/jumpstart/validators.py | 22 +++--- .../jumpstart/test_validate.py | 67 ++++++++++++++----- tests/unit/sagemaker/jumpstart/test_cache.py | 34 ++++++++-- 5 files changed, 123 insertions(+), 43 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 26284419de..25d3b37fcb 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -13,6 +13,7 @@ """This module defines the JumpStartModelsCache class.""" from __future__ import absolute_import import datetime +from difflib import get_close_matches from typing import List, Optional import json import boto3 @@ -204,14 +205,34 @@ def _get_manifest_key_from_model_id_semantic_version( sm_version_to_use = sm_version_to_use_list[0] error_msg = ( - f"Unable to find model manifest for {model_id} with version {version} " - f"compatible with your SageMaker version ({sm_version}). " + f"Unable to find model manifest for '{model_id}' with version '{version}' " + f"compatible with your SageMaker version ('{sm_version}'). " f"Consider upgrading your SageMaker library to at least version " - f"{sm_version_to_use} so you can use version " - f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." + f"'{sm_version_to_use}' so you can use version " + f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'." ) raise KeyError(error_msg) - error_msg = f"Unable to find model manifest for {model_id} with version {version}." + + error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " + error_msg += ( + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html" + " for updated list of models. " + ) + + other_model_id_version = self._select_version( + "*", versions_incompatible_with_sagemaker + ) # all versions here are incompatible with sagemaker + if other_model_id_version is not None: + error_msg += ( + f"Consider using model ID '{model_id}' with version " + f"'{other_model_id_version}'." + ) + + else: + possible_model_ids = [header.model_id for header in manifest.values()] + closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0] + error_msg += f"Did you mean to use model ID '{closest_model_id}'?" + raise KeyError(error_msg) def _get_file_from_s3( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 7c36795652..b9384ca042 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -135,12 +135,12 @@ def from_json(self, json_obj: Dict[str, str]) -> None: class JumpStartECRSpecs(JumpStartDataHolderType): """Data class for JumpStart ECR specs.""" - __slots__ = { + __slots__ = [ "framework", "framework_version", "py_version", "huggingface_transformers_version", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartECRSpecs object from its json representation. @@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]: class JumpStartHyperparameter(JumpStartDataHolderType): """Data class for JumpStart hyperparameter definition in the training container.""" - __slots__ = { + __slots__ = [ "name", "type", "options", @@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartHyperparameter object from its json representation. @@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]: class JumpStartEnvironmentVariable(JumpStartDataHolderType): """Data class for JumpStart environment variable definitions in the hosting container.""" - __slots__ = { + __slots__ = [ "name", "type", "default", "scope", - } + ] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartEnvironmentVariable object from its json representation. diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 10c5b38a81..65268388c3 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -49,7 +49,7 @@ def _validate_hyperparameter( if len(hyperparameter_spec) > 1: raise JumpStartHyperparametersError( - f"Unable to perform validation -- found multiple hyperparameter " + "Unable to perform validation -- found multiple hyperparameter " f"'{hyperparameter_name}' in model specs." ) @@ -76,35 +76,35 @@ def _validate_hyperparameter( if hyperparameter_value not in hyperparameter_spec.options: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have one of the following " - f"values: {', '.join(hyperparameter_spec.options)}" + f"values: {', '.join(hyperparameter_spec.options)}." ) if hasattr(hyperparameter_spec, "min"): if len(hyperparameter_value) < hyperparameter_spec.min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length no less than " - f"{hyperparameter_spec.min}" + f"{hyperparameter_spec.min}." ) if hasattr(hyperparameter_spec, "exclusive_min"): if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length greater than " - f"{hyperparameter_spec.exclusive_min}" + f"{hyperparameter_spec.exclusive_min}." ) if hasattr(hyperparameter_spec, "max"): if len(hyperparameter_value) > hyperparameter_spec.max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length no greater than " - f"{hyperparameter_spec.max}" + f"{hyperparameter_spec.max}." ) if hasattr(hyperparameter_spec, "exclusive_max"): if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must have length less than " - f"{hyperparameter_spec.exclusive_max}" + f"{hyperparameter_spec.exclusive_max}." ) # validate numeric types @@ -125,35 +125,35 @@ def _validate_hyperparameter( if not hyperparameter_value_str[start_index:].isdigit(): raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be integer type " - "('{hyperparameter_value}')." + f"('{hyperparameter_value}')." ) if hasattr(hyperparameter_spec, "min"): if numeric_hyperparam_value < hyperparameter_spec.min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' can be no less than " - "{hyperparameter_spec.min}." + f"{hyperparameter_spec.min}." ) if hasattr(hyperparameter_spec, "max"): if numeric_hyperparam_value > hyperparameter_spec.max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' can be no greater than " - "{hyperparameter_spec.max}." + f"{hyperparameter_spec.max}." ) if hasattr(hyperparameter_spec, "exclusive_min"): if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be greater than " - "{hyperparameter_spec.exclusive_min}." + f"{hyperparameter_spec.exclusive_min}." ) if hasattr(hyperparameter_spec, "exclusive_max"): if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max: raise JumpStartHyperparametersError( f"Hyperparameter '{hyperparameter_name}' must be less than " - "{hyperparameter_spec.exclusive_max}." + f"{hyperparameter_spec.exclusive_max}." ) diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index ddeeccba1d..83092f74e5 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -147,49 +147,54 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) hyperparameter_to_test["batch-size"] = "0" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' " "can be no less than 1.") hyperparameter_to_test["batch-size"] = "-1" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' can be no " "less than 1.") hyperparameter_to_test["batch-size"] = "-1.5" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' must be " "integer type ('-1.5').") hyperparameter_to_test["batch-size"] = "1.5" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' must be integer " "type ('1.5').") hyperparameter_to_test["batch-size"] = "99999" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ("Hyperparameter 'batch-size' can be no greater " "than 1024.") hyperparameter_to_test["batch-size"] = 5 hyperparameters.validate( @@ -210,13 +215,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [None, "", 5, "Truesday", "Falsehood"]: hyperparameter_to_test["test_bool_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Expecting boolean valued hyperparameter, " f"but got '{str(val)}'." + ) + hyperparameter_to_test["test_bool_param"] = original_bool_val original_exclusive_min_val = hyperparameter_to_test["test_exclusive_min_param"] @@ -230,13 +239,16 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [1, 1 - 1e-99, -99]: hyperparameter_to_test["test_exclusive_min_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_exclusive_min_param' must " "be greater than 1." + ) hyperparameter_to_test["test_exclusive_min_param"] = original_exclusive_min_val original_exclusive_max_val = hyperparameter_to_test["test_exclusive_max_param"] @@ -250,13 +262,15 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [4, 5, 99]: hyperparameter_to_test["test_exclusive_max_param"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == "Hyperparameter 'test_exclusive_max_param' must be less than 4." + hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val original_exclusive_max_text_val = hyperparameter_to_test["test_exclusive_max_param_text"] @@ -270,13 +284,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["123456", "123456789"]: hyperparameter_to_test["test_exclusive_max_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert ( + str(e.value) + == "Hyperparameter 'test_exclusive_max_param_text' must have length less than 6." + ) hyperparameter_to_test["test_exclusive_max_param_text"] = original_exclusive_max_text_val original_max_text_val = hyperparameter_to_test["test_max_param_text"] @@ -290,13 +308,17 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["1234567", "123456789"]: hyperparameter_to_test["test_max_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert ( + str(e.value) + == "Hyperparameter 'test_max_param_text' must have length no greater than 6." + ) hyperparameter_to_test["test_max_param_text"] = original_max_text_val original_exclusive_min_text_val = hyperparameter_to_test["test_exclusive_min_param_text"] @@ -310,13 +332,16 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in ["1", "d", ""]: hyperparameter_to_test["test_exclusive_min_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_exclusive_min_param_text' must have length greater " "than 1." + ) hyperparameter_to_test["test_exclusive_min_param_text"] = original_exclusive_min_text_val original_min_text_val = hyperparameter_to_test["test_min_param_text"] @@ -330,24 +355,31 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) for val in [""]: hyperparameter_to_test["test_min_param_text"] = val - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'test_min_param_text' " "must have length no less than 1." + ) hyperparameter_to_test["test_min_param_text"] = original_min_text_val del hyperparameter_to_test["batch-size"] hyperparameter_to_test["penalty"] = "blah" - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, ) + assert str(e.value) == ( + "Hyperparameter 'penalty' must have one of the following values: l1, l2, elasticnet," + " none." + ) hyperparameter_to_test["penalty"] = "elasticnet" hyperparameters.validate( @@ -411,7 +443,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) del hyperparameter_to_test["adam-learning-rate"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -419,6 +451,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, ) + assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -454,7 +487,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): del hyperparameter_to_test["sagemaker_submit_directory"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -462,13 +495,14 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, ) + assert str(e.value) == "Cannot find hyperparameter for 'sagemaker_submit_directory'." hyperparameter_to_test[ "sagemaker_submit_directory" ] = "/opt/ml/input/data/code/sourcedir.tar.gz" del hyperparameter_to_test["epochs"] - with pytest.raises(JumpStartHyperparametersError): + with pytest.raises(JumpStartHyperparametersError) as e: hyperparameters.validate( region=region, model_id=model_id, @@ -476,6 +510,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, ) + assert str(e.value) == "Cannot find hyperparameter for 'epochs'." hyperparameter_to_test["epochs"] = "3" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 761b53d469..93e8114185 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -151,17 +151,41 @@ def test_jumpstart_cache_get_header(): semantic_version_str="3.*", ) assert ( - "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 " - "with version 3.* compatible with your SageMaker version (2.68.3). Consider upgrading " - "your SageMaker library to at least version 4.49.0 so you can use version 3.0.0 of " - "tensorflow-ic-imagenet-inception-v3-classification-4." in str(e.value) + "Unable to find model manifest for 'tensorflow-ic-imagenet-inception-v3-classification-4' " + "with version '3.*' compatible with your SageMaker version ('2.68.3'). Consider upgrading " + "your SageMaker library to at least version '4.49.0' so you can use version '3.0.0' of " + "'tensorflow-ic-imagenet-inception-v3-classification-4'." in str(e.value) ) with pytest.raises(KeyError) as e: cache.get_header( model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="3.*" ) - assert "Consider upgrading" not in str(e.value) + assert ( + "Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with " + "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-" + "classification-4' with version '2.0.0'." + ) in str(e.value) + + with pytest.raises(KeyError) as e: + cache.get_header(model_id="pytorch-ic-", semantic_version_str="*") + assert ( + "Unable to find model manifest for 'pytorch-ic-' with version '*'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. " + "Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?" + ) in str(e.value) + + with pytest.raises(KeyError) as e: + cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*") + assert ( + "Unable to find model manifest for 'tensorflow-ic-' with version '*'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "for updated list of models. " + "Did you mean to use model ID 'tensorflow-ic-imagenet-inception-" + "v3-classification-4'?" + ) in str(e.value) with pytest.raises(KeyError): cache.get_header( From d529475f2639785a6cf193276d578b8f8658729f Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Fri, 18 Feb 2022 19:48:53 -0500 Subject: [PATCH 03/14] feat: override jumpstart content bucket (#2901) Co-authored-by: Navin Soni Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Co-authored-by: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> --- doc/overview.rst | 5 ++--- src/sagemaker/jumpstart/constants.py | 2 ++ src/sagemaker/jumpstart/utils.py | 9 +++++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 13 +++++++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/doc/overview.rst b/doc/overview.rst index 39f5f6ecae..df320e3b47 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -746,6 +746,7 @@ see `Model str: Raises: RuntimeError: If JumpStart is not launched in ``region``. """ + + if ( + constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE] + LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override) + return bucket_override try: return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket except KeyError: diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe494eb459..04eddced08 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -11,11 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import os from mock.mock import Mock, patch import pytest import random from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, JUMPSTART_BUCKET_NAME_SET, JUMPSTART_REGION_NAME_SET, JumpStartScriptScope, @@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket(): utils.get_jumpstart_content_bucket(bad_region) +def test_get_jumpstart_content_bucket_override(): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart bucket override: '%s'", + "some-val", + ) + + def test_get_jumpstart_launched_regions_message(): with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): From a98a6d87b0da8d6b8a0a734617af27023db0af7d Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Mon, 21 Feb 2022 13:01:58 -0800 Subject: [PATCH 04/14] fix: Support primitive types for left value of ConditionSteps (#2886) Co-authored-by: Payton Staub Co-authored-by: Navin Soni --- src/sagemaker/workflow/conditions.py | 2 +- tests/unit/sagemaker/workflow/test_conditions.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 2e2849cc80..065cf01315 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -79,7 +79,7 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "Type": self.condition_type.value, - "LeftValue": self.left.expr, + "LeftValue": primitive_or_expr(self.left), "RightValue": primitive_or_expr(self.right), } diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index d473b36121..f4bea55b6e 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -165,3 +165,12 @@ def test_condition_or(): }, ], } + + +def test_left_and_right_primitives(): + cond = ConditionEquals(left=2, right=1) + assert cond.to_request() == { + "Type": "Equals", + "LeftValue": 2, + "RightValue": 1, + } From a928c0a7898c06b10dba57170ecfb5a5e594eeb3 Mon Sep 17 00:00:00 2001 From: Yifei Zhu <66866419+yzhu0@users.noreply.github.com> Date: Mon, 21 Feb 2022 13:59:13 -0800 Subject: [PATCH 05/14] fix: Add lineage doc (#2937) --- doc/workflows/index.rst | 1 + doc/workflows/lineage/index.rst | 11 ++++ doc/workflows/lineage/sagemaker.lineage.rst | 70 +++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 doc/workflows/lineage/index.rst create mode 100644 doc/workflows/lineage/sagemaker.lineage.rst diff --git a/doc/workflows/index.rst b/doc/workflows/index.rst index 4c0a92c411..bde061aaaf 100644 --- a/doc/workflows/index.rst +++ b/doc/workflows/index.rst @@ -10,3 +10,4 @@ The SageMaker Python SDK supports managed training and inference for a variety o airflow/index step_functions/index pipelines/index + lineage/index diff --git a/doc/workflows/lineage/index.rst b/doc/workflows/lineage/index.rst new file mode 100644 index 0000000000..2f64df0c82 --- /dev/null +++ b/doc/workflows/lineage/index.rst @@ -0,0 +1,11 @@ +################### +SageMaker Lineage +################### +Amazon SageMaker ML Lineage Tracking creates and stores information about the steps of a machine learning (ML) workflow from data preparation to model deployment. With the tracking information, you can reproduce the workflow steps, track model and dataset lineage, and establish model governance and audit standards. + +SageMaker APIs for creating and managing SageMaker Lineage. + +.. toctree:: + :maxdepth: 2 + + sagemaker.lineage diff --git a/doc/workflows/lineage/sagemaker.lineage.rst b/doc/workflows/lineage/sagemaker.lineage.rst new file mode 100644 index 0000000000..20ae7dd8fd --- /dev/null +++ b/doc/workflows/lineage/sagemaker.lineage.rst @@ -0,0 +1,70 @@ +Lineage +========= + + +Artifact +------------- + +.. autoclass:: sagemaker.lineage.artifact.Artifact + +.. autoclass:: sagemaker.lineage.artifact.ModelArtifact + +.. autoclass:: sagemaker.lineage.artifact.DatasetArtifact + +.. autoclass:: sagemaker.lineage.artifact.ImageArtifact + + +Actions +------------- + +.. autoclass:: sagemaker.lineage.action.Action + +.. autoclass:: sagemaker.lineage.action.ModelPackageApprovalAction + + +Association +------------- + +.. autoclass:: sagemaker.lineage.association.Association + + +Context +------------- + +.. autoclass:: sagemaker.lineage.context.Context + +.. autoclass:: sagemaker.lineage.context.EndpointContext + +.. autoclass:: sagemaker.lineage.context.ModelPackageGroup + + +Lineage Trial Component +-------------------------- + +.. autoclass:: sagemaker.lineage.lineage_trial_component.LineageTrialComponent + + +Query +------------- + +.. autoclass:: sagemaker.lineage.query.LineageEntityEnum + +.. autoclass:: sagemaker.lineage.query.LineageSourceEnum + +.. autoclass:: sagemaker.lineage.query.LineageQueryDirectionEnum + +.. autoclass:: sagemaker.lineage.query.Edge + +.. autoclass:: sagemaker.lineage.query.Vertex + +.. autoclass:: sagemaker.lineage.query.LineageQueryResult + +.. autoclass:: sagemaker.lineage.query.LineageFilter + +.. autoclass:: sagemaker.lineage.query.LineageQuery + + +Visualizer +------------- + +.. autoclass:: sagemaker.lineage.visualizer.LineageTableVisualizer From 668359f68011c7a4eb137e8212a3985ca962088e Mon Sep 17 00:00:00 2001 From: Yifei Zhu <66866419+yzhu0@users.noreply.github.com> Date: Mon, 21 Feb 2022 18:12:14 -0800 Subject: [PATCH 06/14] fix: update lineage_trial_compoment get pipeline execution arn (#2944) Co-authored-by: Shreya Pandit --- .../lineage/lineage_trial_component.py | 9 +- tests/integ/sagemaker/lineage/conftest.py | 20 ++-- .../lineage/test_lineage_trial_component.py | 94 +++++++++++++++++-- 3 files changed, 104 insertions(+), 19 deletions(-) diff --git a/src/sagemaker/lineage/lineage_trial_component.py b/src/sagemaker/lineage/lineage_trial_component.py index f8bc0e53b4..1e02e83657 100644 --- a/src/sagemaker/lineage/lineage_trial_component.py +++ b/src/sagemaker/lineage/lineage_trial_component.py @@ -130,8 +130,15 @@ def pipeline_execution_arn(self) -> str: Returns: str: A pipeline execution ARN. """ + trial_component = self.load( + trial_component_name=self.trial_component_name, sagemaker_session=self.sagemaker_session + ) + + if trial_component.source is None or trial_component.source["SourceArn"] is None: + return None + tags = self.sagemaker_session.sagemaker_client.list_tags( - ResourceArn=self.trial_component_arn + ResourceArn=trial_component.source["SourceArn"] )["Tags"] for tag in tags: if tag["Key"] == "sagemaker:pipeline-execution-arn": diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 0139a5b658..4ede5c193d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -233,7 +233,7 @@ def upstream_trial_associated_artifact( sagemaker_session=sagemaker_session, ) trial_obj.add_trial_component(trial_component_obj) - time.sleep(3) + time.sleep(4) yield artifact_obj trial_obj.remove_trial_component(trial_component_obj) assntn.delete() @@ -561,14 +561,14 @@ def static_approval_action( @pytest.fixture -def static_model_deployment_action(sagemaker_session, static_endpoint_context): +def static_model_deployment_action(sagemaker_session, static_processing_job_trial_component): query_filter = LineageFilter( entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_processing_job_trial_component.trial_component_arn], query_filter=query_filter, - direction=LineageQueryDirectionEnum.ASCENDANTS, + direction=LineageQueryDirectionEnum.DESCENDANTS, include_edges=False, ) model_approval_actions = [] @@ -579,14 +579,14 @@ def static_model_deployment_action(sagemaker_session, static_endpoint_context): @pytest.fixture def static_processing_job_trial_component( - sagemaker_session, static_endpoint_context + sagemaker_session, static_dataset_artifact ) -> LineageTrialComponent: query_filter = LineageFilter( entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_dataset_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, @@ -600,14 +600,14 @@ def static_processing_job_trial_component( @pytest.fixture def static_training_job_trial_component( - sagemaker_session, static_endpoint_context + sagemaker_session, static_model_artifact ) -> LineageTrialComponent: query_filter = LineageFilter( entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_endpoint_context.context_arn], + start_arns=[static_model_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, @@ -738,12 +738,12 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session): @pytest.fixture -def static_image_artifact(static_model_artifact, sagemaker_session): +def static_image_artifact(static_dataset_artifact, sagemaker_session): query_filter = LineageFilter( entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.IMAGE] ) query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_model_artifact.artifact_arn], + start_arns=[static_dataset_artifact.artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, diff --git a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py index 9b466832a1..5755f512f9 100644 --- a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py +++ b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py @@ -114,9 +114,28 @@ def test_pipeline_execution_arn(sagemaker_session): trial_component_arn = ( "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" ) - obj = lineage_trial_component.LineageTrialComponent( - sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj sagemaker_session.sagemaker_client.list_tags.return_value = { "Tags": [ @@ -124,9 +143,10 @@ def test_pipeline_execution_arn(sagemaker_session): ], } expected_calls = [ - unittest.mock.call(ResourceArn=trial_component_arn), + unittest.mock.call(ResourceArn=training_job_arn), ] - pipeline_execution_arn_result = obj.pipeline_execution_arn() + pipeline_execution_arn_result = context.pipeline_execution_arn() + assert pipeline_execution_arn_result == "tag1" assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls @@ -135,9 +155,28 @@ def test_no_pipeline_execution_arn(sagemaker_session): trial_component_arn = ( "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" ) - obj = lineage_trial_component.LineageTrialComponent( - sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj sagemaker_session.sagemaker_client.list_tags.return_value = { "Tags": [ @@ -145,9 +184,48 @@ def test_no_pipeline_execution_arn(sagemaker_session): ], } expected_calls = [ - unittest.mock.call(ResourceArn=trial_component_arn), + unittest.mock.call(ResourceArn=training_job_arn), ] - pipeline_execution_arn_result = obj.pipeline_execution_arn() + pipeline_execution_arn_result = context.pipeline_execution_arn() + expected_result = None + assert pipeline_execution_arn_result == expected_result + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls + + +def test_no_source_arn_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain" + ) + context = lineage_trial_component.LineageTrialComponent( + sagemaker_session, + trial_component_name="foo", + trial_component_arn=trial_component_arn, + source={ + "SourceArn": training_job_arn, + "SourceType": "SageMakerTrainingJob", + }, + ) + obj = { + "TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "TrialComponentArn": trial_component_arn, + "DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job", + "Source": { + "SourceArn": None, + "SourceType": None, + }, + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "abcd", "Value": "efg"}, + ], + } + expected_calls = [] + pipeline_execution_arn_result = context.pipeline_execution_arn() expected_result = None assert pipeline_execution_arn_result == expected_result assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls From 0e4fd556833661b3a544081ac1633ce2306394cf Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 21 Feb 2022 19:02:02 -0800 Subject: [PATCH 07/14] fix: Improve Pipeline workflow unit test branch coverage (#2878) Co-authored-by: Dewen Qi --- src/sagemaker/workflow/_repack_model.py | 4 +- src/sagemaker/workflow/_utils.py | 2 +- src/sagemaker/workflow/condition_step.py | 2 +- src/sagemaker/workflow/functions.py | 4 +- src/sagemaker/workflow/lambda_step.py | 2 +- src/sagemaker/workflow/step_collections.py | 14 +- tests/integ/test_workflow.py | 5 +- tests/integ/test_workflow_with_clarify.py | 5 +- .../{ => sagemaker/workflow}/test_airflow.py | 0 .../unit/sagemaker/workflow/test_functions.py | 22 +++ .../sagemaker/workflow/test_lambda_step.py | 78 +++++++++-- .../workflow/test_step_collections.py | 131 +++++++++++++++++- tests/unit/sagemaker/workflow/test_steps.py | 34 ++++- tests/unit/sagemaker/workflow/test_utils.py | 5 +- 14 files changed, 273 insertions(+), 35 deletions(-) rename tests/unit/{ => sagemaker/workflow}/test_airflow.py (100%) diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 6ce7e41831..f98f170f39 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -34,7 +34,7 @@ from distutils.dir_util import copy_tree -def repack(inference_script, model_archive, dependencies=None, source_dir=None): +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive Args: @@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): copy_tree(src_dir, "/opt/ml/model") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover parser = argparse.ArgumentParser() parser.add_argument("--inference_script", type=str, default="inference.py") parser.add_argument("--dependencies", type=str, default=None) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index d341af211d..fbbb6acba9 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -80,7 +80,7 @@ def __init__( artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file (default: None). + ``.tar.gz`` file. entry_point (str): Path (absolute or relative) to the local Python source file which should be executed as the entry point to inference. If ``source_dir`` is specified, then ``entry_point`` diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index a34330d94d..a2597c07f9 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -95,7 +95,7 @@ def properties(self): @attr.s -class JsonGet(Expression): +class JsonGet(Expression): # pragma: no cover """Get JSON properties from PropertyFiles. Attributes: diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 03ac099d18..e0076322de 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -75,8 +75,8 @@ class JsonGet(Expression): @property def expr(self): """The expression dict for a `JsonGet` function.""" - if not isinstance(self.step_name, str): - raise ValueError("Please give step name as a string") + if not isinstance(self.step_name, str) or not self.step_name: + raise ValueError("Please give a valid step name as a string") if isinstance(self.property_file, PropertyFile): name = self.property_file.name diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 0446a0b46c..5240ae60b9 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -161,8 +161,8 @@ def _get_function_arn(self): partition = "aws" if self.lambda_func.function_arn is None: + account_id = self.lambda_func.session.account_id() try: - account_id = self.lambda_func.session.account_id() response = self.lambda_func.create() return response["FunctionArn"] except ValueError as error: diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 27060d928e..1280637006 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -323,15 +323,15 @@ def __init__( """ steps = [] if "entry_point" in kwargs: - entry_point = kwargs["entry_point"] - source_dir = kwargs.get("source_dir") - dependencies = kwargs.get("dependencies") + entry_point = kwargs.get("entry_point", None) + source_dir = kwargs.get("source_dir", None) + dependencies = kwargs.get("dependencies", None) repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, - role=estimator.sagemaker_session, + role=estimator.role, model_data=model_data, entry_point=entry_point, source_dir=source_dir, @@ -357,7 +357,11 @@ def predict_wrapper(endpoint, session): vpc_config=None, sagemaker_session=estimator.sagemaker_session, role=estimator.role, - **kwargs, + env=kwargs.get("env", None), + name=kwargs.get("name", None), + enable_network_isolation=kwargs.get("enable_network_isolation", None), + model_kms_key=kwargs.get("model_kms_key", None), + image_config=kwargs.get("image_config", None), ) model_step = CreateModelStep( name=f"{name}CreateModelStep", diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 14c2cf54b3..dd24149ca4 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -67,7 +67,6 @@ ConditionLessThanOrEqualTo, ) from sagemaker.workflow.condition_step import ConditionStep -from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet from sagemaker.workflow.callback_step import ( CallbackStep, CallbackOutput, @@ -2835,8 +2834,8 @@ def test_end_to_end_pipeline_successful_execution( # define condition step cond_lte = ConditionLessThanOrEqualTo( - left=ConditionStepJsonGet( - step=step_eval, + left=JsonGet( + step_name=step_eval.name, property_file=evaluation_report, json_path="regression_metrics.mse.value", ), diff --git a/tests/integ/test_workflow_with_clarify.py b/tests/integ/test_workflow_with_clarify.py index 0c41b2212a..486abab89b 100644 --- a/tests/integ/test_workflow_with_clarify.py +++ b/tests/integ/test_workflow_with_clarify.py @@ -33,7 +33,8 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.session import get_execution_role from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo -from sagemaker.workflow.condition_step import ConditionStep, JsonGet +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.functions import JsonGet from sagemaker.workflow.parameters import ( ParameterInteger, ParameterString, @@ -237,7 +238,7 @@ def test_workflow_with_clarify( ) cond_left = JsonGet( - step=step_process, + step_name=step_process.name, property_file="BiasOutput", json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value", ) diff --git a/tests/unit/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py similarity index 100% rename from tests/unit/test_airflow.py rename to tests/unit/sagemaker/workflow/test_airflow.py diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 8e5d6b6d31..9b07a41d09 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.parameters import ( @@ -97,3 +99,23 @@ def test_json_get_expressions(): "Path": "my-json-path", }, } + + +def test_json_get_expressions_with_invalid_step_name(): + with pytest.raises(ValueError) as err: + JsonGet( + step_name="", + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) + + with pytest.raises(ValueError) as err: + JsonGet( + step_name=ParameterString(name="MyString"), + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index 0566e39318..bdaa781b1c 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -22,6 +22,7 @@ from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from sagemaker.lambda_helper import Lambda +from sagemaker.workflow.steps import CacheConfig @pytest.fixture() @@ -38,10 +39,25 @@ def sagemaker_session(): return session_mock +@pytest.fixture() +def sagemaker_session_cn(): + boto_mock = Mock(name="boto_session", region_name="cn-north-1") + session_mock = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="cn-north-1", + config=None, + local_mode=False, + ) + session_mock.account_id.return_value = "234567890123" + return session_mock + + def test_lambda_step(sagemaker_session): param = ParameterInteger(name="MyInt") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") lambda_step = LambdaStep( name="MyLambdaStep", depends_on=["TestStep"], @@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session): display_name="MyLambdaStep", description="MyLambdaStepDescription", inputs={"arg1": "foo", "arg2": 5, "arg3": param}, - outputs=[outputParam1, outputParam2], + outputs=[output_param1, output_param2], + cache_config=cache_config, ) lambda_step.add_depends_on(["SecondTestStep"]) - assert lambda_step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[param], + steps=[lambda_step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyLambdaStep", "Type": "Lambda", "DependsOn": ["TestStep", "SecondTestStep"], @@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session): {"OutputName": "output1", "OutputType": "String"}, {"OutputName": "output2", "OutputType": "Boolean"}, ], - "Arguments": {"arg1": "foo", "arg2": 5, "arg3": param}, + "Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } @@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session): def test_pipeline_interpolates_lambda_outputs(sagemaker_session): parameter = ParameterString("MyStr") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) lambda_step1 = LambdaStep( name="MyLambdaStep1", depends_on=["TestStep"], @@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): session=sagemaker_session, ), inputs={"arg1": "foo"}, - outputs=[outputParam1], + outputs=[output_param1], ) lambda_step2 = LambdaStep( name="MyLambdaStep2", @@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, ), - inputs={"arg1": outputParam1}, - outputs=[outputParam2], + inputs={"arg1": output_param1}, + outputs=[output_param2], ) pipeline = Pipeline( @@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session): ) lambda_step._get_function_arn() sagemaker_session.account_id.assert_called_once() + + +def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn): + lambda_func = MagicMock( + function_arn=None, + function_name="name", + execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role", + zipped_code_dir="", + handler="", + session=sagemaker_session_cn, + ) + # The raised ValueError contains ResourceConflictException + lambda_func.create.side_effect = ValueError("ResourceConflictException") + lambda_step1 = LambdaStep( + name="MyLambdaStep1", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + function_arn = lambda_step1._get_function_arn() + assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name" + + # The raised ValueError does not contain ResourceConflictException + lambda_func.create.side_effect = ValueError() + lambda_step2 = LambdaStep( + name="MyLambdaStep2", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + with pytest.raises(ValueError): + lambda_step2._get_function_arn() diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 6c78412b22..d2f1f07059 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -19,6 +19,7 @@ import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR import sagemaker @@ -206,6 +207,16 @@ def test_step_collection(): ] +def test_step_collection_with_list_to_request(): + step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")]) + custom_step = CustomStep("MyStep3") + assert list_to_request([step_collection, custom_step]) == [ + {"Name": "MyStep1", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep2", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep3", "Type": "Training", "Arguments": dict()}, + ] + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel( @@ -216,6 +227,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): response_types=["response_type"], inference_instances=["inference_instance"], transform_instances=["transform_instance"], + image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", model_package_group_name="mpg", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, @@ -236,7 +248,10 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "Arguments": { "InferenceSpecification": { "Containers": [ - {"Image": "fakeimage", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz"} + { + "Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", + "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", + } ], "SupportedContentTypes": ["content_type"], "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"], @@ -865,3 +880,117 @@ def test_estimator_transformer(estimator): } else: raise Exception("A step exists in the collection of an invalid type.") + + +def test_estimator_transformer_with_model_repack_with_estimator(estimator): + model_data = f"s3://{BUCKET}/model.tar.gz" + model_inputs = CreateModelInput( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) + transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") + estimator_transformer = EstimatorTransformer( + name="EstimatorTransformerStep", + estimator=estimator, + model_data=model_data, + model_inputs=model_inputs, + instance_count=1, + instance_type="ml.c4.4xlarge", + transform_inputs=transform_inputs, + depends_on=["TestStep"], + model_step_retry_policies=[service_fault_retry_policy], + transform_step_retry_policies=[service_fault_retry_policy], + repack_model_step_retry_policies=[service_fault_retry_policy], + entry_point=f"{DATA_DIR}/dummy_script.py", + ) + request_dicts = estimator_transformer.request_dicts() + assert len(request_dicts) == 3 + + for request_dict in request_dicts: + if request_dict["Type"] == "Training": + assert request_dict["Name"] == "EstimatorTransformerStepRepackModel" + assert request_dict["DependsOn"] == ["TestStep"] + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + # pop out the dynamic generated fields + arguments["HyperParameters"].pop("sagemaker_submit_directory") + arguments["HyperParameters"].pop("sagemaker_job_name") + assert arguments == { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" + + "sagemaker-scikit-learn:0.23-1-cpu-py3", + }, + "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 30, + }, + "RoleArn": "DummyRole", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://my-bucket", + "S3DataDistributionType": "FullyReplicated", + } + }, + "ChannelName": "training", + } + ], + "HyperParameters": { + "inference_script": '"dummy_script.py"', + "model_archive": '"model.tar.gz"', + "dependencies": "null", + "source_dir": "null", + "sagemaker_program": '"_repack_model.py"', + "sagemaker_container_log_level": "20", + "sagemaker_region": '"us-west-2"', + }, + "VpcConfig": {"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]}, + "DebugHookConfig": { + "S3OutputPath": "s3://my-bucket/", + "CollectionConfigurations": [], + }, + } + elif request_dict["Type"] == "Model": + assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties) + arguments["PrimaryContainer"].pop("ModelDataUrl") + assert "DependsOn" not in request_dict + assert arguments == { + "ExecutionRoleArn": "DummyRole", + "PrimaryContainer": { + "Environment": {}, + "Image": "fakeimage", + }, + } + elif request_dict["Type"] == "Transform": + assert request_dict["Name"] == "EstimatorTransformerStepTransformStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["ModelName"], Properties) + arguments.pop("ModelName") + assert "DependsOn" not in request_dict + assert arguments == { + "TransformInput": { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{BUCKET}/transform_manifest", + } + } + }, + "TransformOutput": {"S3OutputPath": None}, + "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"}, + } + else: + raise Exception("A step exists in the collection of an invalid type.") diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index e3dc10e23e..674c715617 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json + import pytest import sagemaker import os @@ -43,7 +45,8 @@ ) from sagemaker.network import NetworkConfig from sagemaker.transformer import Transformer -from sagemaker.workflow.properties import Properties +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.properties import Properties, PropertyFile from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.retry import ( StepRetryPolicy, @@ -535,6 +538,9 @@ def test_processing_step(sagemaker_session): ) ] cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + evaluation_report = PropertyFile( + name="EvaluationReport", output_name="evaluation", path="evaluation.json" + ) step = ProcessingStep( name="MyProcessingStep", description="ProcessingStep description", @@ -544,9 +550,20 @@ def test_processing_step(sagemaker_session): inputs=inputs, outputs=[], cache_config=cache_config, + property_files=[evaluation_report], ) step.add_depends_on(["ThirdTestStep"]) - assert step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[ + processing_input_data_uri_parameter, + instance_type_parameter, + instance_count_parameter, + ], + steps=[step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyProcessingStep", "Description": "ProcessingStep description", "DisplayName": "MyProcessingStep", @@ -564,20 +581,27 @@ def test_processing_step(sagemaker_session): "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", "S3InputMode": "File", - "S3Uri": processing_input_data_uri_parameter, + "S3Uri": {"Get": "Parameters.ProcessingInputDataUri"}, }, } ], "ProcessingResources": { "ClusterConfig": { - "InstanceCount": instance_count_parameter, - "InstanceType": instance_type_parameter, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, + "InstanceType": {"Get": "Parameters.InstanceType"}, "VolumeSizeInGB": 30, } }, "RoleArn": "DummyRole", }, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + "PropertyFiles": [ + { + "FilePath": "evaluation.json", + "OutputName": "evaluation", + "PropertyFileName": "EvaluationReport", + } + ], } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 5a2a9497f8..e534aa531e 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -26,6 +26,7 @@ ) from sagemaker.estimator import Estimator +from sagemaker.workflow import Properties from sagemaker.workflow._utils import _RepackModelStep from tests.unit import DATA_DIR @@ -156,7 +157,7 @@ def test_repack_model_step(estimator): def test_repack_model_step_with_source_dir(estimator, source_dir): - model_data = f"s3://{BUCKET}/model.tar.gz" + model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" step = _RepackModelStep( name="MyRepackModelStep", @@ -189,7 +190,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir): "S3DataSource": { "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", - "S3Uri": f"s3://{BUCKET}", + "S3Uri": model_data, } }, } From 88ac6c69d732ee32ee381386acdad1166be5d086 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Tue, 22 Feb 2022 15:03:38 -0800 Subject: [PATCH 08/14] Fix for #2949. Extend the set of types that are not handled by json.dumps to elementary types. --- src/sagemaker/estimator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index fd74633584..8838138c84 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -598,7 +598,9 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A current_hyperparameters = hyperparameters if current_hyperparameters is not None: hyperparameters = { - str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v)) + str(k): (v if isinstance( + v, (str, int, float, bool, Parameter, Expression, Properties) + ) else json.dumps(v)) for (k, v) in current_hyperparameters.items() } return hyperparameters From dbdf1063d134dfeeed440ddf9fe64df3aba55f18 Mon Sep 17 00:00:00 2001 From: Christian Date: Tue, 22 Feb 2022 18:20:58 -0800 Subject: [PATCH 09/14] Update test_estimator.py Fixing tests to reflect desired behaviour. --- tests/unit/test_estimator.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 656d773914..906919d370 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -1241,9 +1241,7 @@ def test_custom_code_bucket(time, sagemaker_session): expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( - expected_submit_dir - ) + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir @patch("time.strftime", return_value=TIMESTAMP) @@ -1266,9 +1264,7 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session): expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( - expected_submit_dir - ) + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir def test_invalid_custom_code_bucket(sagemaker_session): @@ -1340,11 +1336,10 @@ def test_shuffle_config(sagemaker_session): BASE_HP = { - "sagemaker_program": json.dumps(SCRIPT_NAME), - "sagemaker_submit_directory": json.dumps( - "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME) - ), - "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_program": SCRIPT_NAME, + "sagemaker_submit_directory": + "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME), + "sagemaker_job_name": JOB_NAME, } @@ -1389,8 +1384,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session): t.fit("s3://{}".format(uri)) expected_hyperparameters = BASE_HP.copy() - expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO) - expected_hyperparameters["learning_rate"] = json.dumps(0.1) + expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO + expected_hyperparameters["learning_rate"] = 0.1 expected_hyperparameters["123"] = json.dumps([456]) expected_hyperparameters["sagemaker_region"] = '"us-west-2"' @@ -1413,7 +1408,7 @@ def test_start_new_wait_called(strftime, sagemaker_session): t.fit("s3://{}".format(uri)) expected_hyperparameters = BASE_HP.copy() - expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO) + expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO expected_hyperparameters["sagemaker_region"] = '"us-west-2"' actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"] From a82cb7a80af8440c6158bde50b37fc42d1d67d12 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Tue, 22 Feb 2022 21:35:00 -0800 Subject: [PATCH 10/14] Updating PR. Fixing omissions and formating. --- src/sagemaker/estimator.py | 8 +++++--- src/sagemaker/local/image.py | 2 +- tests/unit/test_estimator.py | 3 +-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 8838138c84..26a5b5a36b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -598,9 +598,11 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A current_hyperparameters = hyperparameters if current_hyperparameters is not None: hyperparameters = { - str(k): (v if isinstance( - v, (str, int, float, bool, Parameter, Expression, Properties) - ) else json.dumps(v)) + str(k): ( + v + if isinstance(v, (str, int, float, bool, Parameter, Expression, Properties)) + else json.dumps(v) + ) for (k, v) in current_hyperparameters.items() } return hyperparameters diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 7a10eeacc6..c4c6cddbd6 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -493,7 +493,7 @@ def _prepare_training_volumes( # If there is a training script directory and it is a local directory, # mount it to the container. if sagemaker.estimator.DIR_PARAM_NAME in hyperparameters: - training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME]) + training_dir = hyperparameters[sagemaker.estimator.DIR_PARAM_NAME] parsed_uri = urlparse(training_dir) if parsed_uri.scheme == "file": host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 906919d370..82a70e6eee 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -1337,8 +1337,7 @@ def test_shuffle_config(sagemaker_session): BASE_HP = { "sagemaker_program": SCRIPT_NAME, - "sagemaker_submit_directory": - "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME), + "sagemaker_submit_directory": "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME), "sagemaker_job_name": JOB_NAME, } From 30959492eddbde24febdd14876e111bf6261d123 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Wed, 23 Feb 2022 22:26:02 -0800 Subject: [PATCH 11/14] Removing spurious json.loads. Note: This might need a larger review, as the arbitrary use of json.dumps introduces a lot of ambiguity with respect to the underlying types. --- src/sagemaker/estimator.py | 11 ++++------- src/sagemaker/local/image.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 26a5b5a36b..d41a35af55 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2761,13 +2761,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na job_details, model_channel_name ) - init_params["entry_point"] = json.loads( - init_params["hyperparameters"].get(SCRIPT_PARAM_NAME) - ) - init_params["source_dir"] = json.loads(init_params["hyperparameters"].get(DIR_PARAM_NAME)) - init_params["container_log_level"] = json.loads( - init_params["hyperparameters"].get(CONTAINER_LOG_LEVEL_PARAM_NAME) - ) + init_params["entry_point"] = init_params["hyperparameters"].get(SCRIPT_PARAM_NAME) + init_params["source_dir"] = init_params["hyperparameters"].get(DIR_PARAM_NAME) + init_params["container_log_level"] = init_params["hyperparameters"].get( + CONTAINER_LOG_LEVEL_PARAM_NAME) hyperparameters = {} for k, v in init_params["hyperparameters"].items(): diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index c4c6cddbd6..c794fc610e 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -579,7 +579,7 @@ def _update_local_src_path(self, params, key): The updated parameters. """ if key in params: - src_dir = json.loads(params[key]) + src_dir = params[key] parsed_uri = urlparse(src_dir) if parsed_uri.scheme == "file": new_params = params.copy() From 7cd161a1e66a94ac5a68c0b995c44a381f8366f7 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Tue, 1 Mar 2022 12:59:50 -0800 Subject: [PATCH 12/14] Hyperparameters are json encoded at the end of setting up training config. --- src/sagemaker/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index d41a35af55..e42905064a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -686,7 +686,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None: hyperparams[JOB_NAME_PARAM_NAME] = self._current_job_name hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name - self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) + self._hyperparameters.update(hyperparams) def _stage_user_code_in_s3(self) -> str: """Uploads the user training script to S3 and returns the S3 URI. From e7d7085c610d0b35db17d329790ce94383d7e824 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Tue, 1 Mar 2022 12:59:50 -0800 Subject: [PATCH 13/14] Fix for aws#2949. Hyperparameters are only json encoded at the end of setting up a Sagemaker job. --- src/sagemaker/estimator.py | 12 +++++++----- src/sagemaker/local/image.py | 4 ++-- tests/unit/test_estimator.py | 23 ++++++++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e42905064a..83ff1ba945 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2760,11 +2760,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na init_params = super(Framework, cls)._prepare_init_params_from_job_description( job_details, model_channel_name ) - - init_params["entry_point"] = init_params["hyperparameters"].get(SCRIPT_PARAM_NAME) - init_params["source_dir"] = init_params["hyperparameters"].get(DIR_PARAM_NAME) - init_params["container_log_level"] = init_params["hyperparameters"].get( - CONTAINER_LOG_LEVEL_PARAM_NAME) + init_params["entry_point"] = json.loads( + init_params["hyperparameters"].get(SCRIPT_PARAM_NAME) + ) + init_params["source_dir"] = json.loads(init_params["hyperparameters"].get(DIR_PARAM_NAME)) + init_params["container_log_level"] = json.loads( + init_params["hyperparameters"].get(CONTAINER_LOG_LEVEL_PARAM_NAME) + ) hyperparameters = {} for k, v in init_params["hyperparameters"].items(): diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index c794fc610e..7a10eeacc6 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -493,7 +493,7 @@ def _prepare_training_volumes( # If there is a training script directory and it is a local directory, # mount it to the container. if sagemaker.estimator.DIR_PARAM_NAME in hyperparameters: - training_dir = hyperparameters[sagemaker.estimator.DIR_PARAM_NAME] + training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME]) parsed_uri = urlparse(training_dir) if parsed_uri.scheme == "file": host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) @@ -579,7 +579,7 @@ def _update_local_src_path(self, params, key): The updated parameters. """ if key in params: - src_dir = params[key] + src_dir = json.loads(params[key]) parsed_uri = urlparse(src_dir) if parsed_uri.scheme == "file": new_params = params.copy() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 82a70e6eee..1f8363cf62 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -1241,8 +1241,9 @@ def test_custom_code_bucket(time, sagemaker_session): expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir - + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( + expected_submit_dir + ) @patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket_without_prefix(time, sagemaker_session): @@ -1264,7 +1265,9 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session): expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( + expected_submit_dir + ) def test_invalid_custom_code_bucket(sagemaker_session): @@ -1336,9 +1339,11 @@ def test_shuffle_config(sagemaker_session): BASE_HP = { - "sagemaker_program": SCRIPT_NAME, - "sagemaker_submit_directory": "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME), - "sagemaker_job_name": JOB_NAME, + "sagemaker_program": json.dumps(SCRIPT_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME) + ), + "sagemaker_job_name": json.dumps(JOB_NAME), } @@ -1383,8 +1388,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session): t.fit("s3://{}".format(uri)) expected_hyperparameters = BASE_HP.copy() - expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO - expected_hyperparameters["learning_rate"] = 0.1 + expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO) + expected_hyperparameters["learning_rate"] = json.dumps(0.1) expected_hyperparameters["123"] = json.dumps([456]) expected_hyperparameters["sagemaker_region"] = '"us-west-2"' @@ -1407,7 +1412,7 @@ def test_start_new_wait_called(strftime, sagemaker_session): t.fit("s3://{}".format(uri)) expected_hyperparameters = BASE_HP.copy() - expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO + expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO) expected_hyperparameters["sagemaker_region"] = '"us-west-2"' actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"] From c87d17c588988ed20266c4098ffbf8417d04f2d8 Mon Sep 17 00:00:00 2001 From: Christian Osendorfer Date: Fri, 4 Mar 2022 17:40:21 -0800 Subject: [PATCH 14/14] Resolve black formating issue. --- tests/unit/test_estimator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1f8363cf62..656d773914 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -1245,6 +1245,7 @@ def test_custom_code_bucket(time, sagemaker_session): expected_submit_dir ) + @patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket_without_prefix(time, sagemaker_session): code_bucket = "codebucket" @@ -1267,7 +1268,7 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session): _, _, train_kwargs = sagemaker_session.train.mock_calls[0] assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( expected_submit_dir - ) + ) def test_invalid_custom_code_bucket(sagemaker_session): @@ -1341,7 +1342,7 @@ def test_shuffle_config(sagemaker_session): BASE_HP = { "sagemaker_program": json.dumps(SCRIPT_NAME), "sagemaker_submit_directory": json.dumps( - "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME) + "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME) ), "sagemaker_job_name": json.dumps(JOB_NAME), }