Skip to content

fix: Collapse cross-account artifacts in query lineage response #2796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,44 @@ def _convert_api_response(self, response) -> LineageQueryResult:

return converted

def _collapse_cross_account_artifacts(self, query_response):
"""Collapse the duplicate vertices and edges for cross-account."""
for edge in query_response.edges:
if (
"artifact" in edge.source_arn
and "artifact" in edge.destination_arn
and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1]
and edge.source_arn != edge.destination_arn
):
edge_source_arn = edge.source_arn
edge_destination_arn = edge.destination_arn
self._update_cross_account_edge(
edges=query_response.edges,
arn=edge_source_arn,
duplicate_arn=edge_destination_arn,
)
self._update_cross_account_vertex(
query_response=query_response, duplicate_arn=edge_destination_arn
)

# remove the duplicate edges from cross account
new_edge = [e for e in query_response.edges if not e.source_arn == e.destination_arn]
query_response.edges = new_edge

return query_response

def _update_cross_account_edge(self, edges, arn, duplicate_arn):
"""Replace the duplicate arn with arn in edges list."""
for idx, e in enumerate(edges):
if e.destination_arn == duplicate_arn:
edges[idx].destination_arn = arn
elif e.source_arn == duplicate_arn:
edges[idx].source_arn = arn

def _update_cross_account_vertex(self, query_response, duplicate_arn):
"""Remove the vertex with duplicate arn in the vertices list."""
query_response.vertices = [v for v in query_response.vertices if not v.arn == duplicate_arn]

def query(
self,
start_arns: List[str],
Expand Down Expand Up @@ -235,5 +273,7 @@ def query(
Filters=query_filter._to_request_dict() if query_filter else {},
MaxDepth=max_depth,
)
query_response = self._convert_api_response(query_response)
query_response = self._collapse_cross_account_artifacts(query_response)

return self._convert_api_response(query_response)
return query_response
137 changes: 137 additions & 0 deletions tests/unit/sagemaker/lineage/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,143 @@ def test_lineage_query(sagemaker_session):
assert response.vertices[1].lineage_entity == "Context"


def test_lineage_query_cross_account_same_artifact(sagemaker_session):
lineage_query = LineageQuery(sagemaker_session)
sagemaker_session.sagemaker_client.query_lineage.return_value = {
"Vertices": [
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"Type": "Endpoint",
"LineageType": "Artifact",
},
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"Type": "Endpoint",
"LineageType": "Artifact",
},
],
"Edges": [
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"AssociationType": "SAME_AS",
},
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"AssociationType": "SAME_AS",
},
],
}

response = lineage_query.query(
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
)
assert len(response.edges) == 0
assert len(response.vertices) == 1
assert (
response.vertices[0].arn
== "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
)
assert response.vertices[0].lineage_source == "Endpoint"
assert response.vertices[0].lineage_entity == "Artifact"


def test_lineage_query_cross_account(sagemaker_session):
lineage_query = LineageQuery(sagemaker_session)
sagemaker_session.sagemaker_client.query_lineage.return_value = {
"Vertices": [
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"Type": "Endpoint",
"LineageType": "Artifact",
},
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"Type": "Endpoint",
"LineageType": "Artifact",
},
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
"Type": "Endpoint",
"LineageType": "Artifact",
},
{
"Arn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh",
"Type": "Endpoint",
"LineageType": "Artifact",
},
],
"Edges": [
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"AssociationType": "SAME_AS",
},
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0",
"AssociationType": "SAME_AS",
},
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
"AssociationType": "ABC",
},
{
"SourceArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd",
"DestinationArn": "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh",
"AssociationType": "DEF",
},
],
}

response = lineage_query.query(
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
)

assert len(response.edges) == 2
assert (
response.edges[0].source_arn
== "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
)
assert (
response.edges[0].destination_arn
== "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
)
assert response.edges[0].association_type == "ABC"

assert (
response.edges[1].source_arn
== "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
)
assert (
response.edges[1].destination_arn
== "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh"
)
assert response.edges[1].association_type == "DEF"

assert len(response.vertices) == 3
assert (
response.vertices[0].arn
== "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
)
assert response.vertices[0].lineage_source == "Endpoint"
assert response.vertices[0].lineage_entity == "Artifact"
assert (
response.vertices[1].arn
== "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9abcd"
)
assert response.vertices[1].lineage_source == "Endpoint"
assert response.vertices[1].lineage_entity == "Artifact"
assert (
response.vertices[2].arn
== "arn:aws:sagemaker:us-east-2:012345678903:artifact/e1f29799189751939405b0f2b5b9efgh"
)
assert response.vertices[2].lineage_source == "Endpoint"
assert response.vertices[2].lineage_entity == "Artifact"


def test_vertex_to_object_endpoint_context(sagemaker_session):
vertex = Vertex(
arn="arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext",
Expand Down