diff --git a/.gitignore b/.gitignore index 5b496055e9..7a61658c20 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ venv/ env/ .vscode/ **/tmp -.python-version \ No newline at end of file +.python-version +*.html \ No newline at end of file diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 1ea7baecb6..5f3164bdc0 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -17,7 +17,7 @@ from enum import Enum from typing import Optional, Union, List, Dict -from sagemaker.lineage._utils import get_resource_name_from_arn +from sagemaker.lineage._utils import get_resource_name_from_arn, get_module class LineageEntityEnum(Enum): @@ -201,194 +201,81 @@ def _artifact_to_lineage_object(self): return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) -class DashVisualizer(object): - """Create object used for visualizing graph using Dash library.""" +class PyvisVisualizer(object): + """Create object used for visualizing graph using Pyvis library.""" def __init__(self, graph_styles): - """Init for DashVisualizer.""" + """Init for PyvisVisualizer.""" # import visualization packages ( - self.cyto, - self.JupyterDash, - self.html, - self.Input, - self.Output, + self.Network, + self.Options, ) = self._import_visual_modules() self.graph_styles = graph_styles + # pyvis graph options + self._options = """ + var options = { + "configure":{ + "enabled": false + }, + "layout": { + "hierarchical": { + "enabled": true, + "blockShifting": true, + "direction": "LR", + "sortMethod": "directed", + "shakeTowards": "leaves" + } + }, + "interaction": { + "multiselect": true, + "navigationButtons": true + }, + "physics": { + "enabled": false, + "hierarchicalRepulsion": { + "centralGravity": 0, + "avoidOverlap": null + }, + "minVelocity": 0.75, + "solver": "hierarchicalRepulsion" + } + } + """ + def _import_visual_modules(self): """Import modules needed for visualization.""" - try: - import dash_cytoscape as cyto - except ImportError as e: - print(e) - print("Try: pip install dash-cytoscape") - raise - - try: - from jupyter_dash import JupyterDash - except ImportError as e: - print(e) - print("Try: pip install jupyter-dash") - raise - - try: - from dash import html - except ImportError as e: - print(e) - print("Try: pip install dash") - raise - - try: - from dash.dependencies import Input, Output - except ImportError as e: - print(e) - print("Try: pip install dash") - raise - - return cyto, JupyterDash, html, Input, Output - - def _create_legend_component(self, style): - """Create legend component div.""" - text = style["name"] - symbol = "" - color = "#ffffff" - if style["isShape"] == "False": - color = style["style"]["background-color"] - else: - symbol = style["symbol"] - return self.html.Div( - [ - self.html.Div( - symbol, - style={ - "background-color": color, - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - "font-size": "1.5vw", - }, - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - text, - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ) - - def _create_entity_selector(self, entity_name, style): - """Create selector for each lineage entity.""" - return {"selector": "." + entity_name, "style": style["style"]} - - def _get_app(self, elements): - """Create JupyterDash app for interactivity on Jupyter notebook.""" - app = self.JupyterDash(__name__) - self.cyto.load_extra_layouts() - - app.layout = self.html.Div( - [ - # graph section - self.cyto.Cytoscape( - id="cytoscape-graph", - elements=elements, - style={ - "width": "84%", - "height": "350px", - "display": "inline-block", - "border-width": "1vw", - "border-color": "#232f3e", - }, - layout={"name": "klay"}, - stylesheet=[ - { - "selector": "node", - "style": { - "label": "data(label)", - "font-size": "3.5vw", - "height": "10vw", - "width": "10vw", - "border-width": "0.8", - "border-opacity": "0", - "border-color": "#232f3e", - "font-family": "verdana", - }, - }, - { - "selector": "edge", - "style": { - "label": "data(label)", - "color": "gray", - "text-halign": "left", - "text-margin-y": "2.5", - "font-size": "3", - "width": "1", - "curve-style": "bezier", - "control-point-step-size": "15", - "target-arrow-color": "gray", - "target-arrow-shape": "triangle", - "line-color": "gray", - "arrow-scale": "0.5", - "font-family": "verdana", - }, - }, - {"selector": ".select", "style": {"border-opacity": "0.7"}}, - ] - + [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()], - responsive=True, - ), - self.html.Div( - style={ - "width": "0.5%", - "display": "inline-block", - "font-size": "1vw", - "font-family": "verdana", - "vertical-align": "top", - }, - ), - # legend section - self.html.Div( - [self._create_legend_component(v) for k, v in self.graph_styles.items()], - style={ - "display": "inline-block", - "font-size": "1vw", - "font-family": "verdana", - "vertical-align": "top", - }, - ), - ] - ) - - @app.callback( - self.Output("cytoscape-graph", "elements"), - self.Input("cytoscape-graph", "tapNodeData"), - self.Input("cytoscape-graph", "elements"), - ) - def selectNode(tapData, elements): - for n in elements: - if tapData is not None and n["data"]["id"] == tapData["id"]: - # if is tapped node, add "select" class to node - n["classes"] += " select" - elif "classes" in n: - # remove "select" class in "classes" if node not selected - n["classes"] = n["classes"].replace("select", "") + get_module("pyvis") + from pyvis.network import Network + from pyvis.options import Options - return elements + return Network, Options - return app + def _node_color(self, entity): + """Return node color by background-color specified in graph styles.""" + return self.graph_styles[entity]["style"]["background-color"] - def render(self, elements, mode): + def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result.""" - app = self._get_app(elements) + net = self.Network(height="500px", width="100%", notebook=True, directed=True) + net.set_options(self._options) + + # add nodes to graph + for arn, source, entity, is_start_arn in elements["nodes"]: + if is_start_arn: # startarn + net.add_node( + arn, label=source, title=entity, color=self._node_color(entity), shape="star" + ) + else: + net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) - return app.run_server(mode=mode) + # add edges to graph + for src, dest, asso_type in elements["edges"]: + net.add_edge(src, dest, title=asso_type) + + return net.show(path) class LineageQueryResult(object): @@ -449,6 +336,14 @@ def __str__(self): result_dict = vars(self) return str({k: [str(val) for val in v] for k, v in result_dict.items()}) + def _covert_edges_to_tuples(self): + """Convert edges to tuple format for visualizer.""" + edges = [] + # get edge info in the form of (source, target, label) + for edge in self.edges: + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + return edges + def _covert_vertices_to_tuples(self): """Convert vertices to tuple format for visualizer.""" verts = [] @@ -456,42 +351,21 @@ def _covert_vertices_to_tuples(self): for vert in self.vertices: if vert.arn in self.startarn: # add "startarn" class to node if arn is a startarn - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True)) else: - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False)) return verts - def _covert_edges_to_tuples(self): - """Convert edges to tuple format for visualizer.""" - edges = [] - # get edge info in the form of (source, target, label) - for edge in self.edges: - edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) - return edges - def _get_visualization_elements(self): - """Get elements for visualization.""" - # get vertices and edges info for graph + """Get elements(nodes+edges) for visualization.""" verts = self._covert_vertices_to_tuples() edges = self._covert_edges_to_tuples() - nodes = [ - {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts - ] - - edges = [ - {"data": {"source": source, "target": target, "label": label}} - for source, target, label in edges - ] - - elements = nodes + edges - + elements = {"nodes": verts, "edges": edges} return elements def visualize(self): """Visualize lineage query result.""" - elements = self._get_visualization_elements() - lineage_graph = { # nodes can have shape / color "TrialComponent": { @@ -522,12 +396,9 @@ def visualize(self): }, } - # initialize DashVisualizer instance to render graph & interactive components - dash_vis = DashVisualizer(lineage_graph) - - dash_server = dash_vis.render(elements=elements, mode="inline") - - return dash_server + pyvis_vis = PyvisVisualizer(lineage_graph) + elements = self._get_visualization_elements() + return pyvis_vis.render(elements=elements) class LineageFilter(object):