diff --git a/doc/api/inference/serverless.rst b/doc/api/inference/serverless.rst new file mode 100644 index 0000000000..d338efd7be --- /dev/null +++ b/doc/api/inference/serverless.rst @@ -0,0 +1,9 @@ +Serverless Inference +--------------------- + +This module contains classes related to Amazon SageMaker Serverless Inference + +.. automodule:: sagemaker.serverless.serverless_inference_config + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/overview.rst b/doc/overview.rst index 02290ff94c..bdd964c864 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -684,6 +684,63 @@ For more detailed explanations of the classes that this library provides for aut - `API docs for HyperparameterTuner and parameter range classes `__ - `API docs for analytics classes `__ +******************************* +SageMaker Serverless Inference +******************************* +Amazon SageMaker Serverless Inference enables you to easily deploy machine learning models for inference without having +to configure or manage the underlying infrastructure. After you train a model, you can deploy it to an Amazon SageMaker +serverless endpoint and then invoke the endpoint with the model to get inference results back. More information about +SageMaker Serverless Inference can be found in the `AWS documentation `__. + +To deploy a serverless endpoint, you will need to create a ``ServerlessInferenceConfig``. +If you create ``ServerlessInferenceConfig`` without specifying its arguments, the default ``MemorySizeInMB`` will be **2048** and +the default ``MaxConcurrency`` will be **5** : + +.. code:: python + + from sagemaker.serverless import ServerlessInferenceConfig + + # Create an empty ServerlessInferenceConfig object to use default values + serverless_config = ServerlessInferenceConfig() + +Or you can specify ``MemorySizeInMB`` and ``MaxConcurrency`` in ``ServerlessInferenceConfig`` (example shown below): + +.. 
code:: python + + # Specify MemorySizeInMB and MaxConcurrency in the serverless config object + serverless_config = ServerlessInferenceConfig( + memory_size_in_mb=4096, + max_concurrency=10, + ) + +Then use the ``ServerlessInferenceConfig`` in the estimator's ``deploy()`` method to deploy a serverless endpoint: + +.. code:: python + + # Deploys the model that was generated by fit() to a SageMaker serverless endpoint + serverless_predictor = estimator.deploy(serverless_inference_config=serverless_config) + +After deployment is complete, you can use the predictor's ``predict()`` method to invoke the serverless endpoint just like +real-time endpoints: + +.. code:: python + + # Serializes data and makes a prediction request to the SageMaker serverless endpoint + response = serverless_predictor.predict(data) + +Clean up the endpoint and model if needed after inference: + +.. code:: python + + # Tears down the SageMaker endpoint and endpoint configuration + serverless_predictor.delete_endpoint() + + # Deletes the SageMaker model + serverless_predictor.delete_model() + +For more details about ``ServerlessInferenceConfig``, +see the API docs for `Serverless Inference `__ + ************************* SageMaker Batch Transform ************************* diff --git a/setup.py b/setup.py index 5b6c31fd3c..3c4728c96e 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def read_version(): # Declare minimal set for installation required_packages = [ "attrs", - "boto3>=1.20.18", + "boto3>=1.20.21", "google-pasta", "numpy>=1.9.0", "protobuf>=3.1", diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 006cc4846c..70353ea2c2 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -111,33 +111,58 @@ def __init__( """Initializes a configuration of the sensitive groups in the dataset. Args: - label_values_or_threshold (Any): List of label values or threshold to indicate positive - outcome used for bias metrics. 
- facet_name (str or [str]): String or List of strings of sensitive attribute(s) in the - input data for which we like to compare metrics. - facet_values_or_threshold (list): Optional list of values to form a sensitive group or - threshold for a numeric facet column that defines the lower bound of a sensitive - group. Defaults to considering each possible value as sensitive group and - computing metrics vs all the other examples. - If facet_name is a list, this needs to be None or a List consisting of lists or None - with the same length as facet_name list. + label_values_or_threshold ([int or float or str]): List of label value(s) or threshold + to indicate positive outcome used for bias metrics. Depending on the problem type, + + * Binary problem: The list shall include one positive value. + * Categorical problem: The list shall include one or more (but not all) categories + which are the positive values. + * Regression problem: The list shall include one threshold that defines the lower + bound of positive values. + + facet_name (str or int or [str] or [int]): Sensitive attribute column name (or index in + the input data) for which you like to compute bias metrics. It can also be a list + of names (or indexes) if you like to compute for multiple sensitive attributes. + facet_values_or_threshold ([int or float or str] or [[int or float or str]]): + The parameter indicates the sensitive group. If facet_name is a scalar, then it can + be None or a list. Depending on the data type of the facet column, + + * Binary: None means computing the bias metrics for each binary value. Or add one + binary value to the list, to compute its bias metrics only. + * Categorical: None means computing the bias metrics for each category. Or add one + or more (but not all) categories to the list, to compute their bias metrics vs. + the other categories. + * Continuous: The list shall include one and only one threshold which defines the + lower bound of a sensitive group. 
+ + If facet_name is a list, then it can be None if all facets are of binary type or + categorical type. Otherwise it shall be a list, and each element is the values or + threshold of the corresponding facet. group_name (str): Optional column name or index to indicate a group column to be used for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or 'Conditional Demographic Disparity in Predicted Labels - CDDPL'. """ - if isinstance(facet_name, str): + if isinstance(facet_name, list): + assert len(facet_name) > 0, "Please provide at least one facet" + if facet_values_or_threshold is None: + facet_list = [ + {"name_or_index": single_facet_name} for single_facet_name in facet_name + ] + elif len(facet_values_or_threshold) == len(facet_name): + facet_list = [] + for i, single_facet_name in enumerate(facet_name): + facet = {"name_or_index": single_facet_name} + if facet_values_or_threshold is not None: + _set(facet_values_or_threshold[i], "value_or_threshold", facet) + facet_list.append(facet) + else: + raise ValueError( + "The number of facet names doesn't match the number of facet values" + ) + else: facet = {"name_or_index": facet_name} _set(facet_values_or_threshold, "value_or_threshold", facet) facet_list = [facet] - elif facet_values_or_threshold is None or len(facet_name) == len(facet_values_or_threshold): - facet_list = [] - for i, single_facet_name in enumerate(facet_name): - facet = {"name_or_index": single_facet_name} - if facet_values_or_threshold is not None: - _set(facet_values_or_threshold[i], "value_or_threshold", facet) - facet_list.append(facet) - else: - raise ValueError("Wrong combination of argument values passed") self.analysis_config = { "label_values_or_threshold": label_values_or_threshold, "facet": facet_list, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index cf039fa010..1de03d6183 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -852,8 +852,8 @@ def logs(self): def deploy( 
self, - initial_instance_count, - instance_type, + initial_instance_count=None, + instance_type=None, serializer=None, deserializer=None, accelerator_type=None, @@ -864,6 +864,7 @@ def deploy( kms_key=None, data_capture_config=None, tags=None, + serverless_inference_config=None, **kwargs, ): """Deploy the trained model to an Amazon SageMaker endpoint. @@ -874,10 +875,14 @@ http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html Args: - initial_instance_count (int): Minimum number of EC2 instances to - deploy to an endpoint for prediction. - instance_type (str): Type of EC2 instance to deploy to an endpoint - for prediction, for example, 'ml.c4.xlarge'. + initial_instance_count (int): The initial number of instances to run + in the ``Endpoint`` created from this ``Model``. If not using + serverless inference, then it needs to be a number larger than or equal + to 1 (default: None) + instance_type (str): The EC2 instance type to deploy this Model to. + For example, 'ml.p2.xlarge', or 'local' for local mode. If not using + serverless inference, then it is required to deploy a model. + (default: None) serializer (:class:`~sagemaker.serializers.BaseSerializer`): A serializer object, used to encode data for an inference endpoint (default: None). If ``serializer`` is not None, then @@ -910,6 +915,11 @@ data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. + serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): + Specifies configuration related to serverless endpoint. Use this configuration + when trying to create a serverless endpoint and make serverless inference. If + an empty object is passed through, we will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None) tags(List[dict[str, str]]): Optional. 
The list of tags to attach to this specific endpoint. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -927,6 +937,7 @@ def deploy( endpoint and obtain inferences. """ removed_kwargs("update_endpoint", kwargs) + is_serverless = serverless_inference_config is not None self._ensure_latest_training_job() self._ensure_base_job_name() default_name = name_from_base(self.base_job_name) @@ -934,7 +945,7 @@ def deploy( model_name = model_name or default_name self.deploy_instance_type = instance_type - if use_compiled_model: + if use_compiled_model and not is_serverless: family = "_".join(instance_type.split(".")[:-1]) if family not in self._compiled_models: raise ValueError( @@ -959,6 +970,7 @@ def deploy( wait=wait, kms_key=kms_key, data_capture_config=data_capture_config, + serverless_inference_config=serverless_inference_config, ) def register( diff --git a/src/sagemaker/lineage/action.py b/src/sagemaker/lineage/action.py index 67ba6d5db0..9046a3ccf2 100644 --- a/src/sagemaker/lineage/action.py +++ b/src/sagemaker/lineage/action.py @@ -13,13 +13,22 @@ """This module contains code to create and manage SageMaker ``Actions``.""" from __future__ import absolute_import -from typing import Optional, Iterator +from typing import Optional, Iterator, List from datetime import datetime -from sagemaker import Session +from sagemaker.session import Session from sagemaker.apiutils import _base_types from sagemaker.lineage import _api_types, _utils from sagemaker.lineage._api_types import ActionSource, ActionSummary +from sagemaker.lineage.artifact import Artifact + +from sagemaker.lineage.query import ( + LineageQuery, + LineageFilter, + LineageSourceEnum, + LineageEntityEnum, + LineageQueryDirectionEnum, +) class Action(_base_types.Record): @@ -116,7 +125,7 @@ def delete(self, disassociate: bool = False): self._invoke_api(self._boto_delete_method, self._boto_delete_members) @classmethod - def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action": 
+ def load(cls, action_name: str, sagemaker_session=None) -> "Action": """Load an existing action and return an ``Action`` object representing it. Args: @@ -250,3 +259,86 @@ def list( max_results=max_results, next_token=next_token, ) + + def artifacts( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH + ) -> List[Artifact]: + """Use a lineage query to retrieve all artifacts that use this action. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.action_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + +class ModelPackageApprovalAction(Action): + """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage.""" + + def datasets( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[Artifact]: + """Use a lineage query to retrieve all upstream datasets that use this action. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.action_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def model_package(self): + """Get model package from model package approval action. + + Returns: + Model package. 
+ """ + source_uri = self.source.source_uri + if source_uri is None: + return None + + model_package_name = source_uri.split("/")[1] + return self.sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package_name + ) + + def endpoints( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS + ) -> List: + """Use a lineage query to retrieve downstream endpoint contexts that use this action. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Contexts: Contexts representing an endpoint. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.action_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index fc41808099..3921562beb 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact": return artifact def downstream_trials(self, sagemaker_session=None) -> list: - """Retrieve all trial runs which that use this artifact. + """Use the lineage API to retrieve all downstream trials that use this artifact. Args: - sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session + sagemaker_session (obj): Sagemaker Session to use. If not provided a default session will be created. 
Returns: @@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list: ) trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations)) + return self._get_trial_from_trial_component(trial_component_arns) + + def downstream_trials_v2(self) -> list: + """Use a lineage query to retrieve all downstream trials that use this artifact. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + return self._trials(direction=LineageQueryDirectionEnum.DESCENDANTS) + + def upstream_trials(self) -> List: + """Use the lineage query to retrieve all upstream trials that use this artifact. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + return self._trials(direction=LineageQueryDirectionEnum.ASCENDANTS) + + def _trials( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH + ) -> List: + """Use the lineage query to retrieve all trials that use this artifact. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + trial_component_arns: list = list(map(lambda x: x.arn, query_result.vertices)) + return self._get_trial_from_trial_component(trial_component_arns) + + def _get_trial_from_trial_component(self, trial_component_arns: list) -> List: + """Retrieve all upstream trial runs which that use the trial component arns. + + Args: + trial_component_arns (list): list of trial component arns + + Returns: + [Trial]: A list of SageMaker `Trial` objects. 
+ """ if not trial_component_arns: # no outgoing associations for this artifact return [] @@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list: num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn) trial_components: list = [] - sagemaker_session = sagemaker_session or _utils.default_session() + sagemaker_session = self.sagemaker_session or _utils.default_session() sagemaker_client = sagemaker_session.sagemaker_client for i in range(num_search_batches): @@ -335,6 +383,17 @@ def list( sagemaker_session=sagemaker_session, ) + def s3_uri_artifacts(self, s3_uri: str) -> dict: + """Retrieve a list of artifacts that use provided s3 uri. + + Args: + s3_uri (str): A S3 URI. + + Returns: + A list of ``Artifacts`` + """ + return self.sagemaker_session.sagemaker_client.list_artifacts(SourceUri=s3_uri) + class ModelArtifact(Artifact): """A SageMaker lineage artifact representing a model. @@ -349,7 +408,7 @@ def endpoints(self) -> list: """Get association summaries for endpoints deployed with this model. Returns: - [AssociationSummary]: A list of associations repesenting the endpoints using the model. + [AssociationSummary]: A list of associations representing the endpoints using the model. """ endpoint_development_actions: Iterator = Association.list( source_arn=self.artifact_arn, @@ -522,3 +581,69 @@ def endpoint_contexts( for vertex in query_result.vertices: endpoint_contexts.append(vertex.to_lineage_object()) return endpoint_contexts + + def upstream_datasets(self) -> List[Artifact]: + """Use the lineage query to retrieve upstream artifacts that use this dataset artifact. + + Returns: + list of Artifacts: Artifacts representing an dataset. + """ + return self._datasets(direction=LineageQueryDirectionEnum.ASCENDANTS) + + def downstream_datasets(self) -> List[Artifact]: + """Use the lineage query to retrieve downstream artifacts that use this dataset. + + Returns: + list of Artifacts: Artifacts representing an dataset. 
+ """ + return self._datasets(direction=LineageQueryDirectionEnum.DESCENDANTS) + + def _datasets( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH + ) -> List[Artifact]: + """Use the lineage query to retrieve all artifacts that use this dataset. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing an dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + +class ImageArtifact(Artifact): + """A SageMaker lineage artifact representing an image. + + Common model specific lineage traversals to discover how the image is connected + to other entities. + """ + + def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]: + """Use the lineage query to retrieve datasets that use this image artifact. + + Args: + direction (LineageQueryDirectionEnum): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. 
+ """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py index 469b9aeb1a..aef919e876 100644 --- a/src/sagemaker/lineage/context.py +++ b/src/sagemaker/lineage/context.py @@ -31,6 +31,8 @@ LineageQueryDirectionEnum, ) from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent class Context(_base_types.Record): @@ -256,12 +258,30 @@ def list( sagemaker_session=sagemaker_session, ) + def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]: + """Use the lineage query to retrieve actions that use this context. + + Args: + direction (LineageQueryDirectionEnum): The query direction. + + Returns: + list of Actions: Actions. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.ACTION]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + class EndpointContext(Context): """An Amazon SageMaker endpoint context, which is part of a SageMaker lineage.""" def models(self) -> List[association.Association]: - """Get all models deployed by all endpoint versions of the endpoint. + """Use Lineage API to get all models deployed by this endpoint. Returns: list of Associations: Associations that destination represents an endpoint's model. 
@@ -286,7 +306,7 @@ def models(self) -> List[association.Association]: def models_v2( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS ) -> List[Artifact]: - """Get artifacts representing models from the context lineage by querying lineage data. + """Use the lineage query to retrieve downstream model artifacts that use this endpoint. Args: direction (LineageQueryDirectionEnum, optional): The query direction. @@ -335,7 +355,7 @@ def models_v2( def dataset_artifacts( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS ) -> List[Artifact]: - """Get artifacts representing datasets from the endpoint's lineage. + """Use the lineage query to retrieve datasets that use this endpoint. Args: direction (LineageQueryDirectionEnum, optional): The query direction. @@ -360,6 +380,9 @@ def training_job_arns( ) -> List[str]: """Get ARNs for all training jobs that appear in the endpoint's lineage. + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + Returns: list of str: Training job ARNs. """ @@ -382,11 +405,78 @@ def training_job_arns( training_job_arns.append(trial_component["Source"]["SourceArn"]) return training_job_arns + def processing_jobs( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve processing jobs that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component that represent Processing jobs. 
+ """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def transform_jobs( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve transform jobs that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component that represent Transform jobs. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def trial_components( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve trial components that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component. 
+ """ + query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + def pipeline_execution_arn( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS ) -> str: """Get the ARN for the pipeline execution associated with this endpoint (if any). + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + Returns: str: A pipeline execution ARN. """ @@ -400,3 +490,15 @@ def pipeline_execution_arn( return tag["Value"] return None + + +class ModelPackageGroup(Context): + """An Amazon SageMaker model package group context, which is part of a SageMaker lineage.""" + + def pipeline_execution_arn(self) -> str: + """Get the ARN for the pipeline execution associated with this model package group (if any). + + Returns: + str: A pipeline execution ARN. + """ + return self.properties.get("PipelineExecutionArn") diff --git a/src/sagemaker/lineage/lineage_trial_component.py b/src/sagemaker/lineage/lineage_trial_component.py new file mode 100644 index 0000000000..f8bc0e53b4 --- /dev/null +++ b/src/sagemaker/lineage/lineage_trial_component.py @@ -0,0 +1,184 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains code to create and manage SageMaker ``LineageTrialComponent``.""" +from __future__ import absolute_import + +import logging + +from typing import List + +from sagemaker.apiutils import _base_types +from sagemaker.lineage.query import ( + LineageQuery, + LineageFilter, + LineageSourceEnum, + LineageEntityEnum, + LineageQueryDirectionEnum, +) +from sagemaker.lineage.artifact import Artifact + + +LOGGER = logging.getLogger("sagemaker") + + +class LineageTrialComponent(_base_types.Record): + """An Amazon SageMaker, lineage trial component, which is part of a SageMaker lineage. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and also can be + created directly. To automatically associate trial components with a trial and experiment + supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker from the + name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (obj): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. 
+ output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. + output_artifacts_to_remove (list): The output artifacts to remove from the component. + tags (List[dict[str, str]]): A list of tags to associate with the trial component. + """ + + trial_component_name = None + trial_component_arn = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + parameters_to_remove = None + input_artifacts_to_remove = None + output_artifacts_to_remove = None + tags = None + + _boto_create_method: str = "create_trial_component" + _boto_load_method: str = "describe_trial_component" + _boto_update_method: str = "update_trial_component" + _boto_delete_method: str = "delete_trial_component" + + _boto_update_members = [ + "trial_component_name", + "display_name", + "status", + "start_time", + "end_time", + "parameters", + "input_artifacts", + "output_artifacts", + "parameters_to_remove", + "input_artifacts_to_remove", + "output_artifacts_to_remove", + ] + _boto_delete_members = ["trial_component_name"] + + @classmethod + def load(cls, trial_component_name: str, sagemaker_session=None) -> "LineageTrialComponent": + """Load an existing trial component and return an ``TrialComponent`` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. 
+ Returns: + LineageTrialComponent: A SageMaker ``LineageTrialComponent`` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + def pipeline_execution_arn(self) -> str: + """Get the ARN for the pipeline execution associated with this trial component (if any). + + Returns: + str: A pipeline execution ARN. + """ + tags = self.sagemaker_session.sagemaker_client.list_tags( + ResourceArn=self.trial_component_arn + )["Tags"] + for tag in tags: + if tag["Key"] == "sagemaker:pipeline-execution-arn": + return tag["Value"] + return None + + def dataset_artifacts( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[Artifact]: + """Use the lineage query to retrieve datasets that use this trial component. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.trial_component_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def models( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS + ) -> List[Artifact]: + """Use the lineage query to retrieve models that use this trial component. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. 
+ """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.trial_component_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 78cfc700e6..a54331c39a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -23,6 +23,7 @@ class LineageEntityEnum(Enum): """Enum of lineage entities for use in a query filter.""" + TRIAL = "Trial" ACTION = "Action" ARTIFACT = "Artifact" CONTEXT = "Context" @@ -43,6 +44,9 @@ class LineageSourceEnum(Enum): MODEL_REPLACE = "ModelReplaced" TENSORBOARD = "TensorBoard" TRAINING_JOB = "TrainingJob" + APPROVAL = "Approval" + PROCESSING_JOB = "ProcessingJob" + TRANSFORM_JOB = "TransformJob" class LineageQueryDirectionEnum(Enum): @@ -127,11 +131,15 @@ def __eq__(self, other): ) def to_lineage_object(self): - """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object.""" - from sagemaker.lineage.artifact import Artifact, ModelArtifact + """Convert the ``Vertex`` object to its corresponding lineage object. + + Returns: + A ``Vertex`` object to its corresponding ``Artifact``,``Action``, ``Context`` + or ``TrialComponent`` object. 
+ """ from sagemaker.lineage.context import Context, EndpointContext - from sagemaker.lineage.artifact import DatasetArtifact from sagemaker.lineage.action import Action + from sagemaker.lineage.lineage_trial_component import LineageTrialComponent if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -142,17 +150,31 @@ def to_lineage_object(self): return Context.load(context_name=resource_name, sagemaker_session=self._session) if self.lineage_entity == LineageEntityEnum.ARTIFACT.value: - if self.lineage_source == LineageSourceEnum.MODEL.value: - return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) - if self.lineage_source == LineageSourceEnum.DATASET.value: - return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) - return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + return self._artifact_to_lineage_object() if self.lineage_entity == LineageEntityEnum.ACTION.value: return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session) + if self.lineage_entity == LineageEntityEnum.TRIAL_COMPONENT.value: + trial_component_name = get_resource_name_from_arn(self.arn) + return LineageTrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=self._session + ) raise ValueError("Vertex cannot be converted to a lineage object.") + def _artifact_to_lineage_object(self): + """Convert the ``Vertex`` object to its corresponding ``Artifact``.""" + from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact + from sagemaker.lineage.artifact import DatasetArtifact + + if self.lineage_source == LineageSourceEnum.MODEL.value: + return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_source == LineageSourceEnum.DATASET.value: + return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_source == 
LineageSourceEnum.IMAGE.value: + return ImageArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -203,11 +225,11 @@ def __init__( def _to_request_dict(self): """Convert the lineage filter to its API representation.""" filter_request = {} - if self.entities: + if self.sources: filter_request["Types"] = list( map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources) ) - if self.sources: + if self.entities: filter_request["LineageTypes"] = list( map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities) ) @@ -241,9 +263,12 @@ def _get_edge(self, edge): def _get_vertex(self, vertex): """Convert lineage query API response to a Vertex.""" + vertex_type = None + if "Type" in vertex: + vertex_type = vertex["Type"] return Vertex( arn=vertex["Arn"], - lineage_source=vertex["Type"], + lineage_source=vertex_type, lineage_entity=vertex["LineageType"], sagemaker_session=self._session, ) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 830bb50dab..3ed9160f06 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -32,6 +32,7 @@ from sagemaker.inputs import CompilationInput from sagemaker.deprecations import removed_kwargs from sagemaker.predictor import PredictorBase +from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer LOGGER = logging.getLogger("sagemaker") @@ -209,7 +210,7 @@ def register( model_package_arn=model_package.get("ModelPackageArn"), ) - def _init_sagemaker_session_if_does_not_exist(self, instance_type): + def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already. The type of session object is determined by the instance type. 
@@ -663,12 +664,7 @@ def compile( if target_instance_family == "ml_eia2": pass elif target_instance_family.startswith("ml_"): - self.image_uri = self._compilation_image_uri( - self.sagemaker_session.boto_region_name, - target_instance_family, - framework, - framework_version, - ) + self.image_uri = job_status.get("InferenceImage", None) self._is_compiled_model = True else: LOGGER.warning( @@ -688,8 +684,8 @@ def compile( def deploy( self, - initial_instance_count, - instance_type, + initial_instance_count=None, + instance_type=None, serializer=None, deserializer=None, accelerator_type=None, @@ -698,6 +694,7 @@ def deploy( kms_key=None, wait=True, data_capture_config=None, + serverless_inference_config=None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -715,9 +712,13 @@ def deploy( Args: initial_instance_count (int): The initial number of instances to run - in the ``Endpoint`` created from this ``Model``. + in the ``Endpoint`` created from this ``Model``. If not using + serverless inference, then it need to be a number larger or equals + to 1 (default: None) instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p2.xlarge', or 'local' for local mode. + For example, 'ml.p2.xlarge', or 'local' for local mode. If not using + serverless inference, then it is required to deploy a model. + (default: None) serializer (:class:`~sagemaker.serializers.BaseSerializer`): A serializer object, used to encode data for an inference endpoint (default: None). If ``serializer`` is not None, then @@ -746,7 +747,17 @@ def deploy( data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. - + serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): + Specifies configuration related to serverless endpoint. 
Use this configuration + when trying to create serverless endpoint and make serverless inference. If + empty object passed through, we will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None) + Raises: + ValueError: If arguments combination check failed in these circumstances: + - If no role is specified or + - If serverless inference config is not specified and instance type and instance + count are also not specified or + - If a wrong type of object is provided as serverless inference config Returns: callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` @@ -758,27 +769,47 @@ def deploy( if self.role is None: raise ValueError("Role can not be null for deploying a model") - if instance_type.startswith("ml.inf") and not self._is_compiled_model: + is_serverless = serverless_inference_config is not None + if not is_serverless and not (instance_type and initial_instance_count): + raise ValueError( + "Must specify instance type and instance count unless using serverless inference" + ) + + if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig): + raise ValueError( + "serverless_inference_config needs to be a ServerlessInferenceConfig object" + ) + + if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model: LOGGER.warning( "Your model is not compiled. Please compile your model before using Inferentia." 
) - compiled_model_suffix = "-".join(instance_type.split(".")[:-1]) - if self._is_compiled_model: + compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1]) + if self._is_compiled_model and not is_serverless: self._ensure_base_name_if_needed(self.image_uri) if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) self._create_sagemaker_model(instance_type, accelerator_type, tags) + + serverless_inference_config_dict = ( + serverless_inference_config._to_request_dict() if is_serverless else None + ) production_variant = sagemaker.production_variant( - self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type + self.name, + instance_type, + initial_instance_count, + accelerator_type=accelerator_type, + serverless_inference_config=serverless_inference_config_dict, ) if endpoint_name: self.endpoint_name = endpoint_name else: base_endpoint_name = self._base_name or utils.base_from_name(self.name) - if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix): - base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix)) + if self._is_compiled_model and not is_serverless: + if not base_endpoint_name.endswith(compiled_model_suffix): + base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix)) self.endpoint_name = utils.name_from_base(base_endpoint_name) data_capture_config_dict = None diff --git a/src/sagemaker/serverless/__init__.py b/src/sagemaker/serverless/__init__.py index 8bf55c0dcd..4ecffb56d8 100644 --- a/src/sagemaker/serverless/__init__.py +++ b/src/sagemaker/serverless/__init__.py @@ -13,3 +13,6 @@ """Classes for performing machine learning on serverless compute.""" from sagemaker.serverless.model import LambdaModel # noqa: F401 from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401 +from sagemaker.serverless.serverless_inference_config import ( # noqa: F401 + ServerlessInferenceConfig, 
+)
diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py
new file mode 100644
index 0000000000..39950f4f84
--- /dev/null
+++ b/src/sagemaker/serverless/serverless_inference_config.py
@@ -0,0 +1,54 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module contains code related to the ServerlessInferenceConfig class.
+
+Codes are used for configuring serverless inference endpoint. Use it when deploying
+the model to the endpoints.
+"""
+from __future__ import print_function, absolute_import
+
+
+class ServerlessInferenceConfig(object):
+    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.
+
+    This object specifies configuration related to serverless endpoint. Use this configuration
+    when trying to create serverless endpoint and make serverless inference
+    """
+
+    def __init__(
+        self,
+        memory_size_in_mb=2048,
+        max_concurrency=5,
+    ):
+        """Initialize a ServerlessInferenceConfig object for serverless inference configuration.
+
+        Args:
+            memory_size_in_mb (int): Optional. The memory size of your serverless endpoint.
+                Valid values are in 1 GB increments: 1024 MB, 2048 MB, 3072 MB, 4096 MB,
+                5120 MB, or 6144 MB. If no value is provided, Amazon SageMaker will choose
+                the default value for you. (Default: 2048)
+            max_concurrency (int): Optional. The maximum number of concurrent invocations
+                your serverless endpoint can process.
If no value is provided, Amazon + SageMaker will choose the default value for you. (Default: 5) + """ + self.memory_size_in_mb = memory_size_in_mb + self.max_concurrency = max_concurrency + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = { + "MemorySizeInMB": self.memory_size_in_mb, + "MaxConcurrency": self.max_concurrency, + } + + return request_dict diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 56f008be84..1de9571ac6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4382,11 +4382,12 @@ def pipeline_container_def(models, instance_type=None): def production_variant( model_name, - instance_type, - initial_instance_count=1, + instance_type=None, + initial_instance_count=None, variant_name="AllTraffic", initial_weight=1, accelerator_type=None, + serverless_inference_config=None, ): """Create a production variant description suitable for use in a ``ProductionVariant`` list. @@ -4405,14 +4406,15 @@ def production_variant( accelerator_type (str): Type of Elastic Inference accelerator for this production variant. For example, 'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html + serverless_inference_config (dict): Specifies configuration dict related to serverless + endpoint. 
The dict is converted from sagemaker.serverless.ServerlessInferenceConfig
+            object (default: None)
 
     Returns:
         dict[str, str]: An SageMaker ``ProductionVariant`` description
     """
     production_variant_configuration = {
         "ModelName": model_name,
-        "InstanceType": instance_type,
-        "InitialInstanceCount": initial_instance_count,
         "VariantName": variant_name,
         "InitialVariantWeight": initial_weight,
     }
@@ -4420,6 +4422,13 @@
     if accelerator_type:
         production_variant_configuration["AcceleratorType"] = accelerator_type
 
+    if serverless_inference_config:
+        production_variant_configuration["ServerlessConfig"] = serverless_inference_config
+    else:
+        initial_instance_count = initial_instance_count or 1
+        production_variant_configuration["InitialInstanceCount"] = initial_instance_count
+        production_variant_configuration["InstanceType"] = instance_type
+
     return production_variant_configuration
 
 
diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py
index d4eb3e60aa..d13bdc8ffa 100644
--- a/src/sagemaker/tensorflow/model.py
+++ b/src/sagemaker/tensorflow/model.py
@@ -258,8 +258,8 @@ def register(
 
     def deploy(
         self,
-        initial_instance_count,
-        instance_type,
+        initial_instance_count=None,
+        instance_type=None,
         serializer=None,
         deserializer=None,
         accelerator_type=None,
@@ -269,6 +269,7 @@
         wait=True,
         data_capture_config=None,
         update_endpoint=None,
+        serverless_inference_config=None,
     ):
         """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
 
@@ -287,6 +288,7 @@
             kms_key=kms_key,
             wait=wait,
             data_capture_config=data_capture_config,
+            serverless_inference_config=serverless_inference_config,
             update_endpoint=update_endpoint,
         )
 
diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py
new file mode 100644
index 0000000000..8b244c78f2
--- /dev/null
+++ b/src/sagemaker/workflow/emr_step.py
@@ -0,0 +1,119 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The step definitions for workflow.""" +from __future__ import absolute_import + +from typing import List + +from sagemaker.workflow.entities import ( + RequestType, +) +from sagemaker.workflow.properties import ( + Properties, +) +from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig + + +class EMRStepConfig: + """Config for a Hadoop Jar step.""" + + def __init__( + self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None + ): + """Create a definition for input data used by an EMR cluster(job flow) step. + + See AWS documentation on the ``StepConfig`` API for more details on the parameters. + + Args: + args(List[str]): + A list of command line arguments passed to + the JAR file's main function when executed. + jar(str): A path to a JAR file run during the step. + main_class(str): The name of the main class in the specified Java file. + properties(List(dict)): A list of key-value pairs that are set when the step runs. 
+ """ + self.jar = jar + self.args = args + self.main_class = main_class + self.properties = properties + + def to_request(self) -> RequestType: + """Convert EMRStepConfig object to request dict.""" + config = {"HadoopJarStep": {"Jar": self.jar}} + if self.args is not None: + config["HadoopJarStep"]["Args"] = self.args + if self.main_class is not None: + config["HadoopJarStep"]["MainClass"] = self.main_class + if self.properties is not None: + config["HadoopJarStep"]["Properties"] = self.properties + + return config + + +class EMRStep(Step): + """EMR step for workflow.""" + + def __init__( + self, + name: str, + display_name: str, + description: str, + cluster_id: str, + step_config: EMRStepConfig, + depends_on: List[str] = None, + cache_config: CacheConfig = None, + ): + """Constructs a EMRStep. + + Args: + name(str): The name of the EMR step. + display_name(str): The display name of the EMR step. + description(str): The description of the EMR step. + cluster_id(str): The ID of the running EMR cluster. + step_config(EMRStepConfig): One StepConfig to be executed by the job flow. + depends_on(List[str]): + A list of step names this `sagemaker.workflow.steps.EMRStep` depends on + cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. + + """ + super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on) + + emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()} + self.args = emr_step_args + self.cache_config = cache_config + + root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr") + root_property.__dict__["ClusterId"] = cluster_id + self._properties = root_property + + @property + def arguments(self) -> RequestType: + """The arguments dict that is used to call `AddJobFlowSteps`. + + NOTE: The AddFlowJobSteps request is not quite the args list that workflow needs. + The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime. 
+ In addition to that, we will also need to include emr job inputs and output config. + """ + return self.args + + @property + def properties(self) -> RequestType: + """A Properties object representing the EMR DescribeStepResponse model""" + return self._properties + + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + if self.cache_config: + request_dict.update(self.cache_config.config) + return request_dict diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py index 96147e8e8b..6e9aba4408 100644 --- a/src/sagemaker/workflow/properties.py +++ b/src/sagemaker/workflow/properties.py @@ -23,17 +23,24 @@ class PropertiesMeta(type): - """Load an internal shapes attribute from the botocore sagemaker service model.""" + """Load an internal shapes attribute from the botocore service model - _shapes = None + for sagemaker and emr service. + """ + + _shapes_map = dict() _primitive_types = {"string", "boolean", "integer", "float"} def __new__(mcs, *args, **kwargs): - """Loads up the shapes from the botocore sagemaker service model.""" - if mcs._shapes is None: + """Loads up the shapes from the botocore service model.""" + if len(mcs._shapes_map.keys()) == 0: loader = botocore.loaders.Loader() - model = loader.load_service_model("sagemaker", "service-2") - mcs._shapes = model["shapes"] + + sagemaker_model = loader.load_service_model("sagemaker", "service-2") + emr_model = loader.load_service_model("emr", "service-2") + mcs._shapes_map["sagemaker"] = sagemaker_model["shapes"] + mcs._shapes_map["emr"] = emr_model["shapes"] + return super().__new__(mcs, *args, **kwargs) @@ -45,32 +52,41 @@ def __init__( path: str, shape_name: str = None, shape_names: List[str] = None, + service_name: str = "sagemaker", ): """Create a Properties instance representing the given shape. Args: path (str): The parent path of the Properties instance. 
- shape_name (str): The botocore sagemaker service model shape name. - shape_names (str): A List of the botocore sagemaker service model shape name. + shape_name (str): The botocore service model shape name. + shape_names (str): A List of the botocore service model shape name. """ self._path = path shape_names = [] if shape_names is None else shape_names self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names + shapes = Properties._shapes_map.get(service_name, {}) + for name in self._shape_names: - shape = Properties._shapes.get(name, {}) + shape = shapes.get(name, {}) shape_type = shape.get("type") if shape_type in Properties._primitive_types: self.__str__ = name elif shape_type == "structure": members = shape["members"] for key, info in members.items(): - if Properties._shapes.get(info["shape"], {}).get("type") == "list": - self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"]) - elif Properties._shapes.get(info["shape"], {}).get("type") == "map": - self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"]) + if shapes.get(info["shape"], {}).get("type") == "list": + self.__dict__[key] = PropertiesList( + f"{path}.{key}", info["shape"], service_name + ) + elif shapes.get(info["shape"], {}).get("type") == "map": + self.__dict__[key] = PropertiesMap( + f"{path}.{key}", info["shape"], service_name + ) else: - self.__dict__[key] = Properties(f"{path}.{key}", info["shape"]) + self.__dict__[key] = Properties( + f"{path}.{key}", info["shape"], service_name=service_name + ) @property def expr(self): @@ -81,16 +97,17 @@ def expr(self): class PropertiesList(Properties): """PropertiesList for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None): + def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): """Create a PropertiesList instance representing the given shape. Args: path (str): The parent path of the PropertiesList instance. 
- shape_name (str): The botocore sagemaker service model shape name. - root_shape_name (str): The botocore sagemaker service model shape name. + shape_name (str): The botocore service model shape name. + service_name (str): The botocore service name. """ super(PropertiesList, self).__init__(path, shape_name) self.shape_name = shape_name + self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() def __getitem__(self, item: Union[int, str]): @@ -100,7 +117,7 @@ def __getitem__(self, item: Union[int, str]): item (Union[int, str]): The index of the item in sequence. """ if item not in self._items.keys(): - shape = Properties._shapes.get(self.shape_name) + shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["member"]["shape"] if isinstance(item, str): property_item = Properties(f"{self._path}['{item}']", member) @@ -114,15 +131,17 @@ def __getitem__(self, item: Union[int, str]): class PropertiesMap(Properties): """PropertiesMap for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None): + def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): """Create a PropertiesMap instance representing the given shape. Args: path (str): The parent path of the PropertiesMap instance. shape_name (str): The botocore sagemaker service model shape name. + service_name (str): The botocore service name. """ super(PropertiesMap, self).__init__(path, shape_name) self.shape_name = shape_name + self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() def __getitem__(self, item: Union[int, str]): @@ -132,7 +151,7 @@ def __getitem__(self, item: Union[int, str]): item (Union[int, str]): The index of the item in sequence. 
""" if item not in self._items.keys(): - shape = Properties._shapes.get(self.shape_name) + shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["value"]["shape"] if isinstance(item, str): property_item = Properties(f"{self._path}['{item}']", member) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 30eca68f66..329bd1d950 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -60,6 +60,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): LAMBDA = "Lambda" QUALITY_CHECK = "QualityCheck" CLARIFY_CHECK = "ClarifyCheck" + EMR = "EMR" @attr.s diff --git a/tests/data/workflow/emr-script.sh b/tests/data/workflow/emr-script.sh new file mode 100644 index 0000000000..aeee24ec95 --- /dev/null +++ b/tests/data/workflow/emr-script.sh @@ -0,0 +1,2 @@ +echo "This is emr test script..." +sleep 15 diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index e4966ab67c..672af41de9 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -32,6 +32,14 @@ from smexperiments import trial_component, trial, experiment from random import randint from botocore.exceptions import ClientError +from sagemaker.lineage.query import ( + LineageQuery, + LineageFilter, + LineageSourceEnum, + LineageEntityEnum, + LineageQueryDirectionEnum, +) +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from tests.integ.sagemaker.lineage.helpers import name, names @@ -39,6 +47,7 @@ SLEEP_TIME_TWO_SECONDS = 2 STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17" STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17" +STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup" @pytest.fixture @@ -207,6 +216,24 @@ def trial_associated_artifact(artifact_obj, trial_obj, trial_component_obj, sage sagemaker_session=sagemaker_session, ) 
trial_obj.add_trial_component(trial_component_obj) + time.sleep(4) + yield artifact_obj + trial_obj.remove_trial_component(trial_component_obj) + assntn.delete() + + +@pytest.fixture +def upstream_trial_associated_artifact( + artifact_obj, trial_obj, trial_component_obj, sagemaker_session +): + assntn = association.Association.create( + source_arn=trial_component_obj.trial_component_arn, + destination_arn=artifact_obj.artifact_arn, + association_type="ContributedTo", + sagemaker_session=sagemaker_session, + ) + trial_obj.add_trial_component(trial_component_obj) + time.sleep(3) yield artifact_obj trial_obj.remove_trial_component(trial_component_obj) assntn.delete() @@ -514,6 +541,103 @@ def _get_static_pipeline_execution_arn(sagemaker_session): return pipeline_execution_arn +@pytest.fixture +def static_approval_action( + sagemaker_session, static_endpoint_context, static_pipeline_execution_arn +): + query_filter = LineageFilter( + entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.APPROVAL] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + action_name = query_result.vertices[0].arn.split("/")[1] + yield action.ModelPackageApprovalAction.load( + action_name=action_name, sagemaker_session=sagemaker_session + ) + + +@pytest.fixture +def static_model_deployment_action(sagemaker_session, static_endpoint_context): + query_filter = LineageFilter( + entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + model_approval_actions = [] + for vertex in query_result.vertices: + model_approval_actions.append(vertex.to_lineage_object()) + yield 
model_approval_actions[0] + + +@pytest.fixture +def static_processing_job_trial_component( + sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] + ) + + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + processing_jobs = [] + for vertex in query_result.vertices: + processing_jobs.append(vertex.to_lineage_object()) + + return processing_jobs[0] + + +@pytest.fixture +def static_training_job_trial_component( + sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB] + ) + + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + training_jobs = [] + for vertex in query_result.vertices: + training_jobs.append(vertex.to_lineage_object()) + + return training_jobs[0] + + +@pytest.fixture +def static_transform_job_trial_component( + static_processing_job_trial_component, sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_processing_job_trial_component.trial_component_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.DESCENDANTS, + include_edges=False, + ) + transform_jobs = [] + for vertex in query_result.vertices: + transform_jobs.append(vertex.to_lineage_object()) + yield transform_jobs[0] + + @pytest.fixture def 
static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session) @@ -543,6 +667,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): ) +@pytest.fixture +def static_model_package_group_context(sagemaker_session, static_pipeline_execution_arn): + + model_package_group_arn = get_model_package_group_arn_from_static_pipeline(sagemaker_session) + + contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=model_package_group_arn)[ + "ContextSummaries" + ] + if len(contexts) != 1: + raise ( + Exception( + f"Got an unexpected number of Contexts for \ + model package group {STATIC_MODEL_PACKAGE_GROUP_NAME} from pipeline \ + execution {static_pipeline_execution_arn}. \ + Expected 1 but got {len(contexts)}" + ) + ) + + yield context.ModelPackageGroup.load( + contexts[0]["ContextName"], sagemaker_session=sagemaker_session + ) + + @pytest.fixture def static_model_artifact(sagemaker_session, static_pipeline_execution_arn): model_package_arn = get_model_package_arn_from_static_pipeline( @@ -590,6 +737,23 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session): ) +@pytest.fixture +def static_image_artifact(static_model_artifact, sagemaker_session): + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.IMAGE] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_model_artifact.artifact_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + image_artifact = [] + for vertex in query_result.vertices: + image_artifact.append(vertex.to_lineage_object()) + return image_artifact[0] + + def get_endpoint_arn_from_static_pipeline(sagemaker_session): try: endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint( @@ -604,6 +768,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session): raise e +def 
get_model_package_group_arn_from_static_pipeline(sagemaker_session): + static_model_package_group_arn = ( + sagemaker_session.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=STATIC_MODEL_PACKAGE_GROUP_NAME + )["ModelPackageGroupArn"] + ) + return static_model_package_group_arn + + def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session): # get the model package ARN from the pipeline pipeline_execution_steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps( diff --git a/tests/integ/sagemaker/lineage/test_action.py b/tests/integ/sagemaker/lineage/test_action.py index a0531450b5..8b462279ca 100644 --- a/tests/integ/sagemaker/lineage/test_action.py +++ b/tests/integ/sagemaker/lineage/test_action.py @@ -20,6 +20,7 @@ import pytest from sagemaker.lineage import action +from sagemaker.lineage.query import LineageQueryDirectionEnum def test_create_delete(action_obj): @@ -117,3 +118,50 @@ def test_tags(action_obj, sagemaker_session): # length of actual tags will be greater than 1 assert len(actual_tags) > 0 assert [actual_tags[-1]] == tags + + +def test_upstream_artifacts(static_model_deployment_action): + artifacts_from_query = static_model_deployment_action.artifacts( + direction=LineageQueryDirectionEnum.ASCENDANTS + ) + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert "artifact" in artifact.artifact_arn + + +def test_downstream_artifacts(static_approval_action): + artifacts_from_query = static_approval_action.artifacts( + direction=LineageQueryDirectionEnum.DESCENDANTS + ) + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert "artifact" in artifact.artifact_arn + + +def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session): + + sagemaker_session.sagemaker_client.add_association( + SourceArn=static_dataset_artifact.artifact_arn, + DestinationArn=static_approval_action.action_arn, + 
AssociationType="ContributedTo", + ) + time.sleep(3) + artifacts_from_query = static_approval_action.datasets() + + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert "artifact" in artifact.artifact_arn + assert artifact.artifact_type == "DataSet" + + sagemaker_session.sagemaker_client.delete_association( + SourceArn=static_dataset_artifact.artifact_arn, + DestinationArn=static_approval_action.action_arn, + ) + + +def test_endpoints(static_approval_action): + endpoint_contexts_from_query = static_approval_action.endpoints() + assert len(endpoint_contexts_from_query) > 0 + for endpoint in endpoint_contexts_from_query: + assert endpoint.context_type == "Endpoint" + assert "endpoint" in endpoint.context_arn diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index 4a0c6398b2..7ecbd0ac15 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -102,6 +102,13 @@ def test_list_by_type(artifact_objs, sagemaker_session): assert artifact_names_listed[0] == expected_name +def test_get_artifact(static_dataset_artifact): + s3_uri = static_dataset_artifact.source.source_uri + expected_artifact = static_dataset_artifact.s3_uri_artifacts(s3_uri=s3_uri) + for ar in expected_artifact["ArtifactSummaries"]: + assert ar.get("Source")["SourceUri"] == s3_uri + + def test_downstream_trials(trial_associated_artifact, trial_obj, sagemaker_session): # allow trial components to index, 30 seconds max def validate(): @@ -120,6 +127,18 @@ def validate(): retry(validate, num_attempts=3) +def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): + trials = trial_associated_artifact.downstream_trials_v2() + assert len(trials) == 1 + assert trial_obj.trial_name in trials + + +def test_upstream_trials(upstream_trial_associated_artifact, trial_obj, sagemaker_session): + trials = 
upstream_trial_associated_artifact.upstream_trials() + assert len(trials) == 1 + assert trial_obj.trial_name in trials + + @pytest.mark.timeout(30) def test_tag(artifact_obj, sagemaker_session): tag = {"Key": "foo", "Value": "bar"} diff --git a/tests/integ/sagemaker/lineage/test_context.py b/tests/integ/sagemaker/lineage/test_context.py index 5b36cee746..bdc4cb34e3 100644 --- a/tests/integ/sagemaker/lineage/test_context.py +++ b/tests/integ/sagemaker/lineage/test_context.py @@ -20,6 +20,7 @@ import pytest from sagemaker.lineage import context +from sagemaker.lineage.query import LineageQueryDirectionEnum def test_create_delete(context_obj): @@ -32,6 +33,16 @@ def test_create_delete_with_association(context_obj_with_association): assert context_obj_with_association.context_arn +def test_action(static_endpoint_context, sagemaker_session): + actions_from_query = static_endpoint_context.actions( + direction=LineageQueryDirectionEnum.ASCENDANTS + ) + + assert len(actions_from_query) > 0 + for action in actions_from_query: + assert "action" in action.action_arn + + def test_save(context_obj, sagemaker_session): context_obj.description = "updated description" context_obj.properties = {"k3": "v3"} diff --git a/tests/integ/sagemaker/lineage/test_dataset_artifact.py b/tests/integ/sagemaker/lineage/test_dataset_artifact.py index be03a85e86..ee81b7e137 100644 --- a/tests/integ/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/integ/sagemaker/lineage/test_dataset_artifact.py @@ -35,3 +35,19 @@ def test_endpoint_contexts( assert len(contexts_from_query) > 0 for context in contexts_from_query: assert context.context_type == "Endpoint" + + +def test_get_upstream_datasets(static_dataset_artifact, sagemaker_session): + artifacts_from_query = static_dataset_artifact.upstream_datasets() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn + + +def 
test_get_down_datasets(static_dataset_artifact, sagemaker_session): + artifacts_from_query = static_dataset_artifact.downstream_datasets() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn diff --git a/tests/integ/sagemaker/lineage/test_endpoint_context.py b/tests/integ/sagemaker/lineage/test_endpoint_context.py index 78a33e8ef9..2a797bd5cb 100644 --- a/tests/integ/sagemaker/lineage/test_endpoint_context.py +++ b/tests/integ/sagemaker/lineage/test_endpoint_context.py @@ -15,6 +15,7 @@ import time SLEEP_TIME_ONE_SECONDS = 1 +SLEEP_TIME_THREE_SECONDS = 3 def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj): @@ -59,3 +60,46 @@ def test_pipeline_execution_arn(static_endpoint_context, static_pipeline_executi pipeline_execution_arn = static_endpoint_context.pipeline_execution_arn() assert pipeline_execution_arn == static_pipeline_execution_arn + + +def test_transform_jobs( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + sagemaker_session.sagemaker_client.add_association( + SourceArn=static_transform_job_trial_component.trial_component_arn, + DestinationArn=static_endpoint_context.context_arn, + AssociationType="ContributedTo", + ) + time.sleep(SLEEP_TIME_THREE_SECONDS) + transform_jobs_from_query = static_endpoint_context.transform_jobs() + + assert len(transform_jobs_from_query) > 0 + for transform_job in transform_jobs_from_query: + assert "transform-job" in transform_job.trial_component_arn + assert "TransformJob" in transform_job.source.get("SourceType") + + sagemaker_session.sagemaker_client.delete_association( + SourceArn=static_transform_job_trial_component.trial_component_arn, + DestinationArn=static_endpoint_context.context_arn, + ) + + +def test_processing_jobs( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + 
processing_jobs_from_query = static_endpoint_context.processing_jobs() + assert len(processing_jobs_from_query) > 0 + for processing_job in processing_jobs_from_query: + assert "processing-job" in processing_job.trial_component_arn + assert "ProcessingJob" in processing_job.source.get("SourceType") + + +def test_trial_components( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + trial_components_from_query = static_endpoint_context.trial_components() + + assert len(trial_components_from_query) > 0 + for trial_component in trial_components_from_query: + assert "job" in trial_component.trial_component_arn + assert "Job" in trial_component.source.get("SourceType") diff --git a/tests/integ/sagemaker/lineage/test_image_artifact.py b/tests/integ/sagemaker/lineage/test_image_artifact.py new file mode 100644 index 0000000000..bd0f76445d --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_image_artifact.py @@ -0,0 +1,26 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains code to test SageMaker ``ImageArtifact``""" +from __future__ import absolute_import + +from sagemaker.lineage.query import LineageQueryDirectionEnum + + +def test_dataset(static_image_artifact, sagemaker_session): + artifacts_from_query = static_image_artifact.datasets( + direction=LineageQueryDirectionEnum.DESCENDANTS + ) + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn diff --git a/tests/integ/sagemaker/lineage/test_lineage_trial_component.py b/tests/integ/sagemaker/lineage/test_lineage_trial_component.py new file mode 100644 index 0000000000..d8a8a5d9c8 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_lineage_trial_component.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains code to test SageMaker ``Trial Component``""" +from __future__ import absolute_import + + +def test_dataset_artifacts(static_training_job_trial_component): + artifacts_from_query = static_training_job_trial_component.dataset_artifacts() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + + +def test_models(static_processing_job_trial_component): + artifacts_from_query = static_processing_job_trial_component.models() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "Model" + + +def test_pipeline_execution_arn(static_training_job_trial_component, static_pipeline_execution_arn): + pipeline_execution_arn = static_training_job_trial_component.pipeline_execution_arn() + assert pipeline_execution_arn == static_pipeline_execution_arn diff --git a/tests/integ/sagemaker/lineage/test_model_package_group_context.py b/tests/integ/sagemaker/lineage/test_model_package_group_context.py new file mode 100644 index 0000000000..8f6cd85e77 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_model_package_group_context.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains code to test SageMaker ``ModelPackageGroup``""" +from __future__ import absolute_import + + +def test_pipeline_execution_arn(static_model_package_group_context, static_pipeline_execution_arn): + pipeline_execution_arn = static_model_package_group_context.pipeline_execution_arn() + + assert pipeline_execution_arn == static_pipeline_execution_arn diff --git a/tests/integ/test_serverless_inference.py b/tests/integ/test_serverless_inference.py new file mode 100644 index 0000000000..40b1ace147 --- /dev/null +++ b/tests/integ/test_serverless_inference.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest + +import sagemaker.amazon.pca +from sagemaker.utils import unique_name_from_base +from sagemaker.serverless import ServerlessInferenceConfig +from tests.integ import datasets, TRAINING_DEFAULT_TIMEOUT_MINUTES +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name + + +@pytest.fixture +def training_set(): + return datasets.one_p_mnist() + + +def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_set): + job_name = unique_name_from_base("pca") + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + pca = sagemaker.amazon.pca.PCA( + role="SageMakerRole", + instance_count=1, + instance_type=cpu_instance_type, + num_components=48, + sagemaker_session=sagemaker_session, + enable_network_isolation=True, + ) + + pca.algorithm_mode = "randomized" + pca.subtract_mean = True + pca.extra_components = 5 + pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name) + + with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): + + predictor_serverless = pca.deploy( + endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig() + ) + + result = predictor_serverless.predict(training_set[0][:5]) + + assert len(result) == 5 + for record in result: + assert record.label["projection"] is not None diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index de03608b27..4a3354470a 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -69,6 +69,7 @@ from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum +from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig from sagemaker.wrangler.processing import DataWranglerProcessor from sagemaker.dataset_definition.inputs import DatasetDefinition, 
AthenaDatasetDefinition from sagemaker.workflow.execution_variables import ExecutionVariables @@ -1148,6 +1149,50 @@ def test_two_step_lambda_pipeline_with_output_reference( pass +def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_name): + instance_count = ParameterInteger(name="InstanceCount", default_value=2) + + emr_step_config = EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=["dummy_emr_script_path"], + ) + + step_emr_1 = EMRStep( + name="emr-step-1", + cluster_id="j-1YONHTCP3YZKC", + display_name="emr_step_1", + description="MyEMRStepDescription", + step_config=emr_step_config, + ) + + step_emr_2 = EMRStep( + name="emr-step-2", + cluster_id=step_emr_1.properties.ClusterId, + display_name="emr_step_2", + description="MyEMRStepDescription", + step_config=emr_step_config, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count], + steps=[step_emr_1, step_emr_2], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_conditional_pytorch_training_model_registration( sagemaker_session, role, diff --git a/tests/unit/sagemaker/lineage/test_action.py b/tests/unit/sagemaker/lineage/test_action.py index 79e59b679b..120d643063 100644 --- a/tests/unit/sagemaker/lineage/test_action.py +++ b/tests/unit/sagemaker/lineage/test_action.py @@ -16,6 +16,7 @@ import unittest.mock from sagemaker.lineage import action, _api_types +from sagemaker.lineage._api_types import ActionSource def test_create(sagemaker_session): @@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session): delete_with_association_expected_calls == sagemaker_session.sagemaker_client.delete_association.mock_calls ) + + +def 
test_model_package(sagemaker_session): + obj = action.ModelPackageApprovalAction( + sagemaker_session, + action_name="abcd-aws-model-package", + source=ActionSource( + source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1", + source_type="ARN", + ), + status="updated-status", + properties={"k1": "v1"}, + properties_to_remove=["k2"], + ) + sagemaker_session.sagemaker_client.describe_model_package.return_value = {} + obj.model_package() + + sagemaker_session.sagemaker_client.describe_model_package.assert_called_with( + ModelPackageName="pipeline88modelpackage", + ) diff --git a/tests/unit/sagemaker/lineage/test_artifact.py b/tests/unit/sagemaker/lineage/test_artifact.py index 72228ec964..218532c1b7 100644 --- a/tests/unit/sagemaker/lineage/test_artifact.py +++ b/tests/unit/sagemaker/lineage/test_artifact.py @@ -377,3 +377,143 @@ def test_downstream_trials(sagemaker_session): ), ] assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + + +def test_downstream_trials_v2(sagemaker_session): + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "B" + str(i), "Type": "DataSet", "LineageType": "Artifact"} for i in range(10) + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [{"TrialName": "test-trial-name"}], + } + } + ] + } + + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + + result = obj.downstream_trials_v2() + + expected_trials = ["test-trial-name"] + + assert expected_trials == result + + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + 
Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=["test-arn"], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + +def test_upstream_trials(sagemaker_session): + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "B" + str(i), "Type": "DataSet", "LineageType": "Artifact"} for i in range(10) + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [{"TrialName": "test-trial-name"}], + } + } + ] + } + + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + + result = obj.upstream_trials() + + expected_trials = ["test-trial-name"] + + assert expected_trials == result + + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=["test-arn"], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + +def test_s3_uri_artifacts(sagemaker_session): + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + source_uri="s3://abced", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + sagemaker_session.sagemaker_client.list_artifacts.side_effect = [ + { + "ArtifactSummaries": [ + { + "ArtifactArn": "A", + "ArtifactName": "B", + "Source": { + "SourceUri": "D", + "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}], + }, + "ArtifactType": "test-type", + } + ], + "NextToken": "100", + }, + ] + result = 
obj.s3_uri_artifacts(s3_uri="s3://abced") + + expected_calls = [ + unittest.mock.call(SourceUri="s3://abced"), + ] + expected_result = { + "ArtifactSummaries": [ + { + "ArtifactArn": "A", + "ArtifactName": "B", + "Source": { + "SourceUri": "D", + "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}], + }, + "ArtifactType": "test-type", + } + ], + "NextToken": "100", + } + assert expected_calls == sagemaker_session.sagemaker_client.list_artifacts.mock_calls + assert result == expected_result diff --git a/tests/unit/sagemaker/lineage/test_context.py b/tests/unit/sagemaker/lineage/test_context.py index 5cf48dea67..d87120dde2 100644 --- a/tests/unit/sagemaker/lineage/test_context.py +++ b/tests/unit/sagemaker/lineage/test_context.py @@ -17,6 +17,9 @@ import pytest from sagemaker.lineage import context, _api_types +from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent +from sagemaker.lineage.query import LineageQueryDirectionEnum @pytest.fixture @@ -328,3 +331,182 @@ def test_create_delete_with_association(sagemaker_session): delete_with_association_expected_calls == sagemaker_session.sagemaker_client.delete_association.mock_calls ) + + +def test_actions(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + action_arn = "arn:aws:sagemaker:us-west-2:123456789012:action/lineage-unit-3b05f017-0d87-4c37" + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": action_arn, "Type": "Approval", "LineageType": "Action"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + sagemaker_session.sagemaker_client.describe_action.return_value = { + "ActionName": "MyAction", + "ActionArn": action_arn, + } + + action_list = 
obj.actions(direction=LineageQueryDirectionEnum.DESCENDANTS) + + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"LineageTypes": ["Action"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + expected_action_list = [ + Action( + action_arn=action_arn, + action_name="MyAction", + ) + ] + + assert expected_action_list[0].action_arn == action_list[0].action_arn + assert expected_action_list[0].action_name == action_list[0].action_name + + +def test_processing_jobs(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + processing_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": processing_job_arn, "Type": "ProcessingJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyProcessingJob", + "TrialComponentArn": processing_job_arn, + } + + trial_component_list = obj.processing_jobs(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["ProcessingJob"], "LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyProcessingJob", + trial_component_arn=processing_job_arn, + ) + ] + + assert ( + 
expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) + + +def test_transform_jobs(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + transform_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": transform_job_arn, "Type": "TransformJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTransformJob", + "TrialComponentArn": transform_job_arn, + } + + trial_component_list = obj.transform_jobs(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["TransformJob"], "LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyTransformJob", + trial_component_arn=transform_job_arn, + ) + ] + + assert ( + expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) + + +def test_trial_components(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + trial_component_arn = 
( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": trial_component_arn, "Type": "TransformJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTransformJob", + "TrialComponentArn": trial_component_arn, + } + + trial_component_list = obj.trial_components(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyTransformJob", + trial_component_arn=trial_component_arn, + ) + ] + + assert ( + expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) diff --git a/tests/unit/sagemaker/lineage/test_dataset_artifact.py b/tests/unit/sagemaker/lineage/test_dataset_artifact.py index 6db5a215f6..074efb488c 100644 --- a/tests/unit/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/unit/sagemaker/lineage/test_dataset_artifact.py @@ -83,3 +83,89 @@ def test_trained_models(sagemaker_session): ) ] assert expected_model_list == model_list + + +def test_upstream_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = 
"arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.DatasetArtifact( + sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.upstream_datasets() + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name + + +def test_downstream_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.DatasetArtifact( + sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = 
{ + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.downstream_datasets() + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name diff --git a/tests/unit/sagemaker/lineage/test_image_artifact.py b/tests/unit/sagemaker/lineage/test_image_artifact.py new file mode 100644 index 0000000000..485d942db3 --- /dev/null +++ b/tests/unit/sagemaker/lineage/test_image_artifact.py @@ -0,0 +1,65 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest.mock + +import pytest +from sagemaker.lineage import artifact +from sagemaker.lineage.query import LineageQueryDirectionEnum + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.ImageArtifact(sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.datasets(direction=LineageQueryDirectionEnum.DESCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name diff --git a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py new file mode 100644 index 0000000000..9b466832a1 --- /dev/null +++ 
b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest.mock + +import pytest +from sagemaker.lineage import artifact, lineage_trial_component + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_dataset_artifacts(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.dataset_artifacts() + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[trial_component_arn], + ), + 
] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name + + +def test_models(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + model_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/models" + model_name = "myDataset" + + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": model_arn, "Type": "Model", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": model_name, + "ArtifactArn": model_arn, + } + + model_list = obj.models() + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["Model"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[trial_component_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_model_list = [ + artifact.DatasetArtifact( + artifact_name=model_name, + artifact_arn=model_arn, + ) + ] + assert expected_model_list[0].artifact_arn == model_list[0].artifact_arn + assert expected_model_list[0].artifact_name == model_list[0].artifact_name + + +def test_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = 
lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "sagemaker:pipeline-execution-arn", "Value": "tag1"}, + ], + } + expected_calls = [ + unittest.mock.call(ResourceArn=trial_component_arn), + ] + pipeline_execution_arn_result = obj.pipeline_execution_arn() + assert pipeline_execution_arn_result == "tag1" + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls + + +def test_no_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "abcd", "Value": "efg"}, + ], + } + expected_calls = [ + unittest.mock.call(ResourceArn=trial_component_arn), + ] + pipeline_execution_arn_result = obj.pipeline_execution_arn() + expected_result = None + assert pipeline_execution_arn_result == expected_result + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls diff --git a/tests/unit/sagemaker/lineage/test_model_package_group_context.py b/tests/unit/sagemaker/lineage/test_model_package_group_context.py new file mode 100644 index 0000000000..8c14773df7 --- /dev/null +++ b/tests/unit/sagemaker/lineage/test_model_package_group_context.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``ModelPackageGroup``""" +from __future__ import absolute_import + +import unittest.mock +import pytest +from sagemaker.lineage import context + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_pipeline_execution_arn(sagemaker_session): + obj = context.ModelPackageGroup( + sagemaker_session, + context_name="foo", + description="test-description", + properties={"PipelineExecutionArn": "abcd", "k2": "v2"}, + properties_to_remove=["E"], + ) + actual_result = obj.pipeline_execution_arn() + expected_result = "abcd" + assert expected_result == actual_result diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 50bb14e6b1..ae76fd199c 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -11,9 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import unittest.mock from sagemaker.lineage.artifact import DatasetArtifact, ModelArtifact, Artifact from sagemaker.lineage.context import EndpointContext, Context from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest @@ -286,6 +288,49 @@ def test_vertex_to_object_context(sagemaker_session): assert isinstance(context, Context) +def test_vertex_to_object_trial_component(sagemaker_session): + + tc_arn = "arn:aws:sagemaker:us-west-2:963951943925:trial-component/abaloneprocess-ixyt08z3ru-aws-processing-job" + vertex = Vertex( + arn=tc_arn, + lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, + lineage_source=LineageSourceEnum.TRANSFORM_JOB.value, + sagemaker_session=sagemaker_session, + ) + + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTrialComponent", + "TrialComponentArn": tc_arn, + "Source": { + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/my_trial_component", + "SourceType": "ARN", + "SourceId": "Thu Dec 17 17:16:24 UTC 2020", + }, + "TrialComponentType": "ModelDeployment", + "Properties": { + "PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\ + pipeline/mypipeline/execution/0irnteql64d0", + "PipelineStepName": "MyStep", + "Status": "Completed", + }, + "CreationTime": 1608225384.0, + "CreatedBy": {}, + "LastModifiedTime": 1608225384.0, + "LastModifiedBy": {}, + } + + trial_component = vertex.to_lineage_object() + + expected_calls = [ + unittest.mock.call(TrialComponentName="abaloneprocess-ixyt08z3ru-aws-processing-job"), + ] + assert expected_calls == sagemaker_session.sagemaker_client.describe_trial_component.mock_calls + + assert trial_component.trial_component_arn == tc_arn + assert trial_component.trial_component_name == "MyTrialComponent" + assert 
isinstance(trial_component, LineageTrialComponent) + + def test_vertex_to_object_model_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", @@ -317,6 +362,37 @@ def test_vertex_to_object_model_artifact(sagemaker_session): assert isinstance(artifact, ModelArtifact) +def test_vertex_to_object_artifact(sagemaker_session): + vertex = Vertex( + arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", + lineage_entity=LineageEntityEnum.ARTIFACT.value, + lineage_source=LineageSourceEnum.MODEL.value, + sagemaker_session=sagemaker_session, + ) + + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactArn": "arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", + "Source": { + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel", + "SourceTypes": [], + }, + "ArtifactType": None, + "Properties": {}, + "CreationTime": 1608224704.149, + "CreatedBy": {}, + "LastModifiedTime": 1608224704.149, + "LastModifiedBy": {}, + } + + artifact = vertex.to_lineage_object() + + assert ( + artifact.artifact_arn + == "arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f" + ) + assert isinstance(artifact, Artifact) + + def test_vertex_to_dataset_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", @@ -379,7 +455,7 @@ def test_vertex_to_model_artifact(sagemaker_session): assert isinstance(artifact, ModelArtifact) -def test_vertex_to_object_artifact(sagemaker_session): +def test_vertex_to_object_image_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", lineage_entity=LineageEntityEnum.ARTIFACT.value, @@ -441,7 +517,7 @@ def test_vertex_to_object_action(sagemaker_session): def 
test_vertex_to_object_unconvertable(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", - lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, + lineage_entity=LineageEntityEnum.TRIAL.value, lineage_source=LineageSourceEnum.TENSORBOARD.value, sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 284956aa75..03af3acb7d 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -19,6 +19,7 @@ import sagemaker from sagemaker.model import Model +from sagemaker.serverless import ServerlessInferenceConfig MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -62,7 +63,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None) production_variant.assert_called_with( - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None + MODEL_NAME, + INSTANCE_TYPE, + INSTANCE_COUNT, + accelerator_type=None, + serverless_inference_config=None, ) sagemaker_session.create_model.assert_called_with( @@ -101,7 +106,11 @@ def test_deploy_accelerator_type( create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None) production_variant.assert_called_with( - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=ACCELERATOR_TYPE + MODEL_NAME, + INSTANCE_TYPE, + INSTANCE_COUNT, + accelerator_type=ACCELERATOR_TYPE, + serverless_inference_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -279,6 +288,71 @@ def test_deploy_data_capture_config(production_variant, name_from_base, sagemake ) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.model.Model._create_sagemaker_model") +@patch("sagemaker.production_variant") +def 
test_deploy_serverless_inference(production_variant, create_sagemaker_model, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT) + production_variant.return_value = production_variant_result + + serverless_inference_config = ServerlessInferenceConfig() + serverless_inference_config_dict = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 5, + } + + model.deploy( + serverless_inference_config=serverless_inference_config, + ) + + create_sagemaker_model.assert_called_with(None, None, None) + production_variant.assert_called_with( + MODEL_NAME, + None, + None, + accelerator_type=None, + serverless_inference_config=serverless_inference_config_dict, + ) + + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=ENDPOINT_NAME, + production_variants=[production_variant_result], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +def test_deploy_wrong_inference_type(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) + + bad_args = ( + {"instance_type": INSTANCE_TYPE}, + {"initial_instance_count": INSTANCE_COUNT}, + {"instance_type": None, "initial_instance_count": None}, + ) + for args in bad_args: + with pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + model.deploy(args) + + +def test_deploy_wrong_serverless_config(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE) + with pytest.raises( + ValueError, + match="serverless_inference_config needs to be a ServerlessInferenceConfig object", + ): + model.deploy(serverless_inference_config={}) + + @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_deploy_creates_correct_session(local_session, session): diff --git a/tests/unit/sagemaker/model/test_neo.py 
b/tests/unit/sagemaker/model/test_neo.py index 16b5bc6ee6..2357c771f9 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -20,12 +20,15 @@ MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" +IMAGE_URI = "inference-container-uri" + REGION = "us-west-2" NEO_REGION_ACCOUNT = "301217895009" DESCRIBE_COMPILATION_JOB_RESPONSE = { "CompilationJobStatus": "Completed", "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, + "InferenceImage": IMAGE_URI, } @@ -52,12 +55,7 @@ def test_compile_model_for_inferentia(sagemaker_session): framework_version="1.15.0", job_name="compile-model", ) - assert ( - "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format( - NEO_REGION_ACCOUNT, REGION - ) - == model.image_uri - ) + assert DESCRIBE_COMPILATION_JOB_RESPONSE["InferenceImage"] == model.image_uri assert model._is_compiled_model is True @@ -271,11 +269,12 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem assert model.endpoint_name.startswith("{}-ml-c4".format(model_name)) -@patch("sagemaker.session.Session") -def test_compile_with_framework_version_15(session): - session.return_value.boto_region_name = REGION +def test_compile_with_framework_version_15(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) - model = _create_model() + model = _create_model(sagemaker_session) model.compile( target_instance_family="ml_c4", input_shape={"data": [1, 3, 1024, 1024]}, @@ -286,14 +285,15 @@ def test_compile_with_framework_version_15(session): job_name="compile-model", ) - assert "1.5" in model.image_uri + assert IMAGE_URI == model.image_uri -@patch("sagemaker.session.Session") -def test_compile_with_framework_version_16(session): - session.return_value.boto_region_name = REGION +def test_compile_with_framework_version_16(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( 
+ return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) - model = _create_model() + model = _create_model(sagemaker_session) model.compile( target_instance_family="ml_c4", input_shape={"data": [1, 3, 1024, 1024]}, @@ -304,26 +304,7 @@ def test_compile_with_framework_version_16(session): job_name="compile-model", ) - assert "1.6" in model.image_uri - - -@patch("sagemaker.session.Session") -def test_compile_validates_framework_version(session): - session.return_value.boto_region_name = REGION - - model = _create_model() - with pytest.raises(ValueError) as e: - model.compile( - target_instance_family="ml_c4", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="pytorch", - framework_version="1.6.1", - job_name="compile-model", - ) - - assert "Unsupported neo-pytorch version: 1.6.1." in str(e) + assert IMAGE_URI == model.image_uri @patch("sagemaker.session.Session") @@ -347,3 +328,25 @@ def test_compile_with_pytorch_neo_in_ml_inf(session): ) != model.image_uri ) + + +def test_compile_validates_framework_version(sagemaker_session): + sagemaker_session.wait_for_compilation_job = Mock( + return_value={ + "CompilationJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, + "InferenceImage": None, + } + ) + model = _create_model(sagemaker_session) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="pytorch", + framework_version="1.6.1", + job_name="compile-model", + ) + + assert model.image_uri is None diff --git a/tests/unit/sagemaker/serverless/test_serverless_inference_config.py b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py new file mode 100644 index 0000000000..fab80748a4 --- /dev/null +++ b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.serverless import ServerlessInferenceConfig + +DEFAULT_MEMORY_SIZE_IN_MB = 2048 +DEFAULT_MAX_CONCURRENCY = 5 + +DEFAULT_REQUEST_DICT = { + "MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB, + "MaxConcurrency": DEFAULT_MAX_CONCURRENCY, +} + + +def test_init(): + serverless_inference_config = ServerlessInferenceConfig() + + assert serverless_inference_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB + assert serverless_inference_config.max_concurrency == DEFAULT_MAX_CONCURRENCY + + +def test_to_request_dict(): + serverless_inference_config_dict = ServerlessInferenceConfig()._to_request_dict() + + assert serverless_inference_config_dict == DEFAULT_REQUEST_DICT diff --git a/tests/unit/sagemaker/workflow/test_emr_step.py b/tests/unit/sagemaker/workflow/test_emr_step.py new file mode 100644 index 0000000000..e0dd81ebb5 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_emr_step.py @@ -0,0 +1,175 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json + +import pytest + +from mock import Mock + +from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig +from sagemaker.workflow.steps import CacheConfig +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.parameters import ParameterString + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name="us-west-2") + session_mock = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="us-west-2", + config=None, + local_mode=False, + ) + return session_mock + + +def test_emr_step_with_one_step_config(sagemaker_session): + emr_step_config = EMRStepConfig( + jar="s3:/script-runner/script-runner.jar", + args=["--arg_0", "arg_0_value"], + main_class="com.my.main", + properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}], + ) + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=emr_step_config, + depends_on=["TestStep"], + cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"), + ) + emr_step.add_depends_on(["SecondTestStep"]) + assert emr_step.to_request() == { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner.jar", + "MainClass": "com.my.main", + "Properties": [ + {"Key": "Foo", "Value": "Foo_value"}, + {"Key": "Bar", "Value": "Bar_value"}, + ], + } + }, + }, + "DependsOn": ["TestStep", "SecondTestStep"], + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + } + + assert emr_step.properties.ClusterId == "MyClusterID" + assert 
emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"} + assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"} + assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"} + assert emr_step.properties.Config.MainClass.expr == {"Get": "Steps.MyEMRStep.Config.MainClass"} + assert emr_step.properties.Id.expr == {"Get": "Steps.MyEMRStep.Id"} + assert emr_step.properties.Name.expr == {"Get": "Steps.MyEMRStep.Name"} + assert emr_step.properties.Status.State.expr == {"Get": "Steps.MyEMRStep.Status.State"} + assert emr_step.properties.Status.FailureDetails.Reason.expr == { + "Get": "Steps.MyEMRStep.Status.FailureDetails.Reason" + } + + +def test_pipeline_interpolates_emr_outputs(sagemaker_session): + parameter = ParameterString("MyStr") + + emr_step_config_1 = EMRStepConfig( + jar="s3:/script-runner/script-runner_1.jar", + args=["--arg_0", "arg_0_value"], + main_class="com.my.main", + properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}], + ) + + step_emr_1 = EMRStep( + name="emr_step_1", + cluster_id="MyClusterID", + display_name="emr_step_1", + description="MyEMRStepDescription", + depends_on=["TestStep"], + step_config=emr_step_config_1, + ) + + emr_step_config_2 = EMRStepConfig(jar="s3:/script-runner/script-runner_2.jar") + + step_emr_2 = EMRStep( + name="emr_step_2", + cluster_id="MyClusterID", + display_name="emr_step_2", + description="MyEMRStepDescription", + depends_on=["TestStep"], + step_config=emr_step_config_2, + ) + + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step_emr_1, step_emr_2], + sagemaker_session=sagemaker_session, + ) + + assert json.loads(pipeline.definition()) == { + "Version": "2020-12-01", + "Metadata": {}, + "Parameters": [{"Name": "MyStr", "Type": "String"}], + "PipelineExperimentConfig": { + "ExperimentName": {"Get": "Execution.PipelineName"}, + "TrialName": {"Get": 
"Execution.PipelineExecutionId"}, + }, + "Steps": [ + { + "Name": "emr_step_1", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner_1.jar", + "MainClass": "com.my.main", + "Properties": [ + {"Key": "Foo", "Value": "Foo_value"}, + {"Key": "Bar", "Value": "Bar_value"}, + ], + } + }, + }, + "DependsOn": ["TestStep"], + "Description": "MyEMRStepDescription", + "DisplayName": "emr_step_1", + }, + { + "Name": "emr_step_2", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": {"Jar": "s3:/script-runner/script-runner_2.jar"} + }, + }, + "Description": "MyEMRStepDescription", + "DisplayName": "emr_step_2", + "DependsOn": ["TestStep"], + }, + ], + } diff --git a/tests/unit/sagemaker/workflow/test_properties.py b/tests/unit/sagemaker/workflow/test_properties.py index accaf46533..405de5c0b2 100644 --- a/tests/unit/sagemaker/workflow/test_properties.py +++ b/tests/unit/sagemaker/workflow/test_properties.py @@ -70,6 +70,19 @@ def test_properties_tuning_job(): } +def test_properties_emr_step(): + prop = Properties("Steps.MyStep", "Step", service_name="emr") + some_prop_names = ["Id", "Name", "Config", "ActionOnFailure", "Status"] + for name in some_prop_names: + assert name in prop.__dict__.keys() + + assert prop.Id.expr == {"Get": "Steps.MyStep.Id"} + assert prop.Name.expr == {"Get": "Steps.MyStep.Name"} + assert prop.ActionOnFailure.expr == {"Get": "Steps.MyStep.ActionOnFailure"} + assert prop.Config.Jar.expr == {"Get": "Steps.MyStep.Config.Jar"} + assert prop.Status.State.expr == {"Get": "Steps.MyStep.Status.State"} + + def test_properties_describe_model_package_output(): prop = Properties("Steps.MyStep", "DescribeModelPackageOutput") some_prop_names = ["ModelPackageName", "ModelPackageGroupName", "ModelPackageArn"] diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 
48a096a69f..56057351f6 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -82,7 +82,7 @@ def test_invalid_data_config(): ) -def test_data_bias_config(): +def test_bias_config(): label_values = [1] facet_name = "F1" facet_threshold = 0.3 @@ -103,52 +103,122 @@ def test_data_bias_config(): assert expected_config == data_bias_config.get_config() -def test_data_bias_config_multi_facet(): - label_values = [1] - facet_name = ["Facet1", "Facet2"] - facet_threshold = [[0], [1, 2]] - group_name = "A151" - - data_bias_config = BiasConfig( - label_values_or_threshold=label_values, - facet_name=facet_name, - facet_values_or_threshold=facet_threshold, - group_name=group_name, - ) +def test_invalid_bias_config(): + # Empty facet list, + with pytest.raises(AssertionError, match="Please provide at least one facet"): + BiasConfig( + label_values_or_threshold=[1], + facet_name=[], + ) - expected_config = { - "label_values_or_threshold": label_values, - "facet": [ - {"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]}, - {"name_or_index": facet_name[1], "value_or_threshold": facet_threshold[1]}, - ], - "group_variable": group_name, - } - assert expected_config == data_bias_config.get_config() + # Two facets but only one value + with pytest.raises( + ValueError, match="The number of facet names doesn't match the number of facet values" + ): + BiasConfig( + label_values_or_threshold=[1], + facet_name=["Feature1", "Feature2"], + facet_values_or_threshold=[[1]], + ) -def test_data_bias_config_multi_facet_not_all_with_value(): +@pytest.mark.parametrize( + "facet_name,facet_values_or_threshold,expected_result", + [ + # One facet, assume that it is binary and value 1 indicates the sensitive group + [ + "Feature1", + [1], + { + "facet": [{"name_or_index": "Feature1", "value_or_threshold": [1]}], + }, + ], + # The same facet as above, facet value is not specified. (Clarify will compute bias metrics + # for each binary value). 
+ [ + "Feature1", + None, + { + "facet": [{"name_or_index": "Feature1"}], + }, + ], + # Assume that the 2nd column (index 1, zero-based) of the dataset as facet, it has + # four categories and two of them indicate the sensitive group. + [ + 1, + ["category1, category2"], + { + "facet": [{"name_or_index": 1, "value_or_threshold": ["category1, category2"]}], + }, + ], + # The same facet as above, facet values are not specified. (Clarify will iterate + # the categories and compute bias metrics for each category). + [ + 1, + None, + { + "facet": [{"name_or_index": 1}], + }, + ], + # Assume that the facet is numeric value in range [0.0, 1.0]. Given facet threshold 0.5, + # interval (0.5, 1.0] indicates the sensitive group. + [ + "Feature3", + [0.5], + { + "facet": [{"name_or_index": "Feature3", "value_or_threshold": [0.5]}], + }, + ], + # Multiple facets + [ + ["Feature1", 1, "Feature3"], + [[1], ["category1, category2"], [0.5]], + { + "facet": [ + {"name_or_index": "Feature1", "value_or_threshold": [1]}, + {"name_or_index": 1, "value_or_threshold": ["category1, category2"]}, + {"name_or_index": "Feature3", "value_or_threshold": [0.5]}, + ], + }, + ], + # Multiple facets, no value or threshold + [ + ["Feature1", 1, "Feature3"], + None, + { + "facet": [ + {"name_or_index": "Feature1"}, + {"name_or_index": 1}, + {"name_or_index": "Feature3"}, + ], + }, + ], + # Multiple facets, specify values or threshold for some of them + [ + ["Feature1", 1, "Feature3"], + [[1], None, [0.5]], + { + "facet": [ + {"name_or_index": "Feature1", "value_or_threshold": [1]}, + {"name_or_index": 1}, + {"name_or_index": "Feature3", "value_or_threshold": [0.5]}, + ], + }, + ], + ], +) +def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_result): label_values = [1] - facet_name = ["Facet1", "Facet2"] - facet_threshold = [[0], None] - group_name = "A151" - - data_bias_config = BiasConfig( + bias_config = BiasConfig( label_values_or_threshold=label_values, 
facet_name=facet_name, - facet_values_or_threshold=facet_threshold, - group_name=group_name, + facet_values_or_threshold=facet_values_or_threshold, ) - expected_config = { "label_values_or_threshold": label_values, - "facet": [ - {"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]}, - {"name_or_index": facet_name[1]}, - ], - "group_variable": group_name, + **expected_result, } - assert expected_config == data_bias_config.get_config() + assert bias_config.get_config() == expected_config def test_model_config(): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 248eda1aa5..5940ca2c0a 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2801,6 +2801,37 @@ def test_generic_to_deploy(time, sagemaker_session): assert predictor.sagemaker_session == sagemaker_session +def test_generic_to_deploy_bad_arguments_combination(sagemaker_session): + e = Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) + + e.fit() + + bad_args = ( + {"instance_type": INSTANCE_TYPE}, + {"initial_instance_count": INSTANCE_COUNT}, + {"instance_type": None, "initial_instance_count": None}, + ) + for args in bad_args: + with pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + e.deploy(args) + + with pytest.raises( + ValueError, + match="serverless_inference_config needs to be a ServerlessInferenceConfig object", + ): + e.deploy(serverless_inference_config={}) + + def test_generic_to_deploy_network_isolation(sagemaker_session): e = Estimator( IMAGE_URI, @@ -2850,6 +2881,7 @@ def test_generic_to_deploy_kms(create_model, sagemaker_session): wait=True, kms_key=kms_key, data_capture_config=None, + serverless_inference_config=None, ) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 7e6f63eb4e..991eeac2ec 100644 --- a/tests/unit/test_mxnet.py +++ 
b/tests/unit/test_mxnet.py @@ -68,6 +68,8 @@ ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} +INFERENCE_IMAGE_URI = "inference-uri" + @pytest.fixture() def sagemaker_session(): @@ -83,7 +85,10 @@ def sagemaker_session(): ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} - describe_compilation = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}} + describe_compilation = { + "ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}, + "InferenceImage": INFERENCE_IMAGE_URI, + } session.sagemaker_client.create_model_package.side_effect = MODEL_PKG_RESPONSE session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) @@ -195,12 +200,6 @@ def _create_compilation_job(input_shape, output_location): } -def _neo_inference_image(mxnet_version): - return "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-inference-{}:{}-cpu-py3".format( - FRAMEWORK.lower(), mxnet_version - ) - - @patch("sagemaker.estimator.name_from_base") @patch("sagemaker.utils.create_tar_file", MagicMock()) def test_create_model( @@ -422,7 +421,7 @@ def test_mxnet_neo(time, strftime, sagemaker_session, neo_mxnet_version): actual_compile_model_args = sagemaker_session.method_calls[3][2] assert expected_compile_model_args == actual_compile_model_args - assert compiled_model.image_uri == _neo_inference_image(neo_mxnet_version) + assert compiled_model.image_uri == INFERENCE_IMAGE_URI predictor = mx.deploy(1, CPU, use_compiled_model=True) assert isinstance(predictor, MXNetPredictor) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b2c14c5e5a..9a63a6c114 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -749,6 +749,11 @@ def test_training_input_all_arguments(): IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) 
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "InProgress"}) +SERVERLESS_INFERENCE_CONFIG = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 2, +} + @pytest.fixture() def sagemaker_session(): @@ -1911,6 +1916,31 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi ) +def test_endpoint_from_production_variants_with_serverless_inference_config(sagemaker_session): + ims = sagemaker_session + ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"}) + pvs = [ + sagemaker.production_variant( + "A", "ml.p2.xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG + ), + sagemaker.production_variant( + "B", "p299.4096xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG + ), + ] + ex = ClientError( + {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + ) + ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) + tags = [{"ModelName": "TestModel"}] + sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs, tags) + sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( + EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=tags + ) + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=tags + ) + + def test_update_endpoint_succeed(sagemaker_session): sagemaker_session.sagemaker_client.describe_endpoint = Mock( return_value={"EndpointStatus": "InService"}