diff --git a/.gitignore b/.gitignore index cae8f890ea..0e63d2c505 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ env/ .vscode/ **/tmp .python-version +*.html **/_repack_script_launcher.sh tests/data/**/_repack_model.py tests/data/experiment/sagemaker-dev-1.0.tar.gz diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 695c5b2d47..bf1ec6e2d9 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -19,6 +19,7 @@ fabric==2.6.0 requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 +pyvis==0.2.1 pandas>=1.3.5,<1.5 scikit-learn==1.0.2 cloudpickle==2.2.1 diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a54331c39a..182f117913 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -15,8 +15,11 @@ from datetime import datetime from enum import Enum -from typing import Optional, Union, List, Dict +from typing import Any, Optional, Union, List, Dict +from json import dumps +from re import sub, search +from sagemaker.utils import get_module from sagemaker.lineage._utils import get_resource_name_from_arn @@ -92,6 +95,32 @@ def __eq__(self, other): and self.destination_arn == other.destination_arn ) + def __str__(self): + """Define string representation of ``Edge``. + + Format: + { + 'source_arn': 'string', + 'destination_arn': 'string', + 'association_type': 'string' + } + + """ + return str(self.__dict__) + + def __repr__(self): + """Define string representation of ``Edge``. + + Format: + { + 'source_arn': 'string', + 'destination_arn': 'string', + 'association_type': 'string' + } + + """ + return "\n\t" + str(self.__dict__) + class Vertex: """A vertex for a lineage graph.""" @@ -130,6 +159,34 @@ def __eq__(self, other): and self.lineage_source == other.lineage_source ) + def __str__(self): + """Define string representation of ``Vertex``. + + Format: + { + 'arn': 'string', + 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + } + + """ + return str(self.__dict__) + + def __repr__(self): + """Define string representation of ``Vertex``. + + Format: + { + 'arn': 'string', + 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + } + + """ + return "\n\t" + str(self.__dict__) + def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -176,6 +233,196 @@ def _artifact_to_lineage_object(self): return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) +class PyvisVisualizer(object): + """Create object used for visualizing graph using Pyvis library.""" + + def __init__(self, graph_styles, pyvis_options: Optional[Dict[str, Any]] = None): + """Init for PyvisVisualizer. + + Args: + graph_styles: A dictionary that contains graph style for node and edges by their type. + Example: Display the nodes with different color by their lineage entity / different + shape by start arn. + lineage_graph_styles = { + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + pyvis_options(optional): A dict containing PyVis options to customize visualization. + (see https://visjs.github.io/vis-network/docs/network/#options for supported fields) + """ + # import visualization packages + ( + self.Network, + self.Options, + self.IFrame, + self.BeautifulSoup, + ) = self._import_visual_modules() + + self.graph_styles = graph_styles + + if pyvis_options is None: + # default pyvis graph options + pyvis_options = { + "configure": {"enabled": False}, + "layout": { + "hierarchical": { + "enabled": True, + "blockShifting": True, + "direction": "LR", + "sortMethod": "directed", + "shakeTowards": "leaves", + } + }, + "interaction": {"multiselect": True, "navigationButtons": True}, + "physics": { + "enabled": False, + "hierarchicalRepulsion": {"centralGravity": 0, "avoidOverlap": None}, + "minVelocity": 0.75, + "solver": "hierarchicalRepulsion", + }, + } + # A string representation of a Javascript-like object used to override pyvis options + self._pyvis_options = f"var options = {dumps(pyvis_options)}" + + def _import_visual_modules(self): + """Import modules needed for visualization.""" + get_module("pyvis") + from pyvis.network import Network + from pyvis.options import Options + from IPython.display import IFrame + + get_module("bs4") + from bs4 import BeautifulSoup + + return Network, Options, IFrame, BeautifulSoup + + def _node_color(self, entity): + """Return node color by background-color specified in graph styles.""" + return self.graph_styles[entity]["style"]["background-color"] + + def _get_legend_line(self, component_name): + """Generate lengend div line for each graph component in graph_styles.""" + if self.graph_styles[component_name]["isShape"] == "False": + return '
\ +
\ +
{name}
'.format( + color=self.graph_styles[component_name]["style"]["background-color"], + name=self.graph_styles[component_name]["name"], + ) + + return '
{shape}
\ +
\ +
{name}
'.format( + shape=self.graph_styles[component_name]["style"]["shape"], + name=self.graph_styles[component_name]["name"], + ) + + def _add_legend(self, path): + """Embed legend to html file generated by pyvis.""" + f = open(path, "r") + content = self.BeautifulSoup(f, "html.parser") + + legend = """ +
+ """ + # iterate through graph styles to get legend + for component in self.graph_styles.keys(): + legend += self._get_legend_line(component_name=component) + + legend += "
" + + legend_div = self.BeautifulSoup(legend, "html.parser") + + content.div.insert_after(legend_div) + + html = content.prettify() + + with open(path, "w", encoding="utf8") as file: + file.write(html) + + def render(self, elements, path="lineage_graph_pyvis.html"): + """Render graph for lineage query result. + + Args: + elements: A dictionary that contains the node and the edges of the graph. + Example: + elements["nodes"] contains list of tuples, each tuple represents a node + format: (node arn, node lineage source, node lineage entity, + node is start arn) + elements["edges"] contains list of tuples, each tuple represents an edge + format: (edge source arn, edge destination arn, edge association type) + + path(optional): The path/filename of the rendered graph html file. + (default path: "lineage_graph_pyvis.html") + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + + """ + net = self.Network(height="600px", width="82%", notebook=True, directed=True) + net.set_options(self._pyvis_options) + + # add nodes to graph + for arn, source, entity, is_start_arn in elements["nodes"]: + entity_text = sub(r"(\w)([A-Z])", r"\1 \2", entity) + source = sub(r"(\w)([A-Z])", r"\1 \2", source) + account_id = search(r":\d{12}:", arn) + name = search(r"\/.*", arn) + node_info = ( + "Entity: " + + entity_text + + "\nType: " + + source + + "\nAccount ID: " + + str(account_id.group()[1:-1]) + + "\nName: " + + str(name.group()[1:]) + ) + if is_start_arn: # startarn + net.add_node( + arn, + label=source, + title=node_info, + color=self._node_color(entity), + shape="star", + borderWidth=3, + ) + else: + net.add_node( + arn, + label=source, + title=node_info, + color=self._node_color(entity), + borderWidth=3, + ) + + # add edges to graph + for src, dest, asso_type in elements["edges"]: + net.add_edge(src, dest, title=asso_type, width=2) + + net.write_html(path) + self._add_legend(path) + + return self.IFrame(path, width="100%", height="600px") + + class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -183,6 +430,7 @@ def __init__( self, edges: List[Edge] = None, vertices: List[Vertex] = None, + startarn: List[str] = None, ): """Init for LineageQueryResult. @@ -192,6 +440,7 @@ def __init__( """ self.edges = [] self.vertices = [] + self.startarn = [] if edges is not None: self.edges = edges @@ -199,6 +448,124 @@ def __init__( if vertices is not None: self.vertices = vertices + if startarn is not None: + self.startarn = startarn + + def __str__(self): + """Define string representation of ``LineageQueryResult``. + + Format: + { + 'edges':[ + { + 'source_arn': 'string', + 'destination_arn': 'string', + 'association_type': 'string' + }, + ], + + 'vertices':[ + { + 'arn': 'string', + 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + }, + ], + + 'startarn':['string', ...] + } + + """ + return ( + "{" + + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items()) + + "\n}" + ) + + def _covert_edges_to_tuples(self): + """Convert edges to tuple format for visualizer.""" + edges = [] + # get edge info in the form of (source, target, label) + for edge in self.edges: + edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) + return edges + + def _covert_vertices_to_tuples(self): + """Convert vertices to tuple format for visualizer.""" + verts = [] + # get vertex info in the form of (id, label, class) + for vert in self.vertices: + if vert.arn in self.startarn: + # add "startarn" class to node if arn is a startarn + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True)) + else: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False)) + return verts + + def _get_visualization_elements(self): + """Get elements(nodes+edges) for visualization.""" + verts = self._covert_vertices_to_tuples() + edges = self._covert_edges_to_tuples() + + elements = {"nodes": verts, "edges": edges} + return elements + + def visualize( + self, + path: Optional[str] = "lineage_graph_pyvis.html", + pyvis_options: Optional[Dict[str, Any]] = None, + ): + """Visualize lineage query result. + + Creates a PyvisVisualizer object to render network graph with Pyvis library. + Pyvis library should be installed before using this method (run "pip install pyvis") + The elements(nodes & edges) are preprocessed in this method and sent to + PyvisVisualizer for rendering graph. + + Args: + path(optional): The path/filename of the rendered graph html file. + (default path: "lineage_graph_pyvis.html") + pyvis_options(optional): A dict containing PyVis options to customize visualization. + (see https://visjs.github.io/vis-network/docs/network/#options for supported fields) + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + """ + lineage_graph_styles = { + # nodes can have shape / color + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "Action": { + "name": "Action", + "style": {"background-color": "#88c396"}, + "isShape": "False", + }, + "Artifact": { + "name": "Artifact", + "style": {"background-color": "#146eb4"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + + pyvis_vis = PyvisVisualizer(lineage_graph_styles, pyvis_options) + elements = self._get_visualization_elements() + return pyvis_vis.render(elements=elements, path=path) + class LineageFilter(object): """A filter used in a lineage query.""" @@ -273,9 +640,8 @@ def _get_vertex(self, vertex): sagemaker_session=self._session, ) - def _convert_api_response(self, response) -> LineageQueryResult: + def _convert_api_response(self, response, converted) -> LineageQueryResult: """Convert the lineage query API response to its Python representation.""" - converted = LineageQueryResult() converted.edges = [self._get_edge(edge) for edge in response["Edges"]] converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]] @@ -358,7 +724,9 @@ def query( Filters=query_filter._to_request_dict() if query_filter else {}, MaxDepth=max_depth, ) - query_response = self._convert_api_response(query_response) + # create query result for startarn info + query_result = LineageQueryResult(startarn=start_arns) + query_response = self._convert_api_response(query_response, query_result) query_response = self._collapse_cross_account_artifacts(query_response) return query_response diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index abfe6f6d0d..7a27bfb3cd 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -19,6 +19,7 @@ import pytest import logging import uuid +import json from sagemaker.lineage import ( action, context, @@ -892,3 +893,17 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session): pass else: raise (e) + + +@pytest.fixture +def extract_data_from_html(): + def _method(data): + start = data.find("[") + end = data.find("]") + res = data[start + 1 : end].split("}, ") + res = [i + "}" for i in res] + res[-1] = res[-1][:-1] + data_dict = [json.loads(i) for i in res] + return data_dict + + return _method diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index 5548c63cff..faf3081fec 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import uuid +import time from datetime import datetime @@ -64,3 +65,94 @@ def visit(arn, visited: set): ret = [] return visit(start_arn, set()) + + +class LineageResourceHelper: + def __init__(self, sagemaker_session): + self.client = sagemaker_session.sagemaker_client + self.artifacts = [] + self.actions = [] + self.contexts = [] + self.associations = [] + + def create_artifact(self, artifact_name, artifact_type="Dataset"): + response = self.client.create_artifact( + ArtifactName=artifact_name, + Source={ + "SourceUri": "Test-artifact-" + artifact_name, + "SourceTypes": [ + {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"}, + ], + }, + ArtifactType=artifact_type, + ) + self.artifacts.append(response["ArtifactArn"]) + + return response["ArtifactArn"] + + def create_action(self, action_name, action_type="ModelDeployment"): + response = self.client.create_action( + ActionName=action_name, + Source={ + "SourceUri": "Test-action-" + action_name, + "SourceType": "S3ETag", + "SourceId": "Test-action-sourceId-value", + }, + ActionType=action_type, + ) + self.actions.append(response["ActionArn"]) + + return response["ActionArn"] + + def create_context(self, context_name, context_type="Endpoint"): + response = self.client.create_context( + ContextName=context_name, + Source={ + "SourceUri": "Test-context-" + context_name, + "SourceType": "S3ETag", + "SourceId": "Test-context-sourceId-value", + }, + ContextType=context_type, + ) + self.contexts.append(response["ContextArn"]) + + return response["ContextArn"] + + def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"): + response = self.client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type + ) + if "SourceArn" in response.keys(): + self.associations.append((source_arn, dest_arn)) + return True + else: + return False + + def clean_all(self): + # clean all lineage data created by LineageResourceHelper + + time.sleep(1) # avoid GSI race condition between create & delete + + for source, dest in self.associations: + try: + self.client.delete_association(SourceArn=source, DestinationArn=dest) + except Exception as e: + print("skipped " + str(e)) + + for artifact_arn in self.artifacts: + try: + self.client.delete_artifact(ArtifactArn=artifact_arn) + except Exception as e: + print("skipped " + str(e)) + + for action_arn in self.actions: + try: + self.client.delete_action(ActionArn=action_arn) + except Exception as e: + print("skipped " + str(e)) + + for context_arn in self.contexts: + try: + self.client.delete_context(ContextArn=context_arn) + except Exception as e: + print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py new file mode 100644 index 0000000000..4b9e816623 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -0,0 +1,259 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``""" +from __future__ import absolute_import +import time +import os +import re + +import pytest + +import sagemaker.lineage.query +from sagemaker.lineage.query import LineageQueryDirectionEnum +from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper + + +def test_LineageResourceHelper(sagemaker_session): + # check if LineageResourceHelper works properly + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + try: + art1 = lineage_resource_helper.create_artifact(artifact_name=name()) + art2 = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) + except Exception as e: + print(e) + assert False + finally: + lineage_resource_helper.clean_all() + + +@pytest.mark.skip("visualizer load test") +def test_wide_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + + # create wide graph + # Artifact ----> Artifact + # \ \ \-> Artifact + # \ \--> Artifact + # \---> ... + try: + for i in range(150): + artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association( + source_arn=wide_graph_root_arn, dest_arn=artifact_arn + ) + + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[wide_graph_root_arn]) + lq_result.visualize(path="wideGraph.html") + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + + +@pytest.mark.skip("visualizer load test") +def test_long_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + last_arn = long_graph_root_arn + + # create long graph + # Artifact -> Artifact -> ... -> Artifact + try: + for i in range(10): + new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association( + source_arn=last_arn, dest_arn=new_artifact_arn + ) + last_arn = new_artifact_arn + + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query( + start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS + ) + # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction) + lq_result.visualize(path="longGraph.html") + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + + +def test_graph_visualize(sagemaker_session, extract_data_from_html): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + + # create lineage data + # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context + # /-> + # dataset artifact -/ + try: + graph_startarn = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Model" + ) + image_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Image" + ) + lineage_resource_helper.create_association( + source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo" + ) + dataset_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="DataSet" + ) + lineage_resource_helper.create_association( + source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith" + ) + modeldeploy_action = lineage_resource_helper.create_action( + action_name=name(), action_type="ModelDeploy" + ) + lineage_resource_helper.create_association( + source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo" + ) + endpoint_context = lineage_resource_helper.create_context( + context_name=name(), context_type="Endpoint" + ) + lineage_resource_helper.create_association( + source_arn=modeldeploy_action, + dest_arn=endpoint_context, + association_type="AssociatedWith", + ) + time.sleep(3) + + # visualize + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[graph_startarn]) + lq_result.visualize(path="testGraph.html") + + # check generated graph info + fo = open("testGraph.html", "r") + lines = fo.readlines() + for line in lines: + if "nodes = " in line: + node = line + if "edges = " in line: + edge = line + + node_dict = extract_data_from_html(node) + edge_dict = extract_data_from_html(edge) + + # check node number + assert len(node_dict) == 5 + + expected_nodes = { + graph_startarn: { + "color": "#146eb4", + "label": "Model", + "shape": "star", + "title": "Entity: Artifact" + + "\nType: Model" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", graph_startarn).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", graph_startarn).group()[1:]), + }, + image_artifact: { + "color": "#146eb4", + "label": "Image", + "shape": "dot", + "title": "Entity: Artifact" + + "\nType: Image" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", image_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", image_artifact).group()[1:]), + }, + dataset_artifact: { + "color": "#146eb4", + "label": "Data Set", + "shape": "dot", + "title": "Entity: Artifact" + + "\nType: Data Set" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", dataset_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", dataset_artifact).group()[1:]), + }, + modeldeploy_action: { + "color": "#88c396", + "label": "Model Deploy", + "shape": "dot", + "title": "Entity: Action" + + "\nType: Model Deploy" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", modeldeploy_action).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", modeldeploy_action).group()[1:]), + }, + endpoint_context: { + "color": "#ff9900", + "label": "Endpoint", + "shape": "dot", + "title": "Entity: Context" + + "\nType: Endpoint" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", endpoint_context).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", endpoint_context).group()[1:]), + }, + } + + # check node properties + for node in node_dict: + for label, val in expected_nodes[node["id"]].items(): + assert node[label] == val + + # check edge number + assert len(edge_dict) == 4 + + expected_edges = { + image_artifact: { + "from": image_artifact, + "to": graph_startarn, + "title": "ContributedTo", + }, + dataset_artifact: { + "from": dataset_artifact, + "to": graph_startarn, + "title": "AssociatedWith", + }, + graph_startarn: { + "from": graph_startarn, + "to": modeldeploy_action, + "title": "ContributedTo", + }, + modeldeploy_action: { + "from": modeldeploy_action, + "to": endpoint_context, + "title": "AssociatedWith", + }, + } + + # check edge properties + for edge in edge_dict: + for label, val in expected_edges[edge["from"]].items(): + assert edge[label] == val + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + os.remove("testGraph.html") diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index ae76fd199c..bac5cb6cdb 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -18,6 +18,7 @@ from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest +import re def test_lineage_query(sagemaker_session): @@ -524,3 +525,64 @@ def test_vertex_to_object_unconvertable(sagemaker_session): with pytest.raises(ValueError): vertex.to_lineage_object() + + +def test_get_visualization_elements(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + { + "Arn": "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Type": "Model", + "LineageType": "Context", + }, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + elements = query_response._get_visualization_elements() + + assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) + assert elements["nodes"][1] == ("arn2", "Model", "Context", False) + assert elements["nodes"][2] == ( + "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Model", + "Context", + True, + ) + assert elements["edges"][0] == ("arn1", "arn2", "Produced") + + +def test_query_lineage_result_str(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + response_str = query_response.__str__() + pattern = r"Mock id='\d*'" + replace = r"Mock id=''" + response_str = re.sub(pattern, replace, response_str) + + assert ( + response_str + == "{'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " + + "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " + + "'Model', '_session': }],\n\n'startarn': " + + "['arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext'],\n}" + )