15
15
16
16
from datetime import datetime
17
17
from enum import Enum
18
- from typing import Optional , Union , List , Dict
19
- import re
18
+ from typing import Any , Optional , Union , List , Dict
19
+ from json import dumps
20
+ from re import sub , search
20
21
21
22
from sagemaker .utils import get_module
22
23
from sagemaker .lineage ._utils import get_resource_name_from_arn
@@ -235,7 +236,7 @@ def _artifact_to_lineage_object(self):
235
236
class PyvisVisualizer (object ):
236
237
"""Create object used for visualizing graph using Pyvis library."""
237
238
238
- def __init__ (self , graph_styles ):
239
+ def __init__ (self , graph_styles , pyvis_options : Optional [ Dict [ str , Any ]] = None ):
239
240
"""Init for PyvisVisualizer.
240
241
241
242
Args:
@@ -260,7 +261,8 @@ def __init__(self, graph_styles):
260
261
"symbol": "★", # shape symbol for legend
261
262
},
262
263
}
263
-
264
+ pyvis_options(optional): A dict containing PyVis options to customize visualization.
265
+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
264
266
"""
265
267
# import visualization packages
266
268
(
@@ -272,36 +274,29 @@ def __init__(self, graph_styles):
272
274
273
275
self .graph_styles = graph_styles
274
276
275
- # pyvis graph options
276
- self ._options = """
277
- var options = {
278
- "configure":{
279
- "enabled": false
280
- },
281
- "layout": {
282
- "hierarchical": {
283
- "enabled": true,
284
- "blockShifting": true,
285
- "direction": "LR",
286
- "sortMethod": "directed",
287
- "shakeTowards": "leaves"
288
- }
289
- },
290
- "interaction": {
291
- "multiselect": true,
292
- "navigationButtons": true
293
- },
294
- "physics": {
295
- "enabled": false,
296
- "hierarchicalRepulsion": {
297
- "centralGravity": 0,
298
- "avoidOverlap": null
277
+ if pyvis_options is None :
278
+ # default pyvis graph options
279
+ pyvis_options = {
280
+ "configure" : {"enabled" : False },
281
+ "layout" : {
282
+ "hierarchical" : {
283
+ "enabled" : True ,
284
+ "blockShifting" : True ,
285
+ "direction" : "LR" ,
286
+ "sortMethod" : "directed" ,
287
+ "shakeTowards" : "leaves" ,
288
+ }
289
+ },
290
+ "interaction" : {"multiselect" : True , "navigationButtons" : True },
291
+ "physics" : {
292
+ "enabled" : False ,
293
+ "hierarchicalRepulsion" : {"centralGravity" : 0 , "avoidOverlap" : None },
294
+ "minVelocity" : 0.75 ,
295
+ "solver" : "hierarchicalRepulsion" ,
299
296
},
300
- "minVelocity": 0.75,
301
- "solver": "hierarchicalRepulsion"
302
297
}
303
- }
304
- "" "
298
+ # A string representation of a Javascript-like object used to override pyvis options
299
+ self . _pyvis_options = f"var options = { dumps ( pyvis_options ) } "
305
300
306
301
def _import_visual_modules (self ):
307
302
"""Import modules needed for visualization."""
@@ -382,14 +377,14 @@ def render(self, elements, path="lineage_graph_pyvis.html"):
382
377
383
378
"""
384
379
net = self .Network (height = "600px" , width = "82%" , notebook = True , directed = True )
385
- net .set_options (self ._options )
380
+ net .set_options (self ._pyvis_options )
386
381
387
382
# add nodes to graph
388
383
for arn , source , entity , is_start_arn in elements ["nodes" ]:
389
- entity_text = re . sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
390
- source = re . sub (r"(\w)([A-Z])" , r"\1 \2" , source )
391
- account_id = re . search (r":\d{12}:" , arn )
392
- name = re . search (r"\/.*" , arn )
384
+ entity_text = sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
385
+ source = sub (r"(\w)([A-Z])" , r"\1 \2" , source )
386
+ account_id = search (r":\d{12}:" , arn )
387
+ name = search (r"\/.*" , arn )
393
388
node_info = (
394
389
"Entity: "
395
390
+ entity_text
@@ -516,7 +511,11 @@ def _get_visualization_elements(self):
516
511
elements = {"nodes" : verts , "edges" : edges }
517
512
return elements
518
513
519
- def visualize (self , path : Optional [str ] = "lineage_graph_pyvis.html" ):
514
+ def visualize (
515
+ self ,
516
+ path : Optional [str ] = "lineage_graph_pyvis.html" ,
517
+ pyvis_options : Optional [Dict [str , Any ]] = None ,
518
+ ):
520
519
"""Visualize lineage query result.
521
520
522
521
Creates a PyvisVisualizer object to render network graph with Pyvis library.
@@ -527,6 +526,8 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
527
526
Args:
528
527
path(optional): The path/filename of the rendered graph html file.
529
528
(default path: "lineage_graph_pyvis.html")
529
+ pyvis_options(optional): A dict containing PyVis options to customize visualization.
530
+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
530
531
531
532
Returns:
532
533
display graph: The interactive visualization is presented as a static HTML file.
@@ -561,7 +562,7 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
561
562
},
562
563
}
563
564
564
- pyvis_vis = PyvisVisualizer (lineage_graph_styles )
565
+ pyvis_vis = PyvisVisualizer (lineage_graph_styles , pyvis_options )
565
566
elements = self ._get_visualization_elements ()
566
567
return pyvis_vis .render (elements = elements , path = path )
567
568
0 commit comments