diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py
index 2796d138fc..469b9aeb1a 100644
--- a/src/sagemaker/lineage/context.py
+++ b/src/sagemaker/lineage/context.py
@@ -283,6 +283,55 @@ def models(self) -> List[association.Association]:
         ]
         return model_list
 
+    def models_v2(
+        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
+    ) -> List[Artifact]:
+        """Get artifacts representing models from the context lineage by querying lineage data.
+
+        The lookup takes two hops: first the model deployment actions reachable from
+        this context in ``direction``, then the model artifacts that descend from each
+        of those actions.
+
+        Args:
+            direction (LineageQueryDirectionEnum, optional): The query direction of the
+                first hop. The second hop always queries descendants.
+
+        Returns:
+            list of Artifacts: Artifacts representing a model.
+        """
+        lineage_query = LineageQuery(self.sagemaker_session)
+
+        # First hop: the model deployment actions linked to this context.
+        deployment_filter = LineageFilter(
+            entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
+        )
+        model_deployment_query_result = lineage_query.query(
+            start_arns=[self.context_arn],
+            query_filter=deployment_filter,
+            direction=direction,
+            include_edges=False,
+        )
+        if not model_deployment_query_result:
+            return []
+
+        # Second hop: the model artifacts below each model deployment action. The filter
+        # is identical for every action, so it is built once outside the loop.
+        model_filter = LineageFilter(
+            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL]
+        )
+        model_vertices = []
+        for deployment_vertex in model_deployment_query_result.vertices:
+            query_result = lineage_query.query(
+                start_arns=[deployment_vertex.arn],
+                query_filter=model_filter,
+                direction=LineageQueryDirectionEnum.DESCENDANTS,
+                include_edges=False,
+            )
+            model_vertices.extend(query_result.vertices)
+
+        # An empty vertex list simply yields an empty result.
+        return [vertex.to_lineage_object() for vertex in model_vertices]
+
     def dataset_artifacts(
         self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
     ) -> List[Artifact]:
diff --git a/tests/integ/sagemaker/lineage/conftest.py
b/tests/integ/sagemaker/lineage/conftest.py index b6cebdcb61..e4966ab67c 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -36,6 +36,7 @@ from tests.integ.sagemaker.lineage.helpers import name, names SLEEP_TIME_SECONDS = 1 +SLEEP_TIME_TWO_SECONDS = 2 STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17" STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17" @@ -360,12 +361,10 @@ def endpoint_context_obj(sagemaker_session): @pytest.fixture def model_obj(sagemaker_session): - model = context.Context.create( - context_name=name(), + model = artifact.Artifact.create( + artifact_name=name(), + artifact_type="Model", source_uri="bar1", - source_type="test-source-type1", - context_type="Model", - description="test-description", properties={"k1": "v1"}, sagemaker_session=sagemaker_session, ) @@ -417,11 +416,12 @@ def endpoint_context_associate_with_model(sagemaker_session, endpoint_action_obj association.Association.create( source_arn=endpoint_action_obj.action_arn, - destination_arn=model_obj.context_arn, + destination_arn=model_obj.artifact_arn, sagemaker_session=sagemaker_session, ) yield obj - time.sleep(SLEEP_TIME_SECONDS) + # sleep 2 seconds since take longer for lineage injection + time.sleep(SLEEP_TIME_TWO_SECONDS) obj.delete(disassociate=True) diff --git a/tests/integ/sagemaker/lineage/test_endpoint_context.py b/tests/integ/sagemaker/lineage/test_endpoint_context.py index 07cc48142d..78a33e8ef9 100644 --- a/tests/integ/sagemaker/lineage/test_endpoint_context.py +++ b/tests/integ/sagemaker/lineage/test_endpoint_context.py @@ -12,17 +12,31 @@ # language governing permissions and limitations under the License. 
"""This module contains code to test SageMaker ``Contexts``"""
 from __future__ import absolute_import
+import time
+
+SLEEP_TIME_ONE_SECONDS = 1
 
 
 def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj):
     model_list = endpoint_context_associate_with_model.models()
     for model in model_list:
         assert model.source_arn == endpoint_action_obj.action_arn
-        assert model.destination_arn == model_obj.context_arn
+        assert model.destination_arn == model_obj.artifact_arn
         assert model.source_type == "ModelDeployment"
         assert model.destination_type == "Model"
 
 
+def test_model_v2(endpoint_context_associate_with_model, model_obj, sagemaker_session):
+    time.sleep(SLEEP_TIME_ONE_SECONDS)
+    model_list = endpoint_context_associate_with_model.models_v2()
+    assert len(model_list) == 1
+    for model in model_list:
+        assert model.artifact_arn == model_obj.artifact_arn
+        assert model.artifact_name == model_obj.artifact_name
+        assert model.artifact_type == "Model"
+        assert model.properties == model_obj.properties
+
+
 def test_dataset_artifacts(static_endpoint_context):
     artifacts_from_query = static_endpoint_context.dataset_artifacts()
 
diff --git a/tests/unit/sagemaker/lineage/test_endpoint_context.py b/tests/unit/sagemaker/lineage/test_endpoint_context.py
index 61e315c2ae..f1f92f493a 100644
--- a/tests/unit/sagemaker/lineage/test_endpoint_context.py
+++ b/tests/unit/sagemaker/lineage/test_endpoint_context.py
@@ -15,6 +15,8 @@
 import unittest.mock
 
 from sagemaker.lineage import context, _api_types
+from sagemaker.lineage._api_types import ArtifactSource
+from sagemaker.lineage.artifact import ModelArtifact
 
 
 def test_models(sagemaker_session):
@@ -75,3 +77,113 @@ def test_models(sagemaker_session):
         )
     ]
     assert expected_model_list == model_list
+
+
+def test_models_v2(sagemaker_session):
+    """Check that ``models_v2`` runs two lineage queries and loads the model artifact."""
+    arn1 = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-integ-3b05f017-0d87-4c37"
+    # Adjacent literals are used on purpose: a backslash continuation inside the string
+    # would embed the source indentation in the ARN.
+    properties = {
+        "PipelineExecutionArn": (
+            "arn:aws:sagemaker:us-west-2:0123456789012:"
+            "pipeline/mypipeline/execution/0irnteql64d0"
+        ),
+        "PipelineStepName": "MyStep",
+        "Status": "Completed",
+    }
+
+    obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=arn1)
+
+    # The same response is served to both queries, so the vertex returned by the first
+    # query is arn1 again and the second query starts from arn1 as well.
+    sagemaker_session.sagemaker_client.query_lineage.return_value = {
+        "Vertices": [
+            {"Arn": arn1, "Type": "Model", "LineageType": "Artifact"},
+        ],
+        "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}],
+    }
+
+    sagemaker_session.sagemaker_client.describe_context.return_value = {
+        "ContextName": "MyContext",
+        "ContextArn": arn1,
+        "Source": {
+            "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:endpoint/myendpoint",
+            "SourceType": "ARN",
+            "SourceId": "Thu Dec 17 17:16:24 UTC 2020",
+        },
+        "ContextType": "Endpoint",
+        "Properties": dict(properties),
+        "CreationTime": 1608225384.0,
+        "CreatedBy": {},
+        "LastModifiedTime": 1608225384.0,
+        "LastModifiedBy": {},
+    }
+
+    sagemaker_session.sagemaker_client.describe_artifact.return_value = {
+        "ArtifactName": "MyArtifact",
+        "ArtifactArn": arn1,
+        "Source": {
+            "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel",
+            "SourceType": "ARN",
+            "SourceId": "Thu Dec 17 17:16:24 UTC 2020",
+        },
+        "ArtifactType": "Model",
+        "Properties": dict(properties),
+        "CreationTime": 1608225384.0,
+        "CreatedBy": {},
+        "LastModifiedTime": 1608225384.0,
+        "LastModifiedBy": {},
+    }
+
+    model_list = obj.models_v2()
+
+    expected_calls = [
+        unittest.mock.call(
+            Direction="Descendants",
+            Filters={"Types": ["ModelDeployment"], "LineageTypes": ["Action"]},
+            IncludeEdges=False,
+            MaxDepth=10,
+            StartArns=[arn1],
+        ),
+        unittest.mock.call(
+            Direction="Descendants",
+            Filters={"Types": ["Model"], "LineageTypes": ["Artifact"]},
+            IncludeEdges=False,
+            MaxDepth=10,
+            StartArns=[arn1],
+        ),
+    ]
+    assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls
+
+    expected_model = ModelArtifact(
+        artifact_arn=arn1,
+        artifact_name="MyArtifact",
+        source=ArtifactSource(
+            source_uri="arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel",
+            source_types=None,
+            source_type="ARN",
+            source_id="Thu Dec 17 17:16:24 UTC 2020",
+        ),
+        artifact_type="Model",
+        properties=properties,
+        creation_time=1608225384.0,
+        created_by={},
+        last_modified_time=1608225384.0,
+        last_modified_by={},
+    )
+
+    # Check the length first so a wrong result fails here instead of with an IndexError.
+    assert len(model_list) == 1
+    actual_model = model_list[0]
+    assert expected_model.artifact_arn == actual_model.artifact_arn
+    assert expected_model.artifact_name == actual_model.artifact_name
+    assert expected_model.source == actual_model.source
+    assert expected_model.artifact_type == actual_model.artifact_type
+    # Pin the literal type on the returned object rather than on the expected fixture.
+    assert actual_model.artifact_type == "Model"
+    assert expected_model.properties == actual_model.properties
+    assert expected_model.creation_time == actual_model.creation_time
+    assert expected_model.created_by == actual_model.created_by
+    assert expected_model.last_modified_time == actual_model.last_modified_time
+    assert expected_model.last_modified_by == actual_model.last_modified_by