From 2c035d8662200d44c53f4bc4068a63fb8cfd71cc Mon Sep 17 00:00:00 2001 From: Sachin Mysore Satish Date: Thu, 18 Feb 2021 10:20:45 -0800 Subject: [PATCH] Add tests for visualizer to improve test coverage --- .../unit/sagemaker/lineage/test_visualizer.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/tests/unit/sagemaker/lineage/test_visualizer.py b/tests/unit/sagemaker/lineage/test_visualizer.py index 386f4e16be..d5a373e08a 100644 --- a/tests/unit/sagemaker/lineage/test_visualizer.py +++ b/tests/unit/sagemaker/lineage/test_visualizer.py @@ -119,3 +119,137 @@ def test_trial_component_name(sagemaker_session, vizualizer): ) pd.testing.assert_frame_equal(expected_dataframe, df) + + +def test_model_package_arn(sagemaker_session, vizualizer): + name = "model_package_arn" + + sagemaker_session.sagemaker_client.list_artifacts.return_value = { + "ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}] + } + + sagemaker_session.sagemaker_client.list_associations.side_effect = [ + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:artifact/src-arn-1", + "SourceName": "source-name-1", + "SourceType": "source-type-1", + "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1", + "DestinationName": "dest-name-1", + "DestinationType": "dest-type-1", + "AssociationType": "type-1", + } + ] + }, + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:artifact/src-arn-2", + "SourceName": "source-name-2", + "SourceType": "source-type-2", + "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2", + "DestinationName": "dest-name-2", + "DestinationType": "dest-type-2", + "AssociationType": "type-2", + } + ] + }, + ] + + df = vizualizer.show(model_package_arn=name) + + sagemaker_session.sagemaker_client.list_artifacts.assert_called_with( + SourceUri=name, + ) + + expected_calls = [ + unittest.mock.call( + DestinationArn="artifact-arn", + ), + unittest.mock.call( + SourceArn="artifact-arn", + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + + expected_dataframe = pd.DataFrame.from_dict( + OrderedDict( + [ + ("Name/Source", ["source-name-1", "dest-name-2"]), + ("Direction", ["Input", "Output"]), + ("Type", ["source-type-1", "dest-type-2"]), + ("Association Type", ["type-1", "type-2"]), + ("Lineage Type", ["artifact", "artifact"]), + ] + ) + ) + + pd.testing.assert_frame_equal(expected_dataframe, df) + + +def test_endpoint_arn(sagemaker_session, vizualizer): + name = "endpoint_arn" + + sagemaker_session.sagemaker_client.list_contexts.return_value = { + "ContextSummaries": [{"ContextArn": "context-arn"}] + } + + sagemaker_session.sagemaker_client.list_associations.side_effect = [ + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:context/src-arn-1", + "SourceName": "source-name-1", + "SourceType": "source-type-1", + "DestinationArn": "a:b:c:d:e:context/dest-arn-1", + "DestinationName": "dest-name-1", + "DestinationType": "dest-type-1", + "AssociationType": "type-1", + } + ] + }, + { + "AssociationSummaries": [ + { + "SourceArn": "a:b:c:d:e:context/src-arn-2", + "SourceName": "source-name-2", + "SourceType": "source-type-2", + "DestinationArn": "a:b:c:d:e:context/dest-arn-2", + "DestinationName": "dest-name-2", + "DestinationType": "dest-type-2", + "AssociationType": "type-2", + } + ] + }, + ] + + df = vizualizer.show(endpoint_arn=name) + + sagemaker_session.sagemaker_client.list_contexts.assert_called_with( + SourceUri=name, + ) + + expected_calls = [ + unittest.mock.call( + DestinationArn="context-arn", + ), + unittest.mock.call( + SourceArn="context-arn", + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + + expected_dataframe = pd.DataFrame.from_dict( + OrderedDict( + [ + ("Name/Source", ["source-name-1", "dest-name-2"]), + ("Direction", ["Input", "Output"]), + ("Type", ["source-type-1", "dest-type-2"]), + ("Association Type", ["type-1", "type-2"]), + ("Lineage Type", ["context", "context"]), + ] + ) + ) + + pd.testing.assert_frame_equal(expected_dataframe, df)