Skip to content

Commit b3c19d8

Browse files
authored
fix: Remove duplicate vertex/edge in query lineage (#2784)
1 parent 88e4d68 commit b3c19d8

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

src/sagemaker/lineage/query.py

+56
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code to query SageMaker lineage."""
1414
from __future__ import absolute_import
15+
1516
from datetime import datetime
1617
from enum import Enum
1718
from typing import Optional, Union, List, Dict
19+
1820
from sagemaker.lineage._utils import get_resource_name_from_arn
1921

2022

@@ -65,6 +67,27 @@ def __init__(
6567
self.destination_arn = destination_arn
6668
self.association_type = association_type
6769

70+
def __hash__(self):
    """Return a hash built from the edge's endpoints and association type.

    The hashed tuple covers the same three attributes that ``__eq__``
    compares, so edges that compare equal also hash equal.
    """
    identity = (
        "source_arn",
        self.source_arn,
        "destination_arn",
        self.destination_arn,
        "association_type",
        self.association_type,
    )
    return hash(identity)
82+
83+
def __eq__(self, other):
    """Return whether ``other`` is an edge with the same endpoints and type.

    Args:
        other: The object to compare against.

    Returns:
        bool: True when ``association_type``, ``source_arn`` and
        ``destination_arn`` all match. ``NotImplemented`` is returned when
        ``other`` is not an instance of this edge's class, so comparing
        against ``None`` or an unrelated object evaluates to ``False``
        instead of raising ``AttributeError``.
    """
    # Guard before touching other's attributes: an arbitrary object need
    # not define source_arn / destination_arn / association_type.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.association_type == other.association_type
        and self.source_arn == other.source_arn
        and self.destination_arn == other.destination_arn
    )
90+
6891

6992
class Vertex:
7093
"""A vertex for a lineage graph."""
@@ -82,6 +105,27 @@ def __init__(
82105
self.lineage_source = lineage_source
83106
self._session = sagemaker_session
84107

108+
def __hash__(self):
    """Return a hash built from the vertex's ARN, entity and source.

    The hashed tuple covers the same three attributes that ``__eq__``
    compares, so vertices that compare equal also hash equal.
    """
    identity = (
        "arn",
        self.arn,
        "lineage_entity",
        self.lineage_entity,
        "lineage_source",
        self.lineage_source,
    )
    return hash(identity)
120+
121+
def __eq__(self, other):
    """Return whether ``other`` is a vertex with the same ARN, entity and source.

    Args:
        other: The object to compare against.

    Returns:
        bool: True when ``arn``, ``lineage_entity`` and ``lineage_source``
        all match. ``NotImplemented`` is returned when ``other`` is not an
        instance of this vertex's class, so comparing against ``None`` or
        an unrelated object evaluates to ``False`` instead of raising
        ``AttributeError``.
    """
    # Guard before touching other's attributes: an arbitrary object need
    # not define arn / lineage_entity / lineage_source.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.arn == other.arn
        and self.lineage_entity == other.lineage_entity
        and self.lineage_source == other.lineage_source
    )
128+
85129
def to_lineage_object(self):
86130
"""Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object."""
87131
from sagemaker.lineage.artifact import Artifact, ModelArtifact
@@ -210,6 +254,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
210254
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
211255
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
212256

257+
# Drop duplicate edges, keeping the first occurrence of each in its original
# position. A new list is built rather than calling ``list.remove`` inside
# the loop: removing from the list being iterated shifts later items under
# the iterator, so some items are never examined and a value that appears
# three or more times would survive as a duplicate.
seen_edges = set()
unique_edges = []
for edge in converted.edges:
    if edge not in seen_edges:
        seen_edges.add(edge)
        unique_edges.append(edge)
converted.edges = unique_edges

# Same order-preserving de-duplication for vertices.
seen_vertices = set()
unique_vertices = []
for vertex in converted.vertices:
    if vertex not in seen_vertices:
        seen_vertices.add(vertex)
        unique_vertices.append(vertex)
converted.vertices = unique_vertices
268+
213269
return converted
214270

215271
def _collapse_cross_account_artifacts(self, query_response):

tests/unit/sagemaker/lineage/test_query.py

+32
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@ def test_lineage_query(sagemaker_session):
3232
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
3333
)
3434

35+
assert len(response.edges) == 1
36+
assert response.edges[0].source_arn == "arn1"
37+
assert response.edges[0].destination_arn == "arn2"
38+
assert response.edges[0].association_type == "Produced"
39+
assert len(response.vertices) == 2
40+
41+
assert response.vertices[0].arn == "arn1"
42+
assert response.vertices[0].lineage_source == "Endpoint"
43+
assert response.vertices[0].lineage_entity == "Artifact"
44+
assert response.vertices[1].arn == "arn2"
45+
assert response.vertices[1].lineage_source == "Model"
46+
assert response.vertices[1].lineage_entity == "Context"
47+
48+
49+
def test_lineage_query_duplication(sagemaker_session):
50+
lineage_query = LineageQuery(sagemaker_session)
51+
sagemaker_session.sagemaker_client.query_lineage.return_value = {
52+
"Vertices": [
53+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
54+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
55+
{"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
56+
],
57+
"Edges": [
58+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
59+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
60+
],
61+
}
62+
63+
response = lineage_query.query(
64+
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
65+
)
66+
3567
assert len(response.edges) == 1
3668
assert response.edges[0].source_arn == "arn1"
3769
assert response.edges[0].destination_arn == "arn2"

0 commit comments

Comments
 (0)