Skip to content

feature: query lineage visualizer advanced styling & interactive component handling #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 26, 2022
211 changes: 181 additions & 30 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,32 +204,89 @@ def _artifact_to_lineage_object(self):
class DashVisualizer(object):
"""Create object used for visualizing graph using Dash library."""

def __init__(self):
def __init__(self, graph_styles):
"""Init for DashVisualizer."""
# import visualization packages
self.cyto, self.JupyterDash, self.html = self._import_visual_modules()
(
self.cyto,
self.JupyterDash,
self.html,
self.Input,
self.Output,
) = self._import_visual_modules()

self.graph_styles = graph_styles

def _import_visual_modules(self):
"""Import modules needed for visualization."""
try:
import dash_cytoscape as cyto
except ImportError as e:
print(e)
print("try pip install dash-cytoscape")
print("Try: pip install dash-cytoscape")
raise

try:
from jupyter_dash import JupyterDash
except ImportError as e:
print(e)
print("try pip install jupyter-dash")
print("Try: pip install jupyter-dash")
raise

try:
from dash import html
except ImportError as e:
print(e)
print("try pip install dash")
print("Try: pip install dash")
raise

try:
from dash.dependencies import Input, Output
except ImportError as e:
print(e)
print("Try: pip install dash")
raise

return cyto, JupyterDash, html, Input, Output

def _create_legend_component(self, style):
"""Create legend component div."""
text = style["name"]
symbol = ""
color = "#ffffff"
if style["isShape"] == "False":
color = style["style"]["background-color"]
else:
symbol = style["symbol"]
return self.html.Div(
[
self.html.Div(
symbol,
style={
"background-color": color,
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
"font-size": "1.5vw",
},
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
text,
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
)

return cyto, JupyterDash, html
def _create_entity_selector(self, entity_name, style):
"""Create selector for each lineage entity."""
return {"selector": "." + entity_name, "style": style["style"]}

def _get_app(self, elements):
"""Create JupyterDash app for interactivity on Jupyter notebook."""
Expand All @@ -238,10 +295,17 @@ def _get_app(self, elements):

app.layout = self.html.Div(
[
# graph section
self.cyto.Cytoscape(
id="cytoscape-layout-1",
id="cytoscape-graph",
elements=elements,
style={"width": "100%", "height": "350px"},
style={
"width": "84%",
"height": "350px",
"display": "inline-block",
"border-width": "1vw",
"border-color": "#232f3e",
},
layout={"name": "klay"},
stylesheet=[
{
Expand All @@ -251,6 +315,10 @@ def _get_app(self, elements):
"font-size": "3.5vw",
"height": "10vw",
"width": "10vw",
"border-width": "0.8",
"border-opacity": "0",
"border-color": "#232f3e",
"font-family": "verdana",
},
},
{
Expand All @@ -259,23 +327,61 @@ def _get_app(self, elements):
"label": "data(label)",
"color": "gray",
"text-halign": "left",
"text-margin-y": "3px",
"text-margin-x": "-2px",
"font-size": "3%",
"width": "1%",
"curve-style": "taxi",
"text-margin-y": "2.5",
"font-size": "3",
"width": "1",
"curve-style": "bezier",
"control-point-step-size": "15",
"target-arrow-color": "gray",
"target-arrow-shape": "triangle",
"line-color": "gray",
"arrow-scale": "0.5",
"font-family": "verdana",
},
},
],
{"selector": ".select", "style": {"border-opacity": "0.7"}},
]
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
responsive=True,
)
),
self.html.Div(
style={
"width": "0.5%",
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
# legend section
self.html.Div(
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
style={
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
]
)

@app.callback(
self.Output("cytoscape-graph", "elements"),
self.Input("cytoscape-graph", "tapNodeData"),
self.Input("cytoscape-graph", "elements"),
)
def selectNode(tapData, elements):
for n in elements:
if tapData is not None and n["data"]["id"] == tapData["id"]:
# if is tapped node, add "select" class to node
n["classes"] += " select"
elif "classes" in n:
# remove "select" class in "classes" if node not selected
n["classes"] = n["classes"].replace("select", "")

return elements

return app

def render(self, elements, mode):
Expand All @@ -292,6 +398,7 @@ def __init__(
self,
edges: List[Edge] = None,
vertices: List[Vertex] = None,
startarn: List[str] = None,
):
"""Init for LineageQueryResult.

Expand All @@ -301,63 +408,75 @@ def __init__(
"""
self.edges = []
self.vertices = []
self.startarn = []

if edges is not None:
self.edges = edges

if vertices is not None:
self.vertices = vertices

if startarn is not None:
self.startarn = startarn

def __str__(self):
"""Define string representation of ``LineageQueryResult``.

Format:
{
'edges':[
{
"{
'source_arn': 'string', 'destination_arn': 'string',
'association_type': 'string'
},
}",
...
]
],
'vertices':[
{
"{
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'_session': <sagemaker.session.Session object>
},
}",
...
],
'startarn':[
'string',
...
]
}

"""
result_dict = vars(self)
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
return str({k: [str(val) for val in v] for k, v in result_dict.items()})

def _covert_vertices_to_tuples(self):
"""Convert vertices to tuple format for visualizer."""
verts = []
# get vertex info in the form of (id, label, class)
for vert in self.vertices:
verts.append((vert.arn, vert.lineage_source))
if vert.arn in self.startarn:
# add "startarn" class to node if arn is a startarn
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
else:
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
return verts

def _covert_edges_to_tuples(self):
"""Convert edges to tuple format for visualizer."""
edges = []
# get edge info in the form of (source, target, label)
for edge in self.edges:
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
return edges

def _get_visualization_elements(self):
"""Get elements for visualization."""
# get vertices and edges info for graph
verts = self._covert_vertices_to_tuples()
edges = self._covert_edges_to_tuples()

nodes = [
{
"data": {"id": id, "label": label},
}
for id, label in verts
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
]

edges = [
Expand All @@ -373,7 +492,38 @@ def visualize(self):
"""Visualize lineage query result."""
elements = self._get_visualization_elements()

dash_vis = DashVisualizer()
lineage_graph = {
# nodes can have shape / color
"TrialComponent": {
"name": "Trial Component",
"style": {"background-color": "#f6cf61"},
"isShape": "False",
},
"Context": {
"name": "Context",
"style": {"background-color": "#ff9900"},
"isShape": "False",
},
"Action": {
"name": "Action",
"style": {"background-color": "#88c396"},
"isShape": "False",
},
"Artifact": {
"name": "Artifact",
"style": {"background-color": "#146eb4"},
"isShape": "False",
},
"StartArn": {
"name": "StartArn",
"style": {"shape": "star"},
"isShape": "True",
"symbol": "★", # shape symbol for legend
},
}

# initialize DashVisualizer instance to render graph & interactive components
dash_vis = DashVisualizer(lineage_graph)

dash_server = dash_vis.render(elements=elements, mode="inline")

Expand Down Expand Up @@ -453,9 +603,8 @@ def _get_vertex(self, vertex):
sagemaker_session=self._session,
)

def _convert_api_response(self, response) -> LineageQueryResult:
def _convert_api_response(self, response, converted) -> LineageQueryResult:
"""Convert the lineage query API response to its Python representation."""
converted = LineageQueryResult()
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]

Expand Down Expand Up @@ -538,7 +687,9 @@ def query(
Filters=query_filter._to_request_dict() if query_filter else {},
MaxDepth=max_depth,
)
query_response = self._convert_api_response(query_response)
# create query result for startarn info
query_result = LineageQueryResult(startarn=start_arns)
query_response = self._convert_api_response(query_response, query_result)
query_response = self._collapse_cross_account_artifacts(query_response)

return query_response