Skip to content

Commit fe85356

Browse files
mnuyens authored and ChoiByungWook committed
feature: add edge packaging job support (#507)
1 parent 37af818 commit fe85356

File tree

5 files changed

+340
-0
lines changed

5 files changed

+340
-0
lines changed

src/sagemaker/model.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def __init__(
101101
self.sagemaker_session = sagemaker_session
102102
self.endpoint_name = None
103103
self._is_compiled_model = False
104+
self._compilation_job_name = None
105+
self._is_edge_packaged_model = False
104106
self._enable_network_isolation = enable_network_isolation
105107
self.model_kms_key = model_kms_key
106108

@@ -336,6 +338,50 @@ def _get_framework_version(self):
336338
"""Placeholder docstring"""
337339
return getattr(self, "framework_version", None)
338340

341+
def _edge_packaging_job_config(
    self,
    output_path,
    role,
    model_name,
    model_version,
    packaging_job_name,
    compilation_job_name,
    resource_key,
    s3_kms_key,
    tags,
):
    """Build the keyword arguments describing an edge packaging job.

    Args:
        output_path (str): S3 location where the job output is stored.
        role (str): role used when executing the job.
        model_name (str): model name placed in the request.
        model_version (str): model version placed in the request.
        packaging_job_name (str): name given to the packaging job.
        compilation_job_name (str): compilation job the model is sourced from.
        resource_key (str): KMS key used to encrypt the disk.
        s3_kms_key (str): KMS key used to encrypt the output; omitted from
            the output config when ``None``.
        tags (list[dict]): List of tags for labeling an edge packaging job.
            For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

    Returns:
        dict: the request object to use when creating a packaging job.
    """
    # "KmsKeyId" is only present when an output encryption key was supplied.
    s3_output = (
        {"S3OutputLocation": output_path}
        if s3_kms_key is None
        else {"S3OutputLocation": output_path, "KmsKeyId": s3_kms_key}
    )

    request = {
        "output_model_config": s3_output,
        "role": role,
        "tags": tags,
        "model_name": model_name,
        "model_version": model_version,
        "job_name": packaging_job_name,
        "compilation_job_name": compilation_job_name,
        "resource_key": resource_key,
    }
    return request
384+
339385
def _compilation_job_config(
340386
self,
341387
target_instance_type,
@@ -438,6 +484,64 @@ def _compilation_image_uri(self, region, target_instance_type, framework, framew
438484
version=framework_version,
439485
)
440486

487+
def package_for_edge(
    self,
    output_path,
    model_name,
    model_version,
    role=None,
    job_name=None,
    resource_key=None,
    s3_kms_key=None,
    tags=None,
):
    """Package this ``Model`` with SageMaker Edge.

    Creates a new EdgePackagingJob and waits for it to finish.
    ``model_data`` will then point to the packaged artifacts and the model
    is flagged as edge-packaged.

    Args:
        output_path (str): Specifies where to store the packaged model
        model_name (str): the name to attach to the model metadata
        model_version (str): the version to attach to the model metadata
        role (str): Execution role. When omitted, the role set on this
            model (if any) is used instead.
        job_name (str): The name of the edge packaging job. When omitted,
            it is derived from the compilation job name.
        resource_key (str): the kms key to encrypt the disk with
        s3_kms_key (str): the kms key to encrypt the output with
        tags (list[dict]): List of tags for labeling an edge packaging job. For
            more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

    Returns:
        sagemaker.model.Model: A SageMaker ``Model`` object. See
        :func:`~sagemaker.model.Model` for full details.

    Raises:
        ValueError: If the model has not been compiled yet.
    """
    if self._compilation_job_name is None:
        raise ValueError("You must first compile this model")

    if job_name is None:
        compile_prefix = "compilation"
        compiled_name = self._compilation_job_name
        if compiled_name.startswith(compile_prefix):
            # "compilation-xyz" -> "packaging-xyz"
            suffix = compiled_name[len(compile_prefix):]
        else:
            # A user-chosen compilation job name has no prefix to strip;
            # keep the whole name instead of chopping its first characters.
            suffix = f"-{compiled_name}"
        job_name = f"packaging{suffix}"

    # The session must exist before it is used to expand the role.
    self._init_sagemaker_session_if_does_not_exist(None)

    if role is None:
        role = getattr(self, "role", None)
    if role is not None:
        role = self.sagemaker_session.expand_role(role)

    config = self._edge_packaging_job_config(
        output_path,
        role,
        model_name,
        model_version,
        job_name,
        self._compilation_job_name,
        resource_key,
        s3_kms_key,
        tags,
    )
    self.sagemaker_session.package_model_for_edge(**config)
    job_status = self.sagemaker_session.wait_for_edge_packaging_job(job_name)
    self.model_data = job_status["ModelArtifact"]
    self._is_edge_packaged_model = True

    return self
544+
441545
def compile(
442546
self,
443547
target_instance_family,
@@ -557,6 +661,8 @@ def compile(
557661
"supported for deployment via SageMaker. Please deploy the model manually."
558662
)
559663

664+
self._compilation_job_name = job_name
665+
560666
return self
561667

562668
def deploy(

src/sagemaker/session.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,48 @@ def compile_model(
18021802
LOGGER.info("Creating compilation-job with name: %s", job_name)
18031803
self.sagemaker_client.create_compilation_job(**compilation_job_request)
18041804

1805+
def package_model_for_edge(
    self,
    output_model_config,
    role,
    job_name,
    compilation_job_name,
    model_name,
    model_version,
    resource_key,
    tags,
):
    """Create an Amazon SageMaker Edge packaging job.

    Args:
        output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
            SageMaker Edge to save the results of edge packaging job
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Edge
            edge packaging jobs use this role to access model artifacts. You must grant
            sufficient permissions to this role. Sent as ``RoleArn``.
        job_name (str): Name of the edge packaging job being created.
        compilation_job_name (str): Name of the compilation job whose output is packaged.
        model_name (str): Value sent as ``ModelName`` in the request.
        model_version (str): Value sent as ``ModelVersion`` in the request.
        resource_key (str): KMS key to encrypt the disk used to package the job.
            Omitted from the request when ``None``.
        tags (list[dict]): List of tags for labeling an edge packaging job. Omitted from
            the request when ``None``. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
    """
    edge_packaging_job_request = {
        "OutputConfig": output_model_config,
        "RoleArn": role,
        "ModelName": model_name,
        "ModelVersion": model_version,
        "EdgePackagingJobName": job_name,
        "CompilationJobName": compilation_job_name,
    }

    if tags is not None:
        edge_packaging_job_request["Tags"] = tags
    if resource_key is not None:
        # Pass the key as the string itself; a stray trailing comma used to
        # wrap it in a one-element tuple, which is not a valid ResourceKey.
        edge_packaging_job_request["ResourceKey"] = resource_key

    LOGGER.info("Creating edge-packaging-job with name: %s", job_name)
    self.sagemaker_client.create_edge_packaging_job(**edge_packaging_job_request)
1846+
18051847
def tune( # noqa: C901
18061848
self,
18071849
job_name,
@@ -3108,6 +3150,23 @@ def wait_for_compilation_job(self, job, poll=5):
31083150
self._check_job_status(job, desc, "CompilationJobStatus")
31093151
return desc
31103152

3153+
def wait_for_edge_packaging_job(self, job, poll=5):
    """Block until an Amazon SageMaker Edge packaging job stops running.

    Args:
        job (str): Name of the edge packaging job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeEdgePackagingJob`` API.

    Raises:
        exceptions.UnexpectedStatusException: If the edge packaging job fails.
    """

    def _poll_status():
        # Yields None while the job is still in progress, else the description.
        return _edge_packaging_job_status(self.sagemaker_client, job)

    description = _wait_until(_poll_status, poll)
    self._check_job_status(job, description, "EdgePackagingJobStatus")
    return description
3169+
31113170
def wait_for_tuning_job(self, job, poll=5):
31123171
"""Wait for an Amazon SageMaker hyperparameter tuning job to complete.
31133172
@@ -4186,6 +4245,38 @@ def _processing_job_status(sagemaker_client, job_name):
41864245
return desc
41874246

41884247

4248+
def _edge_packaging_job_status(sagemaker_client, job_name):
    """Fetch an edge packaging job's status and print a one-character progress marker.

    Args:
        sagemaker_client (boto3.client.sagemaker): a sagemaker client
        job_name (str): the name of the job to inspect

    Returns:
        dict: the ``describe_edge_packaging_job`` response once the job is no
        longer in an in-progress state; ``None`` while it is still running.
    """
    progress_markers = {
        "Completed": "!",
        "InProgress": ".",
        "Failed": "*",
        "Stopped": "s",
        "Stopping": "_",
    }
    still_running = ("InProgress", "Stopping", "Starting")

    description = sagemaker_client.describe_edge_packaging_job(EdgePackagingJobName=job_name)
    raw_status = description["EdgePackagingJobStatus"]
    # Translate through the shared status table; unknown values pass through unchanged.
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    # Statuses without a marker are shown as "?".
    print(progress_markers.get(status, "?"), end="")
    sys.stdout.flush()

    return None if status in still_running else description
4278+
4279+
41894280
def _compilation_job_status(sagemaker_client, job_name):
41904281
"""Placeholder docstring"""
41914282
compile_status_codes = {

tests/integ/test_edge.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2020-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import pytest
18+
19+
from sagemaker.mxnet.estimator import MXNet
20+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
21+
from tests.integ.timeout import timeout
22+
23+
24+
@pytest.fixture(scope="module")
def mxnet_training_job(
    sagemaker_session,
    cpu_instance_type,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
):
    """Train an MXNet MNIST estimator once per module and return its training job name."""
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
        entry_script = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_neo.py")

        estimator = MXNet(
            entry_point=entry_script,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
        )

        # Upload both channels through the estimator's session before fitting.
        uploader = estimator.sagemaker_session
        train_s3 = uploader.upload_data(
            path=os.path.join(mnist_dir, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
        )
        test_s3 = uploader.upload_data(
            path=os.path.join(mnist_dir, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
        )

        channels = {"train": train_s3, "test": test_s3}
        estimator.fit(channels)
        return estimator.latest_training_job.name
54+
55+
56+
def test_edge_packaging_job(mxnet_training_job, sagemaker_session):
    """Compile the trained MXNet model for ``rasp3b`` and package the result for edge."""
    trained = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)

    input_shape = {"data": [1, 1, 28, 28], "softmax_label": [1]}
    compiled_model = trained.compile_model(
        target_instance_family="rasp3b",
        input_shape=input_shape,
        output_path=trained.output_path,
    )

    # No explicit assertion: the test passes when packaging finishes without raising.
    compiled_model.package_for_edge(
        output_path=trained.output_path,
        role=trained.role,
        model_name="sdk-test-model",
        model_version="1.0",
    )
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2020-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock
17+
18+
from sagemaker.model import Model
19+
20+
# Placeholder model artifact location and image URI handed to ``Model``.
MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"

# Region reported by the mocked session.
REGION = "us-west-2"

# Canned description returned by the mocked edge packaging waiter.
DESCRIBE_EDGE_PACKAGING_JOB_RESPONSE = {
    "EdgePackagingJobStatus": "Completed",
    "ModelArtifact": "s3://output-path/package-model.tar.gz",
}
29+
30+
31+
@pytest.fixture
def sagemaker_session():
    """Provide a mocked session whose ``boto_region_name`` is the test region."""
    session = Mock(boto_region_name=REGION)
    return session
34+
35+
36+
def _create_model(sagemaker_session=None):
    """Return a ``Model`` pre-marked as compiled by the job "compilation-test-name"."""
    compiled = Model(MODEL_IMAGE, MODEL_DATA, role="role", sagemaker_session=sagemaker_session)
    # Set the compiled state directly so no compilation job has to run.
    compiled._is_compiled_model = True
    compiled._compilation_job_name = "compilation-test-name"
    return compiled
41+
42+
43+
def test_package_model(sagemaker_session):
    """Packaging a compiled model flags it as edge-packaged."""
    waiter = Mock(return_value=DESCRIBE_EDGE_PACKAGING_JOB_RESPONSE)
    sagemaker_session.wait_for_edge_packaging_job = waiter

    compiled = _create_model(sagemaker_session)
    packaging_kwargs = {
        "output_path": "s3://output",
        "role": "role",
        "model_name": "model_name",
        "model_version": "1.0",
    }
    compiled.package_for_edge(**packaging_kwargs)

    assert compiled._is_edge_packaged_model is True
55+
56+
57+
def test_package_validates_compiled(sagemaker_session):
    """Packaging a model that was never compiled raises before any job is created."""
    # Request the fixture: without the parameter, ``sagemaker_session`` is the
    # fixture function itself and the mocks end up as attributes on it.
    sagemaker_session.wait_for_edge_packaging_job = Mock(
        return_value=DESCRIBE_EDGE_PACKAGING_JOB_RESPONSE
    )
    sagemaker_session.package_model_for_edge = Mock()
    model = _create_model(sagemaker_session)
    model._compilation_job_name = None

    with pytest.raises(ValueError) as e:
        model.package_for_edge(
            output_path="s3://output",
            role="role",
            model_name="model_name",
            model_version="1.0",
        )

    # Inspect the raised exception, not the ExceptionInfo wrapper.
    assert "You must first compile this model" in str(e.value)
    sagemaker_session.package_model_for_edge.assert_not_called()

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
ACCELERATOR_TYPE = "ml.eia.medium"
4343
IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3"
4444
COMPILATION_JOB_NAME = "{}-{}".format("compilation-sagemaker-mxnet", TIMESTAMP)
45+
# Edge packaging job names swap the "compilation" prefix for "packaging".
EDGE_PACKAGING_JOB_NAME = "{}-{}".format("packaging-sagemaker-mxnet", TIMESTAMP)
4546
FRAMEWORK = "mxnet"
4647
ROLE = "Dummy"
4748
REGION = "us-west-2"

0 commit comments

Comments
 (0)