Skip to content

Commit 827ccf6

Browse files
committed
fix visualizer for pipeline processing job steps
1 parent 6948f17 commit 827ccf6

File tree

8 files changed

+102
-48
lines changed

src/sagemaker/lineage/visualizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step):
105105
return None
106106

107107
metadata = pipeline_execution_step["Metadata"]
108-
jobs = ["TrainingJob", "ProccessingJob", "TransformJob"]
108+
jobs = ["TrainingJob", "ProcessingJob", "TransformJob"]
109109
for job in jobs:
110110
if job in metadata and metadata[job]:
111111
job_arn = metadata[job]["Arn"]
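
tests/unit/sagemaker/lineage/conftest.py (path header missing from the captured page; inferred from the shared pytest fixtures this new file defines — confirm against the commit)

Lines changed: 27 additions & 0 deletions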
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.lineage import visualizer
17+
import unittest.mock
18+
19+
20+
@pytest.fixture
21+
def sagemaker_session():
22+
return unittest.mock.Mock()
23+
24+
25+
@pytest.fixture
26+
def viz(sagemaker_session):
27+
return visualizer.LineageTableVisualizer(sagemaker_session)

tests/unit/sagemaker/lineage/test_action.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import action, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_action.return_value = {
2923
"ActionArn": "bazz",

tests/unit/sagemaker/lineage/test_artifact.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import artifact, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_artifact.return_value = {
2923
"ArtifactArn": "bazz",

tests/unit/sagemaker/lineage/test_association.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import association, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.add_association.return_value = {
2923
"AssociationArn": "bazz",

tests/unit/sagemaker/lineage/test_endpoint_context.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import context, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_models(sagemaker_session):
2721
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz")
2822

tests/unit/sagemaker/lineage/test_model_artifact.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import artifact, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_trained_models(sagemaker_session):
2721
model_artifact_obj = artifact.ModelArtifact(
2822
sagemaker_session, artifact_arn="model-artifact-arn"

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,22 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import visualizer
1918
import pandas as pd
2019
from collections import OrderedDict
2120

2221

23-
@pytest.fixture
24-
def sagemaker_session():
25-
return unittest.mock.Mock()
26-
27-
28-
@pytest.fixture
29-
def vizualizer(sagemaker_session):
30-
return visualizer.LineageTableVisualizer(sagemaker_session)
31-
32-
33-
def test_friendly_name_short_uri(vizualizer, sagemaker_session):
22+
def test_friendly_name_short_uri(viz, sagemaker_session):
3423
uri = "s3://f-069083975568/train.txt"
3524
arn = "test_arn"
3625
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
3726
"Source": {"SourceUri": uri, "SourceTypes": ""}
3827
}
39-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
28+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
4029
assert uri == actual_name
4130

4231

43-
def test_friendly_name_long_uri(vizualizer, sagemaker_session):
32+
def test_friendly_name_long_uri(viz, sagemaker_session):
4433
uri = (
4534
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
4635
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +38,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
4938
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
5039
"Source": {"SourceUri": uri, "SourceTypes": ""}
5140
}
52-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
41+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
5342
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
5443
assert expected_name == actual_name
5544

5645

57-
def test_trial_component_name(sagemaker_session, vizualizer):
46+
def test_trial_component_name(viz, sagemaker_session):
5847
name = "tc-name"
5948

6049
sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
@@ -90,7 +79,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
9079
},
9180
]
9281

93-
df = vizualizer.show(trial_component_name=name)
82+
df = viz.show(trial_component_name=name)
9483

9584
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
9685
TrialComponentName=name,
@@ -119,3 +108,71 @@ def test_trial_component_name(sagemaker_session, vizualizer):
119108
)
120109

121110
pd.testing.assert_frame_equal(expected_dataframe, df)
111+
112+
113+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
114+
115+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
116+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
117+
}
118+
119+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
120+
{
121+
"AssociationSummaries": [
122+
{
123+
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
124+
"SourceName": "source-name-1",
125+
"SourceType": "source-type-1",
126+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
127+
"DestinationName": "dest-name-1",
128+
"DestinationType": "dest-type-1",
129+
"AssociationType": "type-1",
130+
}
131+
]
132+
},
133+
{
134+
"AssociationSummaries": [
135+
{
136+
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
137+
"SourceName": "source-name-2",
138+
"SourceType": "source-type-2",
139+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
140+
"DestinationName": "dest-name-2",
141+
"DestinationType": "dest-type-2",
142+
"AssociationType": "type-2",
143+
}
144+
]
145+
},
146+
]
147+
148+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
149+
150+
df = viz.show(pipeline_execution_step=step)
151+
152+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
153+
SourceArn="proc-job-arn",
154+
)
155+
156+
expected_calls = [
157+
unittest.mock.call(
158+
DestinationArn="tc-arn",
159+
),
160+
unittest.mock.call(
161+
SourceArn="tc-arn",
162+
),
163+
]
164+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
165+
166+
expected_dataframe = pd.DataFrame.from_dict(
167+
OrderedDict(
168+
[
169+
("Name/Source", ["source-name-1", "dest-name-2"]),
170+
("Direction", ["Input", "Output"]),
171+
("Type", ["source-type-1", "dest-type-2"]),
172+
("Association Type", ["type-1", "type-2"]),
173+
("Lineage Type", ["artifact", "artifact"]),
174+
]
175+
)
176+
)
177+
178+
pd.testing.assert_frame_equal(expected_dataframe, df)

0 commit comments

Comments
 (0)