Skip to content

Commit 65979b7

Browse files
fix: visualizer for pipeline processing job steps (#2160)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 9b86920 commit 65979b7

File tree

8 files changed, with 107 additions (+107) and 54 deletions (-54).

src/sagemaker/lineage/visualizer.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step):
105105
return None
106106

107107
metadata = pipeline_execution_step["Metadata"]
108-
jobs = ["TrainingJob", "ProccessingJob", "TransformJob"]
108+
jobs = ["TrainingJob", "ProcessingJob", "TransformJob"]
109109
for job in jobs:
110110
if job in metadata and metadata[job]:
111111
job_arn = metadata[job]["Arn"]
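(new file; path header missing from this page — presumably tests/unit/sagemaker/lineage/conftest.py, since it holds the pytest fixtures shared by the test modules below)

+27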
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.lineage import visualizer
17+
import unittest.mock
18+
19+
20+
@pytest.fixture
21+
def sagemaker_session():
22+
return unittest.mock.Mock()
23+
24+
25+
@pytest.fixture
26+
def viz(sagemaker_session):
27+
return visualizer.LineageTableVisualizer(sagemaker_session)

tests/unit/sagemaker/lineage/test_action.py

-6
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import action, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_action.return_value = {
2923
"ActionArn": "bazz",

tests/unit/sagemaker/lineage/test_artifact.py

-6
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import artifact, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_artifact.return_value = {
2923
"ArtifactArn": "bazz",

tests/unit/sagemaker/lineage/test_association.py

-6
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import association, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.add_association.return_value = {
2923
"AssociationArn": "bazz",

tests/unit/sagemaker/lineage/test_endpoint_context.py

-6
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import context, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_models(sagemaker_session):
2721
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz")
2822

tests/unit/sagemaker/lineage/test_model_artifact.py

-6
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import artifact, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_trained_models(sagemaker_session):
2721
model_artifact_obj = artifact.ModelArtifact(
2822
sagemaker_session, artifact_arn="model-artifact-arn"

tests/unit/sagemaker/lineage/test_visualizer.py

+79 -23
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -14,33 +14,21 @@
1414

1515
import unittest.mock
1616

17-
import pytest
18-
from sagemaker.lineage import visualizer
1917
import pandas as pd
2018
from collections import OrderedDict
2119

2220

23-
@pytest.fixture
24-
def sagemaker_session():
25-
return unittest.mock.Mock()
26-
27-
28-
@pytest.fixture
29-
def vizualizer(sagemaker_session):
30-
return visualizer.LineageTableVisualizer(sagemaker_session)
31-
32-
33-
def test_friendly_name_short_uri(vizualizer, sagemaker_session):
21+
def test_friendly_name_short_uri(viz, sagemaker_session):
3422
uri = "s3://f-069083975568/train.txt"
3523
arn = "test_arn"
3624
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
3725
"Source": {"SourceUri": uri, "SourceTypes": ""}
3826
}
39-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
27+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
4028
assert uri == actual_name
4129

4230

43-
def test_friendly_name_long_uri(vizualizer, sagemaker_session):
31+
def test_friendly_name_long_uri(viz, sagemaker_session):
4432
uri = (
4533
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
4634
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
4937
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
5038
"Source": {"SourceUri": uri, "SourceTypes": ""}
5139
}
52-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
40+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
5341
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
5442
assert expected_name == actual_name
5543

5644

57-
def test_trial_component_name(sagemaker_session, vizualizer):
45+
def test_trial_component_name(viz, sagemaker_session):
5846
name = "tc-name"
5947

6048
sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
@@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
9078
},
9179
]
9280

93-
df = vizualizer.show(trial_component_name=name)
81+
df = viz.show(trial_component_name=name)
9482

9583
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
9684
TrialComponentName=name,
@@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
121109
pd.testing.assert_frame_equal(expected_dataframe, df)
122110

123111

124-
def test_model_package_arn(sagemaker_session, vizualizer):
112+
def test_model_package_arn(viz, sagemaker_session):
125113
name = "model_package_arn"
126114

127115
sagemaker_session.sagemaker_client.list_artifacts.return_value = {
@@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
157145
},
158146
]
159147

160-
df = vizualizer.show(model_package_arn=name)
148+
df = viz.show(model_package_arn=name)
161149

162150
sagemaker_session.sagemaker_client.list_artifacts.assert_called_with(
163151
SourceUri=name,
@@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
188176
pd.testing.assert_frame_equal(expected_dataframe, df)
189177

190178

191-
def test_endpoint_arn(sagemaker_session, vizualizer):
179+
def test_endpoint_arn(viz, sagemaker_session):
192180
name = "endpoint_arn"
193181

194182
sagemaker_session.sagemaker_client.list_contexts.return_value = {
@@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
224212
},
225213
]
226214

227-
df = vizualizer.show(endpoint_arn=name)
215+
df = viz.show(endpoint_arn=name)
228216

229217
sagemaker_session.sagemaker_client.list_contexts.assert_called_with(
230218
SourceUri=name,
@@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
253241
)
254242

255243
pd.testing.assert_frame_equal(expected_dataframe, df)
244+
245+
246+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
247+
248+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250+
}
251+
252+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253+
{
254+
"AssociationSummaries": [
255+
{
256+
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
257+
"SourceName": "source-name-1",
258+
"SourceType": "source-type-1",
259+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
260+
"DestinationName": "dest-name-1",
261+
"DestinationType": "dest-type-1",
262+
"AssociationType": "type-1",
263+
}
264+
]
265+
},
266+
{
267+
"AssociationSummaries": [
268+
{
269+
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
270+
"SourceName": "source-name-2",
271+
"SourceType": "source-type-2",
272+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
273+
"DestinationName": "dest-name-2",
274+
"DestinationType": "dest-type-2",
275+
"AssociationType": "type-2",
276+
}
277+
]
278+
},
279+
]
280+
281+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282+
283+
df = viz.show(pipeline_execution_step=step)
284+
285+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286+
SourceArn="proc-job-arn",
287+
)
288+
289+
expected_calls = [
290+
unittest.mock.call(
291+
DestinationArn="tc-arn",
292+
),
293+
unittest.mock.call(
294+
SourceArn="tc-arn",
295+
),
296+
]
297+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298+
299+
expected_dataframe = pd.DataFrame.from_dict(
300+
OrderedDict(
301+
[
302+
("Name/Source", ["source-name-1", "dest-name-2"]),
303+
("Direction", ["Input", "Output"]),
304+
("Type", ["source-type-1", "dest-type-2"]),
305+
("Association Type", ["type-1", "type-2"]),
306+
("Lineage Type", ["artifact", "artifact"]),
307+
]
308+
)
309+
)
310+
311+
pd.testing.assert_frame_equal(expected_dataframe, df)

0 commit comments

Comments
 (0)