Skip to content

Commit 91563fb

Browse files
authored
Merge branch 'master' into mask-creds-local-mode
2 parents a7df911 + 5dc9e58 commit 91563fb

File tree

9 files changed

+112
-58
lines changed

9 files changed

+112
-58
lines changed

doc/frameworks/mxnet/using_mxnet.rst

+5-4
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ It loads the model parameters from a ``model.params`` file in the SageMaker mode
377377
return net
378378
379379
MXNet on Amazon SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`__, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance.
380-
In order to load and serve your MXNet model through Amazon Elastic Inference, the MXNet context passed to your MXNet Symbol or Module object within your ``model_fn`` needs to be set to ``eia``, as shown `here <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html#ei-mxnet>`__.
380+
In order to load and serve your MXNet model through Amazon Elastic Inference, import the ``eimx`` Python package and make one change in the code to partition your model and optimize it for the ``EIA`` back end, as shown `here <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html#ei-mxnet>`__.
381381

382382
Based on the example above, the following code-snippet shows an example custom ``model_fn`` implementation, which enables loading and serving our MXNet model through Amazon Elastic Inference.
383383

@@ -392,11 +392,12 @@ Based on the example above, the following code-snippet shows an example custom `
392392
Returns:
393393
mxnet.gluon.nn.Block: a Gluon network (for this example)
394394
"""
395-
net = models.get_model('resnet34_v2', ctx=mx.eia(), pretrained=False, classes=10)
396-
net.load_params('%s/model.params' % model_dir, ctx=mx.eia())
395+
net = models.get_model('resnet34_v2', ctx=mx.cpu(), pretrained=False, classes=10)
396+
net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
397+
net.hybridize(backend='EIA', static_alloc=True, static_shape=True)
397398
return net
398399
399-
The `default_model_fn <https://github.com/aws/sagemaker-mxnet-container/pull/55/files#diff-aabf018d906ed282a3c738377d19a8deR71>`__ loads and serve your model through Elastic Inference, if applicable, within the Amazon SageMaker MXNet containers.
400+
If you are using MXNet 1.5.1 and earlier, the `default_model_fn <https://github.com/aws/sagemaker-mxnet-container/pull/55/files#diff-aabf018d906ed282a3c738377d19a8deR71>`__ loads and serve your model through Elastic Inference, if applicable, within the Amazon SageMaker MXNet containers.
400401

401402
For more information on how to enable MXNet to interact with Amazon Elastic Inference, see `Use Elastic Inference with MXNet <https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-mxnet-elastic-inference.html>`__.
402403

src/sagemaker/lineage/visualizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step):
105105
return None
106106

107107
metadata = pipeline_execution_step["Metadata"]
108-
jobs = ["TrainingJob", "ProccessingJob", "TransformJob"]
108+
jobs = ["TrainingJob", "ProcessingJob", "TransformJob"]
109109
for job in jobs:
110110
if job in metadata and metadata[job]:
111111
job_arn = metadata[job]["Arn"]
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.lineage import visualizer
17+
import unittest.mock
18+
19+
20+
@pytest.fixture
21+
def sagemaker_session():
22+
return unittest.mock.Mock()
23+
24+
25+
@pytest.fixture
26+
def viz(sagemaker_session):
27+
return visualizer.LineageTableVisualizer(sagemaker_session)

tests/unit/sagemaker/lineage/test_action.py

-6
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import action, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_action.return_value = {
2923
"ActionArn": "bazz",

tests/unit/sagemaker/lineage/test_artifact.py

-6
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import artifact, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_artifact.return_value = {
2923
"ArtifactArn": "bazz",

tests/unit/sagemaker/lineage/test_association.py

-6
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import association, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.add_association.return_value = {
2923
"AssociationArn": "bazz",

tests/unit/sagemaker/lineage/test_endpoint_context.py

-6
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import context, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_models(sagemaker_session):
2721
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz")
2822

tests/unit/sagemaker/lineage/test_model_artifact.py

-6
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import artifact, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_trained_models(sagemaker_session):
2721
model_artifact_obj = artifact.ModelArtifact(
2822
sagemaker_session, artifact_arn="model-artifact-arn"

tests/unit/sagemaker/lineage/test_visualizer.py

+79-23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -14,33 +14,21 @@
1414

1515
import unittest.mock
1616

17-
import pytest
18-
from sagemaker.lineage import visualizer
1917
import pandas as pd
2018
from collections import OrderedDict
2119

2220

23-
@pytest.fixture
24-
def sagemaker_session():
25-
return unittest.mock.Mock()
26-
27-
28-
@pytest.fixture
29-
def vizualizer(sagemaker_session):
30-
return visualizer.LineageTableVisualizer(sagemaker_session)
31-
32-
33-
def test_friendly_name_short_uri(vizualizer, sagemaker_session):
21+
def test_friendly_name_short_uri(viz, sagemaker_session):
3422
uri = "s3://f-069083975568/train.txt"
3523
arn = "test_arn"
3624
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
3725
"Source": {"SourceUri": uri, "SourceTypes": ""}
3826
}
39-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
27+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
4028
assert uri == actual_name
4129

4230

43-
def test_friendly_name_long_uri(vizualizer, sagemaker_session):
31+
def test_friendly_name_long_uri(viz, sagemaker_session):
4432
uri = (
4533
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
4634
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
4937
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
5038
"Source": {"SourceUri": uri, "SourceTypes": ""}
5139
}
52-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
40+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
5341
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
5442
assert expected_name == actual_name
5543

5644

57-
def test_trial_component_name(sagemaker_session, vizualizer):
45+
def test_trial_component_name(viz, sagemaker_session):
5846
name = "tc-name"
5947

6048
sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
@@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
9078
},
9179
]
9280

93-
df = vizualizer.show(trial_component_name=name)
81+
df = viz.show(trial_component_name=name)
9482

9583
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
9684
TrialComponentName=name,
@@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
121109
pd.testing.assert_frame_equal(expected_dataframe, df)
122110

123111

124-
def test_model_package_arn(sagemaker_session, vizualizer):
112+
def test_model_package_arn(viz, sagemaker_session):
125113
name = "model_package_arn"
126114

127115
sagemaker_session.sagemaker_client.list_artifacts.return_value = {
@@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
157145
},
158146
]
159147

160-
df = vizualizer.show(model_package_arn=name)
148+
df = viz.show(model_package_arn=name)
161149

162150
sagemaker_session.sagemaker_client.list_artifacts.assert_called_with(
163151
SourceUri=name,
@@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
188176
pd.testing.assert_frame_equal(expected_dataframe, df)
189177

190178

191-
def test_endpoint_arn(sagemaker_session, vizualizer):
179+
def test_endpoint_arn(viz, sagemaker_session):
192180
name = "endpoint_arn"
193181

194182
sagemaker_session.sagemaker_client.list_contexts.return_value = {
@@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
224212
},
225213
]
226214

227-
df = vizualizer.show(endpoint_arn=name)
215+
df = viz.show(endpoint_arn=name)
228216

229217
sagemaker_session.sagemaker_client.list_contexts.assert_called_with(
230218
SourceUri=name,
@@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
253241
)
254242

255243
pd.testing.assert_frame_equal(expected_dataframe, df)
244+
245+
246+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
247+
248+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250+
}
251+
252+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253+
{
254+
"AssociationSummaries": [
255+
{
256+
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
257+
"SourceName": "source-name-1",
258+
"SourceType": "source-type-1",
259+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
260+
"DestinationName": "dest-name-1",
261+
"DestinationType": "dest-type-1",
262+
"AssociationType": "type-1",
263+
}
264+
]
265+
},
266+
{
267+
"AssociationSummaries": [
268+
{
269+
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
270+
"SourceName": "source-name-2",
271+
"SourceType": "source-type-2",
272+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
273+
"DestinationName": "dest-name-2",
274+
"DestinationType": "dest-type-2",
275+
"AssociationType": "type-2",
276+
}
277+
]
278+
},
279+
]
280+
281+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282+
283+
df = viz.show(pipeline_execution_step=step)
284+
285+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286+
SourceArn="proc-job-arn",
287+
)
288+
289+
expected_calls = [
290+
unittest.mock.call(
291+
DestinationArn="tc-arn",
292+
),
293+
unittest.mock.call(
294+
SourceArn="tc-arn",
295+
),
296+
]
297+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298+
299+
expected_dataframe = pd.DataFrame.from_dict(
300+
OrderedDict(
301+
[
302+
("Name/Source", ["source-name-1", "dest-name-2"]),
303+
("Direction", ["Input", "Output"]),
304+
("Type", ["source-type-1", "dest-type-2"]),
305+
("Association Type", ["type-1", "type-2"]),
306+
("Lineage Type", ["artifact", "artifact"]),
307+
]
308+
)
309+
)
310+
311+
pd.testing.assert_frame_equal(expected_dataframe, df)

0 commit comments

Comments
 (0)