Skip to content

fix visualizer for pipeline processing job steps #2160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/lineage/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step):
return None

metadata = pipeline_execution_step["Metadata"]
jobs = ["TrainingJob", "ProccessingJob", "TransformJob"]
jobs = ["TrainingJob", "ProcessingJob", "TransformJob"]
for job in jobs:
if job in metadata and metadata[job]:
job_arn = metadata[job]["Arn"]
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/sagemaker/lineage/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker.lineage import visualizer
import unittest.mock


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


@pytest.fixture
def viz(sagemaker_session):
return visualizer.LineageTableVisualizer(sagemaker_session)
6 changes: 0 additions & 6 deletions tests/unit/sagemaker/lineage/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@
import datetime
import unittest.mock

import pytest
from sagemaker.lineage import action, _api_types


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


def test_create(sagemaker_session):
sagemaker_session.sagemaker_client.create_action.return_value = {
"ActionArn": "bazz",
Expand Down
6 changes: 0 additions & 6 deletions tests/unit/sagemaker/lineage/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@
import datetime
import unittest.mock

import pytest
from sagemaker.lineage import artifact, _api_types


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


def test_create(sagemaker_session):
sagemaker_session.sagemaker_client.create_artifact.return_value = {
"ArtifactArn": "bazz",
Expand Down
6 changes: 0 additions & 6 deletions tests/unit/sagemaker/lineage/test_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@
import datetime
import unittest.mock

import pytest
from sagemaker.lineage import association, _api_types


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


def test_create(sagemaker_session):
sagemaker_session.sagemaker_client.add_association.return_value = {
"AssociationArn": "bazz",
Expand Down
6 changes: 0 additions & 6 deletions tests/unit/sagemaker/lineage/test_endpoint_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@

import unittest.mock

import pytest
from sagemaker.lineage import context, _api_types


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


def test_models(sagemaker_session):
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz")

Expand Down
6 changes: 0 additions & 6 deletions tests/unit/sagemaker/lineage/test_model_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@

import unittest.mock

import pytest
from sagemaker.lineage import artifact, _api_types


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


def test_trained_models(sagemaker_session):
model_artifact_obj = artifact.ModelArtifact(
sagemaker_session, artifact_arn="model-artifact-arn"
Expand Down
102 changes: 79 additions & 23 deletions tests/unit/sagemaker/lineage/test_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand All @@ -14,33 +14,21 @@

import unittest.mock

import pytest
from sagemaker.lineage import visualizer
import pandas as pd
from collections import OrderedDict


@pytest.fixture
def sagemaker_session():
return unittest.mock.Mock()


@pytest.fixture
def vizualizer(sagemaker_session):
return visualizer.LineageTableVisualizer(sagemaker_session)


def test_friendly_name_short_uri(vizualizer, sagemaker_session):
def test_friendly_name_short_uri(viz, sagemaker_session):
uri = "s3://f-069083975568/train.txt"
arn = "test_arn"
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
"Source": {"SourceUri": uri, "SourceTypes": ""}
}
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
assert uri == actual_name


def test_friendly_name_long_uri(vizualizer, sagemaker_session):
def test_friendly_name_long_uri(viz, sagemaker_session):
uri = (
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
"preprocessed-data/tuning_data/train.txt"
Expand All @@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
"Source": {"SourceUri": uri, "SourceTypes": ""}
}
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
assert expected_name == actual_name


def test_trial_component_name(sagemaker_session, vizualizer):
def test_trial_component_name(viz, sagemaker_session):
name = "tc-name"

sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
Expand Down Expand Up @@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
},
]

df = vizualizer.show(trial_component_name=name)
df = viz.show(trial_component_name=name)

sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
TrialComponentName=name,
Expand Down Expand Up @@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
pd.testing.assert_frame_equal(expected_dataframe, df)


def test_model_package_arn(sagemaker_session, vizualizer):
def test_model_package_arn(viz, sagemaker_session):
name = "model_package_arn"

sagemaker_session.sagemaker_client.list_artifacts.return_value = {
Expand Down Expand Up @@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
},
]

df = vizualizer.show(model_package_arn=name)
df = viz.show(model_package_arn=name)

sagemaker_session.sagemaker_client.list_artifacts.assert_called_with(
SourceUri=name,
Expand Down Expand Up @@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
pd.testing.assert_frame_equal(expected_dataframe, df)


def test_endpoint_arn(sagemaker_session, vizualizer):
def test_endpoint_arn(viz, sagemaker_session):
name = "endpoint_arn"

sagemaker_session.sagemaker_client.list_contexts.return_value = {
Expand Down Expand Up @@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
},
]

df = vizualizer.show(endpoint_arn=name)
df = viz.show(endpoint_arn=name)

sagemaker_session.sagemaker_client.list_contexts.assert_called_with(
SourceUri=name,
Expand Down Expand Up @@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
)

pd.testing.assert_frame_equal(expected_dataframe, df)


def test_processing_job_pipeline_execution_step(viz, sagemaker_session):

sagemaker_session.sagemaker_client.list_trial_components.return_value = {
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
}

sagemaker_session.sagemaker_client.list_associations.side_effect = [
{
"AssociationSummaries": [
{
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
"SourceName": "source-name-1",
"SourceType": "source-type-1",
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
"DestinationName": "dest-name-1",
"DestinationType": "dest-type-1",
"AssociationType": "type-1",
}
]
},
{
"AssociationSummaries": [
{
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
"SourceName": "source-name-2",
"SourceType": "source-type-2",
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
"DestinationName": "dest-name-2",
"DestinationType": "dest-type-2",
"AssociationType": "type-2",
}
]
},
]

step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}

df = viz.show(pipeline_execution_step=step)

sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
SourceArn="proc-job-arn",
)

expected_calls = [
unittest.mock.call(
DestinationArn="tc-arn",
),
unittest.mock.call(
SourceArn="tc-arn",
),
]
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls

expected_dataframe = pd.DataFrame.from_dict(
OrderedDict(
[
("Name/Source", ["source-name-1", "dest-name-2"]),
("Direction", ["Input", "Output"]),
("Type", ["source-type-1", "dest-type-2"]),
("Association Type", ["type-1", "type-2"]),
("Lineage Type", ["artifact", "artifact"]),
]
)
)

pd.testing.assert_frame_equal(expected_dataframe, df)