|
17 | 17 | from enum import Enum
|
18 | 18 | from typing import Optional, Union, List, Dict
|
19 | 19 |
|
20 |
| -from sagemaker.lineage._utils import get_resource_name_from_arn |
| 20 | +from sagemaker.lineage._utils import get_resource_name_from_arn, get_module |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class LineageEntityEnum(Enum):
|
@@ -201,194 +201,81 @@ def _artifact_to_lineage_object(self):
|
201 | 201 | return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
|
202 | 202 |
|
203 | 203 |
|
204 |
| -class DashVisualizer(object): |
205 |
| - """Create object used for visualizing graph using Dash library.""" |
| 204 | +class PyvisVisualizer(object): |
| 205 | + """Create object used for visualizing graph using Pyvis library.""" |
206 | 206 |
|
207 | 207 | def __init__(self, graph_styles):
|
208 |
| - """Init for DashVisualizer.""" |
| 208 | + """Init for PyvisVisualizer.""" |
209 | 209 | # import visualization packages
|
210 | 210 | (
|
211 |
| - self.cyto, |
212 |
| - self.JupyterDash, |
213 |
| - self.html, |
214 |
| - self.Input, |
215 |
| - self.Output, |
| 211 | + self.Network, |
| 212 | + self.Options, |
216 | 213 | ) = self._import_visual_modules()
|
217 | 214 |
|
218 | 215 | self.graph_styles = graph_styles
|
219 | 216 |
|
| 217 | + # pyvis graph options |
| 218 | + self._options = """ |
| 219 | + var options = { |
| 220 | + "configure":{ |
| 221 | + "enabled": false |
| 222 | + }, |
| 223 | + "layout": { |
| 224 | + "hierarchical": { |
| 225 | + "enabled": true, |
| 226 | + "blockShifting": true, |
| 227 | + "direction": "LR", |
| 228 | + "sortMethod": "directed", |
| 229 | + "shakeTowards": "leaves" |
| 230 | + } |
| 231 | + }, |
| 232 | + "interaction": { |
| 233 | + "multiselect": true, |
| 234 | + "navigationButtons": true |
| 235 | + }, |
| 236 | + "physics": { |
| 237 | + "enabled": false, |
| 238 | + "hierarchicalRepulsion": { |
| 239 | + "centralGravity": 0, |
| 240 | + "avoidOverlap": null |
| 241 | + }, |
| 242 | + "minVelocity": 0.75, |
| 243 | + "solver": "hierarchicalRepulsion" |
| 244 | + } |
| 245 | + } |
| 246 | + """ |
| 247 | + |
220 | 248 | def _import_visual_modules(self):
|
221 | 249 | """Import modules needed for visualization."""
|
222 |
| - try: |
223 |
| - import dash_cytoscape as cyto |
224 |
| - except ImportError as e: |
225 |
| - print(e) |
226 |
| - print("Try: pip install dash-cytoscape") |
227 |
| - raise |
228 |
| - |
229 |
| - try: |
230 |
| - from jupyter_dash import JupyterDash |
231 |
| - except ImportError as e: |
232 |
| - print(e) |
233 |
| - print("Try: pip install jupyter-dash") |
234 |
| - raise |
235 |
| - |
236 |
| - try: |
237 |
| - from dash import html |
238 |
| - except ImportError as e: |
239 |
| - print(e) |
240 |
| - print("Try: pip install dash") |
241 |
| - raise |
242 |
| - |
243 |
| - try: |
244 |
| - from dash.dependencies import Input, Output |
245 |
| - except ImportError as e: |
246 |
| - print(e) |
247 |
| - print("Try: pip install dash") |
248 |
| - raise |
249 |
| - |
250 |
| - return cyto, JupyterDash, html, Input, Output |
251 |
| - |
252 |
| - def _create_legend_component(self, style): |
253 |
| - """Create legend component div.""" |
254 |
| - text = style["name"] |
255 |
| - symbol = "" |
256 |
| - color = "#ffffff" |
257 |
| - if style["isShape"] == "False": |
258 |
| - color = style["style"]["background-color"] |
259 |
| - else: |
260 |
| - symbol = style["symbol"] |
261 |
| - return self.html.Div( |
262 |
| - [ |
263 |
| - self.html.Div( |
264 |
| - symbol, |
265 |
| - style={ |
266 |
| - "background-color": color, |
267 |
| - "width": "1.5vw", |
268 |
| - "height": "1.5vw", |
269 |
| - "display": "inline-block", |
270 |
| - "font-size": "1.5vw", |
271 |
| - }, |
272 |
| - ), |
273 |
| - self.html.Div( |
274 |
| - style={ |
275 |
| - "width": "0.5vw", |
276 |
| - "height": "1.5vw", |
277 |
| - "display": "inline-block", |
278 |
| - } |
279 |
| - ), |
280 |
| - self.html.Div( |
281 |
| - text, |
282 |
| - style={"display": "inline-block", "font-size": "1.5vw"}, |
283 |
| - ), |
284 |
| - ] |
285 |
| - ) |
286 |
| - |
287 |
| - def _create_entity_selector(self, entity_name, style): |
288 |
| - """Create selector for each lineage entity.""" |
289 |
| - return {"selector": "." + entity_name, "style": style["style"]} |
290 |
| - |
291 |
| - def _get_app(self, elements): |
292 |
| - """Create JupyterDash app for interactivity on Jupyter notebook.""" |
293 |
| - app = self.JupyterDash(__name__) |
294 |
| - self.cyto.load_extra_layouts() |
295 |
| - |
296 |
| - app.layout = self.html.Div( |
297 |
| - [ |
298 |
| - # graph section |
299 |
| - self.cyto.Cytoscape( |
300 |
| - id="cytoscape-graph", |
301 |
| - elements=elements, |
302 |
| - style={ |
303 |
| - "width": "84%", |
304 |
| - "height": "350px", |
305 |
| - "display": "inline-block", |
306 |
| - "border-width": "1vw", |
307 |
| - "border-color": "#232f3e", |
308 |
| - }, |
309 |
| - layout={"name": "klay"}, |
310 |
| - stylesheet=[ |
311 |
| - { |
312 |
| - "selector": "node", |
313 |
| - "style": { |
314 |
| - "label": "data(label)", |
315 |
| - "font-size": "3.5vw", |
316 |
| - "height": "10vw", |
317 |
| - "width": "10vw", |
318 |
| - "border-width": "0.8", |
319 |
| - "border-opacity": "0", |
320 |
| - "border-color": "#232f3e", |
321 |
| - "font-family": "verdana", |
322 |
| - }, |
323 |
| - }, |
324 |
| - { |
325 |
| - "selector": "edge", |
326 |
| - "style": { |
327 |
| - "label": "data(label)", |
328 |
| - "color": "gray", |
329 |
| - "text-halign": "left", |
330 |
| - "text-margin-y": "2.5", |
331 |
| - "font-size": "3", |
332 |
| - "width": "1", |
333 |
| - "curve-style": "bezier", |
334 |
| - "control-point-step-size": "15", |
335 |
| - "target-arrow-color": "gray", |
336 |
| - "target-arrow-shape": "triangle", |
337 |
| - "line-color": "gray", |
338 |
| - "arrow-scale": "0.5", |
339 |
| - "font-family": "verdana", |
340 |
| - }, |
341 |
| - }, |
342 |
| - {"selector": ".select", "style": {"border-opacity": "0.7"}}, |
343 |
| - ] |
344 |
| - + [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()], |
345 |
| - responsive=True, |
346 |
| - ), |
347 |
| - self.html.Div( |
348 |
| - style={ |
349 |
| - "width": "0.5%", |
350 |
| - "display": "inline-block", |
351 |
| - "font-size": "1vw", |
352 |
| - "font-family": "verdana", |
353 |
| - "vertical-align": "top", |
354 |
| - }, |
355 |
| - ), |
356 |
| - # legend section |
357 |
| - self.html.Div( |
358 |
| - [self._create_legend_component(v) for k, v in self.graph_styles.items()], |
359 |
| - style={ |
360 |
| - "display": "inline-block", |
361 |
| - "font-size": "1vw", |
362 |
| - "font-family": "verdana", |
363 |
| - "vertical-align": "top", |
364 |
| - }, |
365 |
| - ), |
366 |
| - ] |
367 |
| - ) |
368 |
| - |
369 |
| - @app.callback( |
370 |
| - self.Output("cytoscape-graph", "elements"), |
371 |
| - self.Input("cytoscape-graph", "tapNodeData"), |
372 |
| - self.Input("cytoscape-graph", "elements"), |
373 |
| - ) |
374 |
| - def selectNode(tapData, elements): |
375 |
| - for n in elements: |
376 |
| - if tapData is not None and n["data"]["id"] == tapData["id"]: |
377 |
| - # if is tapped node, add "select" class to node |
378 |
| - n["classes"] += " select" |
379 |
| - elif "classes" in n: |
380 |
| - # remove "select" class in "classes" if node not selected |
381 |
| - n["classes"] = n["classes"].replace("select", "") |
| 250 | + get_module("pyvis") |
| 251 | + from pyvis.network import Network |
| 252 | + from pyvis.options import Options |
382 | 253 |
|
383 |
| - return elements |
| 254 | + return Network, Options |
384 | 255 |
|
385 |
| - return app |
| 256 | + def _node_color(self, entity): |
| 257 | + """Return node color by background-color specified in graph styles.""" |
| 258 | + return self.graph_styles[entity]["style"]["background-color"] |
386 | 259 |
|
387 |
| - def render(self, elements, mode): |
| 260 | + def render(self, elements, path="pyvisExample.html"): |
388 | 261 | """Render graph for lineage query result."""
|
389 |
| - app = self._get_app(elements) |
| 262 | + net = self.Network(height="500px", width="100%", notebook=True, directed=True) |
| 263 | + net.set_options(self._options) |
| 264 | + |
| 265 | + # add nodes to graph |
| 266 | + for arn, source, entity, is_start_arn in elements["nodes"]: |
| 267 | + if is_start_arn: # startarn |
| 268 | + net.add_node( |
| 269 | + arn, label=source, title=entity, color=self._node_color(entity), shape="star" |
| 270 | + ) |
| 271 | + else: |
| 272 | + net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) |
390 | 273 |
|
391 |
| - return app.run_server(mode=mode) |
| 274 | + # add edges to graph |
| 275 | + for src, dest, asso_type in elements["edges"]: |
| 276 | + net.add_edge(src, dest, title=asso_type) |
| 277 | + |
| 278 | + return net.show(path) |
392 | 279 |
|
393 | 280 |
|
394 | 281 | class LineageQueryResult(object):
|
@@ -449,49 +336,36 @@ def __str__(self):
|
449 | 336 | result_dict = vars(self)
|
450 | 337 | return str({k: [str(val) for val in v] for k, v in result_dict.items()})
|
451 | 338 |
|
| 339 | + def _covert_edges_to_tuples(self): |
| 340 | + """Convert edges to tuple format for visualizer.""" |
| 341 | + edges = [] |
| 342 | + # get edge info in the form of (source, target, label) |
| 343 | + for edge in self.edges: |
| 344 | + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) |
| 345 | + return edges |
| 346 | + |
452 | 347 | def _covert_vertices_to_tuples(self):
|
453 | 348 | """Convert vertices to tuple format for visualizer."""
|
454 | 349 | verts = []
|
455 | 350 | # get vertex info in the form of (id, label, class)
|
456 | 351 | for vert in self.vertices:
|
457 | 352 | if vert.arn in self.startarn:
|
458 | 353 | # add "startarn" class to node if arn is a startarn
|
459 |
| - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) |
| 354 | + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True)) |
460 | 355 | else:
|
461 |
| - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) |
| 356 | + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False)) |
462 | 357 | return verts
|
463 | 358 |
|
464 |
| - def _covert_edges_to_tuples(self): |
465 |
| - """Convert edges to tuple format for visualizer.""" |
466 |
| - edges = [] |
467 |
| - # get edge info in the form of (source, target, label) |
468 |
| - for edge in self.edges: |
469 |
| - edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) |
470 |
| - return edges |
471 |
| - |
472 | 359 | def _get_visualization_elements(self):
|
473 |
| - """Get elements for visualization.""" |
474 |
| - # get vertices and edges info for graph |
| 360 | + """Get elements(nodes+edges) for visualization.""" |
475 | 361 | verts = self._covert_vertices_to_tuples()
|
476 | 362 | edges = self._covert_edges_to_tuples()
|
477 | 363 |
|
478 |
| - nodes = [ |
479 |
| - {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts |
480 |
| - ] |
481 |
| - |
482 |
| - edges = [ |
483 |
| - {"data": {"source": source, "target": target, "label": label}} |
484 |
| - for source, target, label in edges |
485 |
| - ] |
486 |
| - |
487 |
| - elements = nodes + edges |
488 |
| - |
| 364 | + elements = {"nodes": verts, "edges": edges} |
489 | 365 | return elements
|
490 | 366 |
|
491 | 367 | def visualize(self):
|
492 | 368 | """Visualize lineage query result."""
|
493 |
| - elements = self._get_visualization_elements() |
494 |
| - |
495 | 369 | lineage_graph = {
|
496 | 370 | # nodes can have shape / color
|
497 | 371 | "TrialComponent": {
|
@@ -522,12 +396,9 @@ def visualize(self):
|
522 | 396 | },
|
523 | 397 | }
|
524 | 398 |
|
525 |
| - # initialize DashVisualizer instance to render graph & interactive components |
526 |
| - dash_vis = DashVisualizer(lineage_graph) |
527 |
| - |
528 |
| - dash_server = dash_vis.render(elements=elements, mode="inline") |
529 |
| - |
530 |
| - return dash_server |
| 399 | + pyvis_vis = PyvisVisualizer(lineage_graph) |
| 400 | + elements = self._get_visualization_elements() |
| 401 | + return pyvis_vis.render(elements=elements) |
531 | 402 |
|
532 | 403 |
|
533 | 404 | class LineageFilter(object):
|
|
0 commit comments