15
15
16
16
from datetime import datetime
17
17
from enum import Enum
18
- from typing import Optional , Union , List , Dict
19
- import re
18
+ from typing import Any , Optional , Union , List , Dict
19
+ from json import dumps
20
+ from re import sub , search
20
21
21
22
from sagemaker .utils import get_module
22
23
from sagemaker .lineage ._utils import get_resource_name_from_arn
@@ -235,7 +236,7 @@ def _artifact_to_lineage_object(self):
235
236
class PyvisVisualizer (object ):
236
237
"""Create object used for visualizing graph using Pyvis library."""
237
238
238
- def __init__ (self , graph_styles ):
239
+ def __init__ (self , graph_styles , pyvis_options : Optional [ dict [ str , Any ]] = None ):
239
240
"""Init for PyvisVisualizer.
240
241
241
242
Args:
@@ -260,7 +261,8 @@ def __init__(self, graph_styles):
260
261
"symbol": "★", # shape symbol for legend
261
262
},
262
263
}
263
-
264
+ pyvis_options(optional): A dictionary containing PyVis options to customize visualization.
265
+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
264
266
"""
265
267
# import visualization packages
266
268
(
@@ -272,36 +274,37 @@ def __init__(self, graph_styles):
272
274
273
275
self .graph_styles = graph_styles
274
276
275
- # pyvis graph options
276
- self ._options = """
277
- var options = {
278
- "configure":{
279
- "enabled": false
280
- },
281
- "layout": {
282
- "hierarchical": {
283
- "enabled": true,
284
- "blockShifting": true,
285
- "direction": "LR",
286
- "sortMethod": "directed",
287
- "shakeTowards": "leaves"
288
- }
289
- },
290
- "interaction": {
291
- "multiselect": true,
292
- "navigationButtons": true
293
- },
294
- "physics": {
295
- "enabled": false,
296
- "hierarchicalRepulsion": {
297
- "centralGravity": 0,
298
- "avoidOverlap": null
277
+ if pyvis_options is None :
278
+ # default pyvis graph options
279
+ pyvis_options = {
280
+ "configure" : {
281
+ "enabled" : False
299
282
},
300
- "minVelocity": 0.75,
301
- "solver": "hierarchicalRepulsion"
283
+ "layout" : {
284
+ "hierarchical" : {
285
+ "enabled" : True ,
286
+ "blockShifting" : True ,
287
+ "direction" : "LR" ,
288
+ "sortMethod" : "directed" ,
289
+ "shakeTowards" : "leaves"
290
+ }
291
+ },
292
+ "interaction" : {
293
+ "multiselect" : True ,
294
+ "navigationButtons" : True
295
+ },
296
+ "physics" : {
297
+ "enabled" : False ,
298
+ "hierarchicalRepulsion" : {
299
+ "centralGravity" : 0 ,
300
+ "avoidOverlap" : None
301
+ },
302
+ "minVelocity" : 0.75 ,
303
+ "solver" : "hierarchicalRepulsion"
304
+ }
302
305
}
303
- }
304
- """
306
+ # A string representation of a Javascript-like object to be used to override default pyvis options
307
+ self . _pyvis_options = f'var options = { dumps ( pyvis_options ) } '
305
308
306
309
def _import_visual_modules (self ):
307
310
"""Import modules needed for visualization."""
@@ -382,14 +385,14 @@ def render(self, elements, path="lineage_graph_pyvis.html"):
382
385
383
386
"""
384
387
net = self .Network (height = "600px" , width = "82%" , notebook = True , directed = True )
385
- net .set_options (self ._options )
388
+ net .set_options (self ._pyvis_options )
386
389
387
390
# add nodes to graph
388
391
for arn , source , entity , is_start_arn in elements ["nodes" ]:
389
- entity_text = re . sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
390
- source = re . sub (r"(\w)([A-Z])" , r"\1 \2" , source )
391
- account_id = re . search (r":\d{12}:" , arn )
392
- name = re . search (r"\/.*" , arn )
392
+ entity_text = sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
393
+ source = sub (r"(\w)([A-Z])" , r"\1 \2" , source )
394
+ account_id = search (r":\d{12}:" , arn )
395
+ name = search (r"\/.*" , arn )
393
396
node_info = (
394
397
"Entity: "
395
398
+ entity_text
@@ -516,7 +519,9 @@ def _get_visualization_elements(self):
516
519
elements = {"nodes" : verts , "edges" : edges }
517
520
return elements
518
521
519
- def visualize (self , path : Optional [str ] = "lineage_graph_pyvis.html" ):
522
+ def visualize (self ,
523
+ path : Optional [str ] = "lineage_graph_pyvis.html" ,
524
+ pyvis_options : Optional [dict [str , Any ]] = None ):
520
525
"""Visualize lineage query result.
521
526
522
527
Creates a PyvisVisualizer object to render network graph with Pyvis library.
@@ -527,6 +532,8 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
527
532
Args:
528
533
path(optional): The path/filename of the rendered graph html file.
529
534
(default path: "lineage_graph_pyvis.html")
535
+ pyvis_options(optional): A dictionary containing PyVis options to customize visualization.
536
+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
530
537
531
538
Returns:
532
539
display graph: The interactive visualization is presented as a static HTML file.
@@ -561,7 +568,7 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
561
568
},
562
569
}
563
570
564
- pyvis_vis = PyvisVisualizer (lineage_graph_styles )
571
+ pyvis_vis = PyvisVisualizer (lineage_graph_styles , pyvis_options )
565
572
elements = self ._get_visualization_elements ()
566
573
return pyvis_vis .render (elements = elements , path = path )
567
574
0 commit comments