Skip to content

Commit c5125c0

Browse files
author
jkasiraj
committed
add support for customizing pyvis options
1 parent 0fa2222 commit c5125c0

File tree

1 file changed

+46
-39
lines changed

1 file changed

+46
-39
lines changed

src/sagemaker/lineage/query.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
from datetime import datetime
1717
from enum import Enum
18-
from typing import Optional, Union, List, Dict
19-
import re
18+
from typing import Any, Optional, Union, List, Dict
19+
from json import dumps
20+
from re import sub, search
2021

2122
from sagemaker.utils import get_module
2223
from sagemaker.lineage._utils import get_resource_name_from_arn
@@ -235,7 +236,7 @@ def _artifact_to_lineage_object(self):
235236
class PyvisVisualizer(object):
236237
"""Create object used for visualizing graph using Pyvis library."""
237238

238-
def __init__(self, graph_styles):
239+
def __init__(self, graph_styles, pyvis_options: Optional[dict[str, Any]] = None):
239240
"""Init for PyvisVisualizer.
240241
241242
Args:
@@ -260,7 +261,8 @@ def __init__(self, graph_styles):
260261
"symbol": "★", # shape symbol for legend
261262
},
262263
}
263-
264+
pyvis_options(optional): A dictionary containing PyVis options to customize visualization.
265+
(see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
264266
"""
265267
# import visualization packages
266268
(
@@ -272,36 +274,37 @@ def __init__(self, graph_styles):
272274

273275
self.graph_styles = graph_styles
274276

275-
# pyvis graph options
276-
self._options = """
277-
var options = {
278-
"configure":{
279-
"enabled": false
280-
},
281-
"layout": {
282-
"hierarchical": {
283-
"enabled": true,
284-
"blockShifting": true,
285-
"direction": "LR",
286-
"sortMethod": "directed",
287-
"shakeTowards": "leaves"
288-
}
289-
},
290-
"interaction": {
291-
"multiselect": true,
292-
"navigationButtons": true
293-
},
294-
"physics": {
295-
"enabled": false,
296-
"hierarchicalRepulsion": {
297-
"centralGravity": 0,
298-
"avoidOverlap": null
277+
if pyvis_options is None:
278+
# default pyvis graph options
279+
pyvis_options = {
280+
"configure": {
281+
"enabled": False
299282
},
300-
"minVelocity": 0.75,
301-
"solver": "hierarchicalRepulsion"
283+
"layout": {
284+
"hierarchical": {
285+
"enabled": True,
286+
"blockShifting": True,
287+
"direction": "LR",
288+
"sortMethod": "directed",
289+
"shakeTowards": "leaves"
290+
}
291+
},
292+
"interaction": {
293+
"multiselect": True,
294+
"navigationButtons": True
295+
},
296+
"physics": {
297+
"enabled": False,
298+
"hierarchicalRepulsion": {
299+
"centralGravity": 0,
300+
"avoidOverlap": None
301+
},
302+
"minVelocity": 0.75,
303+
"solver": "hierarchicalRepulsion"
304+
}
302305
}
303-
}
304-
"""
306+
# A string representation of a Javascript-like object to be used to override default pyvis options
307+
self._pyvis_options = f'var options = {dumps(pyvis_options)}'
305308

306309
def _import_visual_modules(self):
307310
"""Import modules needed for visualization."""
@@ -382,14 +385,14 @@ def render(self, elements, path="lineage_graph_pyvis.html"):
382385
383386
"""
384387
net = self.Network(height="600px", width="82%", notebook=True, directed=True)
385-
net.set_options(self._options)
388+
net.set_options(self._pyvis_options)
386389

387390
# add nodes to graph
388391
for arn, source, entity, is_start_arn in elements["nodes"]:
389-
entity_text = re.sub(r"(\w)([A-Z])", r"\1 \2", entity)
390-
source = re.sub(r"(\w)([A-Z])", r"\1 \2", source)
391-
account_id = re.search(r":\d{12}:", arn)
392-
name = re.search(r"\/.*", arn)
392+
entity_text = sub(r"(\w)([A-Z])", r"\1 \2", entity)
393+
source = sub(r"(\w)([A-Z])", r"\1 \2", source)
394+
account_id = search(r":\d{12}:", arn)
395+
name = search(r"\/.*", arn)
393396
node_info = (
394397
"Entity: "
395398
+ entity_text
@@ -516,7 +519,9 @@ def _get_visualization_elements(self):
516519
elements = {"nodes": verts, "edges": edges}
517520
return elements
518521

519-
def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
522+
def visualize(self,
523+
path: Optional[str] = "lineage_graph_pyvis.html",
524+
pyvis_options: Optional[dict[str, Any]] = None):
520525
"""Visualize lineage query result.
521526
522527
Creates a PyvisVisualizer object to render network graph with Pyvis library.
@@ -527,6 +532,8 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
527532
Args:
528533
path(optional): The path/filename of the rendered graph html file.
529534
(default path: "lineage_graph_pyvis.html")
535+
pyvis_options(optional): A dictionary containing PyVis options to customize visualization.
536+
(see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
530537
531538
Returns:
532539
display graph: The interactive visualization is presented as a static HTML file.
@@ -561,7 +568,7 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
561568
},
562569
}
563570

564-
pyvis_vis = PyvisVisualizer(lineage_graph_styles)
571+
pyvis_vis = PyvisVisualizer(lineage_graph_styles, pyvis_options)
565572
elements = self._get_visualization_elements()
566573
return pyvis_vis.render(elements=elements, path=path)
567574

0 commit comments

Comments
 (0)