Skip to content

Commit e6078a9

Browse files
authored
Merge pull request #4 from ytlee93/master
feature: query lineage visualizer advanced styling & interactive component handling
2 parents cbe445e + 41d2453 commit e6078a9

File tree

1 file changed

+181
-30
lines changed

1 file changed

+181
-30
lines changed

src/sagemaker/lineage/query.py

Lines changed: 181 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -204,32 +204,89 @@ def _artifact_to_lineage_object(self):
204204
class DashVisualizer(object):
205205
"""Create object used for visualizing graph using Dash library."""
206206

207-
def __init__(self):
207+
def __init__(self, graph_styles):
208208
"""Init for DashVisualizer."""
209209
# import visualization packages
210-
self.cyto, self.JupyterDash, self.html = self._import_visual_modules()
210+
(
211+
self.cyto,
212+
self.JupyterDash,
213+
self.html,
214+
self.Input,
215+
self.Output,
216+
) = self._import_visual_modules()
217+
218+
self.graph_styles = graph_styles
211219

212220
def _import_visual_modules(self):
    """Import modules needed for visualization.

    Returns:
        tuple: ``(cyto, JupyterDash, html, Input, Output)`` taken from
        ``dash_cytoscape``, ``jupyter_dash``, ``dash`` and
        ``dash.dependencies``.

    Raises:
        ImportError: If one of the packages cannot be imported. The error and
            an install hint are printed before it is re-raised.
    """

    def _report(error, hint):
        # Print the original error first, then the install hint.
        print(error)
        print(hint)

    try:
        import dash_cytoscape as cyto
    except ImportError as err:
        _report(err, "Try: pip install dash-cytoscape")
        raise

    try:
        from jupyter_dash import JupyterDash
    except ImportError as err:
        _report(err, "Try: pip install jupyter-dash")
        raise

    try:
        from dash import html
    except ImportError as err:
        _report(err, "Try: pip install dash")
        raise

    try:
        from dash.dependencies import Input, Output
    except ImportError as err:
        _report(err, "Try: pip install dash")
        raise

    return cyto, JupyterDash, html, Input, Output
251+
252+
def _create_legend_component(self, style):
253+
"""Create legend component div."""
254+
text = style["name"]
255+
symbol = ""
256+
color = "#ffffff"
257+
if style["isShape"] == "False":
258+
color = style["style"]["background-color"]
259+
else:
260+
symbol = style["symbol"]
261+
return self.html.Div(
262+
[
263+
self.html.Div(
264+
symbol,
265+
style={
266+
"background-color": color,
267+
"width": "1.5vw",
268+
"height": "1.5vw",
269+
"display": "inline-block",
270+
"font-size": "1.5vw",
271+
},
272+
),
273+
self.html.Div(
274+
style={
275+
"width": "0.5vw",
276+
"height": "1.5vw",
277+
"display": "inline-block",
278+
}
279+
),
280+
self.html.Div(
281+
text,
282+
style={"display": "inline-block", "font-size": "1.5vw"},
283+
),
284+
]
285+
)
231286

232-
return cyto, JupyterDash, html
287+
def _create_entity_selector(self, entity_name, style):
288+
"""Create selector for each lineage entity."""
289+
return {"selector": "." + entity_name, "style": style["style"]}
233290

234291
def _get_app(self, elements):
235292
"""Create JupyterDash app for interactivity on Jupyter notebook."""
@@ -238,10 +295,17 @@ def _get_app(self, elements):
238295

239296
app.layout = self.html.Div(
240297
[
298+
# graph section
241299
self.cyto.Cytoscape(
242-
id="cytoscape-layout-1",
300+
id="cytoscape-graph",
243301
elements=elements,
244-
style={"width": "100%", "height": "350px"},
302+
style={
303+
"width": "84%",
304+
"height": "350px",
305+
"display": "inline-block",
306+
"border-width": "1vw",
307+
"border-color": "#232f3e",
308+
},
245309
layout={"name": "klay"},
246310
stylesheet=[
247311
{
@@ -251,6 +315,10 @@ def _get_app(self, elements):
251315
"font-size": "3.5vw",
252316
"height": "10vw",
253317
"width": "10vw",
318+
"border-width": "0.8",
319+
"border-opacity": "0",
320+
"border-color": "#232f3e",
321+
"font-family": "verdana",
254322
},
255323
},
256324
{
@@ -259,23 +327,61 @@ def _get_app(self, elements):
259327
"label": "data(label)",
260328
"color": "gray",
261329
"text-halign": "left",
262-
"text-margin-y": "3px",
263-
"text-margin-x": "-2px",
264-
"font-size": "3%",
265-
"width": "1%",
266-
"curve-style": "taxi",
330+
"text-margin-y": "2.5",
331+
"font-size": "3",
332+
"width": "1",
333+
"curve-style": "bezier",
334+
"control-point-step-size": "15",
267335
"target-arrow-color": "gray",
268336
"target-arrow-shape": "triangle",
269337
"line-color": "gray",
270338
"arrow-scale": "0.5",
339+
"font-family": "verdana",
271340
},
272341
},
273-
],
342+
{"selector": ".select", "style": {"border-opacity": "0.7"}},
343+
]
344+
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
274345
responsive=True,
275-
)
346+
),
347+
self.html.Div(
348+
style={
349+
"width": "0.5%",
350+
"display": "inline-block",
351+
"font-size": "1vw",
352+
"font-family": "verdana",
353+
"vertical-align": "top",
354+
},
355+
),
356+
# legend section
357+
self.html.Div(
358+
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
359+
style={
360+
"display": "inline-block",
361+
"font-size": "1vw",
362+
"font-family": "verdana",
363+
"vertical-align": "top",
364+
},
365+
),
276366
]
277367
)
278368

369+
@app.callback(
    self.Output("cytoscape-graph", "elements"),
    self.Input("cytoscape-graph", "tapNodeData"),
    self.Input("cytoscape-graph", "elements"),
)
def selectNode(tapData, elements):
    """Give the tapped node the ``select`` class and clear it everywhere else.

    Args:
        tapData (dict): Value of the graph's ``tapNodeData`` property; its
            ``id`` is compared with each element's ``data["id"]``. ``None``
            means no element is treated as tapped.
        elements (list): Current value of the graph's ``elements`` property.

    Returns:
        list: The same ``elements`` list, with ``classes`` updated in place.
    """
    for element in elements:
        data = element["data"]
        # Elements without an id can never be the tapped node.
        is_tapped = tapData is not None and "id" in data and data["id"] == tapData["id"]

        if not is_tapped and "classes" not in element:
            # Nothing to clear and nothing to add: leave the element untouched.
            continue

        # Work on whole class tokens so other class names that merely contain
        # "select" survive, and repeated taps cannot pile up duplicates.
        tokens = [name for name in element.get("classes", "").split() if name != "select"]
        if is_tapped:
            tokens.append("select")
        element["classes"] = " ".join(tokens)

    return elements
384+
279385
return app
280386

281387
def render(self, elements, mode):
@@ -292,6 +398,7 @@ def __init__(
292398
self,
293399
edges: List[Edge] = None,
294400
vertices: List[Vertex] = None,
401+
startarn: List[str] = None,
295402
):
296403
"""Init for LineageQueryResult.
297404
@@ -301,63 +408,75 @@ def __init__(
301408
"""
302409
self.edges = []
303410
self.vertices = []
411+
self.startarn = []
304412

305413
if edges is not None:
306414
self.edges = edges
307415

308416
if vertices is not None:
309417
self.vertices = vertices
310418

419+
if startarn is not None:
420+
self.startarn = startarn
421+
311422
def __str__(self):
312423
"""Define string representation of ``LineageQueryResult``.
313424
314425
Format:
315426
{
316427
'edges':[
317-
{
428+
"{
318429
'source_arn': 'string', 'destination_arn': 'string',
319430
'association_type': 'string'
320-
},
431+
}",
321432
...
322-
]
433+
],
323434
'vertices':[
324-
{
435+
"{
325436
'arn': 'string', 'lineage_entity': 'string',
326437
'lineage_source': 'string',
327438
'_session': <sagemaker.session.Session object>
328-
},
439+
}",
440+
...
441+
],
442+
'startarn':[
443+
'string',
329444
...
330445
]
331446
}
332447
333448
"""
334449
result_dict = vars(self)
335-
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
450+
return str({k: [str(val) for val in v] for k, v in result_dict.items()})
336451

337452
def _covert_vertices_to_tuples(self):
338453
"""Convert vertices to tuple format for visualizer."""
339454
verts = []
455+
# get vertex info in the form of (id, label, class)
340456
for vert in self.vertices:
341-
verts.append((vert.arn, vert.lineage_source))
457+
if vert.arn in self.startarn:
458+
# add "startarn" class to node if arn is a startarn
459+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
460+
else:
461+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
342462
return verts
343463

344464
def _covert_edges_to_tuples(self):
345465
"""Convert edges to tuple format for visualizer."""
346466
edges = []
467+
# get edge info in the form of (source, target, label)
347468
for edge in self.edges:
348469
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
349470
return edges
350471

351472
def _get_visualization_elements(self):
352473
"""Get elements for visualization."""
474+
# get vertices and edges info for graph
353475
verts = self._covert_vertices_to_tuples()
354476
edges = self._covert_edges_to_tuples()
355477

356478
nodes = [
357-
{
358-
"data": {"id": id, "label": label},
359-
}
360-
for id, label in verts
479+
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
361480
]
362481

363482
edges = [
@@ -373,7 +492,38 @@ def visualize(self):
373492
"""Visualize lineage query result."""
374493
elements = self._get_visualization_elements()
375494

376-
dash_vis = DashVisualizer()
495+
lineage_graph = {
496+
# nodes can have shape / color
497+
"TrialComponent": {
498+
"name": "Trial Component",
499+
"style": {"background-color": "#f6cf61"},
500+
"isShape": "False",
501+
},
502+
"Context": {
503+
"name": "Context",
504+
"style": {"background-color": "#ff9900"},
505+
"isShape": "False",
506+
},
507+
"Action": {
508+
"name": "Action",
509+
"style": {"background-color": "#88c396"},
510+
"isShape": "False",
511+
},
512+
"Artifact": {
513+
"name": "Artifact",
514+
"style": {"background-color": "#146eb4"},
515+
"isShape": "False",
516+
},
517+
"StartArn": {
518+
"name": "StartArn",
519+
"style": {"shape": "star"},
520+
"isShape": "True",
521+
"symbol": "★", # shape symbol for legend
522+
},
523+
}
524+
525+
# initialize DashVisualizer instance to render graph & interactive components
526+
dash_vis = DashVisualizer(lineage_graph)
377527

378528
dash_server = dash_vis.render(elements=elements, mode="inline")
379529

@@ -453,9 +603,8 @@ def _get_vertex(self, vertex):
453603
sagemaker_session=self._session,
454604
)
455605

456-
def _convert_api_response(self, response) -> LineageQueryResult:
606+
def _convert_api_response(self, response, converted) -> LineageQueryResult:
457607
"""Convert the lineage query API response to its Python representation."""
458-
converted = LineageQueryResult()
459608
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
460609
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
461610

@@ -538,7 +687,9 @@ def query(
538687
Filters=query_filter._to_request_dict() if query_filter else {},
539688
MaxDepth=max_depth,
540689
)
541-
query_response = self._convert_api_response(query_response)
690+
# create query result for startarn info
691+
query_result = LineageQueryResult(startarn=start_arns)
692+
query_response = self._convert_api_response(query_response, query_result)
542693
query_response = self._collapse_cross_account_artifacts(query_response)
543694

544695
return query_response

0 commit comments

Comments
 (0)