Skip to content

Commit 32e7383

Browse files
author
jkasiraj
committed
add support for customizing pyvis options
1 parent 0fa2222 commit 32e7383

File tree

1 file changed

+40
-39
lines changed

1 file changed

+40
-39
lines changed

src/sagemaker/lineage/query.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
from datetime import datetime
1717
from enum import Enum
18-
from typing import Optional, Union, List, Dict
19-
import re
18+
from typing import Any, Optional, Union, List, Dict
19+
from json import dumps
20+
from re import sub, search
2021

2122
from sagemaker.utils import get_module
2223
from sagemaker.lineage._utils import get_resource_name_from_arn
@@ -235,7 +236,7 @@ def _artifact_to_lineage_object(self):
235236
class PyvisVisualizer(object):
236237
"""Create object used for visualizing graph using Pyvis library."""
237238

238-
def __init__(self, graph_styles):
239+
def __init__(self, graph_styles, pyvis_options: Optional[Dict[str, Any]] = None):
239240
"""Init for PyvisVisualizer.
240241
241242
Args:
@@ -260,7 +261,8 @@ def __init__(self, graph_styles):
260261
"symbol": "★", # shape symbol for legend
261262
},
262263
}
263-
264+
pyvis_options(optional): A dict containing PyVis options to customize visualization.
265+
(see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
264266
"""
265267
# import visualization packages
266268
(
@@ -272,36 +274,29 @@ def __init__(self, graph_styles):
272274

273275
self.graph_styles = graph_styles
274276

275-
# pyvis graph options
276-
self._options = """
277-
var options = {
278-
"configure":{
279-
"enabled": false
280-
},
281-
"layout": {
282-
"hierarchical": {
283-
"enabled": true,
284-
"blockShifting": true,
285-
"direction": "LR",
286-
"sortMethod": "directed",
287-
"shakeTowards": "leaves"
288-
}
289-
},
290-
"interaction": {
291-
"multiselect": true,
292-
"navigationButtons": true
293-
},
294-
"physics": {
295-
"enabled": false,
296-
"hierarchicalRepulsion": {
297-
"centralGravity": 0,
298-
"avoidOverlap": null
277+
if pyvis_options is None:
278+
# default pyvis graph options
279+
pyvis_options = {
280+
"configure": {"enabled": False},
281+
"layout": {
282+
"hierarchical": {
283+
"enabled": True,
284+
"blockShifting": True,
285+
"direction": "LR",
286+
"sortMethod": "directed",
287+
"shakeTowards": "leaves",
288+
}
289+
},
290+
"interaction": {"multiselect": True, "navigationButtons": True},
291+
"physics": {
292+
"enabled": False,
293+
"hierarchicalRepulsion": {"centralGravity": 0, "avoidOverlap": None},
294+
"minVelocity": 0.75,
295+
"solver": "hierarchicalRepulsion",
299296
},
300-
"minVelocity": 0.75,
301-
"solver": "hierarchicalRepulsion"
302297
}
303-
}
304-
"""
298+
# A string representation of a Javascript-like object used to override pyvis options
299+
self._pyvis_options = f"var options = {dumps(pyvis_options)}"
305300

306301
def _import_visual_modules(self):
307302
"""Import modules needed for visualization."""
@@ -382,14 +377,14 @@ def render(self, elements, path="lineage_graph_pyvis.html"):
382377
383378
"""
384379
net = self.Network(height="600px", width="82%", notebook=True, directed=True)
385-
net.set_options(self._options)
380+
net.set_options(self._pyvis_options)
386381

387382
# add nodes to graph
388383
for arn, source, entity, is_start_arn in elements["nodes"]:
389-
entity_text = re.sub(r"(\w)([A-Z])", r"\1 \2", entity)
390-
source = re.sub(r"(\w)([A-Z])", r"\1 \2", source)
391-
account_id = re.search(r":\d{12}:", arn)
392-
name = re.search(r"\/.*", arn)
384+
entity_text = sub(r"(\w)([A-Z])", r"\1 \2", entity)
385+
source = sub(r"(\w)([A-Z])", r"\1 \2", source)
386+
account_id = search(r":\d{12}:", arn)
387+
name = search(r"\/.*", arn)
393388
node_info = (
394389
"Entity: "
395390
+ entity_text
@@ -516,7 +511,11 @@ def _get_visualization_elements(self):
516511
elements = {"nodes": verts, "edges": edges}
517512
return elements
518513

519-
def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
514+
def visualize(
515+
self,
516+
path: Optional[str] = "lineage_graph_pyvis.html",
517+
pyvis_options: Optional[Dict[str, Any]] = None,
518+
):
520519
"""Visualize lineage query result.
521520
522521
Creates a PyvisVisualizer object to render network graph with Pyvis library.
@@ -527,6 +526,8 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
527526
Args:
528527
path(optional): The path/filename of the rendered graph html file.
529528
(default path: "lineage_graph_pyvis.html")
529+
pyvis_options(optional): A dict containing PyVis options to customize visualization.
530+
(see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
530531
531532
Returns:
532533
display graph: The interactive visualization is presented as a static HTML file.
@@ -561,7 +562,7 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
561562
},
562563
}
563564

564-
pyvis_vis = PyvisVisualizer(lineage_graph_styles)
565+
pyvis_vis = PyvisVisualizer(lineage_graph_styles, pyvis_options)
565566
elements = self._get_visualization_elements()
566567
return pyvis_vis.render(elements=elements, path=path)
567568

0 commit comments

Comments
 (0)