Add wrapper for LDA. #56
Changes from 7 commits
src/sagemaker/amazon/amazon_estimator.py
@@ -10,6 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
+import boto3
import json
import logging
import tempfile
@@ -18,6 +19,7 @@
from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
from sagemaker.amazon.common import write_numpy_to_dense_tensor
from sagemaker.estimator import EstimatorBase
+from sagemaker.fw_utils import parse_s3_url
from sagemaker.session import s3_input
from sagemaker.utils import sagemaker_timestamp
@@ -47,7 +49,7 @@ def __init__(self, role, train_instance_count, train_instance_type, data_locatio
        self.data_location = data_location

    def train_image(self):
-        return registry(self.sagemaker_session.boto_region_name) + "/" + type(self).repo
+        return registry(self.sagemaker_session.boto_region_name, type(self).__name__) + "/" + type(self).repo

    def hyperparameters(self):
        return hp.serialize_all(self)
@@ -152,6 +154,61 @@ def __repr__(self):
        """Return an unambiguous representation of this RecordSet"""
        return str((RecordSet, self.__dict__))
    @staticmethod
    def from_s3(data_path, num_records, feature_dim, channel='train'):

Reviewer: I don't think this is needed. EASE already has the logic to list files in an S3 prefix and create its internal manifest file - that's what happens when you use its S3Prefix mode and just pass it a prefix. We shouldn't duplicate that logic here. So, you can already create a RecordSet object and set the s3_data_type constructor argument to 'S3Prefix' to get this functionality. (Or if you just pass an S3 URI string as input to .fit(), you'll get it as well.)

Author: Indeed, no need to duplicate. I'll be addressing your comment below about the test, and this will likely go away.
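A minimal sketch of the reviewer's suggestion. The s3_data_type argument is taken from the comment above; the bucket, prefix, record count, and feature dimension are hypothetical:

```python
from sagemaker.amazon.amazon_estimator import RecordSet

# Point RecordSet at the prefix and let the service enumerate the objects,
# instead of building a manifest file by hand.
record_set = RecordSet("s3://my-bucket/lda-training-data/",  # hypothetical location
                       num_records=10000,
                       feature_dim=500,
                       s3_data_type='S3Prefix',
                       channel='train')
```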
        """
        Create an instance of the class given an S3 path. It prepares the manifest file with all files found
        at the location.

        Args:
            data_path: S3 path to files
            num_records: Number of records at the S3 location
            feature_dim: Number of features in each of the files
            channel: Name of the data channel

        Returns:
            Instance of RecordSet that can be used when calling
            :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`
        """
        s3 = boto3.client('s3')

        if not data_path.endswith('/'):
            data_path = data_path + '/'

        bucket, prefix = parse_s3_url(data_path)
        all_files = []
        next_token = None
        more = True
        while more:
            list_req = {
                'Bucket': bucket,
                'Prefix': prefix
            }
            if next_token is not None:
                list_req.update({'ContinuationToken': next_token})
            objects = s3.list_objects_v2(**list_req)
            more = objects['IsTruncated']
            if more:
                next_token = objects['NextContinuationToken']
            files_list = objects.get('Contents', None)
            if files_list is None:
                continue
            long_names = [content['Key'] for content in files_list]
            files = [file.split(prefix)[1] for file in long_names]
            all_files.extend(files)
        if len(all_files) == 0:
            raise ValueError("S3 location:{} doesn't have any files".format(data_path))
        manifest_key = prefix + ".amazon.manifest"
        manifest_str = json.dumps([{'prefix': data_path}] + all_files)

        s3.put_object(Bucket=bucket, Body=manifest_str.encode('utf-8'), Key=manifest_key)

        return RecordSet("s3://{}/{}".format(bucket, manifest_key),
                         num_records=num_records,
                         feature_dim=feature_dim,
                         channel=channel)
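For reference, the manifest written above follows the SageMaker ManifestFile format: a JSON array whose first element maps 'prefix' to the common S3 prefix, followed by the object keys relative to it. A sketch with hypothetical object names:

```python
import json

# What manifest_str looks like for two hypothetical files under the prefix:
manifest_str = json.dumps([{'prefix': 's3://my-bucket/lda-data/'},
                           'part-0001.pbr',
                           'part-0002.pbr'])
# -> '[{"prefix": "s3://my-bucket/lda-data/"}, "part-0001.pbr", "part-0002.pbr"]'
```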
def _build_shards(num_shards, array):
    if num_shards < 1:
@@ -200,12 +257,22 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
        raise ex


-def registry(region_name):
+def registry(region_name, algorithm=None):
Reviewer: Maybe change "algorithm" to "algorithm_repo_name" or similar?
"""Return docker registry for the given AWS region""" | ||
account_id = { | ||
"us-east-1": "382416733822", | ||
"us-east-2": "404615174143", | ||
"us-west-2": "174872318107", | ||
"eu-west-1": "438346466558" | ||
}[region_name] | ||
if algorithm in [None, "PCA", "KMeans", "LinearLearner", "FactorizationMachines"]: | ||
account_id = { | ||
"us-east-1": "382416733822", | ||
"us-east-2": "404615174143", | ||
"us-west-2": "174872318107", | ||
"eu-west-1": "438346466558" | ||
}[region_name] | ||
elif algorithm in ["LDA"]: | ||
account_id = { | ||
"us-east-1": "766337827248", | ||
"us-east-2": "999911452149", | ||
"us-west-2": "266724342769", | ||
"eu-west-1": "999678624901" | ||
}[region_name] | ||
else: | ||
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm)) | ||
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name) |
src/sagemaker/amazon/lda.py (new file)

@@ -0,0 +1,127 @@
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
from sagemaker.amazon.validation import gt
from sagemaker.predictor import RealTimePredictor
from sagemaker.model import Model
from sagemaker.session import Session


class LDA(AmazonAlgorithmEstimatorBase):

    repo = 'lda:1'

    num_topics = hp('num_topics', gt(0), 'An integer greater than zero', int)
    alpha0 = hp('alpha0', (), "A float value", float)
Reviewer: Should also be gt(0)? (Basing this off: https://docs.aws.amazon.com/sagemaker/latest/dg/lda_hyperparameters.html) Also, use single quotes for consistency (here and also line 30).

Author: Good catch!
    max_restarts = hp('max_restarts', gt(0), 'An integer greater than zero', int)
    max_iterations = hp('max_iterations', gt(0), 'An integer greater than zero', int)
    tol = hp('tol', (gt(0)), "A positive float", float)

    def __init__(self, role, train_instance_type, num_topics,
                 alpha0=None, max_restarts=None, max_iterations=None, tol=None, **kwargs):
"""Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. | ||
|
||
Amazon SageMaker Latent Dirichlet Allocation is an unsupervised learning algorithm that attempts to describe | ||
a set of observations as a mixture of distinct categories. LDA is most commonly used to discover | ||
a user-specified number of topics shared by documents within a text corpus. | ||
Here each observation is a document, the features are the presence (or occurrence count) of each word, and | ||
the categories are the topics. | ||
|
||
This Estimator may be fit via calls to | ||
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon | ||
:class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3. | ||
There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that | ||
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed | ||
to the `fit` call. | ||
|
||
To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please | ||
consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html | ||
|
||
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker | ||
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint, | ||
deploy returns a :class:`~sagemaker.amazon.lda.LDAPredictor` object that can be used | ||
for inference calls using the trained model hosted in the SageMaker Endpoint. | ||
|
||
LDA Estimators can be configured by setting hyperparameters. The available hyperparameters for | ||
LDA are documented below. | ||
|
||
For further information on the AWS LDA algorithm, | ||
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/lda.html | ||
|
||
Args: | ||
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and | ||
APIs that create Amazon SageMaker endpoints use this role to access | ||
training data and model artifacts. After the endpoint is created, | ||
the inference code might use the IAM role, if accessing AWS resource. | ||
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. | ||
num_topics (int): The number of topics for LDA to find within the data. | ||
alpha0 (float): Initial guess for the concentration parameter | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indicate that these hps are optional. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
max_restarts (int): The number of restarts to perform during the Alternating Least Squares (ALS) | ||
spectral decomposition phase of the algorithm. | ||
max_iterations (int): The maximum number of iterations to perform during the ALS phase of the algorithm. | ||
tol (float): Target error tolerance for the ALS phase of the algorithm. | ||
**kwargs: base class keyword argument values. | ||
""" | ||
|
||
        # this algorithm only supports single instance training
        super(LDA, self).__init__(role, 1, train_instance_type, **kwargs)

Reviewer: Add a link to the docs for this. It also indicates that it only supports CPU instances for training. That seems like it would be good to validate.

Author: The training job will not fail if not started on a CPU instance, so I think we shouldn't fail too fast here. Will add a comment/link.

Reviewer: Yea, that makes sense. I misunderstood.

        self.num_topics = num_topics
        self.alpha0 = alpha0
        self.max_restarts = max_restarts
        self.max_iterations = max_iterations
        self.tol = tol
    def create_model(self):
        """Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest
        s3 model data produced by this Estimator."""

        return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
    def fit(self, records, mini_batch_size, **kwargs):
        # mini_batch_size is required
        if mini_batch_size is None:

Reviewer: This check isn't needed, since mini_batch_size isn't an optional parameter.

Author: Since it is a required parameter, we must fail if you call: fit(records, None)

Reviewer: I personally think this kind of validation is generally not worth the code clutter, since Python's support for optional parameters as a first-class feature means that people don't randomly pass None to things and expect it to work. However, this is pretty minor and it can add value in some cases, so I'm okay with leaving this if you feel strongly.

Author: Simplified the checks.

            raise ValueError("mini_batch_size must be set")
        if not isinstance(mini_batch_size, int) or mini_batch_size < 1:

Reviewer: I don't think this check should go here - mini_batch_size is a shared concept across all the 1P algorithms, so it'd be better to put it in the base class if we want it at all.

Author: It's not entirely true. Some algorithms have it as optional (e.g. FM) and some do not even have it at all (e.g. XGBoost). Since it is algorithm dependent, we need to validate it there.

Reviewer: mini_batch_size is a parameter in the fit() method of the base class of all 1P algorithms: https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/amazon/amazon_estimator.py#L67 It takes either an int or None; None represents the cases where it's not required. Validating that it's an int if it's not None will apply in all cases.

Author: You are right; eventually the code will check the type, so I'm just leaving the value check.

            raise ValueError("mini_batch_size must be a positive integer")

        super(LDA, self).fit(records, mini_batch_size, **kwargs)
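A minimal sketch of the base-class validation proposed in the thread above. The placement inside AmazonAlgorithmEstimatorBase.fit and the exact message are assumptions:

```python
class AmazonAlgorithmEstimatorBase(object):
    # Hypothetical variant of the shared fit(); None means mini_batch_size
    # is not required for this algorithm.
    def fit(self, records, mini_batch_size=None, **kwargs):
        if mini_batch_size is not None and (not isinstance(mini_batch_size, int) or mini_batch_size < 1):
            raise ValueError("mini_batch_size must be a positive integer")
        # ... existing training logic ...
```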
class LDAPredictor(RealTimePredictor):
    """Transforms input vectors to a topic mixture representation.

    The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this
    `RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the
    same number of columns as the feature-dimension of the data used to fit the model this
    Predictor performs inference on.

    :meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one
    for each row in the input ``ndarray``. The topic mixture result is stored in the ``topic_mixture``
    key of the ``Record.label`` field."""

    def __init__(self, endpoint, sagemaker_session=None):
        super(LDAPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(),
                                           deserializer=record_deserializer())
class LDAModel(Model):
    """Reference LDA s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and returns
    a Predictor that transforms vectors to a topic mixture representation."""

    def __init__(self, model_data, role, sagemaker_session=None):
        sagemaker_session = sagemaker_session or Session()
        image = registry(sagemaker_session.boto_session.region_name, LDA.__name__) + "/" + LDA.repo
        super(LDAModel, self).__init__(model_data, image, role, predictor_cls=LDAPredictor,
                                       sagemaker_session=sagemaker_session)
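The estimator docstring above points at the record_set upload path for numpy data. A hedged end-to-end sketch; the role name, array shapes, and hyperparameter values are hypothetical:

```python
import numpy as np
from sagemaker import LDA

lda = LDA(role='MySageMakerRole',               # hypothetical IAM role
          train_instance_type='ml.c4.xlarge',
          num_topics=10)

# 1000 documents over a 500-word vocabulary, as float32 values.
train = np.random.rand(1000, 500).astype('float32')

records = lda.record_set(train)                 # uploads to S3, returns a RecordSet
lda.fit(records, mini_batch_size=100)
predictor = lda.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
```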
tests/integ/test_lda.py (new file)

@@ -0,0 +1,61 @@
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import boto3
import numpy as np
import os

import sagemaker
from sagemaker import LDA, LDAModel
from sagemaker.amazon.amazon_estimator import RecordSet
from sagemaker.amazon.common import read_records
from sagemaker.utils import name_from_base, sagemaker_timestamp
from tests.integ import DATA_DIR, REGION
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
def test_lda():

    with timeout(minutes=15):
        sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
        data_filename = 'nips-train_1.pbr'
        data_path = os.path.join(DATA_DIR, 'lda', data_filename)

        with open(data_path, 'rb') as f:
            all_records = read_records(f)

        # all records must have the same feature dimension
        feature_num = int(all_records[0].features['values'].float32_tensor.shape[0])

        lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10,
                  sagemaker_session=sagemaker_session, base_job_name='test-lda')

        # upload data and prepare the set
        data_location_key = "integ-test-data/lda-" + sagemaker_timestamp()

Reviewer: You should be able to do something like "lda.record_set(...)" instead of uploading to S3 separately, right?

Author: "lda.record_set(...)" works with numpy arrays but not with binary files. With respect to your other comment about "from_s3", I'll see if it can be nicely refactored into one call.
        sagemaker_session.upload_data(path=data_path, key_prefix=data_location_key)
        record_set = RecordSet.from_s3("s3://{}/{}".format(sagemaker_session.default_bucket(), data_location_key),
                                       num_records=len(all_records),
                                       feature_dim=feature_num,
                                       channel='train')
        lda.fit(record_set, 100)

    endpoint_name = name_from_base('lda')
    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
        model = LDAModel(lda.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)

        predict_input = np.random.rand(1, feature_num)
        result = predictor.predict(predict_input)

        assert len(result) == 1
        for record in result:
            assert record.label["topic_mixture"] is not None
Reviewer: The type(self).__name__ here isn't ideal, since this will break if the LDA class is subclassed. (I think the type(self).repo stuff we're already doing isn't the best either, but at least that still works in that case.) Changing the new algorithm parameter to take in the class instead of the name may be an improvement. Open to other suggestions as well.

Author: Good observation. I'm not sure about the scenario for subclassing these wrapper classes, but in any case there are two options: either you add your new class to the registry() mapping (just like you add any new class) or provide a custom 'train_image' in a subclass.

Reviewer: I'm talking about end users subclassing them in their own code. There's no super obvious case for doing so, but it's also not wrong to do so, and all else being equal there shouldn't be unexpected things that break when it happens. The cost to make this work in the case of subclassing isn't big IMO - you can modify the registry() method to accept a class instead of a string, then instead of just using direct equality checks, you use issubclass (https://docs.python.org/3/library/functions.html#issubclass) instead.

Author: The implementation with a class has an issue with circular dependency; let's use the same approach as with the repo name then.
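A minimal sketch of the issubclass-based registry() variant discussed above, illustrative only. The estimator-class imports at the top are exactly the circular dependency the author mentions (lda.py already imports registry from this module), which is why the PR keeps the string-based lookup:

```python
# These imports would create an import cycle in amazon_estimator.py itself:
from sagemaker import PCA, KMeans, LinearLearner, FactorizationMachines, LDA


def registry(region_name, algorithm_class=None):
    """Return docker registry for the given AWS region (sketch only)."""
    # issubclass instead of equality checks, so user subclasses keep working.
    if algorithm_class is None or issubclass(algorithm_class, (PCA, KMeans, LinearLearner, FactorizationMachines)):
        account_id = {
            "us-east-1": "382416733822",
            "us-east-2": "404615174143",
            "us-west-2": "174872318107",
            "eu-west-1": "438346466558"
        }[region_name]
    elif issubclass(algorithm_class, LDA):
        account_id = {
            "us-east-1": "766337827248",
            "us-east-2": "999911452149",
            "us-west-2": "266724342769",
            "eu-west-1": "999678624901"
        }[region_name]
    else:
        raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm_class))
    return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
```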