Skip to content

feature: query lineage visualizer advanced styling & interactive component handling #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 26, 2022
254 changes: 227 additions & 27 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,29 +207,45 @@ class DashVisualizer(object):
def __init__(self):
"""Init for DashVisualizer."""
# import visualization packages
self.cyto, self.JupyterDash, self.html = self._import_visual_modules()
(
self.cyto,
self.JupyterDash,
self.html,
self.Input,
self.Output,
) = self._import_visual_modules()

def _import_visual_modules(self):
"""Import modules needed for visualization."""
try:
import dash_cytoscape as cyto
except ImportError as e:
print(e)
print("try pip install dash-cytoscape")
print("Try: pip install dash-cytoscape")
raise

try:
from jupyter_dash import JupyterDash
except ImportError as e:
print(e)
print("try pip install jupyter-dash")
print("Try: pip install jupyter-dash")
raise

try:
from dash import html
except ImportError as e:
print(e)
print("try pip install dash")
print("Try: pip install dash")
raise

return cyto, JupyterDash, html
try:
from dash.dependencies import Input, Output
except ImportError as e:
print(e)
print("Try: pip install dash")
raise

return cyto, JupyterDash, html, Input, Output

def _get_app(self, elements):
"""Create JupyterDash app for interactivity on Jupyter notebook."""
Expand All @@ -239,9 +255,15 @@ def _get_app(self, elements):
app.layout = self.html.Div(
[
self.cyto.Cytoscape(
id="cytoscape-layout-1",
id="cytoscape-graph",
elements=elements,
style={"width": "100%", "height": "350px"},
style={
"width": "84%",
"height": "350px",
"display": "inline-block",
"border-width": "1vw",
"border-color": "#232f3e",
},
layout={"name": "klay"},
stylesheet=[
{
Expand All @@ -251,6 +273,10 @@ def _get_app(self, elements):
"font-size": "3.5vw",
"height": "10vw",
"width": "10vw",
"border-width": "0.8",
"border-opacity": "0",
"border-color": "#232f3e",
"font-family": "verdana",
},
},
{
Expand All @@ -259,23 +285,182 @@ def _get_app(self, elements):
"label": "data(label)",
"color": "gray",
"text-halign": "left",
"text-margin-y": "3px",
"text-margin-x": "-2px",
"font-size": "3%",
"width": "1%",
"curve-style": "taxi",
"text-margin-y": "2.5",
"font-size": "3",
"width": "1",
"curve-style": "bezier",
"control-point-step-size": "15",
"target-arrow-color": "gray",
"target-arrow-shape": "triangle",
"line-color": "gray",
"arrow-scale": "0.5",
"font-family": "verdana",
},
},
{"selector": ".Artifact", "style": {"background-color": "#146eb4"}},
{"selector": ".Context", "style": {"background-color": "#ff9900"}},
{"selector": ".TrialComponent", "style": {"background-color": "#f6cf61"}},
{"selector": ".Action", "style": {"background-color": "#88c396"}},
{"selector": ".startarn", "style": {"shape": "star"}},
{"selector": ".select", "style": {"border-opacity": "0.7"}},
],
responsive=True,
)
),
self.html.Div(
style={
"width": "0.5%",
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
self.html.Div(
[
self.html.Div(
[
self.html.Div(
style={
"background-color": "#f6cf61",
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
" Trial Component",
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
),
self.html.Div(
[
self.html.Div(
style={
"background-color": "#ff9900",
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
" Context",
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
),
self.html.Div(
[
self.html.Div(
style={
"background-color": "#88c396",
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
" Action",
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
),
self.html.Div(
[
self.html.Div(
style={
"background-color": "#146eb4",
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
" Artifact",
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
),
self.html.Div(
[
self.html.Div(
"★",
style={
"background-color": "white",
"width": "1.5vw",
"height": "1.5vw",
"display": "inline-block",
"font-size": "1.5vw",
},
),
self.html.Div(
style={
"width": "0.5vw",
"height": "1.5vw",
"display": "inline-block",
}
),
self.html.Div(
"StartArn",
style={"display": "inline-block", "font-size": "1.5vw"},
),
]
),
],
style={
"display": "inline-block",
"font-size": "1vw",
"font-family": "verdana",
"vertical-align": "top",
},
),
]
)

@app.callback(
self.Output("cytoscape-graph", "elements"),
self.Input("cytoscape-graph", "tapNodeData"),
self.Input("cytoscape-graph", "elements"),
)
def selectNode(tapData, elements):
for n in elements:
if tapData is not None and n["data"]["id"] == tapData["id"]:
# if is tapped node, add "select" class to node
n["classes"] += " select"
elif "classes" in n:
# remove "select" class in "classes" if node not selected
n["classes"] = n["classes"].replace("select", "")

return elements

return app

def render(self, elements, mode):
Expand All @@ -292,6 +477,7 @@ def __init__(
self,
edges: List[Edge] = None,
vertices: List[Vertex] = None,
startarn: List[str] = None,
):
"""Init for LineageQueryResult.

Expand All @@ -301,63 +487,75 @@ def __init__(
"""
self.edges = []
self.vertices = []
self.startarn = []

if edges is not None:
self.edges = edges

if vertices is not None:
self.vertices = vertices

if startarn is not None:
self.startarn = startarn

def __str__(self):
"""Define string representation of ``LineageQueryResult``.

Format:
{
'edges':[
{
"{
'source_arn': 'string', 'destination_arn': 'string',
'association_type': 'string'
},
}",
...
]
],
'vertices':[
{
"{
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'_session': <sagemaker.session.Session object>
},
}",
...
],
'startarn':[
'string',
...
]
}

"""
result_dict = vars(self)
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
return str({k: [str(val) for val in v] for k, v in result_dict.items()})

def _covert_vertices_to_tuples(self):
"""Convert vertices to tuple format for visualizer."""
verts = []
# get vertex info in the form of (id, label, class)
for vert in self.vertices:
verts.append((vert.arn, vert.lineage_source))
if vert.arn in self.startarn:
# add "startarn" class to node if arn is a startarn
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
else:
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
return verts

def _covert_edges_to_tuples(self):
"""Convert edges to tuple format for visualizer."""
edges = []
# get edge info in the form of (source, target, label)
for edge in self.edges:
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
return edges

def _get_visualization_elements(self):
"""Get elements for visualization."""
# get vertices and edges info for graph
verts = self._covert_vertices_to_tuples()
edges = self._covert_edges_to_tuples()

nodes = [
{
"data": {"id": id, "label": label},
}
for id, label in verts
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
]

edges = [
Expand All @@ -373,6 +571,7 @@ def visualize(self):
"""Visualize lineage query result."""
elements = self._get_visualization_elements()

# initialize DashVisualizer instance to render graph & interactive components
dash_vis = DashVisualizer()

dash_server = dash_vis.render(elements=elements, mode="inline")
Expand Down Expand Up @@ -453,9 +652,8 @@ def _get_vertex(self, vertex):
sagemaker_session=self._session,
)

def _convert_api_response(self, response) -> LineageQueryResult:
def _convert_api_response(self, response, converted) -> LineageQueryResult:
"""Convert the lineage query API response to its Python representation."""
converted = LineageQueryResult()
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]

Expand Down Expand Up @@ -538,7 +736,9 @@ def query(
Filters=query_filter._to_request_dict() if query_filter else {},
MaxDepth=max_depth,
)
query_response = self._convert_api_response(query_response)
# create query result for startarn info
query_result = LineageQueryResult(startarn=start_arns)
query_response = self._convert_api_response(query_response, query_result)
query_response = self._collapse_cross_account_artifacts(query_response)

return query_response