Skip to content

change: Changed to use pyvis library for visualization #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ venv/
env/
.vscode/
**/tmp
.python-version
.python-version
*.html
281 changes: 76 additions & 205 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from enum import Enum
from typing import Optional, Union, List, Dict

from sagemaker.lineage._utils import get_resource_name_from_arn
from sagemaker.lineage._utils import get_resource_name_from_arn, get_module


class LineageEntityEnum(Enum):
Expand Down Expand Up @@ -201,194 +201,81 @@ def _artifact_to_lineage_object(self):
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)


class DashVisualizer(object):
"""Create object used for visualizing graph using Dash library."""
class PyvisVisualizer(object):
"""Create object used for visualizing graph using Pyvis library."""

def __init__(self, graph_styles):
"""Init for DashVisualizer."""
"""Init for PyvisVisualizer."""
# import visualization packages
(
self.cyto,
self.JupyterDash,
self.html,
self.Input,
self.Output,
self.Network,
self.Options,
) = self._import_visual_modules()

self.graph_styles = graph_styles

# pyvis graph options
self._options = """
var options = {
"configure":{
"enabled": false
},
"layout": {
"hierarchical": {
"enabled": true,
"blockShifting": true,
"direction": "LR",
"sortMethod": "directed",
"shakeTowards": "leaves"
}
},
"interaction": {
"multiselect": true,
"navigationButtons": true
},
"physics": {
"enabled": false,
"hierarchicalRepulsion": {
"centralGravity": 0,
"avoidOverlap": null
},
"minVelocity": 0.75,
"solver": "hierarchicalRepulsion"
}
}
"""

def _import_visual_modules(self):
"""Import modules needed for visualization."""
try:
import dash_cytoscape as cyto
except ImportError as e:
print(e)
print("Try: pip install dash-cytoscape")
raise

try:
from jupyter_dash import JupyterDash
except ImportError as e:
print(e)
print("Try: pip install jupyter-dash")
raise

try:
from dash import html
except ImportError as e:
print(e)
print("Try: pip install dash")
raise

try:
from dash.dependencies import Input, Output
except ImportError as e:
print(e)
print("Try: pip install dash")
raise

return cyto, JupyterDash, html, Input, Output

def _create_legend_component(self, style):
"""Create legend component div."""
text = style["name"]
symbol = ""
color = "#ffffff"
if style["isShape"] == "False":
color = style["style"]["background-color"]
else:
symbol = style["symbol"]
return self.html.Div(
[
self.html.Div(
symbol,
style={
"background-color": color,
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
"font-size": "1.5vw",
},
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
text,
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
)

def _create_entity_selector(self, entity_name, style):
"""Create selector for each lineage entity."""
return {"selector": "." + entity_name, "style": style["style"]}

def _get_app(self, elements):
"""Create JupyterDash app for interactivity on Jupyter notebook."""
app = self.JupyterDash(__name__)
self.cyto.load_extra_layouts()

app.layout = self.html.Div(
[
# graph section
self.cyto.Cytoscape(
id="cytoscape-graph",
elements=elements,
style={
"width": "84%",
"height": "350px",
"display": "inline-block",
"border-width": "1vw",
"border-color": "#232f3e",
},
layout={"name": "klay"},
stylesheet=[
{
"selector": "node",
"style": {
"label": "data(label)",
"font-size": "3.5vw",
"height": "10vw",
"width": "10vw",
"border-width": "0.8",
"border-opacity": "0",
"border-color": "#232f3e",
"font-family": "verdana",
},
},
{
"selector": "edge",
"style": {
"label": "data(label)",
"color": "gray",
"text-halign": "left",
"text-margin-y": "2.5",
"font-size": "3",
"width": "1",
"curve-style": "bezier",
"control-point-step-size": "15",
"target-arrow-color": "gray",
"target-arrow-shape": "triangle",
"line-color": "gray",
"arrow-scale": "0.5",
"font-family": "verdana",
},
},
{"selector": ".select", "style": {"border-opacity": "0.7"}},
]
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
responsive=True,
),
self.html.Div(
style={
"width": "0.5%",
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
# legend section
self.html.Div(
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
style={
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
]
)

@app.callback(
self.Output("cytoscape-graph", "elements"),
self.Input("cytoscape-graph", "tapNodeData"),
self.Input("cytoscape-graph", "elements"),
)
def selectNode(tapData, elements):
for n in elements:
if tapData is not None and n["data"]["id"] == tapData["id"]:
# if is tapped node, add "select" class to node
n["classes"] += " select"
elif "classes" in n:
# remove "select" class in "classes" if node not selected
n["classes"] = n["classes"].replace("select", "")
get_module("pyvis")
from pyvis.network import Network
from pyvis.options import Options

return elements
return Network, Options

return app
def _node_color(self, entity):
"""Return node color by background-color specified in graph styles."""
return self.graph_styles[entity]["style"]["background-color"]

def render(self, elements, mode):
def render(self, elements, path="pyvisExample.html"):
"""Render graph for lineage query result."""
app = self._get_app(elements)
net = self.Network(height="500px", width="100%", notebook=True, directed=True)
net.set_options(self._options)

# add nodes to graph
for arn, source, entity, is_start_arn in elements["nodes"]:
if is_start_arn: # startarn
net.add_node(
arn, label=source, title=entity, color=self._node_color(entity), shape="star"
)
else:
net.add_node(arn, label=source, title=entity, color=self._node_color(entity))

return app.run_server(mode=mode)
# add edges to graph
for src, dest, asso_type in elements["edges"]:
net.add_edge(src, dest, title=asso_type)

return net.show(path)


class LineageQueryResult(object):
Expand Down Expand Up @@ -449,49 +336,36 @@ def __str__(self):
result_dict = vars(self)
return str({k: [str(val) for val in v] for k, v in result_dict.items()})

def _covert_edges_to_tuples(self):
"""Convert edges to tuple format for visualizer."""
edges = []
# get edge info in the form of (source, target, label)
for edge in self.edges:
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
return edges

def _covert_vertices_to_tuples(self):
"""Convert vertices to tuple format for visualizer."""
verts = []
# get vertex info in the form of (id, label, class)
for vert in self.vertices:
if vert.arn in self.startarn:
# add "startarn" class to node if arn is a startarn
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True))
else:
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False))
return verts

def _covert_edges_to_tuples(self):
"""Convert edges to tuple format for visualizer."""
edges = []
# get edge info in the form of (source, target, label)
for edge in self.edges:
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
return edges

def _get_visualization_elements(self):
"""Get elements for visualization."""
# get vertices and edges info for graph
"""Get elements(nodes+edges) for visualization."""
verts = self._covert_vertices_to_tuples()
edges = self._covert_edges_to_tuples()

nodes = [
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
]

edges = [
{"data": {"source": source, "target": target, "label": label}}
for source, target, label in edges
]

elements = nodes + edges

elements = {"nodes": verts, "edges": edges}
return elements

def visualize(self):
"""Visualize lineage query result."""
elements = self._get_visualization_elements()

lineage_graph = {
# nodes can have shape / color
"TrialComponent": {
Expand Down Expand Up @@ -522,12 +396,9 @@ def visualize(self):
},
}

# initialize DashVisualizer instance to render graph & interactive components
dash_vis = DashVisualizer(lineage_graph)

dash_server = dash_vis.render(elements=elements, mode="inline")

return dash_server
pyvis_vis = PyvisVisualizer(lineage_graph)
elements = self._get_visualization_elements()
return pyvis_vis.render(elements=elements)


class LineageFilter(object):
Expand Down