Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit db9c3a3

Browse files
authoredJul 29, 2022
Merge pull request #5 from ytlee93/master
change: Changed to use pyvis library for visualization
2 parents e6078a9 + 3bab76a commit db9c3a3

File tree

2 files changed

+78
-206
lines changed

2 files changed

+78
-206
lines changed
 

‎.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ venv/
2929
env/
3030
.vscode/
3131
**/tmp
32-
.python-version
32+
.python-version
33+
*.html

‎src/sagemaker/lineage/query.py

Lines changed: 76 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from enum import Enum
1818
from typing import Optional, Union, List, Dict
1919

20-
from sagemaker.lineage._utils import get_resource_name_from_arn
20+
from sagemaker.lineage._utils import get_resource_name_from_arn, get_module
2121

2222

2323
class LineageEntityEnum(Enum):
@@ -201,194 +201,81 @@ def _artifact_to_lineage_object(self):
201201
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
202202

203203

204-
class DashVisualizer(object):
205-
"""Create object used for visualizing graph using Dash library."""
204+
class PyvisVisualizer(object):
205+
"""Create object used for visualizing graph using Pyvis library."""
206206

207207
def __init__(self, graph_styles):
208-
"""Init for DashVisualizer."""
208+
"""Init for PyvisVisualizer."""
209209
# import visualization packages
210210
(
211-
self.cyto,
212-
self.JupyterDash,
213-
self.html,
214-
self.Input,
215-
self.Output,
211+
self.Network,
212+
self.Options,
216213
) = self._import_visual_modules()
217214

218215
self.graph_styles = graph_styles
219216

217+
# pyvis graph options
218+
self._options = """
219+
var options = {
220+
"configure":{
221+
"enabled": false
222+
},
223+
"layout": {
224+
"hierarchical": {
225+
"enabled": true,
226+
"blockShifting": true,
227+
"direction": "LR",
228+
"sortMethod": "directed",
229+
"shakeTowards": "leaves"
230+
}
231+
},
232+
"interaction": {
233+
"multiselect": true,
234+
"navigationButtons": true
235+
},
236+
"physics": {
237+
"enabled": false,
238+
"hierarchicalRepulsion": {
239+
"centralGravity": 0,
240+
"avoidOverlap": null
241+
},
242+
"minVelocity": 0.75,
243+
"solver": "hierarchicalRepulsion"
244+
}
245+
}
246+
"""
247+
220248
def _import_visual_modules(self):
221249
"""Import modules needed for visualization."""
222-
try:
223-
import dash_cytoscape as cyto
224-
except ImportError as e:
225-
print(e)
226-
print("Try: pip install dash-cytoscape")
227-
raise
228-
229-
try:
230-
from jupyter_dash import JupyterDash
231-
except ImportError as e:
232-
print(e)
233-
print("Try: pip install jupyter-dash")
234-
raise
235-
236-
try:
237-
from dash import html
238-
except ImportError as e:
239-
print(e)
240-
print("Try: pip install dash")
241-
raise
242-
243-
try:
244-
from dash.dependencies import Input, Output
245-
except ImportError as e:
246-
print(e)
247-
print("Try: pip install dash")
248-
raise
249-
250-
return cyto, JupyterDash, html, Input, Output
251-
252-
def _create_legend_component(self, style):
253-
"""Create legend component div."""
254-
text = style["name"]
255-
symbol = ""
256-
color = "#ffffff"
257-
if style["isShape"] == "False":
258-
color = style["style"]["background-color"]
259-
else:
260-
symbol = style["symbol"]
261-
return self.html.Div(
262-
[
263-
self.html.Div(
264-
symbol,
265-
style={
266-
"background-color": color,
267-
"width": "1.5vw",
268-
"height": "1.5vw",
269-
"display": "inline-block",
270-
"font-size": "1.5vw",
271-
},
272-
),
273-
self.html.Div(
274-
style={
275-
"width": "0.5vw",
276-
"height": "1.5vw",
277-
"display": "inline-block",
278-
}
279-
),
280-
self.html.Div(
281-
text,
282-
style={"display": "inline-block", "font-size": "1.5vw"},
283-
),
284-
]
285-
)
286-
287-
def _create_entity_selector(self, entity_name, style):
288-
"""Create selector for each lineage entity."""
289-
return {"selector": "." + entity_name, "style": style["style"]}
290-
291-
def _get_app(self, elements):
292-
"""Create JupyterDash app for interactivity on Jupyter notebook."""
293-
app = self.JupyterDash(__name__)
294-
self.cyto.load_extra_layouts()
295-
296-
app.layout = self.html.Div(
297-
[
298-
# graph section
299-
self.cyto.Cytoscape(
300-
id="cytoscape-graph",
301-
elements=elements,
302-
style={
303-
"width": "84%",
304-
"height": "350px",
305-
"display": "inline-block",
306-
"border-width": "1vw",
307-
"border-color": "#232f3e",
308-
},
309-
layout={"name": "klay"},
310-
stylesheet=[
311-
{
312-
"selector": "node",
313-
"style": {
314-
"label": "data(label)",
315-
"font-size": "3.5vw",
316-
"height": "10vw",
317-
"width": "10vw",
318-
"border-width": "0.8",
319-
"border-opacity": "0",
320-
"border-color": "#232f3e",
321-
"font-family": "verdana",
322-
},
323-
},
324-
{
325-
"selector": "edge",
326-
"style": {
327-
"label": "data(label)",
328-
"color": "gray",
329-
"text-halign": "left",
330-
"text-margin-y": "2.5",
331-
"font-size": "3",
332-
"width": "1",
333-
"curve-style": "bezier",
334-
"control-point-step-size": "15",
335-
"target-arrow-color": "gray",
336-
"target-arrow-shape": "triangle",
337-
"line-color": "gray",
338-
"arrow-scale": "0.5",
339-
"font-family": "verdana",
340-
},
341-
},
342-
{"selector": ".select", "style": {"border-opacity": "0.7"}},
343-
]
344-
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
345-
responsive=True,
346-
),
347-
self.html.Div(
348-
style={
349-
"width": "0.5%",
350-
"display": "inline-block",
351-
"font-size": "1vw",
352-
"font-family": "verdana",
353-
"vertical-align": "top",
354-
},
355-
),
356-
# legend section
357-
self.html.Div(
358-
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
359-
style={
360-
"display": "inline-block",
361-
"font-size": "1vw",
362-
"font-family": "verdana",
363-
"vertical-align": "top",
364-
},
365-
),
366-
]
367-
)
368-
369-
@app.callback(
370-
self.Output("cytoscape-graph", "elements"),
371-
self.Input("cytoscape-graph", "tapNodeData"),
372-
self.Input("cytoscape-graph", "elements"),
373-
)
374-
def selectNode(tapData, elements):
375-
for n in elements:
376-
if tapData is not None and n["data"]["id"] == tapData["id"]:
377-
# if is tapped node, add "select" class to node
378-
n["classes"] += " select"
379-
elif "classes" in n:
380-
# remove "select" class in "classes" if node not selected
381-
n["classes"] = n["classes"].replace("select", "")
250+
get_module("pyvis")
251+
from pyvis.network import Network
252+
from pyvis.options import Options
382253

383-
return elements
254+
return Network, Options
384255

385-
return app
256+
def _node_color(self, entity):
257+
"""Return node color by background-color specified in graph styles."""
258+
return self.graph_styles[entity]["style"]["background-color"]
386259

387-
def render(self, elements, mode):
260+
def render(self, elements, path="pyvisExample.html"):
388261
"""Render graph for lineage query result."""
389-
app = self._get_app(elements)
262+
net = self.Network(height="500px", width="100%", notebook=True, directed=True)
263+
net.set_options(self._options)
264+
265+
# add nodes to graph
266+
for arn, source, entity, is_start_arn in elements["nodes"]:
267+
if is_start_arn: # startarn
268+
net.add_node(
269+
arn, label=source, title=entity, color=self._node_color(entity), shape="star"
270+
)
271+
else:
272+
net.add_node(arn, label=source, title=entity, color=self._node_color(entity))
390273

391-
return app.run_server(mode=mode)
274+
# add edges to graph
275+
for src, dest, asso_type in elements["edges"]:
276+
net.add_edge(src, dest, title=asso_type)
277+
278+
return net.show(path)
392279

393280

394281
class LineageQueryResult(object):
@@ -449,49 +336,36 @@ def __str__(self):
449336
result_dict = vars(self)
450337
return str({k: [str(val) for val in v] for k, v in result_dict.items()})
451338

339+
def _covert_edges_to_tuples(self):
340+
"""Convert edges to tuple format for visualizer."""
341+
edges = []
342+
# get edge info in the form of (source, target, label)
343+
for edge in self.edges:
344+
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
345+
return edges
346+
452347
def _covert_vertices_to_tuples(self):
453348
"""Convert vertices to tuple format for visualizer."""
454349
verts = []
455350
# get vertex info in the form of (id, label, class)
456351
for vert in self.vertices:
457352
if vert.arn in self.startarn:
458353
# add "startarn" class to node if arn is a startarn
459-
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
354+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True))
460355
else:
461-
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
356+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False))
462357
return verts
463358

464-
def _covert_edges_to_tuples(self):
465-
"""Convert edges to tuple format for visualizer."""
466-
edges = []
467-
# get edge info in the form of (source, target, label)
468-
for edge in self.edges:
469-
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
470-
return edges
471-
472359
def _get_visualization_elements(self):
473-
"""Get elements for visualization."""
474-
# get vertices and edges info for graph
360+
"""Get elements(nodes+edges) for visualization."""
475361
verts = self._covert_vertices_to_tuples()
476362
edges = self._covert_edges_to_tuples()
477363

478-
nodes = [
479-
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
480-
]
481-
482-
edges = [
483-
{"data": {"source": source, "target": target, "label": label}}
484-
for source, target, label in edges
485-
]
486-
487-
elements = nodes + edges
488-
364+
elements = {"nodes": verts, "edges": edges}
489365
return elements
490366

491367
def visualize(self):
492368
"""Visualize lineage query result."""
493-
elements = self._get_visualization_elements()
494-
495369
lineage_graph = {
496370
# nodes can have shape / color
497371
"TrialComponent": {
@@ -522,12 +396,9 @@ def visualize(self):
522396
},
523397
}
524398

525-
# initialize DashVisualizer instance to render graph & interactive components
526-
dash_vis = DashVisualizer(lineage_graph)
527-
528-
dash_server = dash_vis.render(elements=elements, mode="inline")
529-
530-
return dash_server
399+
pyvis_vis = PyvisVisualizer(lineage_graph)
400+
elements = self._get_visualization_elements()
401+
return pyvis_vis.render(elements=elements)
531402

532403

533404
class LineageFilter(object):

0 commit comments

Comments
 (0)
Please sign in to comment.