
Commit 885e4c9

Merge branch 'master' into fix-pydoc
2 parents: fd62c54 + adcbbaf

File tree: 18 files changed, +1498 -8 lines

CHANGELOG.md  (+7)

@@ -1,5 +1,12 @@
 # Changelog
 
+## v2.49.2 (2021-07-21)
+
+### Bug Fixes and Other Changes
+
+* order of populating container list
+* upgrade Adobe Analytics cookie to 3.0
+
 ## v2.49.1 (2021-07-19)
 
 ### Bug Fixes and Other Changes

VERSION  (+1 -1)

@@ -1 +1 @@
-2.49.2.dev0
+2.49.3.dev0

src/sagemaker/model.py  (+19 -1)

@@ -13,6 +13,7 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
+import abc
 import json
 import logging
 import os
@@ -29,6 +30,7 @@
     git_utils,
 )
 from sagemaker.deprecations import removed_kwargs
+from sagemaker.predictor import PredictorBase
 from sagemaker.transformer import Transformer
 
 LOGGER = logging.getLogger("sagemaker")
@@ -38,7 +40,23 @@
 )
 
 
-class Model(object):
+class ModelBase(abc.ABC):
+    """An object that encapsulates a trained model.
+
+    Models can be deployed to compute services like a SageMaker ``Endpoint``
+    or Lambda. Deployed models can be used to perform real-time inference.
+    """
+
+    @abc.abstractmethod
+    def deploy(self, *args, **kwargs) -> PredictorBase:
+        """Deploy this model to a compute service."""
+
+    @abc.abstractmethod
+    def delete_model(self, *args, **kwargs) -> None:
+        """Destroy resources associated with this model."""
+
+
+class Model(ModelBase):
     """A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
 
     def __init__(
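The new ModelBase contract is enforced by abc at instantiation time: a subclass that omits deploy or delete_model cannot be constructed. A minimal sketch of that behavior, assuming an SDK build that includes this commit; the IncompleteModel class is hypothetical and not part of the change:

from sagemaker.model import ModelBase  # added by this commit


class IncompleteModel(ModelBase):
    """Hypothetical subclass that implements deploy() but forgets delete_model()."""

    def deploy(self, *args, **kwargs):
        return None


try:
    IncompleteModel()
except TypeError as err:
    # abc refuses to instantiate a class with unimplemented abstract methods.
    print(err)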

src/sagemaker/predictor.py  (+26 -1)

@@ -13,6 +13,9 @@
 """Placeholder docstring"""
 from __future__ import print_function, absolute_import
 
+import abc
+from typing import Any, Tuple
+
 from sagemaker.deprecations import (
     deprecated_class,
     deprecated_deserialize,
@@ -51,7 +54,29 @@
 from sagemaker.lineage.context import EndpointContext
 
 
-class Predictor(object):
+class PredictorBase(abc.ABC):
+    """An object that encapsulates a deployed model."""
+
+    @abc.abstractmethod
+    def predict(self, *args, **kwargs) -> Any:
+        """Perform inference on the provided data and return a prediction."""
+
+    @abc.abstractmethod
+    def delete_endpoint(self, *args, **kwargs) -> None:
+        """Destroy resources associated with this predictor."""
+
+    @property
+    @abc.abstractmethod
+    def content_type(self) -> str:
+        """The MIME type of the data sent to the inference server."""
+
+    @property
+    @abc.abstractmethod
+    def accept(self) -> Tuple[str]:
+        """The content type(s) that are expected from the inference server."""
+
+
+class Predictor(PredictorBase):
     """Make prediction requests to an Amazon SageMaker endpoint."""
 
     def __init__(
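Because Predictor, and the LambdaPredictor introduced below, both derive from PredictorBase, downstream code can target the abstract interface and stay agnostic about where the model is hosted. A small sketch, assuming an SDK build with this commit; the run_inference helper is hypothetical:

from typing import Any

from sagemaker.predictor import PredictorBase


def run_inference(predictor: PredictorBase, payload: Any) -> Any:
    """Send a payload to any deployed model, SageMaker endpoint or Lambda alike."""
    # content_type and accept are abstract properties every implementation provides.
    print("sending", predictor.content_type, "expecting", predictor.accept)
    return predictor.predict(payload)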

src/sagemaker/serverless/__init__.py  (+15, new file)

@@ -0,0 +1,15 @@
+# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Classes for performing machine learning on serverless compute."""
+from sagemaker.serverless.model import LambdaModel  # noqa: F401
+from sagemaker.serverless.predictor import LambdaPredictor  # noqa: F401

src/sagemaker/serverless/model.py  (+90, new file)

@@ -0,0 +1,90 @@
+# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Models that can be deployed to serverless compute."""
+from __future__ import absolute_import
+
+import time
+from typing import Optional
+
+import boto3
+import botocore
+
+from sagemaker.model import ModelBase
+
+from .predictor import LambdaPredictor
+
+
+class LambdaModel(ModelBase):
+    """A model that can be deployed to Lambda."""
+
+    def __init__(
+        self, image_uri: str, role: str, client: Optional[botocore.client.BaseClient] = None
+    ) -> None:
+        """Initialize instance attributes.
+
+        Arguments:
+            image_uri: URI of a container image in the Amazon ECR registry. The image
+                should contain a handler that performs inference.
+            role: The Amazon Resource Name (ARN) of the IAM role that Lambda will assume
+                when it performs inference
+            client: The Lambda client used to interact with Lambda.
+        """
+        self._client = client or boto3.client("lambda")
+        self._image_uri = image_uri
+        self._role = role
+
+    def deploy(
+        self, function_name: str, timeout: int, memory_size: int, wait: bool = True
+    ) -> LambdaPredictor:
+        """Create a Lambda function using the image specified in the constructor.
+
+        Arguments:
+            function_name: The name of the function.
+            timeout: The number of seconds that the function can run for before being terminated.
+            memory_size: The amount of memory in MB that the function has access to.
+            wait: If true, wait until the deployment completes (default: True).
+
+        Returns:
+            A LambdaPredictor instance that performs inference using the specified image.
+        """
+        response = self._client.create_function(
+            FunctionName=function_name,
+            PackageType="Image",
+            Role=self._role,
+            Code={
+                "ImageUri": self._image_uri,
+            },
+            Timeout=timeout,
+            MemorySize=memory_size,
+        )
+
+        if not wait:
+            return LambdaPredictor(function_name, client=self._client)
+
+        # Poll function state.
+        polling_interval = 5
+        while response["State"] == "Pending":
+            time.sleep(polling_interval)
+            response = self._client.get_function_configuration(FunctionName=function_name)
+
+        if response["State"] != "Active":
+            raise RuntimeError("Failed to deploy model to Lambda: %s" % response["StateReason"])
+
+        return LambdaPredictor(function_name, client=self._client)
+
+    def delete_model(self) -> None:
+        """Destroy resources associated with this model.
+
+        This method does not delete the image specified in the constructor. As
+        a result, this method is a no-op.
+        """

src/sagemaker/serverless/predictor.py  (+81, new file)

@@ -0,0 +1,81 @@
+# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Predictors that are hosted on serverless compute."""
+from __future__ import absolute_import
+
+from typing import Optional, Tuple
+
+import boto3
+import botocore
+
+from sagemaker import deserializers, serializers
+from sagemaker.predictor import PredictorBase
+
+
+class LambdaPredictor(PredictorBase):
+    """A deployed model hosted on Lambda."""
+
+    def __init__(
+        self, function_name: str, client: Optional[botocore.client.BaseClient] = None
+    ) -> None:
+        """Initialize instance attributes.
+
+        Arguments:
+            function_name: The name of the function.
+            client: The Lambda client used to interact with Lambda.
+        """
+        self._client = client or boto3.client("lambda")
+        self._function_name = function_name
+        self._serializer = serializers.JSONSerializer()
+        self._deserializer = deserializers.JSONDeserializer()
+
+    def predict(self, data: dict) -> dict:
+        """Invoke the Lambda function specified in the constructor.
+
+        This function is synchronous. It will only return after the function
+        has produced a prediction.
+
+        Arguments:
+            data: The data sent to the Lambda function as input.
+
+        Returns:
+            The data returned by the Lambda function.
+        """
+        response = self._client.invoke(
+            FunctionName=self._function_name,
+            InvocationType="RequestResponse",
+            Payload=self._serializer.serialize(data),
+        )
+        return self._deserializer.deserialize(
+            response["Payload"],
+            response["ResponseMetadata"]["HTTPHeaders"]["content-type"],
+        )
+
+    def delete_endpoint(self) -> None:
+        """Destroy the Lambda function specified in the constructor."""
+        self._client.delete_function(FunctionName=self._function_name)
+
+    @property
+    def content_type(self) -> str:
+        """The MIME type of the data sent to the Lambda function."""
+        return self._serializer.CONTENT_TYPE
+
+    @property
+    def accept(self) -> Tuple[str]:
+        """The content type(s) that are expected from the Lambda function."""
+        return self._deserializer.ACCEPT
+
+    @property
+    def function_name(self) -> str:
+        """The name of the Lambda function this predictor invokes."""
+        return self._function_name
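A LambdaPredictor can also be attached directly to a function that already exists, since the constructor only needs the function name; the JSON serializer and deserializer handle the payload round trip. A sketch with an illustrative function name and input URL:

from sagemaker.serverless import LambdaPredictor

predictor = LambdaPredictor("my-function")  # illustrative function name

# predict() serializes the dict to JSON, invokes the function synchronously,
# and deserializes the JSON payload in the response.
result = predictor.predict({"url": "https://example.com/cat.jpg"})
print(result)

# delete_endpoint() removes the Lambda function itself.
predictor.delete_endpoint()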

src/sagemaker/workflow/step_collections.py  (+7 -5)

This moves the population of container_def_list until after the repack step, so the container definitions pick up the repacked model_data (the changelog's "order of populating container list" fix), and renames the inner loop variable so it no longer shadows the step collection's name argument.

@@ -136,10 +136,8 @@ def __init__(
         elif model is not None:
             if isinstance(model, PipelineModel):
                 self.model_list = model.models
-                self.container_def_list = model.pipeline_container_def(inference_instances[0])
             elif isinstance(model, Model):
                 self.model_list = [model]
-                self.container_def_list = [model.prepare_container_def(inference_instances[0])]
 
             for model_entity in self.model_list:
                 if estimator is not None:
@@ -154,10 +152,10 @@ def __init__(
                     source_dir = model_entity.source_dir
                     dependencies = model_entity.dependencies
                     kwargs = dict(**kwargs, output_kms_key=model_entity.model_kms_key)
-                    name = model_entity.name or model_entity._framework_name
+                    model_name = model_entity.name or model_entity._framework_name
 
                     repack_model_step = _RepackModelStep(
-                        name=f"{name}RepackModel",
+                        name=f"{model_name}RepackModel",
                         depends_on=depends_on,
                         sagemaker_session=sagemaker_session,
                         role=role,
@@ -171,10 +169,14 @@ def __init__(
                     model_entity.model_data = (
                         repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
                     )
-
                     # remove kwargs consumed by model repacking step
                     kwargs.pop("output_kms_key", None)
 
+        if isinstance(model, PipelineModel):
+            self.container_def_list = model.pipeline_container_def(inference_instances[0])
+        elif isinstance(model, Model):
+            self.container_def_list = [model.prepare_container_def(inference_instances[0])]
+
         register_model_step = _RegisterModelStep(
             name=name,
             estimator=estimator,

tests/data/serverless/Dockerfile  (+12, new file)

@@ -0,0 +1,12 @@
+FROM public.ecr.aws/lambda/python:3.8
+
+COPY requirements.txt /tmp/requirements.txt
+RUN pip3 install -r /tmp/requirements.txt
+
+COPY model.py .
+RUN python model.py
+
+COPY app.py .
+COPY classes.txt .
+
+CMD ["app.handler"]

tests/data/serverless/README.md  (+6, new file)

@@ -0,0 +1,6 @@
+This folder contains the source code for the image used in the
+`sagemaker.serverless` tests.
+
+The image contains the code for a Lambda handler and all of its dependencies.
+The Lambda handler predicts the class of an image using a pre-trained ResNet-34
+model.

tests/data/serverless/app.py  (+34, new file)

@@ -0,0 +1,34 @@
+import urllib
+import json
+import os
+
+import torch
+from PIL import Image
+from torchvision import models
+from torchvision import transforms
+
+transform = transforms.Compose(
+    [
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+with open("classes.txt") as file:
+    classes = [s.strip() for s in file.readlines()]
+
+
+def handler(event, context):
+    data = urllib.request.urlopen(event["url"])
+
+    image = Image.open(data)
+    image = transform(image)
+    image = image.unsqueeze(0)
+
+    model = torch.jit.load("./model.pt")
+    outputs = model(image)
+    target = outputs.argmax().item()
+
+    return {"class": classes[target]}
