From 77965db122046b205b04f3f24026ba8281c877aa Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 1 Jul 2022 09:07:32 -0700 Subject: [PATCH 01/46] feature: add __str__ methods to queryLineage output classes --- src/sagemaker/lineage/query.py | 51 ++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a54331c39a..72bde00a1a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -92,6 +92,18 @@ def __eq__(self, other): and self.destination_arn == other.destination_arn ) + def __str__(self): + """Define string representation of ``Edge``. + + Format: + { + 'source_arn': 'string', 'destination_arn': 'string', + 'association_type': 'string' + } + + """ + return (str(self.__dict__)) + class Vertex: """A vertex for a lineage graph.""" @@ -130,6 +142,19 @@ def __eq__(self, other): and self.lineage_source == other.lineage_source ) + def __str__(self): + """Define string representation of ``Vertex``. + + Format: + { + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + } + + """ + return (str(self.__dict__)) + def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -199,6 +224,32 @@ def __init__( if vertices is not None: self.vertices = vertices + def __str__(self): + """Define string representation of ``LineageQueryResult``. + + Format: + { + 'edges':[ + { + 'source_arn': 'string', 'destination_arn': 'string', + 'association_type': 'string' + }, + ... + ] + 'vertices':[ + { + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + }, + ... + ] + } + + """ + result_dict = vars(self) + return (str({k: [vars(val) for val in v] for k, v in result_dict.items()})) + class LineageFilter(object): """A filter used in a lineage query.""" From 7359b3d16f6112b52e8303c44973ef0ad1fc68ea Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 7 Jul 2022 17:34:45 -0700 Subject: [PATCH 02/46] feature: query lineage visualizer for general case edge.association_type added style changes of graph --- src/sagemaker/lineage/query.py | 119 +++++++++++++++++++++++++++++---- tests/data/_repack_model.py | 110 ++++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+), 13 deletions(-) create mode 100644 tests/data/_repack_model.py diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 72bde00a1a..7345911df0 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -97,12 +97,12 @@ def __str__(self): Format: { - 'source_arn': 'string', 'destination_arn': 'string', + 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' } - + """ - return (str(self.__dict__)) + return str(self.__dict__) class Vertex: @@ -147,13 +147,13 @@ def __str__(self): Format: { - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', '_session': } - + """ - return (str(self.__dict__)) + return str(self.__dict__) def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -226,29 +226,122 @@ def __init__( def __str__(self): """Define string representation of ``LineageQueryResult``. - + Format: { 'edges':[ { - 'source_arn': 'string', 'destination_arn': 'string', + 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' }, ... ] 'vertices':[ { - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', '_session': }, ... ] } - + """ result_dict = vars(self) - return (str({k: [vars(val) for val in v] for k, v in result_dict.items()})) + return str({k: [vars(val) for val in v] for k, v in result_dict.items()}) + + def _import_visual_modules(self): + """Import modules needed for visualization.""" + import dash_cytoscape as cyto + + from jupyter_dash import JupyterDash + + from dash import html + + return cyto, JupyterDash, html + + def _get_verts(self): + """Convert vertices to tuple format for visualizer""" + verts = [] + for vert in self.vertices: + verts.append((vert.arn, vert.lineage_source)) + return verts + + def _get_edges(self): + """Convert edges to tuple format for visualizer""" + edges = [] + for edge in self.edges: + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + return edges + + def visualize(self): + """Visualize lineage query result.""" + + cyto, JupyterDash, html = self._import_visual_modules() + + cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts + app = JupyterDash(__name__) + + verts = self._get_verts() + edges = self._get_edges() + + nodes = [ + { + "data": {"id": id, "label": label}, + } + for id, label in verts + ] + + edges = [ + { + "data": {"source": source, "target": target, "label": label} + } + for source, target, label in edges + ] + + elements = nodes + edges + + app.layout = html.Div( + [ + cyto.Cytoscape( + id="cytoscape-layout-1", + elements=elements, + style={"width": "100%", "height": "350px"}, + layout={"name": "klay"}, + stylesheet=[ + { + "selector": "node", + "style": { + "label": "data(label)", + "font-size": "3.5vw", + "height": "10vw", + "width": "10vw" + } + }, + { + "selector": "edge", + "style": { + "label": "data(label)", + "color": "gray", + "text-halign": "left", + "text-margin-y": "3px", + "text-margin-x": "-2px", + "font-size": "3%", + "width": "1%", + "curve-style": "taxi", + "target-arrow-color": "gray", + "target-arrow-shape": "triangle", + "line-color": "gray", + "arrow-scale": "0.5" + }, + }, + ], + responsive=True, + ) + ] + ) + + return app.run_server(mode="inline") class LineageFilter(object): diff --git a/tests/data/_repack_model.py b/tests/data/_repack_model.py new file mode 100644 index 0000000000..3cfa6760b3 --- /dev/null +++ b/tests/data/_repack_model.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Repack model script for training jobs to inject entry points""" +from __future__ import absolute_import + +import argparse +import os +import shutil +import tarfile +import tempfile + +# Repack Model +# The following script is run via a training job which takes an existing model and a custom +# entry point script as arguments. The script creates a new model archive with the custom +# entry point in the "code" directory along with the existing model. Subsequently, when the model +# is unpacked for inference, the custom entry point will be used. +# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html + +# distutils.dir_util.copy_tree works way better than the half-baked +# shutil.copytree which bombs on previously existing target dirs... +# alas ... https://bugs.python.org/issue10948 +# we'll go ahead and use the copy_tree function anyways because this +# repacking is some short-lived hackery, right?? +from distutils.dir_util import copy_tree + + +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover + """Repack custom dependencies and code into an existing model TAR archive + + Args: + inference_script (str): The path to the custom entry point. + model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive. + dependencies (str): A space-delimited string of paths to custom dependencies. + source_dir (str): The path to a custom source directory. + """ + + # the data directory contains a model archive generated by a previous training job + data_directory = "/opt/ml/input/data/training" + model_path = os.path.join(data_directory, model_archive.split("/")[-1]) + + # create a temporary directory + with tempfile.TemporaryDirectory() as tmp: + local_path = os.path.join(tmp, "local.tar.gz") + # copy the previous training job's model archive to the temporary directory + shutil.copy2(model_path, local_path) + src_dir = os.path.join(tmp, "src") + # create the "code" directory which will contain the inference script + code_dir = os.path.join(src_dir, "code") + os.makedirs(code_dir) + # extract the contents of the previous training job's model archive to the "src" + # directory of this training job + with tarfile.open(name=local_path, mode="r:gz") as tf: + tf.extractall(path=src_dir) + + if source_dir: + # copy /opt/ml/code to code/ + if os.path.exists(code_dir): + shutil.rmtree(code_dir) + shutil.copytree("/opt/ml/code", code_dir) + else: + # copy the custom inference script to code/ + entry_point = os.path.join("/opt/ml/code", inference_script) + shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) + + # copy any dependencies to code/lib/ + if dependencies: + for dependency in dependencies.split(" "): + actual_dependency_path = os.path.join("/opt/ml/code", dependency) + lib_dir = os.path.join(code_dir, "lib") + if not os.path.exists(lib_dir): + os.mkdir(lib_dir) + if os.path.isfile(actual_dependency_path): + shutil.copy2(actual_dependency_path, lib_dir) + else: + if os.path.exists(lib_dir): + shutil.rmtree(lib_dir) + # a directory is in the dependencies. we have to copy + # all of /opt/ml/code into the lib dir because the original directory + # was flattened by the SDK training job upload.. + shutil.copytree("/opt/ml/code", lib_dir) + break + + # copy the "src" dir, which includes the previous training job's model and the + # custom inference script, to the output of this training job + copy_tree(src_dir, "/opt/ml/model") + + +if __name__ == "__main__": # pragma: no cover + parser = argparse.ArgumentParser() + parser.add_argument("--inference_script", type=str, default="inference.py") + parser.add_argument("--dependencies", type=str, default=None) + parser.add_argument("--source_dir", type=str, default=None) + parser.add_argument("--model_archive", type=str, default="model.tar.gz") + args, extra = parser.parse_known_args() + repack( + inference_script=args.inference_script, + dependencies=args.dependencies, + source_dir=args.source_dir, + model_archive=args.model_archive, + ) From d5244167c2d5c0521a8dffcc383972651348a463 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 14 Jul 2022 10:27:31 -0700 Subject: [PATCH 03/46] startarn added to lineageQueryResult --- src/sagemaker/lineage/query.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 7345911df0..400cf5ceeb 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,6 +15,7 @@ from datetime import datetime from enum import Enum +from tracemalloc import start from typing import Optional, Union, List, Dict from sagemaker.lineage._utils import get_resource_name_from_arn @@ -208,6 +209,7 @@ def __init__( self, edges: List[Edge] = None, vertices: List[Vertex] = None, + startarn: List[str] = None, ): """Init for LineageQueryResult. @@ -217,6 +219,7 @@ def __init__( """ self.edges = [] self.vertices = [] + self.startarn = [] if edges is not None: self.edges = edges @@ -224,6 +227,9 @@ def __init__( if vertices is not None: self.vertices = vertices + if startarn is not None: + self.startarn = startarn + def __str__(self): """Define string representation of ``LineageQueryResult``. @@ -248,7 +254,7 @@ def __str__(self): """ result_dict = vars(self) - return str({k: [vars(val) for val in v] for k, v in result_dict.items()}) + return str({k: [str(val) for val in v] for k, v in result_dict.items()}) def _import_visual_modules(self): """Import modules needed for visualization.""" @@ -417,9 +423,8 @@ def _get_vertex(self, vertex): sagemaker_session=self._session, ) - def _convert_api_response(self, response) -> LineageQueryResult: + def _convert_api_response(self, response, converted) -> LineageQueryResult: """Convert the lineage query API response to its Python representation.""" - converted = LineageQueryResult() converted.edges = [self._get_edge(edge) for edge in response["Edges"]] converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]] @@ -502,7 +507,9 @@ def query( Filters=query_filter._to_request_dict() if query_filter else {}, MaxDepth=max_depth, ) - query_response = self._convert_api_response(query_response) + # create query result for startarn info + query_result = LineageQueryResult(startarn=start_arns) + query_response = self._convert_api_response(query_response, query_result) query_response = self._collapse_cross_account_artifacts(query_response) return query_response From acf64f444abe76ea0ae2eab081481d04d2736674 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 14 Jul 2022 11:21:20 -0700 Subject: [PATCH 04/46] color node by lineage entity --- src/sagemaker/lineage/query.py | 44 ++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 400cf5ceeb..a6b56ae7a8 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -236,18 +236,22 @@ def __str__(self): Format: { 'edges':[ - { + "{ 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' - }, + }", ... - ] + ], 'vertices':[ - { + "{ 'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': - }, + }", + ... + ], + 'startarn':[ + 'string', ... ] } @@ -270,7 +274,7 @@ def _get_verts(self): """Convert vertices to tuple format for visualizer""" verts = [] for vert in self.vertices: - verts.append((vert.arn, vert.lineage_source)) + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) return verts def _get_edges(self): @@ -288,14 +292,16 @@ def visualize(self): cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts app = JupyterDash(__name__) + # get vertices and edges info for graph verts = self._get_verts() edges = self._get_edges() nodes = [ { "data": {"id": id, "label": label}, + "classes": classes } - for id, label in verts + for id, label, classes in verts ] edges = [ @@ -341,6 +347,30 @@ def visualize(self): "arrow-scale": "0.5" }, }, + { + "selector": ".Artifact", + "style": { + "background-color": "#146eb4" + } + }, + { + "selector": ".Context", + "style": { + "background-color": "#ff9900" + } + }, + { + "selector": ".TrialComponent", + "style": { + "background-color": "#f6cf61" + } + }, + { + "selector": ".Action", + "style": { + "background-color": "#88c396" + } + } ], responsive=True, ) From e07e83c774fad640761501080f6d50a0264f282c Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 14 Jul 2022 11:43:33 -0700 Subject: [PATCH 05/46] identify startarn node by shape --- src/sagemaker/lineage/query.py | 62 +++++++++++----------------------- 1 file changed, 19 insertions(+), 43 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a6b56ae7a8..030dc866c0 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,7 +15,6 @@ from datetime import datetime from enum import Enum -from tracemalloc import start from typing import Optional, Union, List, Dict from sagemaker.lineage._utils import get_resource_name_from_arn @@ -271,14 +270,17 @@ def _import_visual_modules(self): return cyto, JupyterDash, html def _get_verts(self): - """Convert vertices to tuple format for visualizer""" + """Convert vertices to tuple format for visualizer.""" verts = [] for vert in self.vertices: - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) + if vert.arn in self.startarn: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) + else: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) return verts def _get_edges(self): - """Convert edges to tuple format for visualizer""" + """Convert edges to tuple format for visualizer.""" edges = [] for edge in self.edges: edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) @@ -286,7 +288,6 @@ def _get_edges(self): def visualize(self): """Visualize lineage query result.""" - cyto, JupyterDash, html = self._import_visual_modules() cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts @@ -297,17 +298,11 @@ def visualize(self): edges = self._get_edges() nodes = [ - { - "data": {"id": id, "label": label}, - "classes": classes - } - for id, label, classes in verts + {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts ] edges = [ - { - "data": {"source": source, "target": target, "label": label} - } + {"data": {"source": source, "target": target, "label": label}} for source, target, label in edges ] @@ -322,13 +317,13 @@ def visualize(self): layout={"name": "klay"}, stylesheet=[ { - "selector": "node", + "selector": "node", "style": { - "label": "data(label)", - "font-size": "3.5vw", + "label": "data(label)", + "font-size": "3.5vw", "height": "10vw", - "width": "10vw" - } + "width": "10vw", + }, }, { "selector": "edge", @@ -344,33 +339,14 @@ def visualize(self): "target-arrow-color": "gray", "target-arrow-shape": "triangle", "line-color": "gray", - "arrow-scale": "0.5" + "arrow-scale": "0.5", }, }, - { - "selector": ".Artifact", - "style": { - "background-color": "#146eb4" - } - }, - { - "selector": ".Context", - "style": { - "background-color": "#ff9900" - } - }, - { - "selector": ".TrialComponent", - "style": { - "background-color": "#f6cf61" - } - }, - { - "selector": ".Action", - "style": { - "background-color": "#88c396" - } - } + {"selector": ".Artifact", "style": {"background-color": "#146eb4"}}, + {"selector": ".Context", "style": {"background-color": "#ff9900"}}, + {"selector": ".TrialComponent", "style": {"background-color": "#f6cf61"}}, + {"selector": ".Action", "style": {"background-color": "#88c396"}}, + {"selector": ".startarn", "style": {"shape": "star"}}, ], responsive=True, ) From 687d6b64439ed20388946da67e30206975f4dbc9 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 14 Jul 2022 13:15:52 -0700 Subject: [PATCH 06/46] Add code comments --- src/sagemaker/lineage/query.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 030dc866c0..83dedb5515 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -272,8 +272,10 @@ def _import_visual_modules(self): def _get_verts(self): """Convert vertices to tuple format for visualizer.""" verts = [] + # get vertex info in the form of (id, label, class) for vert in self.vertices: if vert.arn in self.startarn: + # add "startarn" class to node if arn is a startarn verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) else: verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) @@ -282,6 +284,7 @@ def _get_verts(self): def _get_edges(self): """Convert edges to tuple format for visualizer.""" edges = [] + # get edge info in the form of (source, target, label) for edge in self.edges: edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) return edges From e3a4c9de08a3e56bcdbc983a651befd495a943e5 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 15 Jul 2022 13:35:28 -0700 Subject: [PATCH 07/46] Double sided arrows handled --- src/sagemaker/lineage/query.py | 37 ++++++++--- tests/data/_repack_model.py | 110 --------------------------------- 2 files changed, 29 insertions(+), 118 deletions(-) delete mode 100644 tests/data/_repack_model.py diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 83dedb5515..5361458290 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -267,7 +267,9 @@ def _import_visual_modules(self): from dash import html - return cyto, JupyterDash, html + from dash.dependencies import Input, Output + + return cyto, JupyterDash, html, Input, Output def _get_verts(self): """Convert vertices to tuple format for visualizer.""" @@ -287,11 +289,12 @@ def _get_edges(self): # get edge info in the form of (source, target, label) for edge in self.edges: edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + edges.append((self.edges[1].destination_arn, self.edges[1].source_arn, self.edges[1].association_type)) return edges def visualize(self): """Visualize lineage query result.""" - cyto, JupyterDash, html = self._import_visual_modules() + cyto, JupyterDash, html, Input, Output = self._import_visual_modules() cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts app = JupyterDash(__name__) @@ -314,7 +317,7 @@ def visualize(self): app.layout = html.Div( [ cyto.Cytoscape( - id="cytoscape-layout-1", + id="cytoscape-graph", elements=elements, style={"width": "100%", "height": "350px"}, layout={"name": "klay"}, @@ -326,6 +329,9 @@ def visualize(self): "font-size": "3.5vw", "height": "10vw", "width": "10vw", + "border-width": "0.8", + "border-opacity": "0", + "border-color": "#232f3e" }, }, { @@ -334,11 +340,13 @@ def visualize(self): "label": "data(label)", "color": "gray", "text-halign": "left", - "text-margin-y": "3px", - "text-margin-x": "-2px", - "font-size": "3%", - "width": "1%", - "curve-style": "taxi", + "text-margin-y": "2.5", + "font-size": "3", + "width": "1", + "curve-style": "bezier", + "control-point-step-size": "15", + # "taxi-direction": "rightward", + # "taxi-turn": "50%", "target-arrow-color": "gray", "target-arrow-shape": "triangle", "line-color": "gray", @@ -350,12 +358,25 @@ def visualize(self): {"selector": ".TrialComponent", "style": {"background-color": "#f6cf61"}}, {"selector": ".Action", "style": {"background-color": "#88c396"}}, {"selector": ".startarn", "style": {"shape": "star"}}, + {"selector": ".select", "style": { "border-opacity": "0.7"}}, ], responsive=True, ) ] ) + @app.callback(Output("cytoscape-graph", "elements"), + Input("cytoscape-graph", "tapNodeData")) + def selectNode(data): + for n in nodes: + if data != None and n["data"]["id"] == data["id"]: + n["classes"] += " select" + else: + n["classes"] = n["classes"].replace("select", "") + + elements = nodes + edges + return elements + return app.run_server(mode="inline") diff --git a/tests/data/_repack_model.py b/tests/data/_repack_model.py deleted file mode 100644 index 3cfa6760b3..0000000000 --- a/tests/data/_repack_model.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Repack model script for training jobs to inject entry points""" -from __future__ import absolute_import - -import argparse -import os -import shutil -import tarfile -import tempfile - -# Repack Model -# The following script is run via a training job which takes an existing model and a custom -# entry point script as arguments. The script creates a new model archive with the custom -# entry point in the "code" directory along with the existing model. Subsequently, when the model -# is unpacked for inference, the custom entry point will be used. -# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html - -# distutils.dir_util.copy_tree works way better than the half-baked -# shutil.copytree which bombs on previously existing target dirs... -# alas ... https://bugs.python.org/issue10948 -# we'll go ahead and use the copy_tree function anyways because this -# repacking is some short-lived hackery, right?? -from distutils.dir_util import copy_tree - - -def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover - """Repack custom dependencies and code into an existing model TAR archive - - Args: - inference_script (str): The path to the custom entry point. - model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive. - dependencies (str): A space-delimited string of paths to custom dependencies. - source_dir (str): The path to a custom source directory. - """ - - # the data directory contains a model archive generated by a previous training job - data_directory = "/opt/ml/input/data/training" - model_path = os.path.join(data_directory, model_archive.split("/")[-1]) - - # create a temporary directory - with tempfile.TemporaryDirectory() as tmp: - local_path = os.path.join(tmp, "local.tar.gz") - # copy the previous training job's model archive to the temporary directory - shutil.copy2(model_path, local_path) - src_dir = os.path.join(tmp, "src") - # create the "code" directory which will contain the inference script - code_dir = os.path.join(src_dir, "code") - os.makedirs(code_dir) - # extract the contents of the previous training job's model archive to the "src" - # directory of this training job - with tarfile.open(name=local_path, mode="r:gz") as tf: - tf.extractall(path=src_dir) - - if source_dir: - # copy /opt/ml/code to code/ - if os.path.exists(code_dir): - shutil.rmtree(code_dir) - shutil.copytree("/opt/ml/code", code_dir) - else: - # copy the custom inference script to code/ - entry_point = os.path.join("/opt/ml/code", inference_script) - shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) - - # copy any dependencies to code/lib/ - if dependencies: - for dependency in dependencies.split(" "): - actual_dependency_path = os.path.join("/opt/ml/code", dependency) - lib_dir = os.path.join(code_dir, "lib") - if not os.path.exists(lib_dir): - os.mkdir(lib_dir) - if os.path.isfile(actual_dependency_path): - shutil.copy2(actual_dependency_path, lib_dir) - else: - if os.path.exists(lib_dir): - shutil.rmtree(lib_dir) - # a directory is in the dependencies. we have to copy - # all of /opt/ml/code into the lib dir because the original directory - # was flattened by the SDK training job upload.. - shutil.copytree("/opt/ml/code", lib_dir) - break - - # copy the "src" dir, which includes the previous training job's model and the - # custom inference script, to the output of this training job - copy_tree(src_dir, "/opt/ml/model") - - -if __name__ == "__main__": # pragma: no cover - parser = argparse.ArgumentParser() - parser.add_argument("--inference_script", type=str, default="inference.py") - parser.add_argument("--dependencies", type=str, default=None) - parser.add_argument("--source_dir", type=str, default=None) - parser.add_argument("--model_archive", type=str, default="model.tar.gz") - args, extra = parser.parse_known_args() - repack( - inference_script=args.inference_script, - dependencies=args.dependencies, - source_dir=args.source_dir, - model_archive=args.model_archive, - ) From 1808f8799518cab34257ca89e8e1955649229903 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 18 Jul 2022 09:50:25 -0700 Subject: [PATCH 08/46] legend added --- src/sagemaker/lineage/query.py | 103 +++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 5361458290..6f526ce367 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -289,7 +289,6 @@ def _get_edges(self): # get edge info in the form of (source, target, label) for edge in self.edges: edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) - edges.append((self.edges[1].destination_arn, self.edges[1].source_arn, self.edges[1].association_type)) return edges def visualize(self): @@ -319,7 +318,7 @@ def visualize(self): cyto.Cytoscape( id="cytoscape-graph", elements=elements, - style={"width": "100%", "height": "350px"}, + style={"width": "85%", "height": "350px", 'display': 'inline-block', 'border-width': '1vw', "border-color": "#232f3e"}, layout={"name": "klay"}, stylesheet=[ { @@ -331,7 +330,8 @@ def visualize(self): "width": "10vw", "border-width": "0.8", "border-opacity": "0", - "border-color": "#232f3e" + "border-color": "#232f3e", + "font-family": "verdana" }, }, { @@ -351,6 +351,7 @@ def visualize(self): "target-arrow-shape": "triangle", "line-color": "gray", "arrow-scale": "0.5", + "font-family": "verdana" }, }, {"selector": ".Artifact", "style": {"background-color": "#146eb4"}}, @@ -361,7 +362,101 @@ def visualize(self): {"selector": ".select", "style": { "border-opacity": "0.7"}}, ], responsive=True, - ) + ), + html.Div([ + html.Div([ + html.Div( + style={ + 'background-color': "#f6cf61", + 'width': '1.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div( + style={ + 'width': '0.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div(' Trial Component', style={'display': 'inline-block', "font-size": "1.5vw"}), + ]), + html.Div([ + html.Div( + style={ + 'background-color': "#ff9900", + 'width': '1.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div( + style={ + 'width': '0.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div(' Context', style={'display': 'inline-block', "font-size": "1.5vw"}), + ]), + html.Div([ + html.Div( + style={ + 'background-color': "#88c396", + 'width': '1.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div( + style={ + 'width': '0.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div(' Action', style={'display': 'inline-block', "font-size": "1.5vw"}), + ]), + html.Div([ + html.Div( + style={ + 'background-color': "#146eb4", + 'width': '1.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div( + style={ + 'width': '0.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div(' Artifact', style={'display': 'inline-block', "font-size": "1.5vw"}), + ]), + html.Div([ + html.Div( + "★", + style={ + 'background-color': "white", + 'width': '1.5vw', + 'height': '1.5vw', + 'display': 'inline-block', + "font-size": "1.5vw" + } + ), + html.Div( + style={ + 'width': '0.5vw', + 'height': '1.5vw', + 'display': 'inline-block' + } + ), + html.Div('StartArn', style={'display': 'inline-block', "font-size": "1.5vw"}), + ]), + ], style={'width': '15%', 'display': 'inline-block', "font-size": "1vw", "font-family": "verdana", "vertical-align": "top"}) ] ) From 79b193af5dd93c24c8fd421ed7b7f6ae4b040216 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 7 Jul 2022 17:34:45 -0700 Subject: [PATCH 09/46] feature: query lineage visualizer for general case edge.association_type added style changes of graph DashVisualizer class added Try except when importing visual modules --- src/sagemaker/lineage/query.py | 155 ++++++++++++++++++++++++++++++--- 1 file changed, 142 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 72bde00a1a..a198d1ebf5 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -97,12 +97,12 @@ def __str__(self): Format: { - 'source_arn': 'string', 'destination_arn': 'string', + 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' } - + """ - return (str(self.__dict__)) + return str(self.__dict__) class Vertex: @@ -147,13 +147,13 @@ def __str__(self): Format: { - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', '_session': } - + """ - return (str(self.__dict__)) + return str(self.__dict__) def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -201,6 +201,90 @@ def _artifact_to_lineage_object(self): return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) +class DashVisualizer(object): + """Create object used for visualizing graph using Dash library.""" + + def __init__(self): + """Init for DashVisualizer.""" + # import visualization packages + self.cyto, self.JupyterDash, self.html = self._import_visual_modules() + + def _import_visual_modules(self): + """Import modules needed for visualization.""" + try: + import dash_cytoscape as cyto + except ImportError as e: + print(e) + print("try pip install dash-cytoscape") + + try: + from jupyter_dash import JupyterDash + except ImportError as e: + print(e) + print("try pip install jupyter-dash") + + try: + from dash import html + except ImportError as e: + print(e) + print("try pip install dash") + + return cyto, JupyterDash, html + + def _get_app(self, elements): + """Create JupyterDash app for interactivity on Jupyter notebook.""" + app = self.JupyterDash(__name__) + self.cyto.load_extra_layouts() + + app.layout = self.html.Div( + [ + self.cyto.Cytoscape( + id="cytoscape-layout-1", + elements=elements, + style={"width": "100%", "height": "350px"}, + layout={"name": "klay"}, + stylesheet=[ + { + "selector": "node", + "style": { + "label": "data(label)", + "font-size": "3.5vw", + "height": "10vw", + "width": "10vw", + }, + }, + { + "selector": "edge", + "style": { + "label": "data(label)", + "color": "gray", + "text-halign": "left", + "text-margin-y": "3px", + "text-margin-x": "-2px", + "font-size": "3%", + "width": "1%", + "curve-style": "taxi", + "target-arrow-color": "gray", + "target-arrow-shape": "triangle", + "line-color": "gray", + "arrow-scale": "0.5", + }, + }, + ], + responsive=True, + ) + ] + ) + + return app + + def render(self, elements, mode): + """Render graph for lineage query result.""" + app = self._get_app(elements) + + return app.run_server(mode=mode) + + class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -226,29 +310,74 @@ def __init__( def __str__(self): """Define string representation of ``LineageQueryResult``. - + Format: { 'edges':[ { - 'source_arn': 'string', 'destination_arn': 'string', + 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' }, ... ] 'vertices':[ { - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', '_session': }, ... ] } - + """ result_dict = vars(self) - return (str({k: [vars(val) for val in v] for k, v in result_dict.items()})) + return str({k: [vars(val) for val in v] for k, v in result_dict.items()}) + + def _covert_vertices_to_tuples(self): + """Convert vertices to tuple format for visualizer.""" + verts = [] + for vert in self.vertices: + verts.append((vert.arn, vert.lineage_source)) + return verts + + def _covert_edges_to_tuples(self): + """Convert edges to tuple format for visualizer.""" + edges = [] + for edge in self.edges: + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + return edges + + def _get_visualization_elements(self): + """Get elements for visualization.""" + verts = self._covert_vertices_to_tuples() + edges = self._covert_edges_to_tuples() + + nodes = [ + { + "data": {"id": id, "label": label}, + } + for id, label in verts + ] + + edges = [ + {"data": {"source": source, "target": target, "label": label}} + for source, target, label in edges + ] + + elements = nodes + edges + + return elements + + def visualize(self): + """Visualize lineage query result.""" + elements = self._get_visualization_elements() + + dash_vis = DashVisualizer() + + dash_server = dash_vis.render(elements=elements, mode="inline") + + return dash_server class LineageFilter(object): From 60904d5dbe6c77873d610ed7580de0bb0cd805bf Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 18 Jul 2022 17:07:54 -0700 Subject: [PATCH 10/46] try except raise --- src/sagemaker/lineage/query.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a198d1ebf5..3f04800a50 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -215,19 +215,22 @@ def _import_visual_modules(self): import dash_cytoscape as cyto except ImportError as e: print(e) - print("try pip install dash-cytoscape") + print("Try: pip install dash-cytoscape") + raise try: from jupyter_dash import JupyterDash except ImportError as e: print(e) - print("try pip install jupyter-dash") + print("Try: pip install jupyter-dash") + raise try: from dash import html except ImportError as e: print(e) - print("try pip install dash") + print("Try: pip install dash") + raise return cyto, JupyterDash, html From 9125eb930e930375bb27d45d22d0f9765db24e08 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 19 Jul 2022 10:34:02 -0700 Subject: [PATCH 11/46] startarn added --- src/sagemaker/lineage/query.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 3f04800a50..86fd805d81 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,6 +15,7 @@ from datetime import datetime from enum import Enum +from tracemalloc import start from typing import Optional, Union, List, Dict from sagemaker.lineage._utils import get_resource_name_from_arn @@ -295,6 +296,7 @@ def __init__( self, edges: List[Edge] = None, vertices: List[Vertex] = None, + startarn: List[str] = None, ): """Init for LineageQueryResult. @@ -304,6 +306,7 @@ def __init__( """ self.edges = [] self.vertices = [] + self.startarn = [] if edges is not None: self.edges = edges @@ -311,6 +314,9 @@ def __init__( if vertices is not None: self.vertices = vertices + if startarn is not None: + self.startarn = startarn + def __str__(self): """Define string representation of ``LineageQueryResult``. @@ -335,7 +341,7 @@ def __str__(self): """ result_dict = vars(self) - return str({k: [vars(val) for val in v] for k, v in result_dict.items()}) + return str({k: [str(val) for val in v] for k, v in result_dict.items()}) def _covert_vertices_to_tuples(self): """Convert vertices to tuple format for visualizer.""" @@ -456,9 +462,8 @@ def _get_vertex(self, vertex): sagemaker_session=self._session, ) - def _convert_api_response(self, response) -> LineageQueryResult: + def _convert_api_response(self, response, converted) -> LineageQueryResult: """Convert the lineage query API response to its Python representation.""" - converted = LineageQueryResult() converted.edges = [self._get_edge(edge) for edge in response["Edges"]] converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]] @@ -541,7 +546,9 @@ def query( Filters=query_filter._to_request_dict() if query_filter else {}, MaxDepth=max_depth, ) - query_response = self._convert_api_response(query_response) + # create query result for startarn info + query_result = LineageQueryResult(startarn=start_arns) + query_response = self._convert_api_response(query_response, query_result) query_response = self._collapse_cross_account_artifacts(query_response) return query_response From 3db7fa17ac311ca76a583f5e897736085a229f65 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 19 Jul 2022 11:57:40 -0700 Subject: [PATCH 12/46] add get element function --- src/sagemaker/lineage/query.py | 50 ++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 6f526ce367..4fb62bad37 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -235,18 +235,18 @@ def __str__(self): Format: { 'edges':[ - "{ + { 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' - }", + }, ... ], 'vertices':[ - "{ + { 'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': - }", + }, ... ], 'startarn':[ @@ -271,7 +271,7 @@ def _import_visual_modules(self): return cyto, JupyterDash, html, Input, Output - def _get_verts(self): + def _covert_vertices_to_tuples(self): """Convert vertices to tuple format for visualizer.""" verts = [] # get vertex info in the form of (id, label, class) @@ -283,7 +283,7 @@ def _get_verts(self): verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) return verts - def _get_edges(self): + def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" edges = [] # get edge info in the form of (source, target, label) @@ -291,16 +291,11 @@ def _get_edges(self): edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) return edges - def visualize(self): - """Visualize lineage query result.""" - cyto, JupyterDash, html, Input, Output = self._import_visual_modules() - - cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts - app = JupyterDash(__name__) - + def _get_visualization_elements(self): + """Get elements for visualization.""" # get vertices and edges info for graph - verts = self._get_verts() - edges = self._get_edges() + verts = self._covert_vertices_to_tuples() + edges = self._covert_edges_to_tuples() nodes = [ {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts @@ -313,6 +308,17 @@ def visualize(self): elements = nodes + edges + return elements + + def visualize(self): + """Visualize lineage query result.""" + cyto, JupyterDash, html, Input, Output = self._import_visual_modules() + + cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts + app = JupyterDash(__name__) + + elements = self._get_visualization_elements() + app.layout = html.Div( [ cyto.Cytoscape( @@ -461,15 +467,17 @@ def visualize(self): ) @app.callback(Output("cytoscape-graph", "elements"), - Input("cytoscape-graph", "tapNodeData")) - def selectNode(data): - for n in nodes: - if data != None and n["data"]["id"] == data["id"]: + Input("cytoscape-graph", "tapNodeData"), + Input("cytoscape-graph", "elements")) + def selectNode(tapData, elements): + for n in elements: + if tapData != None and n["data"]["id"] == tapData["id"]: + # if is tapped node, add "select" class to node n["classes"] += " select" - else: + elif "classes" in n: + # remove "select" class in "classes" if node not selected n["classes"] = n["classes"].replace("select", "") - elements = nodes + edges return elements return app.run_server(mode="inline") From c135ab957c996a665cae8e2818ad79f7516580b0 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 19 Jul 2022 12:19:41 -0700 Subject: [PATCH 13/46] add DashVisualizer class --- src/sagemaker/lineage/query.py | 313 +++++++++++++++++++-------------- 1 file changed, 177 insertions(+), 136 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 4fb62bad37..8201f0647b 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -201,127 +201,55 @@ def _artifact_to_lineage_object(self): return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) -class LineageQueryResult(object): - """A wrapper around the results of a lineage query.""" - - def __init__( - self, - edges: List[Edge] = None, - vertices: List[Vertex] = None, - startarn: List[str] = None, - ): - """Init for LineageQueryResult. - - Args: - edges (List[Edge]): The edges of the query result. - vertices (List[Vertex]): The vertices of the query result. - """ - self.edges = [] - self.vertices = [] - self.startarn = [] - - if edges is not None: - self.edges = edges - - if vertices is not None: - self.vertices = vertices - - if startarn is not None: - self.startarn = startarn - - def __str__(self): - """Define string representation of ``LineageQueryResult``. - - Format: - { - 'edges':[ - { - 'source_arn': 'string', 'destination_arn': 'string', - 'association_type': 'string' - }, - ... - ], - 'vertices':[ - { - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', - '_session': - }, - ... - ], - 'startarn':[ - 'string', - ... - ] - } +class DashVisualizer(object): + """Create object used for visualizing graph using Dash library.""" - """ - result_dict = vars(self) - return str({k: [str(val) for val in v] for k, v in result_dict.items()}) + def __init__(self): + """Init for DashVisualizer.""" + # import visualization packages + self.cyto, self.JupyterDash, self.html, self.Input, self.Output = self._import_visual_modules() def _import_visual_modules(self): """Import modules needed for visualization.""" - import dash_cytoscape as cyto + try: + import dash_cytoscape as cyto + except ImportError as e: + print(e) + print("Try: pip install dash-cytoscape") + raise + + try: + from jupyter_dash import JupyterDash + except ImportError as e: + print(e) + print("Try: pip install jupyter-dash") + raise + + try: + from dash import html + except ImportError as e: + print(e) + print("Try: pip install dash") + raise + + try: + from dash.dependencies import Input, Output + except ImportError as e: + print(e) + print("Try: pip install dash") + raise - from jupyter_dash import JupyterDash - - from dash import html - - from dash.dependencies import Input, Output return cyto, JupyterDash, html, Input, Output - def _covert_vertices_to_tuples(self): - """Convert vertices to tuple format for visualizer.""" - verts = [] - # get vertex info in the form of (id, label, class) - for vert in self.vertices: - if vert.arn in self.startarn: - # add "startarn" class to node if arn is a startarn - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) - else: - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) - return verts - - def _covert_edges_to_tuples(self): - """Convert edges to tuple format for visualizer.""" - edges = [] - # get edge info in the form of (source, target, label) - for edge in self.edges: - edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) - return edges - - def _get_visualization_elements(self): - """Get elements for visualization.""" - # get vertices and edges info for graph - verts = self._covert_vertices_to_tuples() - edges = self._covert_edges_to_tuples() - - nodes = [ - {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts - ] - - edges = [ - {"data": {"source": source, "target": target, "label": label}} - for source, target, label in edges - ] - - elements = nodes + edges + def _get_app(self, elements): + """Create JupyterDash app for interactivity on Jupyter notebook.""" + app = self.JupyterDash(__name__) + self.cyto.load_extra_layouts() - return elements - - def visualize(self): - """Visualize lineage query result.""" - cyto, JupyterDash, html, Input, Output = self._import_visual_modules() - - cyto.load_extra_layouts() # load "klay" layout (hierarchical layout) from extra layouts - app = JupyterDash(__name__) - - elements = self._get_visualization_elements() - - app.layout = html.Div( + app.layout = self.html.Div( [ - cyto.Cytoscape( + self.cyto.Cytoscape( id="cytoscape-graph", elements=elements, style={"width": "85%", "height": "350px", 'display': 'inline-block', 'border-width': '1vw', "border-color": "#232f3e"}, @@ -369,9 +297,9 @@ def visualize(self): ], responsive=True, ), - html.Div([ - html.Div([ - html.Div( + self.html.Div([ + self.html.Div([ + self.html.Div( style={ 'background-color': "#f6cf61", 'width': '1.5vw', @@ -379,17 +307,17 @@ def visualize(self): 'display': 'inline-block' } ), - html.Div( + self.html.Div( style={ 'width': '0.5vw', 'height': '1.5vw', 'display': 'inline-block' } ), - html.Div(' Trial Component', style={'display': 'inline-block', "font-size": "1.5vw"}), + self.html.Div(' Trial Component', style={'display': 'inline-block', "font-size": "1.5vw"}), ]), - html.Div([ - html.Div( + self.html.Div([ + self.html.Div( style={ 'background-color': "#ff9900", 'width': '1.5vw', @@ -397,17 +325,17 @@ def visualize(self): 'display': 'inline-block' } ), - html.Div( + self.html.Div( style={ 'width': '0.5vw', 'height': '1.5vw', 'display': 'inline-block' } ), - html.Div(' Context', style={'display': 'inline-block', "font-size": "1.5vw"}), + self.html.Div(' Context', style={'display': 'inline-block', "font-size": "1.5vw"}), ]), - html.Div([ - html.Div( + self.html.Div([ + self.html.Div( style={ 'background-color': "#88c396", 'width': '1.5vw', @@ -415,17 +343,17 @@ def visualize(self): 'display': 'inline-block' } ), - html.Div( + self.html.Div( style={ 'width': '0.5vw', 'height': '1.5vw', 'display': 'inline-block' } ), - html.Div(' Action', style={'display': 'inline-block', "font-size": "1.5vw"}), + self.html.Div(' Action', style={'display': 'inline-block', "font-size": "1.5vw"}), ]), - html.Div([ - html.Div( + self.html.Div([ + self.html.Div( style={ 'background-color': "#146eb4", 'width': '1.5vw', @@ -433,17 +361,17 @@ def visualize(self): 'display': 'inline-block' } ), - html.Div( + self.html.Div( style={ 'width': '0.5vw', 'height': '1.5vw', 'display': 'inline-block' } ), - html.Div(' Artifact', style={'display': 'inline-block', "font-size": "1.5vw"}), + self.html.Div(' Artifact', style={'display': 'inline-block', "font-size": "1.5vw"}), ]), - html.Div([ - html.Div( + self.html.Div([ + self.html.Div( "★", style={ 'background-color': "white", @@ -453,22 +381,22 @@ def visualize(self): "font-size": "1.5vw" } ), - html.Div( + self.html.Div( style={ 'width': '0.5vw', 'height': '1.5vw', 'display': 'inline-block' } ), - html.Div('StartArn', style={'display': 'inline-block', "font-size": "1.5vw"}), + self.html.Div('StartArn', style={'display': 'inline-block', "font-size": "1.5vw"}), ]), ], style={'width': '15%', 'display': 'inline-block', "font-size": "1vw", "font-family": "verdana", "vertical-align": "top"}) ] ) - @app.callback(Output("cytoscape-graph", "elements"), - Input("cytoscape-graph", "tapNodeData"), - Input("cytoscape-graph", "elements")) + @app.callback(self.Output("cytoscape-graph", "elements"), + self.Input("cytoscape-graph", "tapNodeData"), + self.Input("cytoscape-graph", "elements")) def selectNode(tapData, elements): for n in elements: if tapData != None and n["data"]["id"] == tapData["id"]: @@ -480,8 +408,121 @@ def selectNode(tapData, elements): return elements - return app.run_server(mode="inline") + return app + + def render(self, elements, mode): + """Render graph for lineage query result.""" + app = self._get_app(elements) + + return app.run_server(mode=mode) + +class LineageQueryResult(object): + """A wrapper around the results of a lineage query.""" + + def __init__( + self, + edges: List[Edge] = None, + vertices: List[Vertex] = None, + startarn: List[str] = None, + ): + """Init for LineageQueryResult. + + Args: + edges (List[Edge]): The edges of the query result. + vertices (List[Vertex]): The vertices of the query result. + """ + self.edges = [] + self.vertices = [] + self.startarn = [] + + if edges is not None: + self.edges = edges + + if vertices is not None: + self.vertices = vertices + + if startarn is not None: + self.startarn = startarn + + def __str__(self): + """Define string representation of ``LineageQueryResult``. + + Format: + { + 'edges':[ + { + 'source_arn': 'string', 'destination_arn': 'string', + 'association_type': 'string' + }, + ... + ], + 'vertices':[ + { + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + }, + ... + ], + 'startarn':[ + 'string', + ... + ] + } + + """ + result_dict = vars(self) + return str({k: [str(val) for val in v] for k, v in result_dict.items()}) + + def _covert_vertices_to_tuples(self): + """Convert vertices to tuple format for visualizer.""" + verts = [] + # get vertex info in the form of (id, label, class) + for vert in self.vertices: + if vert.arn in self.startarn: + # add "startarn" class to node if arn is a startarn + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) + else: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) + return verts + + def _covert_edges_to_tuples(self): + """Convert edges to tuple format for visualizer.""" + edges = [] + # get edge info in the form of (source, target, label) + for edge in self.edges: + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + return edges + + def _get_visualization_elements(self): + """Get elements for visualization.""" + # get vertices and edges info for graph + verts = self._covert_vertices_to_tuples() + edges = self._covert_edges_to_tuples() + + nodes = [ + {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts + ] + + edges = [ + {"data": {"source": source, "target": target, "label": label}} + for source, target, label in edges + ] + + elements = nodes + edges + + return elements + + def visualize(self): + """Visualize lineage query result.""" + elements = self._get_visualization_elements() + + # initialize DashVisualizer instance to render graph & interactive components + dash_vis = DashVisualizer() + + dash_server = dash_vis.render(elements=elements, mode="inline") + return dash_server class LineageFilter(object): """A filter used in a lineage query.""" From daa5cc438a9677e07705dcc8b63760b0f598991d Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 19 Jul 2022 12:29:07 -0700 Subject: [PATCH 14/46] style check --- src/sagemaker/lineage/query.py | 251 +++++++++++++++++++-------------- 1 file changed, 149 insertions(+), 102 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 8201f0647b..d51b7c685d 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -207,7 +207,13 @@ class DashVisualizer(object): def __init__(self): """Init for DashVisualizer.""" # import visualization packages - self.cyto, self.JupyterDash, self.html, self.Input, self.Output = self._import_visual_modules() + ( + self.cyto, + self.JupyterDash, + self.html, + self.Input, + self.Output, + ) = self._import_visual_modules() def _import_visual_modules(self): """Import modules needed for visualization.""" @@ -239,7 +245,6 @@ def _import_visual_modules(self): print("Try: pip install dash") raise - return cyto, JupyterDash, html, Input, Output def _get_app(self, elements): @@ -252,7 +257,13 @@ def _get_app(self, elements): self.cyto.Cytoscape( id="cytoscape-graph", elements=elements, - style={"width": "85%", "height": "350px", 'display': 'inline-block', 'border-width': '1vw', "border-color": "#232f3e"}, + style={ + "width": "85%", + "height": "350px", + "display": "inline-block", + "border-width": "1vw", + "border-color": "#232f3e", + }, layout={"name": "klay"}, stylesheet=[ { @@ -263,9 +274,9 @@ def _get_app(self, elements): "height": "10vw", "width": "10vw", "border-width": "0.8", - "border-opacity": "0", + "border-opacity": "0", "border-color": "#232f3e", - "font-family": "verdana" + "font-family": "verdana", }, }, { @@ -279,13 +290,11 @@ def _get_app(self, elements): "width": "1", "curve-style": "bezier", "control-point-step-size": "15", - # "taxi-direction": "rightward", - # "taxi-turn": "50%", "target-arrow-color": "gray", "target-arrow-shape": "triangle", "line-color": "gray", "arrow-scale": "0.5", - "font-family": "verdana" + "font-family": "verdana", }, }, {"selector": ".Artifact", "style": {"background-color": "#146eb4"}}, @@ -293,113 +302,149 @@ def _get_app(self, elements): {"selector": ".TrialComponent", "style": {"background-color": "#f6cf61"}}, {"selector": ".Action", "style": {"background-color": "#88c396"}}, {"selector": ".startarn", "style": {"shape": "star"}}, - {"selector": ".select", "style": { "border-opacity": "0.7"}}, + {"selector": ".select", "style": {"border-opacity": "0.7"}}, ], responsive=True, ), - self.html.Div([ - self.html.Div([ - self.html.Div( - style={ - 'background-color': "#f6cf61", - 'width': '1.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } - ), - self.html.Div( - style={ - 'width': '0.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } - ), - self.html.Div(' Trial Component', style={'display': 'inline-block', "font-size": "1.5vw"}), - ]), - self.html.Div([ - self.html.Div( - style={ - 'background-color': "#ff9900", - 'width': '1.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } - ), - self.html.Div( - style={ - 'width': '0.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } - ), - self.html.Div(' Context', style={'display': 'inline-block', "font-size": "1.5vw"}), - ]), - self.html.Div([ - self.html.Div( - style={ - 'background-color': "#88c396", - 'width': '1.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } - ), + self.html.Div( + [ self.html.Div( - style={ - 'width': '0.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } + [ + self.html.Div( + style={ + "background-color": "#f6cf61", + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + " Trial Component", + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] ), - self.html.Div(' Action', style={'display': 'inline-block', "font-size": "1.5vw"}), - ]), - self.html.Div([ self.html.Div( - style={ - 'background-color': "#146eb4", - 'width': '1.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } + [ + self.html.Div( + style={ + "background-color": "#ff9900", + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + " Context", + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] ), self.html.Div( - style={ - 'width': '0.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } + [ + self.html.Div( + style={ + "background-color": "#88c396", + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + " Action", + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] ), - self.html.Div(' Artifact', style={'display': 'inline-block', "font-size": "1.5vw"}), - ]), - self.html.Div([ self.html.Div( - "★", - style={ - 'background-color': "white", - 'width': '1.5vw', - 'height': '1.5vw', - 'display': 'inline-block', - "font-size": "1.5vw" - } + [ + self.html.Div( + style={ + "background-color": "#146eb4", + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + " Artifact", + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] ), self.html.Div( - style={ - 'width': '0.5vw', - 'height': '1.5vw', - 'display': 'inline-block' - } + [ + self.html.Div( + "★", + style={ + "background-color": "white", + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + "font-size": "1.5vw", + }, + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + "StartArn", + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] ), - self.html.Div('StartArn', style={'display': 'inline-block', "font-size": "1.5vw"}), - ]), - ], style={'width': '15%', 'display': 'inline-block', "font-size": "1vw", "font-family": "verdana", "vertical-align": "top"}) + ], + style={ + "width": "15%", + "display": "inline-block", + "font-size": "1vw", + "font-family": "verdana", + "vertical-align": "top", + }, + ), ] ) - @app.callback(self.Output("cytoscape-graph", "elements"), - self.Input("cytoscape-graph", "tapNodeData"), - self.Input("cytoscape-graph", "elements")) + @app.callback( + self.Output("cytoscape-graph", "elements"), + self.Input("cytoscape-graph", "tapNodeData"), + self.Input("cytoscape-graph", "elements"), + ) def selectNode(tapData, elements): for n in elements: - if tapData != None and n["data"]["id"] == tapData["id"]: + if tapData is not None and n["data"]["id"] == tapData["id"]: # if is tapped node, add "select" class to node n["classes"] += " select" elif "classes" in n: @@ -416,6 +461,7 @@ def render(self, elements, mode): return app.run_server(mode=mode) + class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -450,18 +496,18 @@ def __str__(self): Format: { 'edges':[ - { + "{ 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' - }, + }", ... ], 'vertices':[ - { + "{ 'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': - }, + }", ... ], 'startarn':[ @@ -515,7 +561,7 @@ def _get_visualization_elements(self): def visualize(self): """Visualize lineage query result.""" - elements = self._get_visualization_elements() + elements = self._get_visualization_elements() # initialize DashVisualizer instance to render graph & interactive components dash_vis = DashVisualizer() @@ -524,6 +570,7 @@ def visualize(self): return dash_server + class LineageFilter(object): """A filter used in a lineage query.""" From 28a9eeb75efa77121e889053c62a5d21662ff86c Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 19 Jul 2022 14:23:29 -0700 Subject: [PATCH 15/46] feature: query lineage visualizer advanced styling & interactive component handle --- src/sagemaker/lineage/query.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index d51b7c685d..8087e4720f 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -258,7 +258,7 @@ def _get_app(self, elements): id="cytoscape-graph", elements=elements, style={ - "width": "85%", + "width": "84%", "height": "350px", "display": "inline-block", "border-width": "1vw", @@ -306,6 +306,15 @@ def _get_app(self, elements): ], responsive=True, ), + self.html.Div( + style={ + "width": "0.5%", + "display": "inline-block", + "font-size": "1vw", + "font-family": "verdana", + "vertical-align": "top", + }, + ), self.html.Div( [ self.html.Div( @@ -427,7 +436,6 @@ def _get_app(self, elements): ), ], style={ - "width": "15%", "display": "inline-block", "font-size": "1vw", "font-family": "verdana", From 7c2b0c314b042838e4da482d172f34c7350d831b Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 20 Jul 2022 16:41:15 -0700 Subject: [PATCH 16/46] add functions that generate html components and style selectors --- src/sagemaker/lineage/query.py | 169 +++++++++------------------------ 1 file changed, 45 insertions(+), 124 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 8087e4720f..8e1ec3000e 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -215,6 +215,13 @@ def __init__(self): self.Output, ) = self._import_visual_modules() + self.entity_color = { + "TrialComponent": "#f6cf61", + "Context": "#ff9900", + "Action": "#88c396", + "Artifact": "#146eb4", + } + def _import_visual_modules(self): """Import modules needed for visualization.""" try: @@ -247,6 +254,38 @@ def _import_visual_modules(self): return cyto, JupyterDash, html, Input, Output + def _create_legend_component(self, text, color, colorText=""): + """Create legend component div.""" + return self.html.Div( + [ + self.html.Div( + colorText, + style={ + "background-color": color, + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + "font-size": "1.5vw", + }, + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + text, + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] + ) + + def _create_entity_selector(self, entity_name, color): + """Create selector for each lineage entity.""" + return {"selector": "." + entity_name, "style": {"background-color": color}} + def _get_app(self, elements): """Create JupyterDash app for interactivity on Jupyter notebook.""" app = self.JupyterDash(__name__) @@ -254,6 +293,7 @@ def _get_app(self, elements): app.layout = self.html.Div( [ + # graph section self.cyto.Cytoscape( id="cytoscape-graph", elements=elements, @@ -297,13 +337,10 @@ def _get_app(self, elements): "font-family": "verdana", }, }, - {"selector": ".Artifact", "style": {"background-color": "#146eb4"}}, - {"selector": ".Context", "style": {"background-color": "#ff9900"}}, - {"selector": ".TrialComponent", "style": {"background-color": "#f6cf61"}}, - {"selector": ".Action", "style": {"background-color": "#88c396"}}, {"selector": ".startarn", "style": {"shape": "star"}}, {"selector": ".select", "style": {"border-opacity": "0.7"}}, - ], + ] + + [self._create_entity_selector(k, v) for k, v in self.entity_color.items()], responsive=True, ), self.html.Div( @@ -315,126 +352,10 @@ def _get_app(self, elements): "vertical-align": "top", }, ), + # legend section self.html.Div( - [ - self.html.Div( - [ - self.html.Div( - style={ - "background-color": "#f6cf61", - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - " Trial Component", - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ), - self.html.Div( - [ - self.html.Div( - style={ - "background-color": "#ff9900", - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - " Context", - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ), - self.html.Div( - [ - self.html.Div( - style={ - "background-color": "#88c396", - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - " Action", - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ), - self.html.Div( - [ - self.html.Div( - style={ - "background-color": "#146eb4", - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - " Artifact", - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ), - self.html.Div( - [ - self.html.Div( - "★", - style={ - "background-color": "white", - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - "font-size": "1.5vw", - }, - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - "StartArn", - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ), - ], + [self._create_legend_component(k, v) for k, v in self.entity_color.items()] + + [self._create_legend_component("StartArn", "#ffffff", "★")], style={ "display": "inline-block", "font-size": "1vw", From 41d2453d9c3bd3496204d2747c325206a3772ff8 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 21 Jul 2022 11:24:34 -0700 Subject: [PATCH 17/46] inject graph data to DashVisualizer task --- src/sagemaker/lineage/query.py | 62 +++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 8e1ec3000e..1ea7baecb6 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -204,7 +204,7 @@ def _artifact_to_lineage_object(self): class DashVisualizer(object): """Create object used for visualizing graph using Dash library.""" - def __init__(self): + def __init__(self, graph_styles): """Init for DashVisualizer.""" # import visualization packages ( @@ -215,12 +215,7 @@ def __init__(self): self.Output, ) = self._import_visual_modules() - self.entity_color = { - "TrialComponent": "#f6cf61", - "Context": "#ff9900", - "Action": "#88c396", - "Artifact": "#146eb4", - } + self.graph_styles = graph_styles def _import_visual_modules(self): """Import modules needed for visualization.""" @@ -254,12 +249,19 @@ def _import_visual_modules(self): return cyto, JupyterDash, html, Input, Output - def _create_legend_component(self, text, color, colorText=""): + def _create_legend_component(self, style): """Create legend component div.""" + text = style["name"] + symbol = "" + color = "#ffffff" + if style["isShape"] == "False": + color = style["style"]["background-color"] + else: + symbol = style["symbol"] return self.html.Div( [ self.html.Div( - colorText, + symbol, style={ "background-color": color, "width": "1.5vw", @@ -282,9 +284,9 @@ def _create_legend_component(self, text, color, colorText=""): ] ) - def _create_entity_selector(self, entity_name, color): + def _create_entity_selector(self, entity_name, style): """Create selector for each lineage entity.""" - return {"selector": "." + entity_name, "style": {"background-color": color}} + return {"selector": "." + entity_name, "style": style["style"]} def _get_app(self, elements): """Create JupyterDash app for interactivity on Jupyter notebook.""" @@ -337,10 +339,9 @@ def _get_app(self, elements): "font-family": "verdana", }, }, - {"selector": ".startarn", "style": {"shape": "star"}}, {"selector": ".select", "style": {"border-opacity": "0.7"}}, ] - + [self._create_entity_selector(k, v) for k, v in self.entity_color.items()], + + [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()], responsive=True, ), self.html.Div( @@ -354,8 +355,7 @@ def _get_app(self, elements): ), # legend section self.html.Div( - [self._create_legend_component(k, v) for k, v in self.entity_color.items()] - + [self._create_legend_component("StartArn", "#ffffff", "★")], + [self._create_legend_component(v) for k, v in self.graph_styles.items()], style={ "display": "inline-block", "font-size": "1vw", @@ -492,8 +492,38 @@ def visualize(self): """Visualize lineage query result.""" elements = self._get_visualization_elements() + lineage_graph = { + # nodes can have shape / color + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "Action": { + "name": "Action", + "style": {"background-color": "#88c396"}, + "isShape": "False", + }, + "Artifact": { + "name": "Artifact", + "style": {"background-color": "#146eb4"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + # initialize DashVisualizer instance to render graph & interactive components - dash_vis = DashVisualizer() + dash_vis = DashVisualizer(lineage_graph) dash_server = dash_vis.render(elements=elements, mode="inline") From b74e8610c6f3515b0c479b7802931c7f6af9defa Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 26 Jul 2022 16:36:53 -0700 Subject: [PATCH 18/46] test_lineage_visualize.py created --- tests/integ/sagemaker/lineage/helpers.py | 56 +++++++++++++++++++ .../lineage/test_lineage_visualize.py | 51 +++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 tests/integ/sagemaker/lineage/test_lineage_visualize.py diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..52eb44363c 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -16,6 +16,10 @@ import uuid from datetime import datetime import time +import boto3 +from sagemaker.lineage import association +from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.association import Association def name(): @@ -78,3 +82,55 @@ def visit(arn, visited: set): ret = [] return visit(start_arn, set()) + + +class LineageResourceHelper: + + def __init__(self): + self.client = boto3.client('sagemaker') + self.artifacts = [] + self.associations = [] + + def create_artifact(self, artifact_name, artifact_type='Dataset'): + response = self.client.create_artifact( + ArtifactName=artifact_name, + Source={ + 'SourceUri': "Test-artifact-" + artifact_name, + 'SourceTypes': [ + { + 'SourceIdType': 'S3ETag', + 'Value': 'Test-artifact-sourceId-value' + }, + ] + }, + ArtifactType=artifact_type + ) + self.artifacts.append(response['ArtifactArn']) + + return response['ArtifactArn'] + + def create_association(self, source_arn, dest_arn, association_type='AssociatedWith'): + response = self.client.add_association( + SourceArn=source_arn, + DestinationArn=dest_arn, + AssociationType=association_type + ) + if "SourceArn" in response.keys(): + self.associations.append((source_arn, dest_arn)) + return True + else: + return False + + def clean_all(self): + for source, dest in self.associations: + try: + self.client.delete_association( + SourceArn=source, + DestinationArn=dest + ) + time.sleep(2) + except(e): + print("skipped " + str(e)) + + for artifact_arn in self.artifacts: + self.client.delete_artifact(ArtifactArn=artifact_arn) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py new file mode 100644 index 0000000000..cb68021da5 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``""" + +import datetime +import logging +import time + +import pytest + +import sagemaker.lineage.query + +from tests.integ.sagemaker.lineage.helpers import name, names, retry, LineageResourceHelper + +def test_LineageResourceHelper(): + lineage_resource_helper = LineageResourceHelper() + art1 = lineage_resource_helper.create_artifact(artifact_name=name()) + art2 = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) + lineage_resource_helper.clean_all() + +def test_wide_graphs(sagemaker_session): + lineage_resource_helper = LineageResourceHelper() + art_root = lineage_resource_helper.create_artifact(artifact_name=name()) + try: + for i in range(10): + art = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=art_root, dest_arn=art) + time.sleep(0.1) + except(e): + lineage_resource_helper.clean_all() + + try: + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + result = lq.query(start_arns=[art_root]) + print(result) + except(e): + lineage_resource_helper.clean_all() + + lineage_resource_helper.clean_all() + From b6b5ee444d8cff2d3ccd651f846ab75d26d2f5bd Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 25 Jul 2022 10:40:00 -0700 Subject: [PATCH 19/46] PyvisVisualizer added --- src/sagemaker/lineage/query.py | 104 ++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 1ea7baecb6..9b65b77646 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -390,6 +390,78 @@ def render(self, elements, mode): return app.run_server(mode=mode) +class PyvisVisualizer(object): + """Create object used for visualizing graph using Pyvis library.""" + + def __init__(self, graph_styles): + """Init for PyvisVisualizer.""" + # import visualization packages + ( + self.pyvis, + self.Network, + self.Options, + ) = self._import_visual_modules() + + self.graph_styles = graph_styles + + def _import_visual_modules(self): + import pyvis + from pyvis.network import Network + from pyvis.options import Options + + return pyvis, Network, Options + + def _get_options(self): + options = """ + var options = { + "configure":{ + "enabled": true + }, + "layout": { + "hierarchical": { + "enabled": true, + "blockShifting": false, + "direction": "LR", + "sortMethod": "directed", + "shakeTowards": "roots" + } + }, + "interaction": { + "multiselect": true, + "navigationButtons": true + }, + "physics": { + "enabled": false, + "hierarchicalRepulsion": { + "centralGravity": 0, + "avoidOverlap": null + }, + "minVelocity": 0.75, + "solver": "hierarchicalRepulsion" + } + } + """ + return options + + def _node_color(self, n): + return self.graph_styles[n[2]]["style"]["background-color"] + + def render(self, elements): + net = self.Network(height='500px', width='100%', notebook = True, directed = True) + options = self._get_options() + net.set_options(options) + + for n in elements["nodes"]: + if(n[3]==True): # startarn + net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n), shape="star") + else: + net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n)) + + for e in elements["edges"]: + print(e) + net.add_edge(e[0], e[1], title=e[2]) + + return net.show('pyvisExample.html') class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -469,6 +541,18 @@ def _covert_edges_to_tuples(self): edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) return edges + def _pyvis_covert_vertices_to_tuples(self): + """Convert vertices to tuple format for visualizer.""" + verts = [] + # get vertex info in the form of (id, label, class) + for vert in self.vertices: + if vert.arn in self.startarn: + # add "startarn" class to node if arn is a startarn + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True)) + else: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False)) + return verts + def _get_visualization_elements(self): """Get elements for visualization.""" # get vertices and edges info for graph @@ -488,6 +572,16 @@ def _get_visualization_elements(self): return elements + def _get_pyvis_visualization_elements(self): + verts = self._pyvis_covert_vertices_to_tuples() + edges = self._covert_edges_to_tuples() + + elements = { + "nodes": verts, + "edges": edges + } + return elements + def visualize(self): """Visualize lineage query result.""" elements = self._get_visualization_elements() @@ -523,11 +617,15 @@ def visualize(self): } # initialize DashVisualizer instance to render graph & interactive components - dash_vis = DashVisualizer(lineage_graph) + # dash_vis = DashVisualizer(lineage_graph) + + # dash_server = dash_vis.render(elements=elements, mode="inline") - dash_server = dash_vis.render(elements=elements, mode="inline") + # return dash_server - return dash_server + pyvis_vis = PyvisVisualizer(lineage_graph) + elements = self._get_pyvis_visualization_elements() + return pyvis_vis.render(elements=elements) class LineageFilter(object): From 794fb578bde195e81199d322e8973e66c32d9430 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 25 Jul 2022 14:06:00 -0700 Subject: [PATCH 20/46] modify options --- src/sagemaker/lineage/query.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 9b65b77646..69132e35c9 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -408,6 +408,7 @@ def _import_visual_modules(self): import pyvis from pyvis.network import Network from pyvis.options import Options + # No module named 'pyvis' return pyvis, Network, Options @@ -415,15 +416,15 @@ def _get_options(self): options = """ var options = { "configure":{ - "enabled": true + "enabled": false }, "layout": { "hierarchical": { "enabled": true, - "blockShifting": false, + "blockShifting": true, "direction": "LR", "sortMethod": "directed", - "shakeTowards": "roots" + "shakeTowards": "leaves" } }, "interaction": { From baeaace2292aa3a67ac3de859f233cc96f5fcb96 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 27 Jul 2022 10:01:23 -0700 Subject: [PATCH 21/46] change: change visualization to using pyvis library --- src/sagemaker/lineage/query.py | 268 ++++----------------------------- 1 file changed, 30 insertions(+), 238 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 69132e35c9..b96880209b 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -201,18 +201,16 @@ def _artifact_to_lineage_object(self): return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) -class DashVisualizer(object): - """Create object used for visualizing graph using Dash library.""" +class PyvisVisualizer(object): + """Create object used for visualizing graph using Pyvis library.""" def __init__(self, graph_styles): - """Init for DashVisualizer.""" + """Init for PyvisVisualizer.""" # import visualization packages ( - self.cyto, - self.JupyterDash, - self.html, - self.Input, - self.Output, + self.pyvis, + self.Network, + self.Options, ) = self._import_visual_modules() self.graph_styles = graph_styles @@ -220,199 +218,30 @@ def __init__(self, graph_styles): def _import_visual_modules(self): """Import modules needed for visualization.""" try: - import dash_cytoscape as cyto + import pyvis except ImportError as e: print(e) - print("Try: pip install dash-cytoscape") + print("Try: pip install pyvis") raise try: - from jupyter_dash import JupyterDash + from pyvis.network import Network except ImportError as e: print(e) - print("Try: pip install jupyter-dash") + print("Try: pip install pyvis") raise try: - from dash import html + from pyvis.options import Options except ImportError as e: print(e) - print("Try: pip install dash") + print("Try: pip install pyvis") raise - try: - from dash.dependencies import Input, Output - except ImportError as e: - print(e) - print("Try: pip install dash") - raise - - return cyto, JupyterDash, html, Input, Output - - def _create_legend_component(self, style): - """Create legend component div.""" - text = style["name"] - symbol = "" - color = "#ffffff" - if style["isShape"] == "False": - color = style["style"]["background-color"] - else: - symbol = style["symbol"] - return self.html.Div( - [ - self.html.Div( - symbol, - style={ - "background-color": color, - "width": "1.5vw", - "height": "1.5vw", - "display": "inline-block", - "font-size": "1.5vw", - }, - ), - self.html.Div( - style={ - "width": "0.5vw", - "height": "1.5vw", - "display": "inline-block", - } - ), - self.html.Div( - text, - style={"display": "inline-block", "font-size": "1.5vw"}, - ), - ] - ) - - def _create_entity_selector(self, entity_name, style): - """Create selector for each lineage entity.""" - return {"selector": "." + entity_name, "style": style["style"]} - - def _get_app(self, elements): - """Create JupyterDash app for interactivity on Jupyter notebook.""" - app = self.JupyterDash(__name__) - self.cyto.load_extra_layouts() - - app.layout = self.html.Div( - [ - # graph section - self.cyto.Cytoscape( - id="cytoscape-graph", - elements=elements, - style={ - "width": "84%", - "height": "350px", - "display": "inline-block", - "border-width": "1vw", - "border-color": "#232f3e", - }, - layout={"name": "klay"}, - stylesheet=[ - { - "selector": "node", - "style": { - "label": "data(label)", - "font-size": "3.5vw", - "height": "10vw", - "width": "10vw", - "border-width": "0.8", - "border-opacity": "0", - "border-color": "#232f3e", - "font-family": "verdana", - }, - }, - { - "selector": "edge", - "style": { - "label": "data(label)", - "color": "gray", - "text-halign": "left", - "text-margin-y": "2.5", - "font-size": "3", - "width": "1", - "curve-style": "bezier", - "control-point-step-size": "15", - "target-arrow-color": "gray", - "target-arrow-shape": "triangle", - "line-color": "gray", - "arrow-scale": "0.5", - "font-family": "verdana", - }, - }, - {"selector": ".select", "style": {"border-opacity": "0.7"}}, - ] - + [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()], - responsive=True, - ), - self.html.Div( - style={ - "width": "0.5%", - "display": "inline-block", - "font-size": "1vw", - "font-family": "verdana", - "vertical-align": "top", - }, - ), - # legend section - self.html.Div( - [self._create_legend_component(v) for k, v in self.graph_styles.items()], - style={ - "display": "inline-block", - "font-size": "1vw", - "font-family": "verdana", - "vertical-align": "top", - }, - ), - ] - ) - - @app.callback( - self.Output("cytoscape-graph", "elements"), - self.Input("cytoscape-graph", "tapNodeData"), - self.Input("cytoscape-graph", "elements"), - ) - def selectNode(tapData, elements): - for n in elements: - if tapData is not None and n["data"]["id"] == tapData["id"]: - # if is tapped node, add "select" class to node - n["classes"] += " select" - elif "classes" in n: - # remove "select" class in "classes" if node not selected - n["classes"] = n["classes"].replace("select", "") - - return elements - - return app - - def render(self, elements, mode): - """Render graph for lineage query result.""" - app = self._get_app(elements) - - return app.run_server(mode=mode) - -class PyvisVisualizer(object): - """Create object used for visualizing graph using Pyvis library.""" - - def __init__(self, graph_styles): - """Init for PyvisVisualizer.""" - # import visualization packages - ( - self.pyvis, - self.Network, - self.Options, - ) = self._import_visual_modules() - - self.graph_styles = graph_styles - - def _import_visual_modules(self): - import pyvis - from pyvis.network import Network - from pyvis.options import Options - # No module named 'pyvis' - return pyvis, Network, Options def _get_options(self): + """Get pyvis graph options.""" options = """ var options = { "configure":{ @@ -445,24 +274,29 @@ def _get_options(self): return options def _node_color(self, n): + """Return node color by background-color specified in graph styles.""" return self.graph_styles[n[2]]["style"]["background-color"] - def render(self, elements): - net = self.Network(height='500px', width='100%', notebook = True, directed = True) + def render(self, elements, path="pyvisExample.html"): + """Render graph for lineage query result.""" + net = self.Network(height="500px", width="100%", notebook=True, directed=True) options = self._get_options() - net.set_options(options) + net.set_options(options) + # add nodes to graph for n in elements["nodes"]: - if(n[3]==True): # startarn - net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n), shape="star") + if n[3]: # startarn + net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n), shape="star") else: - net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n)) + net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n)) + # add edges to graph for e in elements["edges"]: print(e) - net.add_edge(e[0], e[1], title=e[2]) + net.add_edge(e[0], e[1], title=e[2]) + + return net.show(path) - return net.show('pyvisExample.html') class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -522,18 +356,6 @@ def __str__(self): result_dict = vars(self) return str({k: [str(val) for val in v] for k, v in result_dict.items()}) - def _covert_vertices_to_tuples(self): - """Convert vertices to tuple format for visualizer.""" - verts = [] - # get vertex info in the form of (id, label, class) - for vert in self.vertices: - if vert.arn in self.startarn: - # add "startarn" class to node if arn is a startarn - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) - else: - verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) - return verts - def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" edges = [] @@ -542,7 +364,7 @@ def _covert_edges_to_tuples(self): edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) return edges - def _pyvis_covert_vertices_to_tuples(self): + def _covert_vertices_to_tuples(self): """Convert vertices to tuple format for visualizer.""" verts = [] # get vertex info in the form of (id, label, class) @@ -555,38 +377,15 @@ def _pyvis_covert_vertices_to_tuples(self): return verts def _get_visualization_elements(self): - """Get elements for visualization.""" - # get vertices and edges info for graph + """Get elements(nodes+edges) for visualization.""" verts = self._covert_vertices_to_tuples() edges = self._covert_edges_to_tuples() - nodes = [ - {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts - ] - - edges = [ - {"data": {"source": source, "target": target, "label": label}} - for source, target, label in edges - ] - - elements = nodes + edges - - return elements - - def _get_pyvis_visualization_elements(self): - verts = self._pyvis_covert_vertices_to_tuples() - edges = self._covert_edges_to_tuples() - - elements = { - "nodes": verts, - "edges": edges - } + elements = {"nodes": verts, "edges": edges} return elements def visualize(self): """Visualize lineage query result.""" - elements = self._get_visualization_elements() - lineage_graph = { # nodes can have shape / color "TrialComponent": { @@ -617,15 +416,8 @@ def visualize(self): }, } - # initialize DashVisualizer instance to render graph & interactive components - # dash_vis = DashVisualizer(lineage_graph) - - # dash_server = dash_vis.render(elements=elements, mode="inline") - - # return dash_server - pyvis_vis = PyvisVisualizer(lineage_graph) - elements = self._get_pyvis_visualization_elements() + elements = self._get_visualization_elements() return pyvis_vis.render(elements=elements) From 578acdf8dd620180370ae8ab5e68579ca663cfab Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 27 Jul 2022 14:47:06 -0700 Subject: [PATCH 22/46] pyvis import issue on running lineage test --- requirements/extras/test_requirements.txt | 1 + tests/integ/sagemaker/lineage/helpers.py | 35 +++++++------------ .../lineage/test_lineage_visualize.py | 15 ++++---- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 2247394441..7197147032 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -18,3 +18,4 @@ fabric==2.6.0 requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 +pyvis==0.2.1 diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 52eb44363c..7b317c5705 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -85,35 +85,29 @@ def visit(arn, visited: set): class LineageResourceHelper: - def __init__(self): - self.client = boto3.client('sagemaker') + self.client = boto3.client("sagemaker") self.artifacts = [] self.associations = [] - def create_artifact(self, artifact_name, artifact_type='Dataset'): + def create_artifact(self, artifact_name, artifact_type="Dataset"): response = self.client.create_artifact( ArtifactName=artifact_name, Source={ - 'SourceUri': "Test-artifact-" + artifact_name, - 'SourceTypes': [ - { - 'SourceIdType': 'S3ETag', - 'Value': 'Test-artifact-sourceId-value' - }, - ] + "SourceUri": "Test-artifact-" + artifact_name, + "SourceTypes": [ + {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"}, + ], }, - ArtifactType=artifact_type + ArtifactType=artifact_type, ) - self.artifacts.append(response['ArtifactArn']) + self.artifacts.append(response["ArtifactArn"]) - return response['ArtifactArn'] + return response["ArtifactArn"] - def create_association(self, source_arn, dest_arn, association_type='AssociatedWith'): + def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"): response = self.client.add_association( - SourceArn=source_arn, - DestinationArn=dest_arn, - AssociationType=association_type + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type ) if "SourceArn" in response.keys(): self.associations.append((source_arn, dest_arn)) @@ -124,12 +118,9 @@ def create_association(self, source_arn, dest_arn, association_type='AssociatedW def clean_all(self): for source, dest in self.associations: try: - self.client.delete_association( - SourceArn=source, - DestinationArn=dest - ) + self.client.delete_association(SourceArn=source, DestinationArn=dest) time.sleep(2) - except(e): + except (e): print("skipped " + str(e)) for artifact_arn in self.artifacts: diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index cb68021da5..eede2eaf9f 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -22,6 +22,7 @@ from tests.integ.sagemaker.lineage.helpers import name, names, retry, LineageResourceHelper + def test_LineageResourceHelper(): lineage_resource_helper = LineageResourceHelper() art1 = lineage_resource_helper.create_artifact(artifact_name=name()) @@ -29,6 +30,7 @@ def test_LineageResourceHelper(): lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) lineage_resource_helper.clean_all() + def test_wide_graphs(sagemaker_session): lineage_resource_helper = LineageResourceHelper() art_root = lineage_resource_helper.create_artifact(artifact_name=name()) @@ -36,16 +38,17 @@ def test_wide_graphs(sagemaker_session): for i in range(10): art = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association(source_arn=art_root, dest_arn=art) - time.sleep(0.1) - except(e): + time.sleep(0.2) + except Exception as e: + print(e) lineage_resource_helper.clean_all() try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) - result = lq.query(start_arns=[art_root]) - print(result) - except(e): + lq_result = lq.query(start_arns=[art_root]) + lq_result.visualize() + except Exception as e: + print(e) lineage_resource_helper.clean_all() lineage_resource_helper.clean_all() - From 9171e4e6afecd2d7b19e251f2c361704c0c4fcb4 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 27 Jul 2022 15:02:14 -0700 Subject: [PATCH 23/46] import visualization modules using get_module() --- src/sagemaker/lineage/query.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index b96880209b..95c5f6fa22 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -17,7 +17,7 @@ from enum import Enum from typing import Optional, Union, List, Dict -from sagemaker.lineage._utils import get_resource_name_from_arn +from sagemaker.lineage._utils import get_resource_name_from_arn, get_module class LineageEntityEnum(Enum): @@ -208,7 +208,6 @@ def __init__(self, graph_styles): """Init for PyvisVisualizer.""" # import visualization packages ( - self.pyvis, self.Network, self.Options, ) = self._import_visual_modules() @@ -217,28 +216,11 @@ def __init__(self, graph_styles): def _import_visual_modules(self): """Import modules needed for visualization.""" - try: - import pyvis - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise - - try: - from pyvis.network import Network - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise - - try: - from pyvis.options import Options - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise + get_module("pyvis") + from pyvis.network import Network + from pyvis.options import Options - return pyvis, Network, Options + return Network, Options def _get_options(self): """Get pyvis graph options.""" From d679f4a015a16bfb1d950c93158e7681ef5d0bbb Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 28 Jul 2022 09:51:40 -0700 Subject: [PATCH 24/46] test_wide_graphs added --- .gitignore | 3 ++- requirements/extras/local_requirements.txt | 2 +- tests/integ/sagemaker/lineage/helpers.py | 13 +++++++++---- .../sagemaker/lineage/test_lineage_visualize.py | 16 +++++++++++----- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 5b496055e9..7a61658c20 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ venv/ env/ .vscode/ **/tmp -.python-version \ No newline at end of file +.python-version +*.html \ No newline at end of file diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 17512c3388..439f42010c 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker~=5.0.0 -PyYAML==5.4.1 +PyYAML==5.4.1 \ No newline at end of file diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 7b317c5705..361019d9b4 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -17,6 +17,7 @@ from datetime import datetime import time import boto3 +from botocore.config import Config from sagemaker.lineage import association from sagemaker.lineage.artifact import Artifact from sagemaker.lineage.association import Association @@ -86,7 +87,7 @@ def visit(arn, visited: set): class LineageResourceHelper: def __init__(self): - self.client = boto3.client("sagemaker") + self.client = boto3.client("sagemaker", config=Config(connect_timeout=5, read_timeout=60, retries={'max_attempts': 20})) self.artifacts = [] self.associations = [] @@ -119,9 +120,13 @@ def clean_all(self): for source, dest in self.associations: try: self.client.delete_association(SourceArn=source, DestinationArn=dest) - time.sleep(2) - except (e): + time.sleep(0.5) + except Exception as e: print("skipped " + str(e)) for artifact_arn in self.artifacts: - self.client.delete_artifact(ArtifactArn=artifact_arn) + try: + self.client.delete_artifact(ArtifactArn=artifact_arn) + time.sleep(0.5) + except Exception as e: + print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index eede2eaf9f..3c84ae057b 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -25,23 +25,28 @@ def test_LineageResourceHelper(): lineage_resource_helper = LineageResourceHelper() - art1 = lineage_resource_helper.create_artifact(artifact_name=name()) - art2 = lineage_resource_helper.create_artifact(artifact_name=name()) - lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) - lineage_resource_helper.clean_all() + try: + art1 = lineage_resource_helper.create_artifact(artifact_name=name()) + art2 = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) + lineage_resource_helper.clean_all() + except Exception as e: + print(e) + assert False def test_wide_graphs(sagemaker_session): lineage_resource_helper = LineageResourceHelper() art_root = lineage_resource_helper.create_artifact(artifact_name=name()) try: - for i in range(10): + for i in range(500): art = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association(source_arn=art_root, dest_arn=art) time.sleep(0.2) except Exception as e: print(e) lineage_resource_helper.clean_all() + assert False try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) @@ -50,5 +55,6 @@ def test_wide_graphs(sagemaker_session): except Exception as e: print(e) lineage_resource_helper.clean_all() + assert False lineage_resource_helper.clean_all() From 3bab76a2a2854d6f73ea71e55e15a9cf280a7a40 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 28 Jul 2022 10:25:37 -0700 Subject: [PATCH 25/46] used get_module() method for import & code style change --- .gitignore | 3 +- src/sagemaker/lineage/query.py | 64 ++++++++++++---------------------- 2 files changed, 24 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index 5b496055e9..7a61658c20 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ venv/ env/ .vscode/ **/tmp -.python-version \ No newline at end of file +.python-version +*.html \ No newline at end of file diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index b96880209b..5f3164bdc0 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -17,7 +17,7 @@ from enum import Enum from typing import Optional, Union, List, Dict -from sagemaker.lineage._utils import get_resource_name_from_arn +from sagemaker.lineage._utils import get_resource_name_from_arn, get_module class LineageEntityEnum(Enum): @@ -208,41 +208,14 @@ def __init__(self, graph_styles): """Init for PyvisVisualizer.""" # import visualization packages ( - self.pyvis, self.Network, self.Options, ) = self._import_visual_modules() self.graph_styles = graph_styles - def _import_visual_modules(self): - """Import modules needed for visualization.""" - try: - import pyvis - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise - - try: - from pyvis.network import Network - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise - - try: - from pyvis.options import Options - except ImportError as e: - print(e) - print("Try: pip install pyvis") - raise - - return pyvis, Network, Options - - def _get_options(self): - """Get pyvis graph options.""" - options = """ + # pyvis graph options + self._options = """ var options = { "configure":{ "enabled": false @@ -271,29 +244,36 @@ def _get_options(self): } } """ - return options - def _node_color(self, n): + def _import_visual_modules(self): + """Import modules needed for visualization.""" + get_module("pyvis") + from pyvis.network import Network + from pyvis.options import Options + + return Network, Options + + def _node_color(self, entity): """Return node color by background-color specified in graph styles.""" - return self.graph_styles[n[2]]["style"]["background-color"] + return self.graph_styles[entity]["style"]["background-color"] def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result.""" net = self.Network(height="500px", width="100%", notebook=True, directed=True) - options = self._get_options() - net.set_options(options) + net.set_options(self._options) # add nodes to graph - for n in elements["nodes"]: - if n[3]: # startarn - net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n), shape="star") + for arn, source, entity, is_start_arn in elements["nodes"]: + if is_start_arn: # startarn + net.add_node( + arn, label=source, title=entity, color=self._node_color(entity), shape="star" + ) else: - net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n)) + net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) # add edges to graph - for e in elements["edges"]: - print(e) - net.add_edge(e[0], e[1], title=e[2]) + for src, dest, asso_type in elements["edges"]: + net.add_edge(src, dest, title=asso_type) return net.show(path) From 5c89c0931d9ce9c0b1bbfaca443fb5fb98795602 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 28 Jul 2022 13:46:31 -0700 Subject: [PATCH 26/46] long graph visualize test added --- src/sagemaker/lineage/query.py | 5 +- .../lineage/test_lineage_visualize.py | 52 ++++++++++++++++--- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 95c5f6fa22..a5c8c0bfae 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -274,7 +274,6 @@ def render(self, elements, path="pyvisExample.html"): # add edges to graph for e in elements["edges"]: - print(e) net.add_edge(e[0], e[1], title=e[2]) return net.show(path) @@ -366,7 +365,7 @@ def _get_visualization_elements(self): elements = {"nodes": verts, "edges": edges} return elements - def visualize(self): + def visualize(self, path="pyvisExample.html"): """Visualize lineage query result.""" lineage_graph = { # nodes can have shape / color @@ -400,7 +399,7 @@ def visualize(self): pyvis_vis = PyvisVisualizer(lineage_graph) elements = self._get_visualization_elements() - return pyvis_vis.render(elements=elements) + return pyvis_vis.render(elements=elements, path=path) class LineageFilter(object): diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 3c84ae057b..0fb1c04f43 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -19,11 +19,12 @@ import pytest import sagemaker.lineage.query - +from sagemaker.lineage.query import LineageQueryDirectionEnum from tests.integ.sagemaker.lineage.helpers import name, names, retry, LineageResourceHelper def test_LineageResourceHelper(): + # check if LineageResourceHelper works properly lineage_resource_helper = LineageResourceHelper() try: art1 = lineage_resource_helper.create_artifact(artifact_name=name()) @@ -35,13 +36,19 @@ def test_LineageResourceHelper(): assert False -def test_wide_graphs(sagemaker_session): +def test_wide_graph_visualize(sagemaker_session): lineage_resource_helper = LineageResourceHelper() - art_root = lineage_resource_helper.create_artifact(artifact_name=name()) + wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + + # create wide graph + # Artifact ----> Artifact + # \ \ \-> Artifact + # \ \--> Artifact + # \---> ... try: - for i in range(500): - art = lineage_resource_helper.create_artifact(artifact_name=name()) - lineage_resource_helper.create_association(source_arn=art_root, dest_arn=art) + for i in range(3): + artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=wide_graph_root_arn, dest_arn=artifact_arn) time.sleep(0.2) except Exception as e: print(e) @@ -50,11 +57,40 @@ def test_wide_graphs(sagemaker_session): try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) - lq_result = lq.query(start_arns=[art_root]) - lq_result.visualize() + lq_result = lq.query(start_arns=[wide_graph_root_arn]) + lq_result.visualize(path="wideGraph.html") except Exception as e: print(e) lineage_resource_helper.clean_all() assert False lineage_resource_helper.clean_all() + +def test_long_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper() + long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + last_arn = long_graph_root_arn + + # create long graph + # Artifact -> Artifact -> ... -> Artifact + try: + for i in range(20): + new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=last_arn, dest_arn=new_artifact_arn) + last_arn = new_artifact_arn + except Exception as e: + print(e) + lineage_resource_helper.clean_all() + assert False + + try: + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS) + # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction) + lq_result.visualize(path="longGraph.html") + except Exception as e: + print(e) + lineage_resource_helper.clean_all() + assert False + + lineage_resource_helper.clean_all() \ No newline at end of file From 5609e429b00ad529bceafc4fef2b942399600815 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 29 Jul 2022 10:00:29 -0700 Subject: [PATCH 27/46] test_get_visualization_elements added --- src/sagemaker/lineage/query.py | 4 ++-- tests/integ/sagemaker/lineage/helpers.py | 4 +++- tests/unit/sagemaker/lineage/test_query.py | 28 ++++++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 5f3164bdc0..a46e88867a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -364,7 +364,7 @@ def _get_visualization_elements(self): elements = {"nodes": verts, "edges": edges} return elements - def visualize(self): + def visualize(self, path="pyvisExample.html"): """Visualize lineage query result.""" lineage_graph = { # nodes can have shape / color @@ -398,7 +398,7 @@ def visualize(self): pyvis_vis = PyvisVisualizer(lineage_graph) elements = self._get_visualization_elements() - return pyvis_vis.render(elements=elements) + return pyvis_vis.render(elements=elements, path=path) class LineageFilter(object): diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..4ed78127bb 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -16,7 +16,8 @@ import uuid from datetime import datetime import time - +import boto3 +from botocore.config import Config def name(): return "lineage-integ-{}-{}".format( @@ -78,3 +79,4 @@ def visit(arn, visited: set): ret = [] return visit(start_arn, set()) + diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index ae76fd199c..b5b809138d 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -524,3 +524,31 @@ def test_vertex_to_object_unconvertable(sagemaker_session): with pytest.raises(ValueError): vertex.to_lineage_object() + + +def test_get_visualization_elements(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + print(query_response) + + elements = query_response._get_visualization_elements() + + print(elements) + + assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) + assert elements["nodes"][1] == ("arn2", "Model", "Context", False) + assert elements["edges"][0] == ("arn1", "arn2", "Produced") + + + From fa7b9dd752b00740060a22ae95be5f87597f35df Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 29 Jul 2022 14:06:09 -0700 Subject: [PATCH 28/46] create context & action added to helper --- tests/integ/sagemaker/lineage/helpers.py | 51 ++++++++++++++++++- .../lineage/test_lineage_visualize.py | 14 ++--- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 361019d9b4..4e73537584 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module contains helper methods for tests of SageMaker Lineage""" from __future__ import absolute_import +from urllib import response import uuid from datetime import datetime @@ -86,9 +87,12 @@ def visit(arn, visited: set): class LineageResourceHelper: - def __init__(self): - self.client = boto3.client("sagemaker", config=Config(connect_timeout=5, read_timeout=60, retries={'max_attempts': 20})) + def __init__(self, sagemaker_session): + self.client = sagemaker_session.sagemaker_client self.artifacts = [] + self.actions = [] + self.contexts = [] + self.trialComponents = [] self.associations = [] def create_artifact(self, artifact_name, artifact_type="Dataset"): @@ -106,6 +110,42 @@ def create_artifact(self, artifact_name, artifact_type="Dataset"): return response["ArtifactArn"] + def create_action(self, action_name, action_type="ModelDeployment"): + response = self.client.create_action( + ActionName=action_name, + Source={ + "SourceUri": "Test-action-" + action_name, + "SourceTypes": [ + {"SourceIdType": "S3ETag", "Value": "Test-action-sourceId-value"}, + ], + }, + ActionType=action_type + ) + self.actions.append(response["ActionArn"]) + + return response["ActionArn"] + + def create_context(self, context_name, context_type="Endpoint"): + response = self.client.create_context( + ContextName=context_name, + Source={ + "SourceUri": "Test-context-" + context_name, + "SourceTypes": [ + {"SourceIdType": "S3ETag", "Value": "Test-context-sourceId-value"}, + ], + }, + ContextType=context_type + ) + self.contexts.append(response["ContextArn"]) + + return response["ContextArn"] + + def create_trialComponent(self, trialComponent_name, trialComponent_type="TrainingJob"): + response = self.client.create_trial_component( + TrialComponentName=trialComponent_name, + + ) + def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"): response = self.client.add_association( SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type @@ -130,3 +170,10 @@ def clean_all(self): time.sleep(0.5) except Exception as e: print("skipped " + str(e)) + + for action_arn in self.actions: + try: + self.client.delete_action(ActionArn=action_arn) + time.sleep(0.5) + except Exception as e: + print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 0fb1c04f43..5f12386392 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -23,9 +23,9 @@ from tests.integ.sagemaker.lineage.helpers import name, names, retry, LineageResourceHelper -def test_LineageResourceHelper(): +def test_LineageResourceHelper(sagemaker_session): # check if LineageResourceHelper works properly - lineage_resource_helper = LineageResourceHelper() + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) try: art1 = lineage_resource_helper.create_artifact(artifact_name=name()) art2 = lineage_resource_helper.create_artifact(artifact_name=name()) @@ -35,9 +35,9 @@ def test_LineageResourceHelper(): print(e) assert False - +@pytest.mark.skip("visualizer load test") def test_wide_graph_visualize(sagemaker_session): - lineage_resource_helper = LineageResourceHelper() + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) # create wide graph @@ -46,10 +46,9 @@ def test_wide_graph_visualize(sagemaker_session): # \ \--> Artifact # \---> ... try: - for i in range(3): + for i in range(500): artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association(source_arn=wide_graph_root_arn, dest_arn=artifact_arn) - time.sleep(0.2) except Exception as e: print(e) lineage_resource_helper.clean_all() @@ -66,8 +65,9 @@ def test_wide_graph_visualize(sagemaker_session): lineage_resource_helper.clean_all() +@pytest.mark.skip("visualizer load test") def test_long_graph_visualize(sagemaker_session): - lineage_resource_helper = LineageResourceHelper() + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) last_arn = long_graph_root_arn From 83964da9b479d9ff4f034ce179f95b3fa56554ce Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 29 Jul 2022 14:22:21 -0700 Subject: [PATCH 29/46] resolve conflict with master branch --- src/sagemaker/lineage/query.py | 42 ++++++++++++++++------------------ 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a5c8c0bfae..f5ff1edef1 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -214,17 +214,7 @@ def __init__(self, graph_styles): self.graph_styles = graph_styles - def _import_visual_modules(self): - """Import modules needed for visualization.""" - get_module("pyvis") - from pyvis.network import Network - from pyvis.options import Options - - return Network, Options - - def _get_options(self): - """Get pyvis graph options.""" - options = """ + self._options = """ var options = { "configure":{ "enabled": false @@ -253,28 +243,36 @@ def _get_options(self): } } """ - return options - def _node_color(self, n): + def _import_visual_modules(self): + """Import modules needed for visualization.""" + get_module("pyvis") + from pyvis.network import Network + from pyvis.options import Options + + return Network, Options + + def _node_color(self, entity): """Return node color by background-color specified in graph styles.""" - return self.graph_styles[n[2]]["style"]["background-color"] + return self.graph_styles[entity]["style"]["background-color"] def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result.""" net = self.Network(height="500px", width="100%", notebook=True, directed=True) - options = self._get_options() - net.set_options(options) + net.set_options(self._options) # add nodes to graph - for n in elements["nodes"]: - if n[3]: # startarn - net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n), shape="star") + for arn, source, entity, is_start_arn in elements["nodes"]: + if is_start_arn: # startarn + net.add_node( + arn, label=source, title=entity, color=self._node_color(entity), shape="star" + ) else: - net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n)) + net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) # add edges to graph - for e in elements["edges"]: - net.add_edge(e[0], e[1], title=e[2]) + for src, dest, asso_type in elements["edges"]: + net.add_edge(src, dest, title=asso_type) return net.show(path) From c66edf87243c7b2819ea0681aef16576cdded30f Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 1 Aug 2022 14:17:32 -0700 Subject: [PATCH 30/46] change: add queryLineageResult visualizer load test & integ test --- tests/integ/sagemaker/lineage/helpers.py | 34 ++- .../lineage/test_lineage_visualize.py | 194 +++++++++++++++++- 2 files changed, 197 insertions(+), 31 deletions(-) diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 4e73537584..0c40bbac91 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -12,16 +12,10 @@ # language governing permissions and limitations under the License. """This module contains helper methods for tests of SageMaker Lineage""" from __future__ import absolute_import -from urllib import response import uuid from datetime import datetime import time -import boto3 -from botocore.config import Config -from sagemaker.lineage import association -from sagemaker.lineage.artifact import Artifact -from sagemaker.lineage.association import Association def name(): @@ -92,7 +86,6 @@ def __init__(self, sagemaker_session): self.artifacts = [] self.actions = [] self.contexts = [] - self.trialComponents = [] self.associations = [] def create_artifact(self, artifact_name, artifact_type="Dataset"): @@ -115,11 +108,10 @@ def create_action(self, action_name, action_type="ModelDeployment"): ActionName=action_name, Source={ "SourceUri": "Test-action-" + action_name, - "SourceTypes": [ - {"SourceIdType": "S3ETag", "Value": "Test-action-sourceId-value"}, - ], + "SourceType": "S3ETag", + "SourceId": "Test-action-sourceId-value", }, - ActionType=action_type + ActionType=action_type, ) self.actions.append(response["ActionArn"]) @@ -130,22 +122,15 @@ def create_context(self, context_name, context_type="Endpoint"): ContextName=context_name, Source={ "SourceUri": "Test-context-" + context_name, - "SourceTypes": [ - {"SourceIdType": "S3ETag", "Value": "Test-context-sourceId-value"}, - ], + "SourceType": "S3ETag", + "SourceId": "Test-context-sourceId-value", }, - ContextType=context_type + ContextType=context_type, ) self.contexts.append(response["ContextArn"]) return response["ContextArn"] - def create_trialComponent(self, trialComponent_name, trialComponent_type="TrainingJob"): - response = self.client.create_trial_component( - TrialComponentName=trialComponent_name, - - ) - def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"): response = self.client.add_association( SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type @@ -177,3 +162,10 @@ def clean_all(self): time.sleep(0.5) except Exception as e: print("skipped " + str(e)) + + for context_arn in self.contexts: + try: + self.client.delete_context(ContextArn=context_arn) + time.sleep(0.5) + except Exception as e: + print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 5f12386392..274b18d503 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -11,16 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """This module contains code to test SageMaker ``LineageQueryResult.visualize()``""" - -import datetime -import logging +from __future__ import absolute_import import time +import json import pytest import sagemaker.lineage.query from sagemaker.lineage.query import LineageQueryDirectionEnum -from tests.integ.sagemaker.lineage.helpers import name, names, retry, LineageResourceHelper +from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper def test_LineageResourceHelper(sagemaker_session): @@ -35,6 +34,7 @@ def test_LineageResourceHelper(sagemaker_session): print(e) assert False + @pytest.mark.skip("visualizer load test") def test_wide_graph_visualize(sagemaker_session): lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) @@ -46,9 +46,11 @@ def test_wide_graph_visualize(sagemaker_session): # \ \--> Artifact # \---> ... try: - for i in range(500): + for i in range(10): artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) - lineage_resource_helper.create_association(source_arn=wide_graph_root_arn, dest_arn=artifact_arn) + lineage_resource_helper.create_association( + source_arn=wide_graph_root_arn, dest_arn=artifact_arn + ) except Exception as e: print(e) lineage_resource_helper.clean_all() @@ -65,6 +67,7 @@ def test_wide_graph_visualize(sagemaker_session): lineage_resource_helper.clean_all() + @pytest.mark.skip("visualizer load test") def test_long_graph_visualize(sagemaker_session): lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) @@ -74,9 +77,11 @@ def test_long_graph_visualize(sagemaker_session): # create long graph # Artifact -> Artifact -> ... -> Artifact try: - for i in range(20): + for i in range(10): new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) - lineage_resource_helper.create_association(source_arn=last_arn, dest_arn=new_artifact_arn) + lineage_resource_helper.create_association( + source_arn=last_arn, dest_arn=new_artifact_arn + ) last_arn = new_artifact_arn except Exception as e: print(e) @@ -85,7 +90,9 @@ def test_long_graph_visualize(sagemaker_session): try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) - lq_result = lq.query(start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS) + lq_result = lq.query( + start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS + ) # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction) lq_result.visualize(path="longGraph.html") except Exception as e: @@ -93,4 +100,171 @@ def test_long_graph_visualize(sagemaker_session): lineage_resource_helper.clean_all() assert False - lineage_resource_helper.clean_all() \ No newline at end of file + lineage_resource_helper.clean_all() + + +def test_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + + # create lineage data + # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context + # /-> + # dataset artifact -/ + try: + graph_startarn = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Model" + ) + image_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Image" + ) + lineage_resource_helper.create_association( + source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo" + ) + dataset_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="DataSet" + ) + lineage_resource_helper.create_association( + source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith" + ) + modeldeploy_action = lineage_resource_helper.create_action( + action_name=name(), action_type="ModelDeploy" + ) + lineage_resource_helper.create_association( + source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo" + ) + endpoint_context = lineage_resource_helper.create_context( + context_name=name(), context_type="Endpoint" + ) + lineage_resource_helper.create_association( + source_arn=modeldeploy_action, + dest_arn=endpoint_context, + association_type="AssociatedWith", + ) + time.sleep(1) + except Exception as e: + print(e) + lineage_resource_helper.clean_all() + assert False + + # visualize + try: + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[graph_startarn]) + lq_result.visualize(path="testGraph.html") + except Exception as e: + print(e) + lineage_resource_helper.clean_all() + assert False + + # check generated graph info + try: + fo = open("testGraph.html", "r") + lines = fo.readlines() + for line in lines: + if "nodes = " in line: + node = line + if "edges = " in line: + edge = line + + # extract node data + start = node.find("[") + end = node.find("]") + res = node[start + 1 : end].split("}, ") + res = [i + "}" for i in res] + res[-1] = res[-1][:-1] + node_dict = [json.loads(i) for i in res] + + # extract edge data + start = edge.find("[") + end = edge.find("]") + res = edge[start + 1 : end].split("}, ") + res = [i + "}" for i in res] + res[-1] = res[-1][:-1] + edge_dict = [json.loads(i) for i in res] + + # check node number + assert len(node_dict) == 5 + + # check startarn + found_value = next( + dictionary for dictionary in node_dict if dictionary["id"] == graph_startarn + ) + assert found_value["color"] == "#146eb4" + assert found_value["label"] == "Model" + assert found_value["shape"] == "star" + assert found_value["title"] == "Artifact" + + # check image artifact + found_value = next( + dictionary for dictionary in node_dict if dictionary["id"] == image_artifact + ) + assert found_value["color"] == "#146eb4" + assert found_value["label"] == "Image" + assert found_value["shape"] == "dot" + assert found_value["title"] == "Artifact" + + # check dataset artifact + found_value = next( + dictionary for dictionary in node_dict if dictionary["id"] == dataset_artifact + ) + assert found_value["color"] == "#146eb4" + assert found_value["label"] == "DataSet" + assert found_value["shape"] == "dot" + assert found_value["title"] == "Artifact" + + # check modeldeploy action + found_value = next( + dictionary for dictionary in node_dict if dictionary["id"] == modeldeploy_action + ) + assert found_value["color"] == "#88c396" + assert found_value["label"] == "ModelDeploy" + assert found_value["shape"] == "dot" + assert found_value["title"] == "Action" + + # check endpoint context + found_value = next( + dictionary for dictionary in node_dict if dictionary["id"] == endpoint_context + ) + assert found_value["color"] == "#ff9900" + assert found_value["label"] == "Endpoint" + assert found_value["shape"] == "dot" + assert found_value["title"] == "Context" + + # check edge number + assert len(edge_dict) == 4 + + # check image_artifact -> model_artifact(startarn) edge + found_value = next( + dictionary for dictionary in edge_dict if dictionary["from"] == image_artifact + ) + assert found_value["to"] == graph_startarn + assert found_value["title"] == "ContributedTo" + + # check dataset_artifact -> model_artifact(startarn) edge + found_value = next( + dictionary for dictionary in edge_dict if dictionary["from"] == dataset_artifact + ) + assert found_value["to"] == graph_startarn + assert found_value["title"] == "AssociatedWith" + + # check model_artifact(startarn) -> modeldeploy_action edge + found_value = next( + dictionary for dictionary in edge_dict if dictionary["from"] == graph_startarn + ) + assert found_value["to"] == modeldeploy_action + assert found_value["title"] == "ContributedTo" + + # check modeldeploy_action -> endpoint_context edge + found_value = next( + dictionary for dictionary in edge_dict if dictionary["from"] == modeldeploy_action + ) + assert found_value["to"] == endpoint_context + assert found_value["title"] == "AssociatedWith" + + except Exception as e: + print(e) + lineage_resource_helper.clean_all() + assert False + + # clean lineage data + lineage_resource_helper.clean_all() From 7cec38d2c88aafb402c23818ce3bb3b6e55ce505 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 1 Aug 2022 15:59:04 -0700 Subject: [PATCH 31/46] remove generated graph html file after visualize integ test --- tests/integ/sagemaker/lineage/test_lineage_visualize.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 274b18d503..1ba87d6a3c 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import time import json +import os import pytest @@ -264,7 +265,10 @@ def test_graph_visualize(sagemaker_session): except Exception as e: print(e) lineage_resource_helper.clean_all() + os.remove("testGraph.html") assert False + # delete generated test graph + os.remove("testGraph.html") # clean lineage data lineage_resource_helper.clean_all() From 628ba0ff3e4c327fb9eba2532090a2d7bec60a62 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 2 Aug 2022 09:55:08 -0700 Subject: [PATCH 32/46] __str__ function update --- src/sagemaker/lineage/query.py | 51 ++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a46e88867a..5277a2d806 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -104,6 +104,18 @@ def __str__(self): """ return str(self.__dict__) + def __repr__(self): + """Define string representation of ``Edge``. + + Format: + { + 'source_arn': 'string', 'destination_arn': 'string', + 'association_type': 'string' + } + + """ + return "\n\t" + str(self.__dict__) + class Vertex: """A vertex for a lineage graph.""" @@ -155,6 +167,19 @@ def __str__(self): """ return str(self.__dict__) + def __repr__(self): + """Define string representation of ``Vertex``. + + Format: + { + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + } + + """ + return "\n\t" + str(self.__dict__) + def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -312,29 +337,19 @@ def __str__(self): Format: { 'edges':[ - "{ - 'source_arn': 'string', 'destination_arn': 'string', - 'association_type': 'string' - }", - ... - ], + {'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string'}, + ...], + 'vertices':[ - "{ - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', - '_session': - }", - ... - ], - 'startarn':[ - 'string', - ... - ] + {'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': }, + ...], + + 'startarn':['string', ...] } """ result_dict = vars(self) - return str({k: [str(val) for val in v] for k, v in result_dict.items()}) + return '{\n' + '\n\n'.join('\'{}\': {},'.format(key, val) for key, val in self.__dict__.items()) + '\n}' def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" From f9d0ae126dcb8ef740b560e5c8af4892ac9f20a1 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 2 Aug 2022 12:37:20 -0700 Subject: [PATCH 33/46] query lineage result str function test added --- src/sagemaker/lineage/query.py | 14 +++++---- tests/integ/sagemaker/lineage/helpers.py | 3 +- tests/unit/sagemaker/lineage/test_query.py | 33 ++++++++++++++++++---- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 5277a2d806..0acdccd698 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -339,17 +339,21 @@ def __str__(self): 'edges':[ {'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string'}, ...], - + 'vertices':[ - {'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': }, + {'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', + '_session': }, ...], - + 'startarn':['string', ...] } """ - result_dict = vars(self) - return '{\n' + '\n\n'.join('\'{}\': {},'.format(key, val) for key, val in self.__dict__.items()) + '\n}' + return ( + "{\n" + + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items()) + + "\n}" + ) def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index d2ef2c067d..0c40bbac91 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -16,8 +16,7 @@ import uuid from datetime import datetime import time -import boto3 -from botocore.config import Config + def name(): return "lineage-integ-{}-{}".format( diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index b5b809138d..2573314d3b 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -18,6 +18,7 @@ from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest +import re def test_lineage_query(sagemaker_session): @@ -540,15 +541,37 @@ def test_get_visualization_elements(sagemaker_session): start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] ) - print(query_response) - elements = query_response._get_visualization_elements() - print(elements) - assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) assert elements["nodes"][1] == ("arn2", "Model", "Context", False) assert elements["edges"][0] == ("arn1", "arn2", "Produced") - +def test_query_lineage_result_str(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + response_str = query_response.__str__() + pattern = "Mock id='\d*'" + replace = "Mock id=''" + response_str = re.sub(pattern, replace, response_str) + + assert ( + response_str + == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " + + "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " + + "'Model', '_session': }],\n\n'startarn': " + + "['arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext'],\n}" + ) From c1169be2c0ea443a85bcc2325fc4b46df1393ce7 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 3 Aug 2022 12:03:42 -0700 Subject: [PATCH 34/46] validation logic clean on integ tests --- tests/integ/sagemaker/lineage/conftest.py | 15 ++ tests/integ/sagemaker/lineage/helpers.py | 4 - .../lineage/test_lineage_visualize.py | 202 +++++++----------- 3 files changed, 95 insertions(+), 126 deletions(-) diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 5e201eef42..7450cc5935 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -19,6 +19,7 @@ import pytest import logging import uuid +import json from sagemaker.lineage import ( action, context, @@ -891,3 +892,17 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session): pass else: raise (e) + + +@pytest.fixture +def extract_data_from_html(): + def _method(data): + start = data.find("[") + end = data.find("]") + res = data[start + 1 : end].split("}, ") + res = [i + "}" for i in res] + res[-1] = res[-1][:-1] + data_dict = [json.loads(i) for i in res] + return data_dict + + return _method diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 0c40bbac91..609ba9836d 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -145,27 +145,23 @@ def clean_all(self): for source, dest in self.associations: try: self.client.delete_association(SourceArn=source, DestinationArn=dest) - time.sleep(0.5) except Exception as e: print("skipped " + str(e)) for artifact_arn in self.artifacts: try: self.client.delete_artifact(ArtifactArn=artifact_arn) - time.sleep(0.5) except Exception as e: print("skipped " + str(e)) for action_arn in self.actions: try: self.client.delete_action(ActionArn=action_arn) - time.sleep(0.5) except Exception as e: print("skipped " + str(e)) for context_arn in self.contexts: try: self.client.delete_context(ContextArn=context_arn) - time.sleep(0.5) except Exception as e: print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 1ba87d6a3c..55cddab97b 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -13,7 +13,6 @@ """This module contains code to test SageMaker ``LineageQueryResult.visualize()``""" from __future__ import absolute_import import time -import json import os import pytest @@ -47,26 +46,22 @@ def test_wide_graph_visualize(sagemaker_session): # \ \--> Artifact # \---> ... try: - for i in range(10): + for i in range(200): artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association( source_arn=wide_graph_root_arn, dest_arn=artifact_arn ) - except Exception as e: - print(e) - lineage_resource_helper.clean_all() - assert False - try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) lq_result = lq.query(start_arns=[wide_graph_root_arn]) lq_result.visualize(path="wideGraph.html") + except Exception as e: print(e) - lineage_resource_helper.clean_all() assert False - lineage_resource_helper.clean_all() + finally: + lineage_resource_helper.clean_all() @pytest.mark.skip("visualizer load test") @@ -84,27 +79,23 @@ def test_long_graph_visualize(sagemaker_session): source_arn=last_arn, dest_arn=new_artifact_arn ) last_arn = new_artifact_arn - except Exception as e: - print(e) - lineage_resource_helper.clean_all() - assert False - try: lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) lq_result = lq.query( start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS ) # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction) lq_result.visualize(path="longGraph.html") + except Exception as e: print(e) - lineage_resource_helper.clean_all() assert False - lineage_resource_helper.clean_all() + finally: + lineage_resource_helper.clean_all() -def test_graph_visualize(sagemaker_session): +def test_graph_visualize(sagemaker_session, extract_data_from_html): lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) # create lineage data @@ -141,24 +132,14 @@ def test_graph_visualize(sagemaker_session): dest_arn=endpoint_context, association_type="AssociatedWith", ) - time.sleep(1) - except Exception as e: - print(e) - lineage_resource_helper.clean_all() - assert False + time.sleep(3) - # visualize - try: + # visualize lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) lq_result = lq.query(start_arns=[graph_startarn]) lq_result.visualize(path="testGraph.html") - except Exception as e: - print(e) - lineage_resource_helper.clean_all() - assert False - # check generated graph info - try: + # check generated graph info fo = open("testGraph.html", "r") lines = fo.readlines() for line in lines: @@ -167,108 +148,85 @@ def test_graph_visualize(sagemaker_session): if "edges = " in line: edge = line - # extract node data - start = node.find("[") - end = node.find("]") - res = node[start + 1 : end].split("}, ") - res = [i + "}" for i in res] - res[-1] = res[-1][:-1] - node_dict = [json.loads(i) for i in res] - - # extract edge data - start = edge.find("[") - end = edge.find("]") - res = edge[start + 1 : end].split("}, ") - res = [i + "}" for i in res] - res[-1] = res[-1][:-1] - edge_dict = [json.loads(i) for i in res] + node_dict = extract_data_from_html(node) + edge_dict = extract_data_from_html(edge) # check node number assert len(node_dict) == 5 - # check startarn - found_value = next( - dictionary for dictionary in node_dict if dictionary["id"] == graph_startarn - ) - assert found_value["color"] == "#146eb4" - assert found_value["label"] == "Model" - assert found_value["shape"] == "star" - assert found_value["title"] == "Artifact" - - # check image artifact - found_value = next( - dictionary for dictionary in node_dict if dictionary["id"] == image_artifact - ) - assert found_value["color"] == "#146eb4" - assert found_value["label"] == "Image" - assert found_value["shape"] == "dot" - assert found_value["title"] == "Artifact" - - # check dataset artifact - found_value = next( - dictionary for dictionary in node_dict if dictionary["id"] == dataset_artifact - ) - assert found_value["color"] == "#146eb4" - assert found_value["label"] == "DataSet" - assert found_value["shape"] == "dot" - assert found_value["title"] == "Artifact" - - # check modeldeploy action - found_value = next( - dictionary for dictionary in node_dict if dictionary["id"] == modeldeploy_action - ) - assert found_value["color"] == "#88c396" - assert found_value["label"] == "ModelDeploy" - assert found_value["shape"] == "dot" - assert found_value["title"] == "Action" - - # check endpoint context - found_value = next( - dictionary for dictionary in node_dict if dictionary["id"] == endpoint_context - ) - assert found_value["color"] == "#ff9900" - assert found_value["label"] == "Endpoint" - assert found_value["shape"] == "dot" - assert found_value["title"] == "Context" + expected_nodes = { + graph_startarn: { + "color": "#146eb4", + "label": "Model", + "shape": "star", + "title": "Artifact", + }, + image_artifact: { + "color": "#146eb4", + "label": "Image", + "shape": "dot", + "title": "Artifact", + }, + dataset_artifact: { + "color": "#146eb4", + "label": "DataSet", + "shape": "dot", + "title": "Artifact", + }, + modeldeploy_action: { + "color": "#88c396", + "label": "ModelDeploy", + "shape": "dot", + "title": "Action", + }, + endpoint_context: { + "color": "#ff9900", + "label": "Endpoint", + "shape": "dot", + "title": "Context", + }, + } + + # check node properties + for node in node_dict: + for label, val in expected_nodes[node["id"]].items(): + assert node[label] == val # check edge number assert len(edge_dict) == 4 - # check image_artifact -> model_artifact(startarn) edge - found_value = next( - dictionary for dictionary in edge_dict if dictionary["from"] == image_artifact - ) - assert found_value["to"] == graph_startarn - assert found_value["title"] == "ContributedTo" - - # check dataset_artifact -> model_artifact(startarn) edge - found_value = next( - dictionary for dictionary in edge_dict if dictionary["from"] == dataset_artifact - ) - assert found_value["to"] == graph_startarn - assert found_value["title"] == "AssociatedWith" - - # check model_artifact(startarn) -> modeldeploy_action edge - found_value = next( - dictionary for dictionary in edge_dict if dictionary["from"] == graph_startarn - ) - assert found_value["to"] == modeldeploy_action - assert found_value["title"] == "ContributedTo" - - # check modeldeploy_action -> endpoint_context edge - found_value = next( - dictionary for dictionary in edge_dict if dictionary["from"] == modeldeploy_action - ) - assert found_value["to"] == endpoint_context - assert found_value["title"] == "AssociatedWith" + expected_edges = { + image_artifact: { + "from": image_artifact, + "to": graph_startarn, + "title": "ContributedTo", + }, + dataset_artifact: { + "from": dataset_artifact, + "to": graph_startarn, + "title": "AssociatedWith", + }, + graph_startarn: { + "from": graph_startarn, + "to": modeldeploy_action, + "title": "ContributedTo", + }, + modeldeploy_action: { + "from": modeldeploy_action, + "to": endpoint_context, + "title": "AssociatedWith", + }, + } + + # check edge properties + for edge in edge_dict: + for label, val in expected_edges[edge["from"]].items(): + assert edge[label] == val except Exception as e: print(e) - lineage_resource_helper.clean_all() - os.remove("testGraph.html") assert False - # delete generated test graph - os.remove("testGraph.html") - # clean lineage data - lineage_resource_helper.clean_all() + finally: + lineage_resource_helper.clean_all() + os.remove("testGraph.html") From 23fe1265a8768d03b0bc5c53e879236a3a3d484d Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 3 Aug 2022 15:01:57 -0700 Subject: [PATCH 35/46] sleep time before clean_all added (avoid race condition) --- tests/integ/sagemaker/lineage/helpers.py | 4 ++++ tests/integ/sagemaker/lineage/test_lineage_visualize.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 609ba9836d..3ab22ce332 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -142,6 +142,10 @@ def create_association(self, source_arn, dest_arn, association_type="AssociatedW return False def clean_all(self): + # clean all lineage data created by LineageResourceHelper + + time.sleep(1) # avoid GSI race condition between create & delete + for source, dest in self.associations: try: self.client.delete_association(SourceArn=source, DestinationArn=dest) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 55cddab97b..d9b3f879a8 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -29,10 +29,11 @@ def test_LineageResourceHelper(sagemaker_session): art1 = lineage_resource_helper.create_artifact(artifact_name=name()) art2 = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) - lineage_resource_helper.clean_all() except Exception as e: print(e) assert False + finally: + lineage_resource_helper.clean_all() @pytest.mark.skip("visualizer load test") @@ -46,7 +47,7 @@ def test_wide_graph_visualize(sagemaker_session): # \ \--> Artifact # \---> ... try: - for i in range(200): + for i in range(150): artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) lineage_resource_helper.create_association( source_arn=wide_graph_root_arn, dest_arn=artifact_arn @@ -56,6 +57,10 @@ def test_wide_graph_visualize(sagemaker_session): lq_result = lq.query(start_arns=[wide_graph_root_arn]) lq_result.visualize(path="wideGraph.html") + print("vertex len = ") + print(len(lq_result.vertices)) + assert False + except Exception as e: print(e) assert False From 09c9fc166fb090b0526567b7e611fd357be12806 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 3 Aug 2022 15:46:18 -0700 Subject: [PATCH 36/46] change: add queryLineageResult visualizer unit test --- tests/integ/sagemaker/lineage/test_lineage_visualize.py | 4 ---- tests/unit/sagemaker/lineage/test_query.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index d9b3f879a8..555b98452e 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -57,10 +57,6 @@ def test_wide_graph_visualize(sagemaker_session): lq_result = lq.query(start_arns=[wide_graph_root_arn]) lq_result.visualize(path="wideGraph.html") - print("vertex len = ") - print(len(lq_result.vertices)) - assert False - except Exception as e: print(e) assert False diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 2573314d3b..ada905008b 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -563,8 +563,8 @@ def test_query_lineage_result_str(sagemaker_session): ) response_str = query_response.__str__() - pattern = "Mock id='\d*'" - replace = "Mock id=''" + pattern = r"Mock id='\d*'" + replace = r"Mock id=''" response_str = re.sub(pattern, replace, response_str) assert ( From 684d45d556b874eb854a982e0ed857a6f93dc6fa Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 4 Aug 2022 13:47:34 -0700 Subject: [PATCH 37/46] startarn added to lineage return value --- tests/unit/sagemaker/lineage/test_query.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index ada905008b..0a357eb1fc 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -533,6 +533,11 @@ def test_get_visualization_elements(sagemaker_session): "Vertices": [ {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + { + "Arn": "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Type": "Model", + "LineageType": "Context", + }, ], "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], } @@ -545,6 +550,12 @@ def test_get_visualization_elements(sagemaker_session): assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) assert elements["nodes"][1] == ("arn2", "Model", "Context", False) + assert elements["nodes"][2] == ( + "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Model", + "Context", + True, + ) assert elements["edges"][0] == ("arn1", "arn2", "Produced") From 2950b9f6855f7df89f2a922d18c5beec162253ef Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Mon, 8 Aug 2022 11:18:00 -0700 Subject: [PATCH 38/46] documentation: add visualize & PyvisVisualizer documentation --- src/sagemaker/lineage/query.py | 55 ++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 0acdccd698..c44e439b5a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -230,7 +230,31 @@ class PyvisVisualizer(object): """Create object used for visualizing graph using Pyvis library.""" def __init__(self, graph_styles): - """Init for PyvisVisualizer.""" + """Init for PyvisVisualizer. + + Args: + graph_styles: A dictionary that contains graph style for node and edges by their type. + Example: Display the nodes with different color by their lineage entity / different shape by start arn. + lineage_graph = { + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + + """ # import visualization packages ( self.Network, @@ -283,7 +307,22 @@ def _node_color(self, entity): return self.graph_styles[entity]["style"]["background-color"] def render(self, elements, path="pyvisExample.html"): - """Render graph for lineage query result.""" + """Render graph for lineage query result. + + Args: + elements: A dictionary that contains the node and the edges of the graph. + Example: + elements["nodes"] contains a list of tuples, each tuple represents a node in the format + (node arn, node lineage source, node lineage entity, node is start arn) + elements["edges"] contains a list of tuples, each tuple represents an edge in the format + (edge source arn, edge destination arn, edge association type) + + path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + + """ net = self.Network(height="500px", width="100%", notebook=True, directed=True) net.set_options(self._options) @@ -384,7 +423,17 @@ def _get_visualization_elements(self): return elements def visualize(self, path="pyvisExample.html"): - """Visualize lineage query result.""" + """Visualize lineage query result. + + Creates a PyvisVisualizer object to render network graph with Pyvis library. The elements(nodes & edges) are + preprocessed in this method and sent to PyvisVisualizer for rendering graph. + + Args: + path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + """ lineage_graph = { # nodes can have shape / color "TrialComponent": { From 143143d7cdee94aebcd2c371508fa5743d33811f Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 9 Aug 2022 10:44:48 -0700 Subject: [PATCH 39/46] doc: install pyvis before using visualize() --- src/sagemaker/lineage/query.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index c44e439b5a..f6966d9fa3 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -234,7 +234,8 @@ def __init__(self, graph_styles): Args: graph_styles: A dictionary that contains graph style for node and edges by their type. - Example: Display the nodes with different color by their lineage entity / different shape by start arn. + Example: Display the nodes with different color by their lineage entity / different + shape by start arn. lineage_graph = { "TrialComponent": { "name": "Trial Component", @@ -312,10 +313,11 @@ def render(self, elements, path="pyvisExample.html"): Args: elements: A dictionary that contains the node and the edges of the graph. Example: - elements["nodes"] contains a list of tuples, each tuple represents a node in the format - (node arn, node lineage source, node lineage entity, node is start arn) - elements["edges"] contains a list of tuples, each tuple represents an edge in the format - (edge source arn, edge destination arn, edge association type) + elements["nodes"] contains list of tuples, each tuple represents a node + format: (node arn, node lineage source, node lineage entity, + node is start arn) + elements["edges"] contains list of tuples, each tuple represents an edge + format: (edge source arn, edge destination arn, edge association type) path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") @@ -425,8 +427,10 @@ def _get_visualization_elements(self): def visualize(self, path="pyvisExample.html"): """Visualize lineage query result. - Creates a PyvisVisualizer object to render network graph with Pyvis library. The elements(nodes & edges) are - preprocessed in this method and sent to PyvisVisualizer for rendering graph. + Creates a PyvisVisualizer object to render network graph with Pyvis library. + Pyvis library should be installed before using this method (run "pip install pyvis") + The elements(nodes & edges) are preprocessed in this method and sent to + PyvisVisualizer for rendering graph. Args: path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") From f54e1b600d575f9b7560559a57156f2983899f29 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 9 Aug 2022 11:56:56 -0700 Subject: [PATCH 40/46] graph style fine-tune --- src/sagemaker/lineage/query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index f6966d9fa3..bcd4c26a85 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -332,14 +332,14 @@ def render(self, elements, path="pyvisExample.html"): for arn, source, entity, is_start_arn in elements["nodes"]: if is_start_arn: # startarn net.add_node( - arn, label=source, title=entity, color=self._node_color(entity), shape="star" + arn, label=source, title=entity, color=self._node_color(entity), shape="star", borderWidth=3 ) else: - net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) + net.add_node(arn, label=source, title=entity, color=self._node_color(entity), borderWidth=3) # add edges to graph for src, dest, asso_type in elements["edges"]: - net.add_edge(src, dest, title=asso_type) + net.add_edge(src, dest, title=asso_type, width=2) return net.show(path) From ff136c10471c3d7c90597abd9de42fbc07949a70 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Tue, 9 Aug 2022 16:03:02 -0700 Subject: [PATCH 41/46] query visualize more info on hover node --- src/sagemaker/lineage/query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index bcd4c26a85..9e257d7733 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -332,10 +332,10 @@ def render(self, elements, path="pyvisExample.html"): for arn, source, entity, is_start_arn in elements["nodes"]: if is_start_arn: # startarn net.add_node( - arn, label=source, title=entity, color=self._node_color(entity), shape="star", borderWidth=3 + arn, label=source, title=entity+"\n"+arn, color=self._node_color(entity), shape="star", borderWidth=3 ) else: - net.add_node(arn, label=source, title=entity, color=self._node_color(entity), borderWidth=3) + net.add_node(arn, label=source, title=entity+"\n"+arn, color=self._node_color(entity), borderWidth=3) # add edges to graph for src, dest, asso_type in elements["edges"]: From 1a3b5f3b64783fef36bc846f9af0de6c02ad91ec Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Wed, 10 Aug 2022 14:25:02 -0700 Subject: [PATCH 42/46] info on hover --- src/sagemaker/lineage/query.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 9e257d7733..348bd46697 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,7 +15,9 @@ from datetime import datetime from enum import Enum +from platform import node from typing import Optional, Union, List, Dict +import re from sagemaker.lineage._utils import get_resource_name_from_arn, get_module @@ -330,12 +332,14 @@ def render(self, elements, path="pyvisExample.html"): # add nodes to graph for arn, source, entity, is_start_arn in elements["nodes"]: + source = re.sub(r"(\w)([A-Z])", r"\1 \2", source) + node_info = "Entity: " + entity + "\n" + "Type: " + source + "\n" + "Name: " + arn if is_start_arn: # startarn net.add_node( - arn, label=source, title=entity+"\n"+arn, color=self._node_color(entity), shape="star", borderWidth=3 + arn, label=source, title=node_info, color=self._node_color(entity), shape="star", borderWidth=3 ) else: - net.add_node(arn, label=source, title=entity+"\n"+arn, color=self._node_color(entity), borderWidth=3) + net.add_node(arn, label=source, title=node_info, color=self._node_color(entity), borderWidth=3) # add edges to graph for src, dest, asso_type in elements["edges"]: @@ -391,7 +395,7 @@ def __str__(self): """ return ( - "{\n" + "{" + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items()) + "\n}" ) From a9e51141ccffb37f85b7559e84f3802dd61130c2 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Thu, 11 Aug 2022 14:46:52 -0700 Subject: [PATCH 43/46] split generate html file & display --- src/sagemaker/lineage/query.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 348bd46697..8258be0273 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -262,6 +262,8 @@ def __init__(self, graph_styles): ( self.Network, self.Options, + self.IFrame, + self.BeautifulSoup ) = self._import_visual_modules() self.graph_styles = graph_styles @@ -302,13 +304,22 @@ def _import_visual_modules(self): get_module("pyvis") from pyvis.network import Network from pyvis.options import Options + from IPython.display import IFrame - return Network, Options + get_module("bs4") + from bs4 import BeautifulSoup + + return Network, Options, IFrame, BeautifulSoup def _node_color(self, entity): """Return node color by background-color specified in graph styles.""" return self.graph_styles[entity]["style"]["background-color"] + def _add_legend(self, path): + f = open(path, "r+") + soup = self.BeautifulSoup(f, 'html.parser') + print(soup.prettify()) + def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result. @@ -345,7 +356,10 @@ def render(self, elements, path="pyvisExample.html"): for src, dest, asso_type in elements["edges"]: net.add_edge(src, dest, title=asso_type, width=2) - return net.show(path) + net.write_html(path) + self._add_legend(path) + + return self.IFrame(path, width="100%", height="500px") class LineageQueryResult(object): From c1be76563914857aebea563cd4d9665c34331e76 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 12 Aug 2022 11:35:51 -0700 Subject: [PATCH 44/46] change: lineage query visualization experience enhancement --- src/sagemaker/lineage/query.py | 86 ++++++++++++++++--- .../lineage/test_lineage_visualize.py | 40 +++++++-- tests/unit/sagemaker/lineage/test_query.py | 2 +- 3 files changed, 110 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 8258be0273..31885d3e7f 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,7 +15,6 @@ from datetime import datetime from enum import Enum -from platform import node from typing import Optional, Union, List, Dict import re @@ -263,7 +262,7 @@ def __init__(self, graph_styles): self.Network, self.Options, self.IFrame, - self.BeautifulSoup + self.BeautifulSoup, ) = self._import_visual_modules() self.graph_styles = graph_styles @@ -316,9 +315,53 @@ def _node_color(self, entity): return self.graph_styles[entity]["style"]["background-color"] def _add_legend(self, path): - f = open(path, "r+") - soup = self.BeautifulSoup(f, 'html.parser') - print(soup.prettify()) + """Embed legend to html file generated by pyvis.""" + f = open(path, "r") + content = self.BeautifulSoup(f, "html.parser") + + legend = """ +
+
+
+
+
Trial Component
+
+
+
+
+
Context
+
+
+
+
+
Action
+
+
+
+
+
Artifact
+
+
+
star
+
+
StartArn
+
+
+ """ + legend_div = self.BeautifulSoup(legend, "html.parser") + + content.div.insert_after(legend_div) + + html = content.prettify() + + with open(path, "w", encoding="utf8") as file: + file.write(html) def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result. @@ -338,19 +381,42 @@ def render(self, elements, path="pyvisExample.html"): display graph: The interactive visualization is presented as a static HTML file. """ - net = self.Network(height="500px", width="100%", notebook=True, directed=True) + net = self.Network(height="600px", width="82%", notebook=True, directed=True) net.set_options(self._options) # add nodes to graph for arn, source, entity, is_start_arn in elements["nodes"]: + entity_text = re.sub(r"(\w)([A-Z])", r"\1 \2", entity) source = re.sub(r"(\w)([A-Z])", r"\1 \2", source) - node_info = "Entity: " + entity + "\n" + "Type: " + source + "\n" + "Name: " + arn + account_id = re.search(r":\d{12}:", arn) + name = re.search(r"\/.*", arn) + node_info = ( + "Entity: " + + entity_text + + "\nType: " + + source + + "\nAccount ID: " + + str(account_id.group()[1:-1]) + + "\nName: " + + str(name.group()[1:]) + ) if is_start_arn: # startarn net.add_node( - arn, label=source, title=node_info, color=self._node_color(entity), shape="star", borderWidth=3 + arn, + label=source, + title=node_info, + color=self._node_color(entity), + shape="star", + borderWidth=3, ) else: - net.add_node(arn, label=source, title=node_info, color=self._node_color(entity), borderWidth=3) + net.add_node( + arn, + label=source, + title=node_info, + color=self._node_color(entity), + borderWidth=3, + ) # add edges to graph for src, dest, asso_type in elements["edges"]: @@ -359,7 +425,7 @@ def render(self, elements, path="pyvisExample.html"): net.write_html(path) self._add_legend(path) - return self.IFrame(path, width="100%", height="500px") + return self.IFrame(path, width="100%", height="600px") class LineageQueryResult(object): diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 555b98452e..4b9e816623 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import time import os +import re import pytest @@ -160,31 +161,56 @@ def test_graph_visualize(sagemaker_session, extract_data_from_html): "color": "#146eb4", "label": "Model", "shape": "star", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Model" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", graph_startarn).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", graph_startarn).group()[1:]), }, image_artifact: { "color": "#146eb4", "label": "Image", "shape": "dot", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Image" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", image_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", image_artifact).group()[1:]), }, dataset_artifact: { "color": "#146eb4", - "label": "DataSet", + "label": "Data Set", "shape": "dot", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Data Set" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", dataset_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", dataset_artifact).group()[1:]), }, modeldeploy_action: { "color": "#88c396", - "label": "ModelDeploy", + "label": "Model Deploy", "shape": "dot", - "title": "Action", + "title": "Entity: Action" + + "\nType: Model Deploy" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", modeldeploy_action).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", modeldeploy_action).group()[1:]), }, endpoint_context: { "color": "#ff9900", "label": "Endpoint", "shape": "dot", - "title": "Context", + "title": "Entity: Context" + + "\nType: Endpoint" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", endpoint_context).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", endpoint_context).group()[1:]), }, } diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 0a357eb1fc..bac5cb6cdb 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -580,7 +580,7 @@ def test_query_lineage_result_str(sagemaker_session): assert ( response_str - == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + == "{'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " + "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " + "'Model', '_session': }],\n\n'startarn': " From e1190092b79544ed3f978533a5939225f536ce03 Mon Sep 17 00:00:00 2001 From: Yi-Ting Lee Date: Fri, 12 Aug 2022 16:33:34 -0700 Subject: [PATCH 45/46] generate legend divs programmatically --- src/sagemaker/lineage/query.py | 56 +++++++++++++++------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 31885d3e7f..659f88a59c 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -314,6 +314,25 @@ def _node_color(self, entity): """Return node color by background-color specified in graph styles.""" return self.graph_styles[entity]["style"]["background-color"] + def _get_legend_line(self, component_name): + """Generate lengend div line for each graph component in graph_styles.""" + if self.graph_styles[component_name]["isShape"] == "False": + return '
\ +
\ +
{name}
'.format( + color=self.graph_styles[component_name]["style"]["background-color"], + name=self.graph_styles[component_name]["name"], + ) + else: + return '
{shape}
\ +
\ +
{name}
'.format( + shape=self.graph_styles[component_name]["style"]["shape"], + name=self.graph_styles[component_name]["name"], + ) + def _add_legend(self, path): """Embed legend to html file generated by pyvis.""" f = open(path, "r") @@ -322,38 +341,13 @@ def _add_legend(self, path): legend = """
-
-
-
-
Trial Component
-
-
-
-
-
Context
-
-
-
-
-
Action
-
-
-
-
-
Artifact
-
-
-
star
-
-
StartArn
-
-
""" + # iterate through graph styles to get legend + for component in self.graph_styles.keys(): + legend += self._get_legend_line(component_name=component) + + legend += "" + legend_div = self.BeautifulSoup(legend, "html.parser") content.div.insert_after(legend_div) From 2ac472ae8ce1a5fe9a9e23285609373e97725c80 Mon Sep 17 00:00:00 2001 From: jkasiraj Date: Wed, 16 Nov 2022 15:11:16 -0800 Subject: [PATCH 46/46] fix: remove no-else-return block in _get_legend_line --- src/sagemaker/lineage/query.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 659f88a59c..afc49b9ba9 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -324,14 +324,14 @@ def _get_legend_line(self, component_name): color=self.graph_styles[component_name]["style"]["background-color"], name=self.graph_styles[component_name]["name"], ) - else: - return '
{shape}
\ -
\ -
{name}
'.format( - shape=self.graph_styles[component_name]["style"]["shape"], - name=self.graph_styles[component_name]["name"], - ) + + return '
{shape}
\ +
\ +
{name}
'.format( + shape=self.graph_styles[component_name]["style"]["shape"], + name=self.graph_styles[component_name]["name"], + ) def _add_legend(self, path): """Embed legend to html file generated by pyvis."""