Skip to content

Commit 31051e3

Browse files
committed
add ntm algorithm and unit tests
1 parent b45d79c commit 31051e3

File tree

4 files changed

+442
-1
lines changed

4 files changed

+442
-1
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
228228

229229
def registry(region_name, algorithm=None):
230230
"""Return docker registry for the given AWS region"""
231-
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines"]:
231+
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm"]:
232232
account_id = {
233233
"us-east-1": "382416733822",
234234
"us-east-2": "404615174143",

src/sagemaker/amazon/ntm.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
14+
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
15+
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
16+
from sagemaker.amazon.validation import ge, le, isin
17+
from sagemaker.predictor import RealTimePredictor
18+
from sagemaker.model import Model
19+
from sagemaker.session import Session
20+
21+
22+
class NTM(AmazonAlgorithmEstimatorBase):
23+
24+
repo_name = 'ntm'
25+
repo_version = 1
26+
27+
num_topics = hp('num_topics', (ge(2), le(1000)), 'An integer in [2, 1000]', int)
28+
encoder_layers = hp(name='encoder_layers', validation_message='A comma separated list of '
29+
'positive integers or "auto"', data_type=list)
30+
epochs = hp('epochs', (ge(1), le(100)), 'An integer in [1, 100]', int)
31+
encoder_layers_activation = hp('encoder_layers_activation', isin('sigmoid', 'tanh', 'relu'),
32+
'One of "sigmoid", "tanh" or "relu"', str)
33+
optimizer = hp('optimizer', isin('adagrad', 'adam', 'rmsprop', 'sgd', 'adadelta'),
34+
'One of "adagrad", "adam", "rmsprop", "sgd" and "adadelta"', str)
35+
tolerance = hp('tolerance', (ge(1e-6), le(0.1)), 'A float in [1e-6, 0.1]', float)
36+
num_patience_epochs = hp('num_patience_epochs', (ge(1), le(10)), 'An integer in [1, 10]', int)
37+
batch_norm = hp(name='batch_norm', validation_message='Value must be a boolean', data_type=bool)
38+
rescale_gradient = hp('rescale_gradient', (ge(1e-3), le(1.0)), 'A float in [1e-3, 1.0]', float)
39+
clip_gradient = hp('clip_gradient', ge(1e-3), 'A float greater equal to 1e-3', float)
40+
weight_decay = hp('weight_decay', (ge(0.0), le(1.0)), 'A float in [0.0, 1.0]', float)
41+
learning_rate = hp('learning_rate', (ge(1e-6), le(1.0)), 'A float in [1e-6, 1.0]', float)
42+
43+
def __init__(self, role, train_instance_count, train_instance_type, num_topics,
44+
encoder_layers=None, epochs=None, encoder_layers_activation=None, optimizer=None, tolerance=None,
45+
num_patience_epochs=None, batch_norm=None, rescale_gradient=None, clip_gradient=None,
46+
weight_decay=None, learning_rate=None, **kwargs):
47+
"""Neural Topic Model (NTM) is :class:`Estimator` used for unsupervised learning.
48+
49+
This Estimator may be fit via calls to
50+
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon
51+
:class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3.
52+
There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
53+
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed
54+
to the `fit` call.
55+
56+
To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please
57+
consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
58+
59+
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
60+
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint,
61+
deploy returns a :class:`~sagemaker.amazon.ntm.NTMPredictor` object that can be used
62+
for inference calls using the trained model hosted in the SageMaker Endpoint.
63+
64+
NTM Estimators can be configured by setting hyperparameters. The available hyperparameters for
65+
NTM are documented below.
66+
67+
For further information on the AWS NTM algorithm,
68+
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/ntm.html
69+
70+
Args:
71+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
72+
APIs that create Amazon SageMaker endpoints use this role to access
73+
training data and model artifacts. After the endpoint is created,
74+
the inference code might use the IAM role, if accessing AWS resource.
75+
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
76+
num_topics (int): Required. The number of topics for NTM to find within the data.
77+
encoder_layers (list): Optional. Represents number of layers in the encoder and the output size of
78+
each layer.
79+
epochs (int): Optional. Maximum number of passes over the training data.
80+
encoder_layers_activation (str): Optional. Activation function to use in the encoder layers.
81+
optimizer (str): Optional. Optimizer to use for training.
82+
tolerance (float): Optional. Maximum relative change in the loss function within the last
83+
num_patience_epochs number of epochs below which early stopping is triggered.
84+
num_patience_epochs (int): Optional. Number of successive epochs over which early stopping criterion
85+
is evaluated.
86+
batch_norm (bool): Optional. Whether to use batch normalization during training.
87+
rescale_gradient (float): Optional. Rescale factor for gradient.
88+
clip_gradient (float): Optional. Maximum magnitude for each gradient component.
89+
weight_decay (float): Optional. Weight decay coefficient. Adds L2 regularization.
90+
learning_rate (float): Optional. Learning rate for the optimizer.
91+
**kwargs: base class keyword argument values.
92+
"""
93+
94+
super(NTM, self).__init__(role, train_instance_count, train_instance_type, **kwargs)
95+
self.num_topics = num_topics
96+
self.encoder_layers = encoder_layers
97+
self.epochs = epochs
98+
self.encoder_layers_activation = encoder_layers_activation
99+
self.optimizer = optimizer
100+
self.tolerance = tolerance
101+
self.num_patience_epochs = num_patience_epochs
102+
self.batch_norm = batch_norm
103+
self.rescale_gradient = rescale_gradient
104+
self.clip_gradient = clip_gradient
105+
self.weight_decay = weight_decay
106+
self.learning_rate = learning_rate
107+
108+
def create_model(self):
109+
"""Return a :class:`~sagemaker.amazon.NTMModel` referencing the latest
110+
s3 model data produced by this Estimator."""
111+
112+
return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
113+
114+
def fit(self, records, mini_batch_size, **kwargs):
115+
# mini_batch_size is required, prevent explicit calls with None
116+
if mini_batch_size is None:
117+
raise ValueError("mini_batch_size must be set")
118+
super(NTM, self).fit(records, mini_batch_size, **kwargs)
119+
120+
121+
class NTMPredictor(RealTimePredictor):
122+
"""Transforms input vectors to lower-dimesional representations.
123+
124+
The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this
125+
`RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the
126+
same number of columns as the feature-dimension of the data used to fit the model this
127+
Predictor performs inference on.
128+
129+
:meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one
130+
for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection``
131+
key of the ``Record.label`` field."""
132+
133+
def __init__(self, endpoint, sagemaker_session=None):
134+
super(NTMPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(),
135+
deserializer=record_deserializer())
136+
137+
138+
class NTMModel(Model):
139+
"""Reference NTM s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and return
140+
a Predictor that transforms vectors to a lower-dimensional representation."""
141+
142+
def __init__(self, model_data, role, sagemaker_session=None):
143+
sagemaker_session = sagemaker_session or Session()
144+
repo = '{}:{}'.format(NTM.repo_name, NTM.repo_version)
145+
image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo)
146+
super(NTMModel, self).__init__(model_data, image, role, predictor_cls=NTMPredictor,
147+
sagemaker_session=sagemaker_session)

src/sagemaker/amazon/validation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def validate(value):
3030
return validate
3131

3232

33+
def le(maximum):
34+
def validate(value):
35+
return value <= maximum
36+
return validate
37+
38+
3339
def isin(*expected):
3440
def validate(value):
3541
return value in expected

0 commit comments

Comments
 (0)