Skip to content

Commit b9e1104

Browse files
committed
fix: Remove duplicate vertices/edges in query lineage
1 parent 87c1d2c commit b9e1104

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

src/sagemaker/lineage/query.py

+56
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code to query SageMaker lineage."""
1414
from __future__ import absolute_import
15+
1516
from datetime import datetime
1617
from enum import Enum
1718
from typing import Optional, Union, List, Dict
19+
1820
from sagemaker.lineage._utils import get_resource_name_from_arn
1921

2022

@@ -65,6 +67,27 @@ def __init__(
6567
self.destination_arn = destination_arn
6668
self.association_type = association_type
6769

70+
def __hash__(self):
    """Return a hash derived from the edge's endpoints and association type.

    The hashed tuple interleaves field-name labels with the field values, so
    two edges whose ``source_arn``, ``destination_arn`` and
    ``association_type`` match produce the same hash and can be collapsed
    by a ``set``.

    Returns:
        int: Hash of the labelled field tuple.
    """
    identity = (
        "source_arn",
        self.source_arn,
        "destination_arn",
        self.destination_arn,
        "association_type",
        self.association_type,
    )
    return hash(identity)
82+
83+
def __eq__(self, other):
    """Return whether ``other`` is an equivalent ``Edge``.

    Two edges are equal when their association type, source ARN and
    destination ARN all match.

    Args:
        other: The object to compare against.

    Returns:
        bool: Result of the field-wise comparison, or ``NotImplemented``
            when ``other`` is not an instance of this edge's class.
    """
    # Guard against foreign types: without this, ``edge == None`` or a
    # membership test against a mixed list raises AttributeError instead of
    # letting Python fall back to its default comparison.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.association_type == other.association_type
        and self.source_arn == other.source_arn
        and self.destination_arn == other.destination_arn
    )
90+
6891

6992
class Vertex:
7093
"""A vertex for a lineage graph."""
@@ -82,6 +105,27 @@ def __init__(
82105
self.lineage_source = lineage_source
83106
self._session = sagemaker_session
84107

108+
def __hash__(self):
    """Return a hash derived from the vertex's ARN, entity and source.

    The hashed tuple interleaves field-name labels with the field values, so
    two vertices whose ``arn``, ``lineage_entity`` and ``lineage_source``
    match produce the same hash and can be collapsed by a ``set``.

    Returns:
        int: Hash of the labelled field tuple.
    """
    identity = (
        "arn",
        self.arn,
        "lineage_entity",
        self.lineage_entity,
        "lineage_source",
        self.lineage_source,
    )
    return hash(identity)
120+
121+
def __eq__(self, other):
    """Return whether ``other`` is an equivalent ``Vertex``.

    Two vertices are equal when their ARN, lineage entity and lineage
    source all match.

    Args:
        other: The object to compare against.

    Returns:
        bool: Result of the field-wise comparison, or ``NotImplemented``
            when ``other`` is not an instance of this vertex's class.
    """
    # Guard against foreign types: without this, ``vertex == None`` or a
    # membership test against a mixed list raises AttributeError instead of
    # letting Python fall back to its default comparison.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.arn == other.arn
        and self.lineage_entity == other.lineage_entity
        and self.lineage_source == other.lineage_source
    )
128+
85129
def to_lineage_object(self):
86130
"""Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object."""
87131
from sagemaker.lineage.artifact import Artifact, ModelArtifact
@@ -206,6 +250,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
206250
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
207251
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
208252

253+
# Collapse duplicate edges, keeping the first occurrence of each in
# response order. A fresh list is built instead of calling ``remove`` on
# the list being iterated: removing during iteration shifts later items
# under the iterator, so entries are skipped and a run of three or more
# identical edges would still leave duplicates behind.
seen_edges = set()
unique_edges = []
for edge in converted.edges:
    if edge not in seen_edges:
        seen_edges.add(edge)
        unique_edges.append(edge)
converted.edges = unique_edges

# Same order-preserving de-duplication for vertices.
seen_vertices = set()
unique_vertices = []
for vertex in converted.vertices:
    if vertex not in seen_vertices:
        seen_vertices.add(vertex)
        unique_vertices.append(vertex)
converted.vertices = unique_vertices
264+
209265
return converted
210266

211267
def query(

tests/unit/sagemaker/lineage/test_query.py

+32
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,38 @@ def test_lineage_query(sagemaker_session):
3131
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
3232
)
3333

34+
assert len(response.edges) == 1
35+
assert response.edges[0].source_arn == "arn1"
36+
assert response.edges[0].destination_arn == "arn2"
37+
assert response.edges[0].association_type == "Produced"
38+
assert len(response.vertices) == 2
39+
40+
assert response.vertices[0].arn == "arn1"
41+
assert response.vertices[0].lineage_source == "Endpoint"
42+
assert response.vertices[0].lineage_entity == "Artifact"
43+
assert response.vertices[1].arn == "arn2"
44+
assert response.vertices[1].lineage_source == "Model"
45+
assert response.vertices[1].lineage_entity == "Context"
46+
47+
48+
def test_lineage_query_duplication(sagemaker_session):
49+
lineage_query = LineageQuery(sagemaker_session)
50+
sagemaker_session.sagemaker_client.query_lineage.return_value = {
51+
"Vertices": [
52+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
53+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
54+
{"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
55+
],
56+
"Edges": [
57+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
58+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
59+
],
60+
}
61+
62+
response = lineage_query.query(
63+
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
64+
)
65+
3466
assert len(response.edges) == 1
3567
assert response.edges[0].source_arn == "arn1"
3668
assert response.edges[0].destination_arn == "arn2"

0 commit comments

Comments
 (0)