diff --git a/src/sagemaker/lineage/visualizer.py b/src/sagemaker/lineage/visualizer.py index cab9cf2891..e3faeaa491 100644 --- a/src/sagemaker/lineage/visualizer.py +++ b/src/sagemaker/lineage/visualizer.py @@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step): return None metadata = pipeline_execution_step["Metadata"] - jobs = ["TrainingJob", "ProccessingJob", "TransformJob"] + jobs = ["TrainingJob", "ProcessingJob", "TransformJob"] for job in jobs: if job in metadata and metadata[job]: job_arn = metadata[job]["Arn"] diff --git a/tests/unit/sagemaker/lineage/conftest.py b/tests/unit/sagemaker/lineage/conftest.py new file mode 100644 index 0000000000..10e52f051a --- /dev/null +++ b/tests/unit/sagemaker/lineage/conftest.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker.lineage import visualizer +import unittest.mock + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +@pytest.fixture +def viz(sagemaker_session): + return visualizer.LineageTableVisualizer(sagemaker_session) diff --git a/tests/unit/sagemaker/lineage/test_action.py b/tests/unit/sagemaker/lineage/test_action.py index 6d483bcc0b..79e59b679b 100644 --- a/tests/unit/sagemaker/lineage/test_action.py +++ b/tests/unit/sagemaker/lineage/test_action.py @@ -15,15 +15,9 @@ import datetime import unittest.mock -import pytest from sagemaker.lineage import action, _api_types -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - def test_create(sagemaker_session): sagemaker_session.sagemaker_client.create_action.return_value = { "ActionArn": "bazz", diff --git a/tests/unit/sagemaker/lineage/test_artifact.py b/tests/unit/sagemaker/lineage/test_artifact.py index 8a3ac5d5cc..72228ec964 100644 --- a/tests/unit/sagemaker/lineage/test_artifact.py +++ b/tests/unit/sagemaker/lineage/test_artifact.py @@ -15,15 +15,9 @@ import datetime import unittest.mock -import pytest from sagemaker.lineage import artifact, _api_types -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - def test_create(sagemaker_session): sagemaker_session.sagemaker_client.create_artifact.return_value = { "ArtifactArn": "bazz", diff --git a/tests/unit/sagemaker/lineage/test_association.py b/tests/unit/sagemaker/lineage/test_association.py index 9e02be2a69..268bdac0e0 100644 --- a/tests/unit/sagemaker/lineage/test_association.py +++ b/tests/unit/sagemaker/lineage/test_association.py @@ -15,15 +15,9 @@ import datetime import unittest.mock -import pytest from sagemaker.lineage import association, _api_types -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - def test_create(sagemaker_session): sagemaker_session.sagemaker_client.add_association.return_value = { "AssociationArn": "bazz", diff --git a/tests/unit/sagemaker/lineage/test_endpoint_context.py b/tests/unit/sagemaker/lineage/test_endpoint_context.py index 80f6ee8dac..01369e5378 100644 --- a/tests/unit/sagemaker/lineage/test_endpoint_context.py +++ b/tests/unit/sagemaker/lineage/test_endpoint_context.py @@ -14,15 +14,9 @@ import unittest.mock -import pytest from sagemaker.lineage import context, _api_types -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - def test_models(sagemaker_session): obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz") diff --git a/tests/unit/sagemaker/lineage/test_model_artifact.py b/tests/unit/sagemaker/lineage/test_model_artifact.py index 9e9fe84920..def352efbc 100644 --- a/tests/unit/sagemaker/lineage/test_model_artifact.py +++ b/tests/unit/sagemaker/lineage/test_model_artifact.py @@ -14,15 +14,9 @@ import unittest.mock -import pytest from sagemaker.lineage import artifact, _api_types -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - def test_trained_models(sagemaker_session): model_artifact_obj = artifact.ModelArtifact( sagemaker_session, artifact_arn="model-artifact-arn" diff --git a/tests/unit/sagemaker/lineage/test_visualizer.py b/tests/unit/sagemaker/lineage/test_visualizer.py index d5a373e08a..d2e1073ca1 100644 --- a/tests/unit/sagemaker/lineage/test_visualizer.py +++ b/tests/unit/sagemaker/lineage/test_visualizer.py @@ -1,4 +1,4 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -14,33 +14,21 @@ import unittest.mock -import pytest -from sagemaker.lineage import visualizer import pandas as pd from collections import OrderedDict -@pytest.fixture -def sagemaker_session(): - return unittest.mock.Mock() - - -@pytest.fixture -def vizualizer(sagemaker_session): - return visualizer.LineageTableVisualizer(sagemaker_session) - - -def test_friendly_name_short_uri(vizualizer, sagemaker_session): +def test_friendly_name_short_uri(viz, sagemaker_session): uri = "s3://f-069083975568/train.txt" arn = "test_arn" sagemaker_session.sagemaker_client.describe_artifact.return_value = { "Source": {"SourceUri": uri, "SourceTypes": ""} } - actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact") + actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact") assert uri == actual_name -def test_friendly_name_long_uri(vizualizer, sagemaker_session): +def test_friendly_name_long_uri(viz, sagemaker_session): uri = ( "s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/" "preprocessed-data/tuning_data/train.txt" @@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session): sagemaker_session.sagemaker_client.describe_artifact.return_value = { "Source": {"SourceUri": uri, "SourceTypes": ""} } - actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact") + actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact") expected_name = "s3://.../preprocessed-data/tuning_data/train.txt" assert expected_name == actual_name -def test_trial_component_name(sagemaker_session, vizualizer): +def test_trial_component_name(viz, sagemaker_session): name = "tc-name" sagemaker_session.sagemaker_client.describe_trial_component.return_value = { @@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer): }, ] - df = vizualizer.show(trial_component_name=name) + df = viz.show(trial_component_name=name) sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with( TrialComponentName=name, @@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer): pd.testing.assert_frame_equal(expected_dataframe, df) -def test_model_package_arn(sagemaker_session, vizualizer): +def test_model_package_arn(viz, sagemaker_session): name = "model_package_arn" sagemaker_session.sagemaker_client.list_artifacts.return_value = { @@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer): }, ] - df = vizualizer.show(model_package_arn=name) + df = viz.show(model_package_arn=name) sagemaker_session.sagemaker_client.list_artifacts.assert_called_with( SourceUri=name, @@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer): pd.testing.assert_frame_equal(expected_dataframe, df) -def test_endpoint_arn(sagemaker_session, vizualizer): +def test_endpoint_arn(viz, sagemaker_session): name = "endpoint_arn" sagemaker_session.sagemaker_client.list_contexts.return_value = { @@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer): }, ] - df = vizualizer.show(endpoint_arn=name) + df = viz.show(endpoint_arn=name) sagemaker_session.sagemaker_client.list_contexts.assert_called_with( SourceUri=name, @@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer): ) pd.testing.assert_frame_equal(expected_dataframe, df) + + +def test_processing_job_pipeline_execution_step(viz, sagemaker_session): + + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}] + } + + sagemaker_session.sagemaker_client.list_associations.side_effect = [ + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:artifact/src-arn-1", + "SourceName": "source-name-1", + "SourceType": "source-type-1", + "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1", + "DestinationName": "dest-name-1", + "DestinationType": "dest-type-1", + "AssociationType": "type-1", + } + ] + }, + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:artifact/src-arn-2", + "SourceName": "source-name-2", + "SourceType": "source-type-2", + "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2", + "DestinationName": "dest-name-2", + "DestinationType": "dest-type-2", + "AssociationType": "type-2", + } + ] + }, + ] + + step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}} + + df = viz.show(pipeline_execution_step=step) + + sagemaker_session.sagemaker_client.list_trial_components.assert_called_with( + SourceArn="proc-job-arn", + ) + + expected_calls = [ + unittest.mock.call( + DestinationArn="tc-arn", + ), + unittest.mock.call( + SourceArn="tc-arn", + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + + expected_dataframe = pd.DataFrame.from_dict( + OrderedDict( + [ + ("Name/Source", ["source-name-1", "dest-name-2"]), + ("Direction", ["Input", "Output"]), + ("Type", ["source-type-1", "dest-type-2"]), + ("Association Type", ["type-1", "type-2"]), + ("Lineage Type", ["artifact", "artifact"]), + ] + ) + ) + + pd.testing.assert_frame_equal(expected_dataframe, df)