Skip to content

Commit 7b2bcf1

Browse files
authored
Merge pull request #6 from ytlee93/master
change: add queryLineageResult visualizer load test & integ test
2 parents db9c3a3 + 23fe126 commit 7b2bcf1

File tree

6 files changed

+347
-3
lines changed

6 files changed

+347
-3
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
urllib3==1.26.8
22
docker-compose==1.29.2
33
docker~=5.0.0
4-
PyYAML==5.4.1
4+
PyYAML==5.4.1

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ fabric==2.6.0
1818
requests==2.27.1
1919
sagemaker-experiments==0.1.35
2020
Jinja2==3.0.3
21+
pyvis==0.2.1

src/sagemaker/lineage/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def _get_visualization_elements(self):
364364
elements = {"nodes": verts, "edges": edges}
365365
return elements
366366

367-
def visualize(self):
367+
def visualize(self, path="pyvisExample.html"):
368368
"""Visualize lineage query result."""
369369
lineage_graph = {
370370
# nodes can have shape / color
@@ -398,7 +398,7 @@ def visualize(self):
398398

399399
pyvis_vis = PyvisVisualizer(lineage_graph)
400400
elements = self._get_visualization_elements()
401-
return pyvis_vis.render(elements=elements)
401+
return pyvis_vis.render(elements=elements, path=path)
402402

403403

404404
class LineageFilter(object):

tests/integ/sagemaker/lineage/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
import logging
2121
import uuid
22+
import json
2223
from sagemaker.lineage import (
2324
action,
2425
context,
@@ -891,3 +892,17 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session):
891892
pass
892893
else:
893894
raise (e)
895+
896+
897+
@pytest.fixture
def extract_data_from_html():
    """Provide a parser for a JSON array embedded in a line of text.

    Returns:
        callable: ``_method(data)`` which takes a string containing a JSON
        array (for example a ``nodes = ...`` or ``edges = ...`` line of the
        rendered HTML) and returns the decoded array as a Python list.
    """

    def _method(data):
        """Decode the first JSON array found in ``data``.

        Args:
            data (str): Text containing a JSON array somewhere inside it.

        Returns:
            list: The decoded array elements (empty list for ``[]``).

        Raises:
            ValueError: If ``data`` contains no ``[`` or the text starting at
                the first ``[`` is not valid JSON.
        """
        start = data.find("[")
        if start == -1:
            raise ValueError("no JSON array found in: " + data)

        # raw_decode parses exactly one JSON value beginning at ``start`` and
        # ignores any trailing text, so nested objects/arrays and values that
        # contain "]" or "}, " no longer confuse the extraction.
        data_dict, _ = json.JSONDecoder().raw_decode(data, start)
        return data_dict

    return _method

tests/integ/sagemaker/lineage/helpers.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,94 @@ def visit(arn, visited: set):
7878

7979
ret = []
8080
return visit(start_arn, set())
81+
82+
83+
class LineageResourceHelper:
    """Create lineage entities for a test and remember them for cleanup.

    Every ARN returned by the ``create_*`` methods, and every association
    pair that was added, is recorded on the instance so that ``clean_all``
    can delete everything this helper created.
    """

    def __init__(self, sagemaker_session):
        """Bind the helper to the client held by ``sagemaker_session``.

        Args:
            sagemaker_session: Object exposing a ``sagemaker_client``
                attribute used for all create/delete calls.
        """
        self.client = sagemaker_session.sagemaker_client
        self.artifacts = []
        self.actions = []
        self.contexts = []
        self.associations = []

    def create_artifact(self, artifact_name, artifact_type="Dataset"):
        """Create an artifact and record its ARN.

        Args:
            artifact_name (str): Name of the artifact; also used to build
                the test source URI.
            artifact_type (str): Artifact type (default: "Dataset").

        Returns:
            str: The ``ArtifactArn`` from the create response.
        """
        response = self.client.create_artifact(
            ArtifactName=artifact_name,
            Source={
                "SourceUri": "Test-artifact-" + artifact_name,
                "SourceTypes": [
                    {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"},
                ],
            },
            ArtifactType=artifact_type,
        )
        self.artifacts.append(response["ArtifactArn"])

        return response["ArtifactArn"]

    def create_action(self, action_name, action_type="ModelDeployment"):
        """Create an action and record its ARN.

        Args:
            action_name (str): Name of the action; also used to build the
                test source URI.
            action_type (str): Action type (default: "ModelDeployment").

        Returns:
            str: The ``ActionArn`` from the create response.
        """
        response = self.client.create_action(
            ActionName=action_name,
            Source={
                "SourceUri": "Test-action-" + action_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-action-sourceId-value",
            },
            ActionType=action_type,
        )
        self.actions.append(response["ActionArn"])

        return response["ActionArn"]

    def create_context(self, context_name, context_type="Endpoint"):
        """Create a context and record its ARN.

        Args:
            context_name (str): Name of the context; also used to build the
                test source URI.
            context_type (str): Context type (default: "Endpoint").

        Returns:
            str: The ``ContextArn`` from the create response.
        """
        response = self.client.create_context(
            ContextName=context_name,
            Source={
                "SourceUri": "Test-context-" + context_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-context-sourceId-value",
            },
            ContextType=context_type,
        )
        self.contexts.append(response["ContextArn"])

        return response["ContextArn"]

    def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"):
        """Add an association between two entities and record the pair.

        Args:
            source_arn (str): ARN of the association source.
            dest_arn (str): ARN of the association destination.
            association_type (str): Association type
                (default: "AssociatedWith").

        Returns:
            bool: True when the response contains ``SourceArn`` (the pair is
            then recorded for cleanup), False otherwise.
        """
        response = self.client.add_association(
            SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type
        )
        created = "SourceArn" in response
        if created:
            self.associations.append((source_arn, dest_arn))
        return created

    def clean_all(self, delay=1):
        """Delete all lineage data created by this helper, best effort.

        Associations are deleted first, then artifacts, actions and
        contexts. A failing delete is reported on stdout and skipped so the
        remaining resources are still attempted.

        Args:
            delay (int or float): Seconds to wait before deleting, to avoid
                the GSI race condition between create & delete
                (default: 1). A falsy value skips the wait.
        """
        if delay:
            time.sleep(delay)  # avoid GSI race condition between create & delete

        for source, dest in self.associations:
            self._delete_quietly("delete_association", SourceArn=source, DestinationArn=dest)

        for artifact_arn in self.artifacts:
            self._delete_quietly("delete_artifact", ArtifactArn=artifact_arn)

        for action_arn in self.actions:
            self._delete_quietly("delete_action", ActionArn=action_arn)

        for context_arn in self.contexts:
            self._delete_quietly("delete_context", ContextArn=context_arn)

    def _delete_quietly(self, operation, **kwargs):
        """Call the named client delete operation, printing instead of raising."""
        try:
            getattr(self.client, operation)(**kwargs)
        except Exception as e:
            # Cleanup is deliberately best effort: one failed delete must not
            # prevent the remaining resources from being removed.
            print("skipped " + str(e))
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``"""
14+
from __future__ import absolute_import
15+
import time
16+
import os
17+
18+
import pytest
19+
20+
import sagemaker.lineage.query
21+
from sagemaker.lineage.query import LineageQueryDirectionEnum
22+
from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper
23+
24+
25+
def test_LineageResourceHelper(sagemaker_session):
    """Check that ``LineageResourceHelper`` can create and associate artifacts.

    Two artifacts are created and associated; everything created is deleted
    again in the ``finally`` clause whether or not the checks pass.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    try:
        art1 = lineage_resource_helper.create_artifact(artifact_name=name())
        art2 = lineage_resource_helper.create_artifact(artifact_name=name())
        # Let any exception propagate: pytest then reports the real traceback
        # instead of a bare "assert False".
        assert lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2)
    finally:
        lineage_resource_helper.clean_all()
37+
38+
39+
@pytest.mark.skip("visualizer load test")
def test_wide_graph_visualize(sagemaker_session):
    """Load test: render a graph whose root artifact has 150 direct children.

    The rendered page is written to ``wideGraph.html`` and the number of
    vertices returned by the query is printed. The lineage data created for
    the test is deleted afterwards.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())

    # create wide graph
    # Artifact ----> Artifact
    #        \ \ \-> Artifact
    #         \ \--> Artifact
    #          \---> ...
    try:
        for _ in range(150):
            artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=wide_graph_root_arn, dest_arn=artifact_arn
            )

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[wide_graph_root_arn])
        lq_result.visualize(path="wideGraph.html")

        # Run pytest with -s to see this output; failures are left to
        # propagate so pytest shows the real traceback.
        print("vertex len = ")
        print(len(lq_result.vertices))

    finally:
        lineage_resource_helper.clean_all()
70+
71+
72+
@pytest.mark.skip("visualizer load test")
def test_long_graph_visualize(sagemaker_session):
    """Load test: render a chain of artifacts queried in DESCENDANTS direction.

    Ten artifacts are appended one after another to a root artifact and the
    rendered page is written to ``longGraph.html``. The lineage data created
    for the test is deleted afterwards.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())
    last_arn = long_graph_root_arn

    # create long graph
    # Artifact -> Artifact -> ... -> Artifact
    try:
        for _ in range(10):
            new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=last_arn, dest_arn=new_artifact_arn
            )
            last_arn = new_artifact_arn

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(
            start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS
        )
        # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction)
        lq_result.visualize(path="longGraph.html")

    finally:
        # Failures propagate unchanged so pytest shows the real traceback;
        # cleanup still always runs.
        lineage_resource_helper.clean_all()
101+
102+
103+
def test_graph_visualize(sagemaker_session, extract_data_from_html):
    """Render a small lineage graph and verify the nodes and edges in the HTML.

    Five entities and four associations are created, the query result for
    the model artifact is rendered to ``testGraph.html``, and the ``nodes``
    and ``edges`` arrays found in that file are compared with the expected
    properties. Created lineage data and the HTML file are removed
    afterwards.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    html_path = "testGraph.html"

    # create lineage data
    # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context
    #                    /->
    # dataset artifact -/
    try:
        graph_startarn = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Model"
        )
        image_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Image"
        )
        lineage_resource_helper.create_association(
            source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo"
        )
        dataset_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="DataSet"
        )
        lineage_resource_helper.create_association(
            source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith"
        )
        modeldeploy_action = lineage_resource_helper.create_action(
            action_name=name(), action_type="ModelDeploy"
        )
        lineage_resource_helper.create_association(
            source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo"
        )
        endpoint_context = lineage_resource_helper.create_context(
            context_name=name(), context_type="Endpoint"
        )
        lineage_resource_helper.create_association(
            source_arn=modeldeploy_action,
            dest_arn=endpoint_context,
            association_type="AssociatedWith",
        )
        time.sleep(3)

        # visualize
        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[graph_startarn])
        lq_result.visualize(path=html_path)

        # check generated graph info
        nodes_line = None
        edges_line = None
        with open(html_path, "r") as fo:
            for line in fo:
                if "nodes = " in line:
                    nodes_line = line
                if "edges = " in line:
                    edges_line = line

        assert nodes_line is not None, "no 'nodes = ' line found in " + html_path
        assert edges_line is not None, "no 'edges = ' line found in " + html_path

        node_dict = extract_data_from_html(nodes_line)
        edge_dict = extract_data_from_html(edges_line)

        # check node number
        assert len(node_dict) == 5

        expected_nodes = {
            graph_startarn: {
                "color": "#146eb4",
                "label": "Model",
                "shape": "star",
                "title": "Artifact",
            },
            image_artifact: {
                "color": "#146eb4",
                "label": "Image",
                "shape": "dot",
                "title": "Artifact",
            },
            dataset_artifact: {
                "color": "#146eb4",
                "label": "DataSet",
                "shape": "dot",
                "title": "Artifact",
            },
            modeldeploy_action: {
                "color": "#88c396",
                "label": "ModelDeploy",
                "shape": "dot",
                "title": "Action",
            },
            endpoint_context: {
                "color": "#ff9900",
                "label": "Endpoint",
                "shape": "dot",
                "title": "Context",
            },
        }

        # check node properties
        for node in node_dict:
            for label, val in expected_nodes[node["id"]].items():
                assert node[label] == val

        # check edge number
        assert len(edge_dict) == 4

        # keyed by source ARN: every source has exactly one outgoing edge here
        expected_edges = {
            image_artifact: {
                "from": image_artifact,
                "to": graph_startarn,
                "title": "ContributedTo",
            },
            dataset_artifact: {
                "from": dataset_artifact,
                "to": graph_startarn,
                "title": "AssociatedWith",
            },
            graph_startarn: {
                "from": graph_startarn,
                "to": modeldeploy_action,
                "title": "ContributedTo",
            },
            modeldeploy_action: {
                "from": modeldeploy_action,
                "to": endpoint_context,
                "title": "AssociatedWith",
            },
        }

        # check edge properties
        for edge in edge_dict:
            for label, val in expected_edges[edge["from"]].items():
                assert edge[label] == val

    finally:
        lineage_resource_helper.clean_all()
        # The file does not exist when a failure happened before rendering;
        # removing it unconditionally would mask the original error.
        if os.path.exists(html_path):
            os.remove(html_path)

0 commit comments

Comments
 (0)