diff --git a/tests/unit/sagemaker/lineage/test_visualizer.py b/tests/unit/sagemaker/lineage/test_visualizer.py index d2e1073ca1..402c1c82a2 100644 --- a/tests/unit/sagemaker/lineage/test_visualizer.py +++ b/tests/unit/sagemaker/lineage/test_visualizer.py @@ -49,34 +49,7 @@ def test_trial_component_name(viz, sagemaker_session): "TrialComponentArn": "tc-arn", } - sagemaker_session.sagemaker_client.list_associations.side_effect = [ - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:artifact/src-arn-1", - "SourceName": "source-name-1", - "SourceType": "source-type-1", - "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1", - "DestinationName": "dest-name-1", - "DestinationType": "dest-type-1", - "AssociationType": "type-1", - } - ] - }, - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:artifact/src-arn-2", - "SourceName": "source-name-2", - "SourceType": "source-type-2", - "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2", - "DestinationName": "dest-name-2", - "DestinationType": "dest-type-2", - "AssociationType": "type-2", - } - ] - }, - ] + get_list_associations_side_effect(sagemaker_session) df = viz.show(trial_component_name=name) @@ -84,29 +57,9 @@ def test_trial_component_name(viz, sagemaker_session): TrialComponentName=name, ) - expected_calls = [ - unittest.mock.call( - DestinationArn="tc-arn", - ), - unittest.mock.call( - SourceArn="tc-arn", - ), - ] - assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + assert_list_associations_mock_calls(sagemaker_session) - expected_dataframe = pd.DataFrame.from_dict( - OrderedDict( - [ - ("Name/Source", ["source-name-1", "dest-name-2"]), - ("Direction", ["Input", "Output"]), - ("Type", ["source-type-1", "dest-type-2"]), - ("Association Type", ["type-1", "type-2"]), - ("Lineage Type", ["artifact", "artifact"]), - ] - ) - ) - - pd.testing.assert_frame_equal(expected_dataframe, df) + pd.testing.assert_frame_equal(get_expected_dataframe(), df) def test_model_package_arn(viz, sagemaker_session): @@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session): "ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}] } - sagemaker_session.sagemaker_client.list_associations.side_effect = [ - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:artifact/src-arn-1", - "SourceName": "source-name-1", - "SourceType": "source-type-1", - "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1", - "DestinationName": "dest-name-1", - "DestinationType": "dest-type-1", - "AssociationType": "type-1", - } - ] - }, - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:artifact/src-arn-2", - "SourceName": "source-name-2", - "SourceType": "source-type-2", - "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2", - "DestinationName": "dest-name-2", - "DestinationType": "dest-type-2", - "AssociationType": "type-2", - } - ] - }, - ] + get_list_associations_side_effect(sagemaker_session) df = viz.show(model_package_arn=name) @@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session): ] assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls - expected_dataframe = pd.DataFrame.from_dict( - OrderedDict( - [ - ("Name/Source", ["source-name-1", "dest-name-2"]), - ("Direction", ["Input", "Output"]), - ("Type", ["source-type-1", "dest-type-2"]), - ("Association Type", ["type-1", "type-2"]), - ("Lineage Type", ["artifact", "artifact"]), - ] - ) - ) - - pd.testing.assert_frame_equal(expected_dataframe, df) + pd.testing.assert_frame_equal(get_expected_dataframe(), df) def test_endpoint_arn(viz, sagemaker_session): @@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session): "ContextSummaries": [{"ContextArn": "context-arn"}] } - sagemaker_session.sagemaker_client.list_associations.side_effect = [ - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:context/src-arn-1", - "SourceName": "source-name-1", - "SourceType": "source-type-1", - "DestinationArn": "a:b:c:d:e:context/dest-arn-1", - "DestinationName": "dest-name-1", - "DestinationType": "dest-type-1", - "AssociationType": "type-1", - } - ] - }, - { - "AssociationSummaries": [ - { - "SourceArn": "a:b:c:d:e:context/src-arn-2", - "SourceName": "source-name-2", - "SourceType": "source-type-2", - "DestinationArn": "a:b:c:d:e:context/dest-arn-2", - "DestinationName": "dest-name-2", - "DestinationType": "dest-type-2", - "AssociationType": "type-2", - } - ] - }, - ] + get_list_associations_side_effect(sagemaker_session) df = viz.show(endpoint_arn=name) @@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session): ] assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls - expected_dataframe = pd.DataFrame.from_dict( - OrderedDict( - [ - ("Name/Source", ["source-name-1", "dest-name-2"]), - ("Direction", ["Input", "Output"]), - ("Type", ["source-type-1", "dest-type-2"]), - ("Association Type", ["type-1", "type-2"]), - ("Lineage Type", ["context", "context"]), - ] - ) + pd.testing.assert_frame_equal(get_expected_dataframe(), df) + + +def test_processing_job_pipeline_execution_step(viz, sagemaker_session): + + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}] + } + + get_list_associations_side_effect(sagemaker_session) + + step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}} + + df = viz.show(pipeline_execution_step=step) + + sagemaker_session.sagemaker_client.list_trial_components.assert_called_with( + SourceArn="proc-job-arn", ) - pd.testing.assert_frame_equal(expected_dataframe, df) + assert_list_associations_mock_calls(sagemaker_session) + pd.testing.assert_frame_equal(get_expected_dataframe(), df) -def test_processing_job_pipeline_execution_step(viz, sagemaker_session): + +def test_training_job_pipeline_execution_step(viz, sagemaker_session): sagemaker_session.sagemaker_client.list_trial_components.return_value = { "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}] } + get_list_associations_side_effect(sagemaker_session) + + step = {"Metadata": {"TrainingJob": {"Arn": "training-job-arn"}}} + + df = viz.show(pipeline_execution_step=step) + + sagemaker_session.sagemaker_client.list_trial_components.assert_called_with( + SourceArn="training-job-arn", + ) + + assert_list_associations_mock_calls(sagemaker_session) + + pd.testing.assert_frame_equal(get_expected_dataframe(), df) + + +def test_transform_job_pipeline_execution_step(viz, sagemaker_session): + + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}] + } + + get_list_associations_side_effect(sagemaker_session) + + step = {"Metadata": {"TransformJob": {"Arn": "transform-job-arn"}}} + + df = viz.show(pipeline_execution_step=step) + + sagemaker_session.sagemaker_client.list_trial_components.assert_called_with( + SourceArn="transform-job-arn", + ) + + assert_list_associations_mock_calls(sagemaker_session) + + pd.testing.assert_frame_equal(get_expected_dataframe(), df) + + +def get_list_associations_side_effect(sagemaker_session): + sagemaker_session.sagemaker_client.list_associations.side_effect = [ { "AssociationSummaries": [ @@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session): }, ] - step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}} - - df = viz.show(pipeline_execution_step=step) - sagemaker_session.sagemaker_client.list_trial_components.assert_called_with( - SourceArn="proc-job-arn", - ) +def assert_list_associations_mock_calls(sagemaker_session): expected_calls = [ unittest.mock.call( @@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session): ] assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + +def get_expected_dataframe(): + expected_dataframe = pd.DataFrame.from_dict( OrderedDict( [ @@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session): ) ) - pd.testing.assert_frame_equal(expected_dataframe, df) + return expected_dataframe