
Commit 795b030

yangawslukmis authored and committed
Add ntm algorithm with doc, unit tests, integ tests (#73)
1 parent b400fa4 · commit 795b030

15 files changed: +603 −30 lines

CHANGELOG.rst

+9

@@ -2,6 +2,15 @@
 CHANGELOG
 =========

+1.0.4
+=====
+
+* feature: Estimators: add support for Amazon Neural Topic Model (NTM) algorithm
+* feature: Documentation: fix description of an argument of sagemaker.session.train
+* feature: Documentation: add FM and LDA to the documentation
+* feature: Estimators: add support for async fit
+* bug-fix: Estimators: fix estimator role expansion
+
 1.0.3
 =====
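The "async fit" item refers to starting a training job without blocking the calling process. A minimal sketch, under the assumption that fit() forwards a wait=False flag and that estimators expose an attach() classmethod in this release (the estimator and job name below are placeholders):

    # start training and return immediately instead of streaming logs
    estimator.fit(record_set, wait=False)

    # later, possibly from a different process, re-attach to the running job
    estimator = NTM.attach('test-ntm-2017-12-04-12-34-56-789')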

README.rst

+2 −2

@@ -39,7 +39,7 @@ You can install from source by cloning this repository and issuing a pip install

     git clone https://github.com/aws/sagemaker-python-sdk.git
     python setup.py sdist
-    pip install dist/sagemaker-1.0.3.tar.gz
+    pip install dist/sagemaker-1.0.4.tar.gz

 Supported Python versions
 ~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -1447,7 +1447,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you

 The full list of algorithms is available on the AWS website: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html

-SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis, Linear Learner, Factorization Machines and LDA algorithms.
+SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation (LDA) and Neural Topic Model (NTM) algorithms.

 Definition and usage
 ~~~~~~~~~~~~~~~~~~~~
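For context on what the new wrapper looks like in use, here is a minimal sketch based on the NTM constructor added in this commit (the role, instance type, and hyperparameter values are placeholders):

    import sagemaker
    from sagemaker import NTM

    ntm = NTM(role='SageMakerRole',
              train_instance_count=1,
              train_instance_type='ml.c4.xlarge',
              num_topics=10,      # required hyperparameter
              epochs=50,          # optional, validated to [1, 100]
              optimizer='adam')   # optional, one of the allowed optimizers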

doc/conf.py

+1 −1

@@ -18,7 +18,7 @@ def __getattr__(cls, name):
     'tensorflow.python.framework', 'tensorflow_serving', 'tensorflow_serving.apis']
 sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

-version = '1.0.3'
+version = '1.0.4'
 project = u'sagemaker'

 # Add any Sphinx extension module names here, as strings. They can be extensions

doc/index.rst

+1

@@ -49,3 +49,4 @@ Amazon provides implementations of some common machine learning algorithms opti…
     sagemaker.amazon.amazon_estimator
     factorization_machines
     lda
+    ntm

doc/ntm.rst

+23

@@ -0,0 +1,23 @@
+NTM
+--------------------
+
+The Amazon SageMaker NTM algorithm.
+
+.. autoclass:: sagemaker.NTM
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+    :exclude-members: image, num_topics, encoder_layers, epochs, encoder_layers_activation, optimizer, tolerance,
+        num_patience_epochs, batch_norm, rescale_gradient, clip_gradient, weight_decay, learning_rate
+
+
+.. autoclass:: sagemaker.NTMModel
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autoclass:: sagemaker.NTMPredictor
+    :members:
+    :undoc-members:
+    :show-inheritance:

setup.py

+1 −1

@@ -11,7 +11,7 @@ def read(fname):


 setup(name="sagemaker",
-      version="1.0.3",
+      version="1.0.4",
       description="Open source library for training and deploying models on Amazon SageMaker.",
       packages=find_packages('src'),
       package_dir={'': 'src'},

src/sagemaker/__init__.py

+2 −1

@@ -19,6 +19,7 @@
 from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor
 from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel
 from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor
+from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor

 from sagemaker.model import Model
 from sagemaker.predictor import RealTimePredictor

@@ -33,5 +34,5 @@
           LinearLearnerModel, LinearLearnerPredictor,
           LDA, LDAModel, LDAPredictor,
           FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor,
-          Model, RealTimePredictor, Session,
+          Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session,
           container_def, s3_input, production_variant, get_execution_role]
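With these two changes the new classes become part of the package's public surface, so downstream code can import them directly from the root:

    from sagemaker import NTM, NTMModel, NTMPredictor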

src/sagemaker/amazon/amazon_estimator.py

+1 −1

@@ -228,7 +228,7 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=

 def registry(region_name, algorithm=None):
     """Return docker registry for the given AWS region"""
-    if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines"]:
+    if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm"]:
         account_id = {
             "us-east-1": "382416733822",
             "us-east-2": "404615174143",

src/sagemaker/amazon/ntm.py

+146

@@ -0,0 +1,146 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
+from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
+from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
+from sagemaker.amazon.validation import ge, le, isin
+from sagemaker.predictor import RealTimePredictor
+from sagemaker.model import Model
+from sagemaker.session import Session
+
+
+class NTM(AmazonAlgorithmEstimatorBase):
+
+    repo_name = 'ntm'
+    repo_version = 1
+
+    num_topics = hp('num_topics', (ge(2), le(1000)), 'An integer in [2, 1000]', int)
+    encoder_layers = hp(name='encoder_layers', validation_message='A comma separated list of '
+                                                                  'positive integers', data_type=list)
+    epochs = hp('epochs', (ge(1), le(100)), 'An integer in [1, 100]', int)
+    encoder_layers_activation = hp('encoder_layers_activation', isin('sigmoid', 'tanh', 'relu'),
+                                   'One of "sigmoid", "tanh" or "relu"', str)
+    optimizer = hp('optimizer', isin('adagrad', 'adam', 'rmsprop', 'sgd', 'adadelta'),
+                   'One of "adagrad", "adam", "rmsprop", "sgd" or "adadelta"', str)
+    tolerance = hp('tolerance', (ge(1e-6), le(0.1)), 'A float in [1e-6, 0.1]', float)
+    num_patience_epochs = hp('num_patience_epochs', (ge(1), le(10)), 'An integer in [1, 10]', int)
+    batch_norm = hp(name='batch_norm', validation_message='Value must be a boolean', data_type=bool)
+    rescale_gradient = hp('rescale_gradient', (ge(1e-3), le(1.0)), 'A float in [1e-3, 1.0]', float)
+    clip_gradient = hp('clip_gradient', ge(1e-3), 'A float greater than or equal to 1e-3', float)
+    weight_decay = hp('weight_decay', (ge(0.0), le(1.0)), 'A float in [0.0, 1.0]', float)
+    learning_rate = hp('learning_rate', (ge(1e-6), le(1.0)), 'A float in [1e-6, 1.0]', float)
+
+    def __init__(self, role, train_instance_count, train_instance_type, num_topics,
+                 encoder_layers=None, epochs=None, encoder_layers_activation=None, optimizer=None, tolerance=None,
+                 num_patience_epochs=None, batch_norm=None, rescale_gradient=None, clip_gradient=None,
+                 weight_decay=None, learning_rate=None, **kwargs):
+        """Neural Topic Model (NTM) is an :class:`Estimator` used for unsupervised learning.
+
+        This Estimator may be fit via calls to
+        :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon
+        :class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3.
+        There is a utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
+        can be used to upload data to S3 and create a :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be
+        passed to the `fit` call.
+
+        To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please
+        consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
+
+        After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
+        Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an
+        Endpoint, deploy returns a :class:`~sagemaker.amazon.ntm.NTMPredictor` object that can be used
+        for inference calls using the trained model hosted in the SageMaker Endpoint.
+
+        NTM Estimators can be configured by setting hyperparameters. The available hyperparameters for
+        NTM are documented below.
+
+        For further information on the AWS NTM algorithm,
+        please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/ntm.html
+
+        Args:
+            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
+                APIs that create Amazon SageMaker endpoints use this role to access
+                training data and model artifacts. After the endpoint is created,
+                the inference code might use the IAM role, if accessing AWS resources.
+            train_instance_count (int): Number of Amazon EC2 instances to use for training.
+            train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
+            num_topics (int): Required. The number of topics for NTM to find within the data.
+            encoder_layers (list): Optional. Represents the number of layers in the encoder and the output size
+                of each layer.
+            epochs (int): Optional. Maximum number of passes over the training data.
+            encoder_layers_activation (str): Optional. Activation function to use in the encoder layers.
+            optimizer (str): Optional. Optimizer to use for training.
+            tolerance (float): Optional. Maximum relative change in the loss function within the last
+                num_patience_epochs number of epochs below which early stopping is triggered.
+            num_patience_epochs (int): Optional. Number of successive epochs over which the early stopping
+                criterion is evaluated.
+            batch_norm (bool): Optional. Whether to use batch normalization during training.
+            rescale_gradient (float): Optional. Rescale factor for gradient.
+            clip_gradient (float): Optional. Maximum magnitude for each gradient component.
+            weight_decay (float): Optional. Weight decay coefficient. Adds L2 regularization.
+            learning_rate (float): Optional. Learning rate for the optimizer.
+            **kwargs: base class keyword argument values.
+        """
+
+        super(NTM, self).__init__(role, train_instance_count, train_instance_type, **kwargs)
+        self.num_topics = num_topics
+        self.encoder_layers = encoder_layers
+        self.epochs = epochs
+        self.encoder_layers_activation = encoder_layers_activation
+        self.optimizer = optimizer
+        self.tolerance = tolerance
+        self.num_patience_epochs = num_patience_epochs
+        self.batch_norm = batch_norm
+        self.rescale_gradient = rescale_gradient
+        self.clip_gradient = clip_gradient
+        self.weight_decay = weight_decay
+        self.learning_rate = learning_rate
+
+    def create_model(self):
+        """Return a :class:`~sagemaker.amazon.ntm.NTMModel` referencing the latest
+        s3 model data produced by this Estimator."""
+
+        return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
+
+    def fit(self, records, mini_batch_size=None, **kwargs):
+        if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 10000):
+            raise ValueError("mini_batch_size must be in [1, 10000]")
+        super(NTM, self).fit(records, mini_batch_size, **kwargs)
+
+
+class NTMPredictor(RealTimePredictor):
+    """Transforms input vectors to lower-dimensional representations.
+
+    The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this
+    `RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the
+    same number of columns as the feature-dimension of the data used to fit the model this
+    Predictor performs inference on.
+
+    :meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one
+    for each row in the input ``ndarray``. The lower dimension vector result is stored in the
+    ``topic_weights`` key of the ``Record.label`` field."""
+
+    def __init__(self, endpoint, sagemaker_session=None):
+        super(NTMPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(),
+                                           deserializer=record_deserializer())
+
+
+class NTMModel(Model):
+    """Reference NTM s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and returns
+    a Predictor that transforms vectors to a lower-dimensional representation."""
+
+    def __init__(self, model_data, role, sagemaker_session=None):
+        sagemaker_session = sagemaker_session or Session()
+        repo = '{}:{}'.format(NTM.repo_name, NTM.repo_version)
+        image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo)
+        super(NTMModel, self).__init__(model_data, image, role, predictor_cls=NTMPredictor,
+                                       sagemaker_session=sagemaker_session)
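The hp descriptors above validate assignments eagerly. A small sketch of the expected behavior, assuming the Hyperparameter descriptor raises ValueError with the validation message when a check fails (as the other Amazon estimators do):

    ntm = NTM(role='SageMakerRole', train_instance_count=1,
              train_instance_type='ml.c4.xlarge', num_topics=10)

    ntm.epochs = 50       # accepted: inside [1, 100]
    try:
        ntm.epochs = 500  # rejected: above the maximum
    except ValueError as e:
        print(e)          # expected to mention 'An integer in [1, 100]'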

src/sagemaker/amazon/validation.py

+6

@@ -30,6 +30,12 @@ def validate(value):
     return validate


+def le(maximum):
+    def validate(value):
+        return value <= maximum
+    return validate
+
+
 def isin(*expected):
     def validate(value):
         return value in expected
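The new le validator mirrors the existing ge: each is a plain closure returning True or False. When ntm.py passes a tuple such as (ge(1e-6), le(0.1)), the hp descriptor is assumed to apply every validator in the tuple conjunctively; a quick illustration:

    from sagemaker.amazon.validation import ge, le

    tolerance_checks = (ge(1e-6), le(0.1))  # as used for NTM's tolerance
    print(all(check(0.05) for check in tolerance_checks))  # True: in range
    print(all(check(0.5) for check in tolerance_checks))   # False: above 0.1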

tests/data/ntm/nips-train_1.pbr

Binary file not shown (1.01 MB).

tests/integ/record_set.py

+23

@@ -0,0 +1,23 @@
+from six.moves.urllib.parse import urlparse
+
+from sagemaker.amazon.amazon_estimator import RecordSet
+from sagemaker.utils import sagemaker_timestamp
+
+
+def prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session):
+    """Build a :class:`~RecordSet` by pointing to local files.
+
+    Args:
+        dir_path (string): Path to local directory from where the files shall be uploaded.
+        destination (string): S3 path to upload the files to.
+        num_records (int): Number of records in all the files.
+        feature_dim (int): Number of features in the data set.
+        sagemaker_session (sagemaker.session.Session): Session object to manage interactions with Amazon SageMaker APIs.
+    Returns:
+        RecordSet: A RecordSet specified by S3Prefix to be used in training.
+    """
+    key_prefix = urlparse(destination).path
+    key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp())
+    key_prefix = key_prefix.lstrip('/')
+    uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix)
+    return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix')
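This helper is now shared by the LDA and NTM integration tests below; typical usage mirrors those tests (the variable names here come from test_ntm.py):

    record_set = prepare_record_set_from_local_files(
        dir_path=data_path,               # local directory of .pbr files
        destination=ntm.data_location,    # estimator's default S3 data location
        num_records=len(all_records),
        feature_dim=feature_num,
        sagemaker_session=sagemaker_session)
    ntm.fit(record_set, None)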

tests/integ/test_lda.py

+4 −24

@@ -13,16 +13,15 @@
 import boto3
 import numpy as np
 import os
-from six.moves.urllib.parse import urlparse

 import sagemaker
 from sagemaker import LDA, LDAModel
-from sagemaker.amazon.amazon_estimator import RecordSet
 from sagemaker.amazon.common import read_records
-from sagemaker.utils import name_from_base, sagemaker_timestamp
+from sagemaker.utils import name_from_base

 from tests.integ import DATA_DIR, REGION
 from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
+from tests.integ.record_set import prepare_record_set_from_local_files


@@ -41,8 +40,8 @@ def test_lda():
     lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10,
               sagemaker_session=sagemaker_session, base_job_name='test-lda')

-    record_set = _prepare_record_set_from_local_files(data_path, lda.data_location,
-                                                      len(all_records), feature_num, sagemaker_session)
+    record_set = prepare_record_set_from_local_files(data_path, lda.data_location,
+                                                     len(all_records), feature_num, sagemaker_session)
     lda.fit(record_set, 100)

     endpoint_name = name_from_base('lda')

@@ -56,22 +55,3 @@ def test_lda():
     assert len(result) == 1
     for record in result:
         assert record.label["topic_mixture"] is not None
-
-
-def _prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session):
-    """Build a :class:`~RecordSet` by pointing to local files.
-
-    Args:
-        dir_path (string): Path to local directory from where the files shall be uploaded.
-        destination (string): S3 path to upload the file to.
-        num_records (int): Number of records in all the files
-        feature_dim (int): Number of features in the data set
-        sagemaker_session (sagemaker.session.Session): Session object to manage interactions with Amazon SageMaker APIs.
-    Returns:
-        RecordSet: A RecordSet specified by S3Prefix to to be used in training.
-    """
-    key_prefix = urlparse(destination).path
-    key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp())
-    key_prefix = key_prefix.lstrip('/')
-    uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix)
-    return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix')

tests/integ/test_ntm.py

+57

@@ -0,0 +1,57 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import boto3
+import numpy as np
+import os
+
+import sagemaker
+from sagemaker import NTM, NTMModel
+from sagemaker.amazon.common import read_records
+from sagemaker.utils import name_from_base
+
+from tests.integ import DATA_DIR, REGION
+from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
+from tests.integ.record_set import prepare_record_set_from_local_files
+
+
+def test_ntm():
+
+    with timeout(minutes=15):
+        sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
+        data_path = os.path.join(DATA_DIR, 'ntm')
+        data_filename = 'nips-train_1.pbr'
+
+        with open(os.path.join(data_path, data_filename), 'rb') as f:
+            all_records = read_records(f)
+
+        # all records must have the same feature dimension
+        feature_num = int(all_records[0].features['values'].float32_tensor.shape[0])
+
+        ntm = NTM(role='SageMakerRole', train_instance_count=1, train_instance_type='ml.c4.xlarge', num_topics=10,
+                  sagemaker_session=sagemaker_session, base_job_name='test-ntm')
+
+        record_set = prepare_record_set_from_local_files(data_path, ntm.data_location,
+                                                         len(all_records), feature_num, sagemaker_session)
+        ntm.fit(record_set, None)
+
+    endpoint_name = name_from_base('ntm')
+    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
+        model = NTMModel(ntm.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
+        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
+
+        predict_input = np.random.rand(1, feature_num)
+        result = predictor.predict(predict_input)
+
+        assert len(result) == 1
+        for record in result:
+            assert record.label["topic_weights"] is not None
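For a quick local sanity check of the bundled test data (no AWS calls involved), the same read_records helper can confirm the record count and feature dimension the test relies on:

    from sagemaker.amazon.common import read_records

    # Inspect the .pbr file the same way test_ntm derives feature_num.
    with open('tests/data/ntm/nips-train_1.pbr', 'rb') as f:
        records = read_records(f)
    feature_dim = int(records[0].features['values'].float32_tensor.shape[0])
    print(len(records), feature_dim)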
