12
12
# language governing permissions and limitations under the License.
13
13
"""This module contains code to query SageMaker lineage."""
14
14
from __future__ import absolute_import
15
+
15
16
from datetime import datetime
16
17
from enum import Enum
17
18
from typing import Optional , Union , List , Dict
19
+
18
20
from sagemaker .lineage ._utils import get_resource_name_from_arn
19
21
20
22
@@ -65,6 +67,27 @@ def __init__(
65
67
self .destination_arn = destination_arn
66
68
self .association_type = association_type
67
69
70
+ def __hash__ (self ):
71
+ """Define hash function for ``Edge``."""
72
+ return hash (
73
+ (
74
+ "source_arn" ,
75
+ self .source_arn ,
76
+ "destination_arn" ,
77
+ self .destination_arn ,
78
+ "association_type" ,
79
+ self .association_type ,
80
+ )
81
+ )
82
+
83
+ def __eq__ (self , other ):
84
+ """Define equal function for ``Edge``."""
85
+ return (
86
+ self .association_type == other .association_type
87
+ and self .source_arn == other .source_arn
88
+ and self .destination_arn == other .destination_arn
89
+ )
90
+
68
91
69
92
class Vertex :
70
93
"""A vertex for a lineage graph."""
@@ -82,6 +105,27 @@ def __init__(
82
105
self .lineage_source = lineage_source
83
106
self ._session = sagemaker_session
84
107
108
+ def __hash__ (self ):
109
+ """Define hash function for ``Vertex``."""
110
+ return hash (
111
+ (
112
+ "arn" ,
113
+ self .arn ,
114
+ "lineage_entity" ,
115
+ self .lineage_entity ,
116
+ "lineage_source" ,
117
+ self .lineage_source ,
118
+ )
119
+ )
120
+
121
+ def __eq__ (self , other ):
122
+ """Define equal function for ``Vertex``."""
123
+ return (
124
+ self .arn == other .arn
125
+ and self .lineage_entity == other .lineage_entity
126
+ and self .lineage_source == other .lineage_source
127
+ )
128
+
85
129
def to_lineage_object (self ):
86
130
"""Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object."""
87
131
from sagemaker .lineage .artifact import Artifact , ModelArtifact
@@ -206,6 +250,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
206
250
converted .edges = [self ._get_edge (edge ) for edge in response ["Edges" ]]
207
251
converted .vertices = [self ._get_vertex (vertex ) for vertex in response ["Vertices" ]]
208
252
253
+ edge_set = set ()
254
+ for edge in converted .edges :
255
+ if edge in edge_set :
256
+ converted .edges .remove (edge )
257
+ edge_set .add (edge )
258
+
259
+ vertex_set = set ()
260
+ for vertex in converted .vertices :
261
+ if vertex in vertex_set :
262
+ converted .vertices .remove (vertex )
263
+ vertex_set .add (vertex )
264
+
209
265
return converted
210
266
211
267
def query (
0 commit comments