# Patch commentary. These '#' lines sit before the first "diff --git" header
# and are not part of any hunk.
# NOTE(review): assumes the consuming tool skips leading text before the first
# header (as `git apply` and `patch` normally do) -- confirm for your tooling.
#
# src/sagemaker/lineage/query.py
#   * Adds `_collapse_cross_account_artifacts(query_response)`: for every edge
#     whose source and destination ARNs both contain the substring "artifact",
#     share the same `split("/")[1]` segment, and are not equal, the
#     destination ARN is treated as the duplicate. It is rewritten to the
#     source ARN across the edge list and its vertex is dropped. Afterwards
#     all edges whose source equals their destination are removed. The
#     response object is modified in place and also returned.
#   * Adds `_update_cross_account_edge(edges, arn, duplicate_arn)`: rewrites
#     `duplicate_arn` to `arn` on edge endpoints, in place.
#   * Adds `_update_cross_account_vertex(query_response, duplicate_arn)`:
#     rebuilds `query_response.vertices` without the duplicate ARN.
#   * `query` now converts the API response first and then passes the
#     converted result through `_collapse_cross_account_artifacts` before
#     returning it.
#
# tests/unit/sagemaker/lineage/test_query.py
#   * Adds `test_lineage_query_cross_account_same_artifact` (two vertices and
#     two mutual SAME_AS edges collapse to one vertex and zero edges) and
#     `test_lineage_query_cross_account` (four vertices / four edges collapse
#     to three vertices / two edges).
#
# NOTE(review): `_update_cross_account_edge` uses `if ... elif ...`, so an edge
#   whose source AND destination both equal `duplicate_arn` only has its
#   destination rewritten -- confirm such self-loops cannot occur in responses.
# NOTE(review): the "artifact" test is a substring match on the whole ARN, and
#   `split("/")[1]` would raise IndexError for an ARN that contains "artifact"
#   but no "/" -- presumably ARNs always look like ".../<id>"; verify.
# NOTE(review): line breaks and blank-line placement below follow the hunk line
#   counts in the "@@" headers -- verify against the upstream commit.
diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py
index f2d1bf8c14..6592aeaafe 100644
--- a/src/sagemaker/lineage/query.py
+++ b/src/sagemaker/lineage/query.py
@@ -208,6 +208,44 @@ def _convert_api_response(self, response) -> LineageQueryResult:
 
         return converted
 
+    def _collapse_cross_account_artifacts(self, query_response):
+        """Collapse the duplicate vertices and edges for cross-account."""
+        for edge in query_response.edges:
+            if (
+                "artifact" in edge.source_arn
+                and "artifact" in edge.destination_arn
+                and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1]
+                and edge.source_arn != edge.destination_arn
+            ):
+                edge_source_arn = edge.source_arn
+                edge_destination_arn = edge.destination_arn
+                self._update_cross_account_edge(
+                    edges=query_response.edges,
+                    arn=edge_source_arn,
+                    duplicate_arn=edge_destination_arn,
+                )
+                self._update_cross_account_vertex(
+                    query_response=query_response, duplicate_arn=edge_destination_arn
+                )
+
+        # remove the duplicate edges from cross account
+        new_edge = [e for e in query_response.edges if not e.source_arn == e.destination_arn]
+        query_response.edges = new_edge
+
+        return query_response
+
+    def _update_cross_account_edge(self, edges, arn, duplicate_arn):
+        """Replace the duplicate arn with arn in edges list."""
+        for idx, e in enumerate(edges):
+            if e.destination_arn == duplicate_arn:
+                edges[idx].destination_arn = arn
+            elif e.source_arn == duplicate_arn:
+                edges[idx].source_arn = arn
+
+    def _update_cross_account_vertex(self, query_response, duplicate_arn):
+        """Remove the vertex with duplicate arn in the vertices list."""
+        query_response.vertices = [v for v in query_response.vertices if not v.arn == duplicate_arn]
+
     def query(
         self,
         start_arns: List[str],
@@ -235,5 +273,7 @@ def query(
             Filters=query_filter._to_request_dict() if query_filter else {},
             MaxDepth=max_depth,
         )
+        query_response = self._convert_api_response(query_response)
+        query_response = self._collapse_cross_account_artifacts(query_response)
 
-        return self._convert_api_response(query_response)
+        return query_response
diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py
index 17d3eabe92..df288234e7 100644
--- a/tests/unit/sagemaker/lineage/test_query.py
+++ b/tests/unit/sagemaker/lineage/test_query.py
@@ -44,6 +44,143 @@ def test_lineage_query(sagemaker_session):
     assert response.vertices[1].lineage_entity == "Context"
 
 
+def test_lineage_query_cross_account_same_artifact(sagemaker_session):
+    lineage_query = LineageQuery(sagemaker_session)
+    sagemaker_session.sagemaker_client.query_lineage.return_value = {
+        "Vertices": [
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+        ],
+        "Edges": [
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "AssociationType": "SAME_AS",
+            },
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "AssociationType": "SAME_AS",
+            },
+        ],
+    }
+
+    response = lineage_query.query(
+        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
+    )
+    assert len(response.edges) == 0
+    assert len(response.vertices) == 1
+    assert (
+        response.vertices[0].arn
+        == "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
+    )
+    assert response.vertices[0].lineage_source == "Endpoint"
+    assert response.vertices[0].lineage_entity == "Artifact"
+
+
+def test_lineage_query_cross_account(sagemaker_session):
+    lineage_query = LineageQuery(sagemaker_session)
+    sagemaker_session.sagemaker_client.query_lineage.return_value = {
+        "Vertices": [
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+            {
+                "Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh",
+                "Type": "Endpoint",
+                "LineageType": "Artifact",
+            },
+        ],
+        "Edges": [
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "AssociationType": "SAME_AS",
+            },
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "AssociationType": "SAME_AS",
+            },
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
+                "AssociationType": "ABC",
+            },
+            {
+                "SourceArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
+                "DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh",
+                "AssociationType": "DEF",
+            },
+        ],
+    }
+
+    response = lineage_query.query(
+        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
+    )
+
+    assert len(response.edges) == 2
+    assert (
+        response.edges[0].source_arn
+        == "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
+    )
+    assert (
+        response.edges[0].destination_arn
+        == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
+    )
+    assert response.edges[0].association_type == "ABC"
+
+    assert (
+        response.edges[1].source_arn
+        == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
+    )
+    assert (
+        response.edges[1].destination_arn
+        == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh"
+    )
+    assert response.edges[1].association_type == "DEF"
+
+    assert len(response.vertices) == 3
+    assert (
+        response.vertices[0].arn
+        == "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
+    )
+    assert response.vertices[0].lineage_source == "Endpoint"
+    assert response.vertices[0].lineage_entity == "Artifact"
+    assert (
+        response.vertices[1].arn
+        == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
+    )
+    assert response.vertices[1].lineage_source == "Endpoint"
+    assert response.vertices[1].lineage_entity == "Artifact"
+    assert (
+        response.vertices[2].arn
+        == "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh"
+    )
+    assert response.vertices[2].lineage_source == "Endpoint"
+    assert response.vertices[2].lineage_entity == "Artifact"
+
+
 def test_vertex_to_object_endpoint_context(sagemaker_session):
     vertex = Vertex(
         arn="arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext",