diff --git a/src/sagemaker/lineage/action.py b/src/sagemaker/lineage/action.py
index 67ba6d5db0..1c8015a451 100644
--- a/src/sagemaker/lineage/action.py
+++ b/src/sagemaker/lineage/action.py
@@ -13,13 +13,23 @@
 """This module contains code to create and manage SageMaker ``Actions``."""
 from __future__ import absolute_import
 
-from typing import Optional, Iterator
+from typing import Optional, Iterator, List
 from datetime import datetime
 
 from sagemaker import Session
 from sagemaker.apiutils import _base_types
 from sagemaker.lineage import _api_types, _utils
 from sagemaker.lineage._api_types import ActionSource, ActionSummary
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.context import Context
+
+from sagemaker.lineage.query import (
+    LineageQuery,
+    LineageFilter,
+    LineageSourceEnum,
+    LineageEntityEnum,
+    LineageQueryDirectionEnum,
+)
 
 
 class Action(_base_types.Record):
@@ -250,3 +260,94 @@ def list(
             max_results=max_results,
             next_token=next_token,
         )
+
+    def artifacts(
+        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
+    ) -> List[Artifact]:
+        """Use a lineage query to retrieve all artifacts that use this action.
+
+        Args:
+            direction (LineageQueryDirectionEnum, optional): The query direction.
+                Defaults to ``LineageQueryDirectionEnum.BOTH``.
+
+        Returns:
+            list of Artifacts: The lineage object of each vertex returned by the query.
+        """
+        query_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
+        query_result = LineageQuery(self.sagemaker_session).query(
+            start_arns=[self.action_arn],
+            query_filter=query_filter,
+            direction=direction,
+            include_edges=False,
+        )
+        return [vertex.to_lineage_object() for vertex in query_result.vertices]
+
+
+class ModelPackageApprovalAction(Action):
+    """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage."""
+
+    def datasets(
+        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
+    ) -> List[Artifact]:
+        """Use a lineage query to retrieve all upstream datasets that use this action.
+
+        Args:
+            direction (LineageQueryDirectionEnum, optional): The query direction.
+                Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.
+
+        Returns:
+            list of Artifacts: Artifacts representing a dataset.
+        """
+        query_filter = LineageFilter(
+            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
+        )
+        query_result = LineageQuery(self.sagemaker_session).query(
+            start_arns=[self.action_arn],
+            query_filter=query_filter,
+            direction=direction,
+            include_edges=False,
+        )
+        return [vertex.to_lineage_object() for vertex in query_result.vertices]
+
+    def model_package(self):
+        """Get model package from model package approval action.
+
+        The model package name is taken from the second "/"-separated segment
+        of the action's source URI.
+
+        Returns:
+            The response of the ``describe_model_package`` call, or ``None`` if
+            the action has no source URI.
+        """
+        source_uri = self.source.source_uri
+        if source_uri is None:
+            return None
+
+        # e.g. "arn:...:model-package/<name>/<version>" yields "<name>".
+        model_package_name = source_uri.split("/")[1]
+        return self.sagemaker_session.sagemaker_client.describe_model_package(
+            ModelPackageName=model_package_name
+        )
+
+    def endpoints(
+        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
+    ) -> List[Context]:
+        """Use a lineage query to retrieve downstream endpoint contexts that use this action.
+
+        Args:
+            direction (LineageQueryDirectionEnum, optional): The query direction.
+                Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.
+
+        Returns:
+            list of Contexts: Contexts representing an endpoint.
+        """
+        query_filter = LineageFilter(
+            entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
+        )
+        query_result = LineageQuery(self.sagemaker_session).query(
+            start_arns=[self.action_arn],
+            query_filter=query_filter,
+            direction=direction,
+            include_edges=False,
+        )
+        return [vertex.to_lineage_object() for vertex in query_result.vertices]
diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py
index 78cfc700e6..ecb48e3661 100644
--- a/src/sagemaker/lineage/query.py
+++ b/src/sagemaker/lineage/query.py
@@ -43,6 +43,7 @@ class LineageSourceEnum(Enum):
     MODEL_REPLACE = "ModelReplaced"
     TENSORBOARD = "TensorBoard"
     TRAINING_JOB = "TrainingJob"
+    APPROVAL = "Approval"
 
 
 class LineageQueryDirectionEnum(Enum):
@@ -203,11 +204,11 @@ def __init__(
     def _to_request_dict(self):
         """Convert the lineage filter to its API representation."""
         filter_request = {}
-        if self.entities:
+        if self.sources:
             filter_request["Types"] = list(
                 map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
             )
-        if self.sources:
+        if self.entities:
             filter_request["LineageTypes"] = list(
                 map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
             )
diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py
index e4966ab67c..863ab62183 100644
--- a/tests/integ/sagemaker/lineage/conftest.py
+++ b/tests/integ/sagemaker/lineage/conftest.py
@@ -25,6 +25,13 @@
     association,
     artifact,
 )
+from sagemaker.lineage.query import (
+    LineageFilter,
+    LineageEntityEnum,
+    LineageSourceEnum,
+    LineageQuery,
+    LineageQueryDirectionEnum,
+)
 from sagemaker.model import ModelPackage
 from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution
 from sagemaker.workflow.pipeline import _PipelineExecution
@@ -514,6 +521,42 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
     return pipeline_execution_arn
 
 
+@pytest.fixture
+def static_approval_action(
+    sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
+):
+    """Yield the approval action found upstream of the static endpoint context."""
+    query_filter = LineageFilter(
+        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.APPROVAL]
+    )
+    query_result = LineageQuery(sagemaker_session).query(
+        start_arns=[static_endpoint_context.context_arn],
+        query_filter=query_filter,
+        direction=LineageQueryDirectionEnum.ASCENDANTS,
+        include_edges=False,
+    )
+    action_name = query_result.vertices[0].arn.split("/")[1]
+    yield action.ModelPackageApprovalAction.load(
+        action_name=action_name, sagemaker_session=sagemaker_session
+    )
+
+
+@pytest.fixture
+def static_model_deployment_action(sagemaker_session, static_endpoint_context):
+    """Yield the first model deployment action upstream of the static endpoint context."""
+    query_filter = LineageFilter(
+        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
+    )
+    query_result = LineageQuery(sagemaker_session).query(
+        start_arns=[static_endpoint_context.context_arn],
+        query_filter=query_filter,
+        direction=LineageQueryDirectionEnum.ASCENDANTS,
+        include_edges=False,
+    )
+    model_deployment_actions = [vertex.to_lineage_object() for vertex in query_result.vertices]
+    yield model_deployment_actions[0]
+
+
 @pytest.fixture
 def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
     endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
diff --git a/tests/integ/sagemaker/lineage/test_action.py b/tests/integ/sagemaker/lineage/test_action.py
index a0531450b5..8b462279ca 100644
--- a/tests/integ/sagemaker/lineage/test_action.py
+++ b/tests/integ/sagemaker/lineage/test_action.py
@@ -20,6 +20,7 @@
 import pytest
 
 from sagemaker.lineage import action
+from sagemaker.lineage.query import LineageQueryDirectionEnum
 
 
 def test_create_delete(action_obj):
@@ -117,3 +118,52 @@ def test_tags(action_obj, sagemaker_session):
     # length of actual tags will be greater than 1
     assert len(actual_tags) > 0
     assert [actual_tags[-1]] == tags
+
+
+def test_upstream_artifacts(static_model_deployment_action):
+    artifacts_from_query = static_model_deployment_action.artifacts(
+        direction=LineageQueryDirectionEnum.ASCENDANTS
+    )
+    assert len(artifacts_from_query) > 0
+    for artifact in artifacts_from_query:
+        assert "artifact" in artifact.artifact_arn
+
+
+def test_downstream_artifacts(static_approval_action):
+    artifacts_from_query = static_approval_action.artifacts(
+        direction=LineageQueryDirectionEnum.DESCENDANTS
+    )
+    assert len(artifacts_from_query) > 0
+    for artifact in artifacts_from_query:
+        assert "artifact" in artifact.artifact_arn
+
+
+def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session):
+    sagemaker_session.sagemaker_client.add_association(
+        SourceArn=static_dataset_artifact.artifact_arn,
+        DestinationArn=static_approval_action.action_arn,
+        AssociationType="ContributedTo",
+    )
+    try:
+        time.sleep(3)
+        artifacts_from_query = static_approval_action.datasets()
+
+        assert len(artifacts_from_query) > 0
+        for artifact in artifacts_from_query:
+            assert "artifact" in artifact.artifact_arn
+            assert artifact.artifact_type == "DataSet"
+    finally:
+        # Always remove the temporary association so that a failed assertion
+        # does not leave it behind.
+        sagemaker_session.sagemaker_client.delete_association(
+            SourceArn=static_dataset_artifact.artifact_arn,
+            DestinationArn=static_approval_action.action_arn,
+        )
+
+
+def test_endpoints(static_approval_action):
+    endpoint_contexts_from_query = static_approval_action.endpoints()
+    assert len(endpoint_contexts_from_query) > 0
+    for endpoint in endpoint_contexts_from_query:
+        assert endpoint.context_type == "Endpoint"
+        assert "endpoint" in endpoint.context_arn
diff --git a/tests/unit/sagemaker/lineage/test_action.py b/tests/unit/sagemaker/lineage/test_action.py
index 79e59b679b..120d643063 100644
--- a/tests/unit/sagemaker/lineage/test_action.py
+++ b/tests/unit/sagemaker/lineage/test_action.py
@@ -16,6 +16,7 @@
 import unittest.mock
 
 from sagemaker.lineage import action, _api_types
+from sagemaker.lineage._api_types import ActionSource
 
 
 def test_create(sagemaker_session):
@@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session):
         delete_with_association_expected_calls
         == sagemaker_session.sagemaker_client.delete_association.mock_calls
     )
+
+
+def test_model_package(sagemaker_session):
+    obj = action.ModelPackageApprovalAction(
+        sagemaker_session,
+        action_name="abcd-aws-model-package",
+        source=ActionSource(
+            source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1",
+            source_type="ARN",
+        ),
+        status="updated-status",
+        properties={"k1": "v1"},
+        properties_to_remove=["k2"],
+    )
+    sagemaker_session.sagemaker_client.describe_model_package.return_value = {}
+    obj.model_package()
+
+    sagemaker_session.sagemaker_client.describe_model_package.assert_called_with(
+        ModelPackageName="pipeline88modelpackage",
+    )