-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add wrapper for LDA. #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
e2b8f0d
33fc211
30c65e0
e32aac7
b8658b3
4e1c134
324655d
5094e97
238db37
7cc945a
1dfc4c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,6 @@ | |
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import boto3 | ||
import json | ||
import logging | ||
import tempfile | ||
|
@@ -19,7 +18,6 @@ | |
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa | ||
from sagemaker.amazon.common import write_numpy_to_dense_tensor | ||
from sagemaker.estimator import EstimatorBase | ||
from sagemaker.fw_utils import parse_s3_url | ||
from sagemaker.session import s3_input | ||
from sagemaker.utils import sagemaker_timestamp | ||
|
||
|
@@ -49,7 +47,8 @@ def __init__(self, role, train_instance_count, train_instance_type, data_locatio | |
self.data_location = data_location | ||
|
||
def train_image(self): | ||
return registry(self.sagemaker_session.boto_region_name, type(self).__name__) + "/" + type(self).repo | ||
repo = '{}:{}'.format(type(self).alg_name, type(self).alg_version) | ||
return registry(self.sagemaker_session.boto_region_name, type(self).alg_name) + "/" + repo | ||
|
||
def hyperparameters(self): | ||
return hp.serialize_all(self) | ||
|
@@ -127,6 +126,26 @@ def record_set(self, train, labels=None, channel="train"): | |
logger.debug("Created manifest file {}".format(manifest_s3_file)) | ||
return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) | ||
|
||
def record_set_from_local_files(self, data_path, num_records, feature_dim, channel="train"): | ||
"""Build a :class:`~RecordSet` by pointing to local files. | ||
|
||
Args: | ||
data_path (string): Path to local file to be uploaded for training. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need more specificity either here or in the overall comment. The user needs to know whether it works on single files, directories, or both, etc. |
||
num_records (int): Number of records in all the files | ||
feature_dim (int): Number of features in the data set | ||
channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to. | ||
Returns: | ||
RecordSet: A RecordSet specified by S3Prefix to be used in training. | ||
""" | ||
|
||
parsed_s3_url = urlparse(self.data_location) | ||
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path | ||
key_prefix = key_prefix + '{}-{}'.format(type(self).__name__, sagemaker_timestamp()) | ||
key_prefix = key_prefix.lstrip('/') | ||
logger.debug('Uploading to bucket {} and key_prefix {}'.format(bucket, key_prefix)) | ||
uploaded_location = self.sagemaker_session.upload_data(path=data_path, key_prefix=key_prefix) | ||
return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix', channel=channel) | ||
|
||
|
||
class RecordSet(object): | ||
|
||
|
@@ -154,61 +173,6 @@ def __repr__(self): | |
"""Return an unambiguous representation of this RecordSet""" | ||
return str((RecordSet, self.__dict__)) | ||
|
||
@staticmethod | ||
def from_s3(data_path, num_records, feature_dim, channel='train'): | ||
""" | ||
Create an instance of the class given an S3 path. It prepares the manifest file with all files found at the location. | ||
|
||
Args: | ||
data_path: S3 path to files | ||
num_records: Number of records at S3 location | ||
feature_dim: Number of features in each of the files | ||
channel: Name of the data channel | ||
|
||
Returns: | ||
Instance of RecordSet that can be used when calling | ||
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit` | ||
""" | ||
s3 = boto3.client('s3') | ||
|
||
if not data_path.endswith('/'): | ||
data_path = data_path + '/' | ||
|
||
bucket, prefix = parse_s3_url(data_path) | ||
|
||
all_files = [] | ||
next_token = None | ||
more = True | ||
while more: | ||
list_req = { | ||
'Bucket': bucket, | ||
'Prefix': prefix | ||
} | ||
if next_token is not None: | ||
list_req.update({'ContinuationToken': next_token}) | ||
objects = s3.list_objects_v2(**list_req) | ||
more = objects['IsTruncated'] | ||
if more: | ||
next_token = objects['NextContinuationToken'] | ||
files_list = objects.get('Contents', None) | ||
if files_list is None: | ||
continue | ||
long_names = [content['Key'] for content in files_list] | ||
files = [file.split(prefix)[1] for file in long_names] | ||
[all_files.append(f) for f in files] | ||
|
||
if len(all_files) == 0: | ||
raise ValueError("S3 location:{} doesn't have any files".format(data_path)) | ||
manifest_key = prefix + ".amazon.manifest" | ||
manifest_str = json.dumps([{'prefix': data_path}] + all_files) | ||
|
||
s3.put_object(Bucket=bucket, Body=manifest_str.encode('utf-8'), Key=manifest_key) | ||
|
||
return RecordSet("s3://{}/{}".format(bucket, manifest_key), | ||
num_records=num_records, | ||
feature_dim=feature_dim, | ||
channel=channel) | ||
|
||
|
||
def _build_shards(num_shards, array): | ||
if num_shards < 1: | ||
|
@@ -259,14 +223,14 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= | |
|
||
def registry(region_name, algorithm=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe change "algorithm" to "algorithm_repo_name" or similar? |
||
"""Return docker registry for the given AWS region""" | ||
if algorithm in [None, "PCA", "KMeans", "LinearLearner", "FactorizationMachines"]: | ||
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines"]: | ||
account_id = { | ||
"us-east-1": "382416733822", | ||
"us-east-2": "404615174143", | ||
"us-west-2": "174872318107", | ||
"eu-west-1": "438346466558" | ||
}[region_name] | ||
elif algorithm in ["LDA"]: | ||
elif algorithm in ["lda"]: | ||
account_id = { | ||
"us-east-1": "766337827248", | ||
"us-east-2": "999911452149", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,8 @@ | |
|
||
class FactorizationMachines(AmazonAlgorithmEstimatorBase): | ||
|
||
repo = 'factorization-machines:1' | ||
alg_name = 'factorization-machines' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would still name these "repo_name" and "repo_tag" or similar, since that's what they are first and foremost. |
||
alg_version = 1 | ||
|
||
num_factors = hp('num_factors', gt(0), 'An integer greater than zero', int) | ||
predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'), | ||
|
@@ -194,7 +195,8 @@ class FactorizationMachinesModel(Model): | |
|
||
def __init__(self, model_data, role, sagemaker_session=None): | ||
sagemaker_session = sagemaker_session or Session() | ||
image = registry(sagemaker_session.boto_session.region_name) + "/" + FactorizationMachines.repo | ||
repo = '{}:{}'.format(FactorizationMachines.alg_name, FactorizationMachines.alg_version) | ||
image = registry(sagemaker_session.boto_session.region_name) + "/" + repo | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format here as well for consistency? |
||
super(FactorizationMachinesModel, self).__init__(model_data, | ||
image, | ||
role, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,13 +21,14 @@ | |
|
||
class LDA(AmazonAlgorithmEstimatorBase): | ||
|
||
repo = 'lda:1' | ||
alg_name = 'lda' | ||
alg_version = 1 | ||
|
||
num_topics = hp('num_topics', gt(0), 'An integer greater than zero', int) | ||
alpha0 = hp('alpha0', (), "A float value", float) | ||
alpha0 = hp('alpha0', gt(0), 'A positive float', float) | ||
max_restarts = hp('max_restarts', gt(0), 'An integer greater than zero', int) | ||
max_iterations = hp('max_iterations', gt(0), 'An integer greater than zero', int) | ||
tol = hp('tol', (gt(0)), "A positive float", float) | ||
tol = hp('tol', gt(0), 'A positive float', float) | ||
|
||
def __init__(self, role, train_instance_type, num_topics, | ||
alpha0=None, max_restarts=None, max_iterations=None, tol=None, **kwargs): | ||
|
@@ -67,11 +68,12 @@ def __init__(self, role, train_instance_type, num_topics, | |
the inference code might use the IAM role, if accessing AWS resource. | ||
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. | ||
num_topics (int): The number of topics for LDA to find within the data. | ||
alpha0 (float): Initial guess for the concentration parameter | ||
max_restarts (int): The number of restarts to perform during the Alternating Least Squares (ALS) | ||
alpha0 (float): Optional. Initial guess for the concentration parameter | ||
max_restarts (int): Optional. The number of restarts to perform during the Alternating Least Squares (ALS) | ||
spectral decomposition phase of the algorithm. | ||
max_iterations (int): The maximum number of iterations to perform during the ALS phase of the algorithm. | ||
tol (float): Target error tolerance for the ALS phase of the algorithm. | ||
max_iterations (int): Optional. The maximum number of iterations to perform during the | ||
ALS phase of the algorithm. | ||
tol (float): Optional. Target error tolerance for the ALS phase of the algorithm. | ||
**kwargs: base class keyword argument values. | ||
""" | ||
|
||
|
@@ -90,12 +92,9 @@ def create_model(self): | |
return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session) | ||
|
||
def fit(self, records, mini_batch_size, **kwargs): | ||
# mini_batch_size is required | ||
# mini_batch_size is required, prevent explicit calls with None | ||
if mini_batch_size is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check isn't needed since mini_batch_size isn't an optional parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since it is required parameter we must fail if you call: fit(records, None) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally think this kind of validation is generally not worth the code clutter, since Python's support for optional parameters as a first-class feature means that people don't randomly pass None to things and expect it to work. However this is pretty minor and it can add value in some cases so I'm okay with leaving this if you feel strongly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplified the checks |
||
raise ValueError("mini_batch_size must be set") | ||
if not isinstance(mini_batch_size, int) or mini_batch_size < 1: | ||
raise ValueError("mini_batch_size must be positive integer") | ||
|
||
super(LDA, self).fit(records, mini_batch_size, **kwargs) | ||
|
||
|
||
|
@@ -122,6 +121,7 @@ class LDAModel(Model): | |
|
||
def __init__(self, model_data, role, sagemaker_session=None): | ||
sagemaker_session = sagemaker_session or Session() | ||
image = registry(sagemaker_session.boto_session.region_name, LDA.__name__) + "/" + LDA.repo | ||
repo = '{}:{}'.format(LDA.alg_name, LDA.alg_version) | ||
image = registry(sagemaker_session.boto_session.region_name, LDA.alg_name) + "/" + repo | ||
super(LDAModel, self).__init__(model_data, image, role, predictor_cls=LDAPredictor, | ||
sagemaker_session=sagemaker_session) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unit tests please.