Skip to content

Commit e2b8f0d

Browse files
committed
Add wrapper for LDA.
Update CHANGELOG and bump the version number.
1 parent e82fb4f commit e2b8f0d

File tree

8 files changed

+505
-9
lines changed

8 files changed

+505
-9
lines changed

CHANGELOG.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
CHANGELOG
33
=========
44

5+
1.0.3
6+
=====
7+
8+
* feature: Estimators: add support for Amazon LDA algorithm
9+
* feature: Documentation: Update TensorFlow examples following API change
10+
* feature: Session: Support multi-part uploads
11+
12+
513
1.0.2
614
=====
715

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def read(fname):
1111

1212

1313
setup(name="sagemaker",
14-
version="1.0.2",
14+
version="1.0.3",
1515
description="Open source library for training and deploying models on Amazon SageMaker.",
1616
packages=find_packages('src'),
1717
package_dir={'': 'src'},

src/sagemaker/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sagemaker import estimator
1616
from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor
1717
from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor
18+
from sagemaker.amazon.lda import LDA, LDAModel, LDAPredictor
1819
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor
1920
from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel
2021
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor
@@ -30,6 +31,7 @@
3031

3132
__all__ = [estimator, KMeans, KMeansModel, KMeansPredictor, PCA, PCAModel, PCAPredictor, LinearLearner,
3233
LinearLearnerModel, LinearLearnerPredictor,
34+
LDA, LDAModel, LDAPredictor,
3335
FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor,
3436
Model, RealTimePredictor, Session,
3537
container_def, s3_input, production_variant, get_execution_role]

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import boto3
1314
import json
1415
import logging
1516
import tempfile
@@ -18,6 +19,7 @@
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.common import write_numpy_to_dense_tensor
2021
from sagemaker.estimator import EstimatorBase
22+
from sagemaker.fw_utils import parse_s3_url
2123
from sagemaker.session import s3_input
2224
from sagemaker.utils import sagemaker_timestamp
2325

@@ -47,7 +49,7 @@ def __init__(self, role, train_instance_count, train_instance_type, data_locatio
4749
self.data_location = data_location
4850

4951
def train_image(self):
50-
return registry(self.sagemaker_session.boto_region_name) + "/" + type(self).repo
52+
return registry(self.sagemaker_session.boto_region_name, type(self).__name__) + "/" + type(self).repo
5153

5254
def hyperparameters(self):
5355
return hp.serialize_all(self)
@@ -152,6 +154,61 @@ def __repr__(self):
152154
"""Return an unambiguous representation of this RecordSet"""
153155
return str((RecordSet, self.__dict__))
154156

157+
@staticmethod
158+
def from_s3(data_path, num_records, feature_dim, channel='train'):
159+
"""
160+
Create instance of the class given S3 path. It prepares the manifest file with all files found at the location.
161+
162+
Args:
163+
data_path: S3 path to files
164+
num_records: Number of records at S3 location
165+
feature_dim: Number of features in each of the files
166+
channel: Name of the data channel
167+
168+
Returns:
169+
Instance of RecordSet that can be used when calling
170+
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`
171+
"""
172+
s3 = boto3.client('s3')
173+
174+
if not data_path.endswith('/'):
175+
data_path = data_path + '/'
176+
177+
bucket, prefix = parse_s3_url(data_path)
178+
179+
all_files = []
180+
next_token = None
181+
more = True
182+
while more:
183+
list_req = {
184+
'Bucket': bucket,
185+
'Prefix': prefix
186+
}
187+
if next_token is not None:
188+
list_req.update({'ContinuationToken': next_token})
189+
objects = s3.list_objects_v2(**list_req)
190+
more = objects['IsTruncated']
191+
if more:
192+
next_token = objects['NextContinuationToken']
193+
files_list = objects.get('Contents', None)
194+
if files_list is None:
195+
continue
196+
long_names = [content['Key'] for content in files_list]
197+
files = [file.split(prefix)[1] for file in long_names]
198+
[all_files.append(f) for f in files]
199+
200+
if len(all_files) == 0:
201+
raise ValueError("S3 location:{} doesn't have any files".format(data_path))
202+
manifest_key = prefix + ".amazon.manifest"
203+
manifest_str = json.dumps([{'prefix': data_path}] + all_files)
204+
205+
s3.put_object(Bucket=bucket, Body=manifest_str.encode('utf-8'), Key=manifest_key)
206+
207+
return RecordSet("s3://{}/{}".format(bucket, manifest_key),
208+
num_records=num_records,
209+
feature_dim=feature_dim,
210+
channel=channel)
211+
155212

156213
def _build_shards(num_shards, array):
157214
if num_shards < 1:
@@ -200,12 +257,22 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
200257
raise ex
201258

202259

203-
def registry(region_name):
260+
def registry(region_name, algorithm=None):
204261
"""Return docker registry for the given AWS region"""
205-
account_id = {
206-
"us-east-1": "382416733822",
207-
"us-east-2": "404615174143",
208-
"us-west-2": "174872318107",
209-
"eu-west-1": "438346466558"
210-
}[region_name]
262+
if algorithm in [None, "PCA", "KMeans", "LinearLearner", "FactorizationMachines"]:
263+
account_id = {
264+
"us-east-1": "382416733822",
265+
"us-east-2": "404615174143",
266+
"us-west-2": "174872318107",
267+
"eu-west-1": "438346466558"
268+
}[region_name]
269+
elif algorithm in ["LDA"]:
270+
account_id = {
271+
"us-east-1": "766337827248",
272+
"us-east-2": "999911452149",
273+
"us-west-2": "266724342769",
274+
"eu-west-1": "999678624901"
275+
}[region_name]
276+
else:
277+
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
211278
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)

src/sagemaker/amazon/lda.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
14+
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
15+
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
16+
from sagemaker.amazon.validation import gt, isint, isnumber
17+
from sagemaker.predictor import RealTimePredictor
18+
from sagemaker.model import Model
19+
from sagemaker.session import Session
20+
21+
22+
class LDA(AmazonAlgorithmEstimatorBase):
23+
24+
repo = 'lda:1'
25+
26+
num_topics = hp('num_topics', (gt(0), isint), 'An integer greater than zero')
27+
alpha0 = hp('alpha0', isnumber, "A float value")
28+
max_restarts = hp('max_restarts', (gt(0), isint), 'An integer greater than zero')
29+
max_iterations = hp('max_iterations', (gt(0), isint), 'An integer greater than zero')
30+
tol = hp('tol', (gt(0), isnumber), "A positive float")
31+
32+
def __init__(self, role, train_instance_type, num_topics,
33+
alpha0=None, max_restarts=None, max_iterations=None, tol=None, **kwargs):
34+
"""Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning.
35+
36+
Amazon SageMaker Latent Dirichlet Allocation is an unsupervised learning algorithm that attempts to describe
37+
a set of observations as a mixture of distinct categories. LDA is most commonly used to discover
38+
a user-specified number of topics shared by documents within a text corpus.
39+
Here each observation is a document, the features are the presence (or occurrence count) of each word, and
40+
the categories are the topics.
41+
42+
This Estimator may be fit via calls to
43+
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon
44+
:class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3.
45+
There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
46+
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed
47+
to the `fit` call.
48+
49+
To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please
50+
consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
51+
52+
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
53+
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint,
54+
deploy returns a :class:`~sagemaker.amazon.lda.LDAPredictor` object that can be used
55+
for inference calls using the trained model hosted in the SageMaker Endpoint.
56+
57+
LDA Estimators can be configured by setting hyperparameters. The available hyperparameters for
58+
LDA are documented below.
59+
60+
For further information on the AWS LDA algorithm,
61+
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/lda.html
62+
63+
Args:
64+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
65+
APIs that create Amazon SageMaker endpoints use this role to access
66+
training data and model artifacts. After the endpoint is created,
67+
the inference code might use the IAM role, if accessing AWS resource.
68+
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
69+
num_topics (int): The number of topics for LDA to find within the data.
70+
alpha0 (float): Initial guess for the concentration parameter
71+
max_restarts (int): The number of restarts to perform during the Alternating Least Squares (ALS)
72+
spectral decomposition phase of the algorithm.
73+
max_iterations (int): The maximum number of iterations to perform during the ALS phase of the algorithm.
74+
tol (float): Target error tolerance for the ALS phase of the algorithm.
75+
**kwargs: base class keyword argument values.
76+
"""
77+
78+
# this algorithm only supports single instance training
79+
super(LDA, self).__init__(role, 1, train_instance_type, **kwargs)
80+
self.num_topics = num_topics
81+
self.alpha0 = alpha0
82+
self.max_restarts = max_restarts
83+
self.max_iterations = max_iterations
84+
self.tol = tol
85+
86+
def create_model(self):
87+
"""Return a :class:`~sagemaker.amazon.FactorizationMachinesModel` referencing the latest
88+
s3 model data produced by this Estimator."""
89+
90+
return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
91+
92+
def fit(self, records, mini_batch_size, **kwargs):
93+
# mini_batch_size is required
94+
if mini_batch_size is None:
95+
raise ValueError("mini_batch_size must be set")
96+
if not isinstance(mini_batch_size, int) or mini_batch_size < 1:
97+
raise ValueError("mini_batch_size must be positive integer")
98+
99+
super(LDA, self).fit(records, mini_batch_size, **kwargs)
100+
101+
102+
class LDAPredictor(RealTimePredictor):
103+
"""Transforms input vectors to lower-dimesional representations.
104+
105+
The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this
106+
`RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the
107+
same number of columns as the feature-dimension of the data used to fit the model this
108+
Predictor performs inference on.
109+
110+
:meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one
111+
for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection``
112+
key of the ``Record.label`` field."""
113+
114+
def __init__(self, endpoint, sagemaker_session=None):
115+
super(LDAPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(),
116+
deserializer=record_deserializer())
117+
118+
119+
class LDAModel(Model):
120+
"""Reference LDA s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and return
121+
a Predictor that transforms vectors to a lower-dimensional representation."""
122+
123+
def __init__(self, model_data, role, sagemaker_session=None):
124+
sagemaker_session = sagemaker_session or Session()
125+
image = registry(sagemaker_session.boto_session.region_name, LDA.__name__) + "/" + LDA.repo
126+
super(LDAModel, self).__init__(model_data, image, role, predictor_cls=LDAPredictor,
127+
sagemaker_session=sagemaker_session)

tests/data/lda/nips-train_1.pbr

1.01 MB
Binary file not shown.

tests/integ/test_lda.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import boto3
14+
import numpy as np
15+
import os
16+
17+
import sagemaker
18+
from sagemaker import LDA, LDAModel
19+
from sagemaker.amazon.amazon_estimator import RecordSet
20+
from sagemaker.amazon.common import read_records
21+
from sagemaker.utils import name_from_base, sagemaker_timestamp
22+
from tests.integ import DATA_DIR, REGION
23+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
24+
25+
26+
def test_lda():
27+
28+
with timeout(minutes=15):
29+
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
30+
data_filename = 'nips-train_1.pbr'
31+
data_path = os.path.join(DATA_DIR, 'lda', data_filename)
32+
33+
with open(data_path, 'r') as f:
34+
all_records = read_records(f)
35+
36+
# all records must be same
37+
feature_num = int(all_records[0].features['values'].float32_tensor.shape[0])
38+
39+
lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10,
40+
sagemaker_session=sagemaker_session, base_job_name='test-lda')
41+
42+
# upload data and prepare the set
43+
data_location_key = "integ-test-data/lda-" + sagemaker_timestamp()
44+
sagemaker_session.upload_data(path=data_path, key_prefix=data_location_key)
45+
record_set = RecordSet.from_s3("s3://{}/{}".format(sagemaker_session.default_bucket(), data_location_key),
46+
num_records=len(all_records),
47+
feature_dim=feature_num,
48+
channel='train')
49+
lda.fit(record_set, 100)
50+
51+
endpoint_name = name_from_base('lda')
52+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
53+
model = LDAModel(lda.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
54+
predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
55+
56+
predict_input = np.random.rand(1, feature_num)
57+
result = predictor.predict(predict_input)
58+
59+
assert len(result) == 1
60+
for record in result:
61+
assert record.label["topic_mixture"] is not None

0 commit comments

Comments
 (0)