@@ -204,32 +204,89 @@ def _artifact_to_lineage_object(self):
204
204
class DashVisualizer (object ):
205
205
"""Create object used for visualizing graph using Dash library."""
206
206
207
- def __init__ (self ):
207
+ def __init__ (self , graph_styles ):
208
208
"""Init for DashVisualizer."""
209
209
# import visualization packages
210
- self .cyto , self .JupyterDash , self .html = self ._import_visual_modules ()
210
+ (
211
+ self .cyto ,
212
+ self .JupyterDash ,
213
+ self .html ,
214
+ self .Input ,
215
+ self .Output ,
216
+ ) = self ._import_visual_modules ()
217
+
218
+ self .graph_styles = graph_styles
211
219
212
220
def _import_visual_modules (self ):
213
221
"""Import modules needed for visualization."""
214
222
try :
215
223
import dash_cytoscape as cyto
216
224
except ImportError as e :
217
225
print (e )
218
- print ("try pip install dash-cytoscape" )
226
+ print ("Try: pip install dash-cytoscape" )
227
+ raise
219
228
220
229
try :
221
230
from jupyter_dash import JupyterDash
222
231
except ImportError as e :
223
232
print (e )
224
- print ("try pip install jupyter-dash" )
233
+ print ("Try: pip install jupyter-dash" )
234
+ raise
225
235
226
236
try :
227
237
from dash import html
228
238
except ImportError as e :
229
239
print (e )
230
- print ("try pip install dash" )
240
+ print ("Try: pip install dash" )
241
+ raise
242
+
243
+ try :
244
+ from dash .dependencies import Input , Output
245
+ except ImportError as e :
246
+ print (e )
247
+ print ("Try: pip install dash" )
248
+ raise
249
+
250
+ return cyto , JupyterDash , html , Input , Output
251
+
252
+ def _create_legend_component (self , style ):
253
+ """Create legend component div."""
254
+ text = style ["name" ]
255
+ symbol = ""
256
+ color = "#ffffff"
257
+ if style ["isShape" ] == "False" :
258
+ color = style ["style" ]["background-color" ]
259
+ else :
260
+ symbol = style ["symbol" ]
261
+ return self .html .Div (
262
+ [
263
+ self .html .Div (
264
+ symbol ,
265
+ style = {
266
+ "background-color" : color ,
267
+ "width" : "1.5vw" ,
268
+ "height" : "1.5vw" ,
269
+ "display" : "inline-block" ,
270
+ "font-size" : "1.5vw" ,
271
+ },
272
+ ),
273
+ self .html .Div (
274
+ style = {
275
+ "width" : "0.5vw" ,
276
+ "height" : "1.5vw" ,
277
+ "display" : "inline-block" ,
278
+ }
279
+ ),
280
+ self .html .Div (
281
+ text ,
282
+ style = {"display" : "inline-block" , "font-size" : "1.5vw" },
283
+ ),
284
+ ]
285
+ )
231
286
232
- return cyto , JupyterDash , html
287
+ def _create_entity_selector (self , entity_name , style ):
288
+ """Create selector for each lineage entity."""
289
+ return {"selector" : "." + entity_name , "style" : style ["style" ]}
233
290
234
291
def _get_app (self , elements ):
235
292
"""Create JupyterDash app for interactivity on Jupyter notebook."""
@@ -238,10 +295,17 @@ def _get_app(self, elements):
238
295
239
296
app .layout = self .html .Div (
240
297
[
298
+ # graph section
241
299
self .cyto .Cytoscape (
242
- id = "cytoscape-layout-1 " ,
300
+ id = "cytoscape-graph " ,
243
301
elements = elements ,
244
- style = {"width" : "100%" , "height" : "350px" },
302
+ style = {
303
+ "width" : "84%" ,
304
+ "height" : "350px" ,
305
+ "display" : "inline-block" ,
306
+ "border-width" : "1vw" ,
307
+ "border-color" : "#232f3e" ,
308
+ },
245
309
layout = {"name" : "klay" },
246
310
stylesheet = [
247
311
{
@@ -251,6 +315,10 @@ def _get_app(self, elements):
251
315
"font-size" : "3.5vw" ,
252
316
"height" : "10vw" ,
253
317
"width" : "10vw" ,
318
+ "border-width" : "0.8" ,
319
+ "border-opacity" : "0" ,
320
+ "border-color" : "#232f3e" ,
321
+ "font-family" : "verdana" ,
254
322
},
255
323
},
256
324
{
@@ -259,23 +327,61 @@ def _get_app(self, elements):
259
327
"label" : "data(label)" ,
260
328
"color" : "gray" ,
261
329
"text-halign" : "left" ,
262
- "text-margin-y" : "3px " ,
263
- "text-margin-x " : "-2px " ,
264
- "font-size " : "3% " ,
265
- "width " : "1% " ,
266
- "curve-style " : "taxi " ,
330
+ "text-margin-y" : "2.5 " ,
331
+ "font-size " : "3 " ,
332
+ "width " : "1 " ,
333
+ "curve-style " : "bezier " ,
334
+ "control-point-step-size " : "15 " ,
267
335
"target-arrow-color" : "gray" ,
268
336
"target-arrow-shape" : "triangle" ,
269
337
"line-color" : "gray" ,
270
338
"arrow-scale" : "0.5" ,
339
+ "font-family" : "verdana" ,
271
340
},
272
341
},
273
- ],
342
+ {"selector" : ".select" , "style" : {"border-opacity" : "0.7" }},
343
+ ]
344
+ + [self ._create_entity_selector (k , v ) for k , v in self .graph_styles .items ()],
274
345
responsive = True ,
275
- )
346
+ ),
347
+ self .html .Div (
348
+ style = {
349
+ "width" : "0.5%" ,
350
+ "display" : "inline-block" ,
351
+ "font-size" : "1vw" ,
352
+ "font-family" : "verdana" ,
353
+ "vertical-align" : "top" ,
354
+ },
355
+ ),
356
+ # legend section
357
+ self .html .Div (
358
+ [self ._create_legend_component (v ) for k , v in self .graph_styles .items ()],
359
+ style = {
360
+ "display" : "inline-block" ,
361
+ "font-size" : "1vw" ,
362
+ "font-family" : "verdana" ,
363
+ "vertical-align" : "top" ,
364
+ },
365
+ ),
276
366
]
277
367
)
278
368
369
+ @app .callback (
370
+ self .Output ("cytoscape-graph" , "elements" ),
371
+ self .Input ("cytoscape-graph" , "tapNodeData" ),
372
+ self .Input ("cytoscape-graph" , "elements" ),
373
+ )
374
+ def selectNode (tapData , elements ):
375
+ for n in elements :
376
+ if tapData is not None and n ["data" ]["id" ] == tapData ["id" ]:
377
+ # if is tapped node, add "select" class to node
378
+ n ["classes" ] += " select"
379
+ elif "classes" in n :
380
+ # remove "select" class in "classes" if node not selected
381
+ n ["classes" ] = n ["classes" ].replace ("select" , "" )
382
+
383
+ return elements
384
+
279
385
return app
280
386
281
387
def render (self , elements , mode ):
@@ -292,6 +398,7 @@ def __init__(
292
398
self ,
293
399
edges : List [Edge ] = None ,
294
400
vertices : List [Vertex ] = None ,
401
+ startarn : List [str ] = None ,
295
402
):
296
403
"""Init for LineageQueryResult.
297
404
@@ -301,63 +408,75 @@ def __init__(
301
408
"""
302
409
self .edges = []
303
410
self .vertices = []
411
+ self .startarn = []
304
412
305
413
if edges is not None :
306
414
self .edges = edges
307
415
308
416
if vertices is not None :
309
417
self .vertices = vertices
310
418
419
+ if startarn is not None :
420
+ self .startarn = startarn
421
+
311
422
def __str__ (self ):
312
423
"""Define string representation of ``LineageQueryResult``.
313
424
314
425
Format:
315
426
{
316
427
'edges':[
317
- {
428
+ " {
318
429
'source_arn': 'string', 'destination_arn': 'string',
319
430
'association_type': 'string'
320
- },
431
+ }" ,
321
432
...
322
- ]
433
+ ],
323
434
'vertices':[
324
- {
435
+ " {
325
436
'arn': 'string', 'lineage_entity': 'string',
326
437
'lineage_source': 'string',
327
438
'_session': <sagemaker.session.Session object>
328
- },
439
+ }",
440
+ ...
441
+ ],
442
+ 'startarn':[
443
+ 'string',
329
444
...
330
445
]
331
446
}
332
447
333
448
"""
334
449
result_dict = vars (self )
335
- return str ({k : [vars (val ) for val in v ] for k , v in result_dict .items ()})
450
+ return str ({k : [str (val ) for val in v ] for k , v in result_dict .items ()})
336
451
337
452
def _covert_vertices_to_tuples (self ):
338
453
"""Convert vertices to tuple format for visualizer."""
339
454
verts = []
455
+ # get vertex info in the form of (id, label, class)
340
456
for vert in self .vertices :
341
- verts .append ((vert .arn , vert .lineage_source ))
457
+ if vert .arn in self .startarn :
458
+ # add "startarn" class to node if arn is a startarn
459
+ verts .append ((vert .arn , vert .lineage_source , vert .lineage_entity + " startarn" ))
460
+ else :
461
+ verts .append ((vert .arn , vert .lineage_source , vert .lineage_entity ))
342
462
return verts
343
463
344
464
def _covert_edges_to_tuples (self ):
345
465
"""Convert edges to tuple format for visualizer."""
346
466
edges = []
467
+ # get edge info in the form of (source, target, label)
347
468
for edge in self .edges :
348
469
edges .append ((edge .source_arn , edge .destination_arn , edge .association_type ))
349
470
return edges
350
471
351
472
def _get_visualization_elements (self ):
352
473
"""Get elements for visualization."""
474
+ # get vertices and edges info for graph
353
475
verts = self ._covert_vertices_to_tuples ()
354
476
edges = self ._covert_edges_to_tuples ()
355
477
356
478
nodes = [
357
- {
358
- "data" : {"id" : id , "label" : label },
359
- }
360
- for id , label in verts
479
+ {"data" : {"id" : id , "label" : label }, "classes" : classes } for id , label , classes in verts
361
480
]
362
481
363
482
edges = [
@@ -373,7 +492,38 @@ def visualize(self):
373
492
"""Visualize lineage query result."""
374
493
elements = self ._get_visualization_elements ()
375
494
376
- dash_vis = DashVisualizer ()
495
+ lineage_graph = {
496
+ # nodes can have shape / color
497
+ "TrialComponent" : {
498
+ "name" : "Trial Component" ,
499
+ "style" : {"background-color" : "#f6cf61" },
500
+ "isShape" : "False" ,
501
+ },
502
+ "Context" : {
503
+ "name" : "Context" ,
504
+ "style" : {"background-color" : "#ff9900" },
505
+ "isShape" : "False" ,
506
+ },
507
+ "Action" : {
508
+ "name" : "Action" ,
509
+ "style" : {"background-color" : "#88c396" },
510
+ "isShape" : "False" ,
511
+ },
512
+ "Artifact" : {
513
+ "name" : "Artifact" ,
514
+ "style" : {"background-color" : "#146eb4" },
515
+ "isShape" : "False" ,
516
+ },
517
+ "StartArn" : {
518
+ "name" : "StartArn" ,
519
+ "style" : {"shape" : "star" },
520
+ "isShape" : "True" ,
521
+ "symbol" : "★" , # shape symbol for legend
522
+ },
523
+ }
524
+
525
+ # initialize DashVisualizer instance to render graph & interactive components
526
+ dash_vis = DashVisualizer (lineage_graph )
377
527
378
528
dash_server = dash_vis .render (elements = elements , mode = "inline" )
379
529
@@ -453,9 +603,8 @@ def _get_vertex(self, vertex):
453
603
sagemaker_session = self ._session ,
454
604
)
455
605
456
- def _convert_api_response (self , response ) -> LineageQueryResult :
606
+ def _convert_api_response (self , response , converted ) -> LineageQueryResult :
457
607
"""Convert the lineage query API response to its Python representation."""
458
- converted = LineageQueryResult ()
459
608
converted .edges = [self ._get_edge (edge ) for edge in response ["Edges" ]]
460
609
converted .vertices = [self ._get_vertex (vertex ) for vertex in response ["Vertices" ]]
461
610
@@ -538,7 +687,9 @@ def query(
538
687
Filters = query_filter ._to_request_dict () if query_filter else {},
539
688
MaxDepth = max_depth ,
540
689
)
541
- query_response = self ._convert_api_response (query_response )
690
+ # create query result for startarn info
691
+ query_result = LineageQueryResult (startarn = start_arns )
692
+ query_response = self ._convert_api_response (query_response , query_result )
542
693
query_response = self ._collapse_cross_account_artifacts (query_response )
543
694
544
695
return query_response
0 commit comments