Skip to content

Commit 537ba79

Browse files
jaredleekatzman authored and jesterhazy committed
Add custom estimator for IP Insights algorithm (aws#493)
* Estimators: add support for Amazon IP Insights algorithm
1 parent 6412991 commit 537ba79

File tree

9 files changed

+453
-3
lines changed

9 files changed

+453
-3
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ CHANGELOG
1414
* feature: HyperparameterTuner: Make input channels optional
1515
* feature: Add support for Chainer 5.0
1616
* feature: Estimator: add support for MetricDefinitions
17+
* feature: Estimators: add support for Amazon IP Insights algorithm
1718

1819
1.14.2
1920
======

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you
414414
The full list of algorithms is available at: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
415415
416416
The SageMaker Python SDK includes estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines,
417-
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM), Random Cut Forest, k-nearest neighbors (k-NN), and Object2Vec algorithms.
417+
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM), Random Cut Forest, k-nearest neighbors (k-NN), Object2Vec, and IP Insights algorithms.
418418
419419
For more information, see `AWS SageMaker Estimators and Models`_.
420420

src/sagemaker/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
RandomCutForestPredictor)
2525
from sagemaker.amazon.knn import KNN, KNNModel, KNNPredictor # noqa: F401
2626
from sagemaker.amazon.object2vec import Object2Vec, Object2VecModel # noqa: F401
27+
from sagemaker.amazon.ipinsights import IPInsights, IPInsightsModel, IPInsightsPredictor # noqa: F401
2728

2829
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401
2930
from sagemaker.local.local_session import LocalSession # noqa: F401

src/sagemaker/amazon/README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you
77

88
The full list of algorithms is available on the AWS website: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
99

10-
SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis(PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation(LDA), Neural Topic Model(NTM), Random Cut Forest algorithms, k-nearest neighbors (k-NN) and Object2Vec.
10+
SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM), Random Cut Forest, k-nearest neighbors (k-NN), Object2Vec, and IP Insights algorithms.
1111

1212
Definition and usage
1313
~~~~~~~~~~~~~~~~~~~~

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def registry(region_name, algorithm=None):
284284
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
285285
"""
286286
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm",
287-
"randomcutforest", "knn", "object2vec"]:
287+
"randomcutforest", "knn", "object2vec", "ipinsights"]:
288288
account_id = {
289289
"us-east-1": "382416733822",
290290
"us-east-2": "404615174143",

src/sagemaker/amazon/ipinsights.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
17+
from sagemaker.amazon.validation import ge, le
18+
from sagemaker.predictor import RealTimePredictor, csv_serializer, json_deserializer
19+
from sagemaker.model import Model
20+
from sagemaker.session import Session
21+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
22+
23+
24+
class IPInsights(AmazonAlgorithmEstimatorBase):
    """Estimator wrapper for the Amazon SageMaker IP Insights algorithm.

    Each hyperparameter is declared at class level through ``Hyperparameter``
    (imported as ``hp``) together with its validators, a human-readable
    description of the accepted range, and its data type.
    """

    repo_name = 'ipinsights'
    repo_version = 1
    MINI_BATCH_SIZE = 10000

    # Required hyperparameters (positional arguments of ``__init__``).
    num_entity_vectors = hp(
        'num_entity_vectors', (ge(1), le(250000000)), 'An integer in [1, 250000000]', int)
    vector_dim = hp(
        'vector_dim', (ge(4), le(4096)), 'An integer in [4, 4096]', int)

    # Optional hyperparameters (default to ``None`` in ``__init__``).
    batch_metrics_publish_interval = hp(
        'batch_metrics_publish_interval', ge(1), 'An integer greater than 0', int)
    epochs = hp(
        'epochs', ge(1), 'An integer greater than 0', int)
    learning_rate = hp(
        'learning_rate', (ge(1e-6), le(10.0)), 'A float in [1e-6, 10.0]', float)
    num_ip_encoder_layers = hp(
        'num_ip_encoder_layers', (ge(0), le(100)), 'An integer in [0, 100]', int)
    random_negative_sampling_rate = hp(
        'random_negative_sampling_rate', (ge(0), le(500)), 'An integer in [0, 500]', int)
    shuffled_negative_sampling_rate = hp(
        'shuffled_negative_sampling_rate', (ge(0), le(500)), 'An integer in [0, 500]', int)
    weight_decay = hp(
        'weight_decay', (ge(0.0), le(10.0)), 'A float in [0.0, 10.0]', float)

    def __init__(self, role, train_instance_count, train_instance_type, num_entity_vectors, vector_dim,
                 batch_metrics_publish_interval=None, epochs=None, learning_rate=None,
                 num_ip_encoder_layers=None, random_negative_sampling_rate=None,
                 shuffled_negative_sampling_rate=None, weight_decay=None, **kwargs):
        """Create an estimator for IP Insights, an unsupervised algorithm that learns usage patterns of
        IP addresses.

        Train the estimator by calling
        :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. The training data
        must be CSV data stored in S3.

        Once training has finished, the model data is stored in S3. Call
        :meth:`~sagemaker.estimator.EstimatorBase.deploy` to host the model on an Amazon SageMaker
        Endpoint; besides creating the Endpoint, ``deploy`` returns a
        :class:`~sagemaker.amazon.ipinsights.IPInsightsPredictor` that sends inference requests to the
        trained model hosted behind that Endpoint.

        IPInsights Estimators are configured through the hyperparameters documented below. For further
        information on the AWS IPInsights algorithm, please consult the AWS technical documentation:
        https://docs.aws.amazon.com/sagemaker/latest/dg/ip-insights-hyperparameters.html

        Args:
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
                APIs that create Amazon SageMaker endpoints use this role to access training data and
                model artifacts. After the endpoint is created, the inference code might use the IAM
                role if it accesses AWS resources.
            train_instance_count (int): Number of Amazon EC2 instances to use for training.
            train_instance_type (str): Type of EC2 instance to use for training, for example,
                'ml.m5.xlarge'.
            num_entity_vectors (int): Required. The number of embeddings to train for entities accessing
                online resources. We recommend 2x the total number of unique entity IDs.
            vector_dim (int): Required. The size of the embedding vectors for both entity and IP
                addresses.
            batch_metrics_publish_interval (int): Optional. The period at which to publish metrics
                (batches).
            epochs (int): Optional. Maximum number of passes over the training data.
            learning_rate (float): Optional. Learning rate for the optimizer.
            num_ip_encoder_layers (int): Optional. The number of fully-connected layers to encode IP
                address embedding.
            random_negative_sampling_rate (int): Optional. The ratio of random negative samples to draw
                during training. Random negative samples are randomly drawn IPv4 addresses.
            shuffled_negative_sampling_rate (int): Optional. The ratio of shuffled negative samples to
                draw during training. Shuffled negative samples are IP addresses picked from within a
                batch.
            weight_decay (float): Optional. Weight decay coefficient. Adds L2 regularization.
            **kwargs: base class keyword argument values.
        """
        super(IPInsights, self).__init__(role, train_instance_count, train_instance_type, **kwargs)

        # Required hyperparameters.
        self.num_entity_vectors = num_entity_vectors
        self.vector_dim = vector_dim

        # Optional hyperparameters; ``None`` is stored when the caller omits them.
        self.batch_metrics_publish_interval = batch_metrics_publish_interval
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.num_ip_encoder_layers = num_ip_encoder_layers
        self.random_negative_sampling_rate = random_negative_sampling_rate
        self.shuffled_negative_sampling_rate = shuffled_negative_sampling_rate
        self.weight_decay = weight_decay

    def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
        """Build an :class:`IPInsightsModel` for the latest S3 model data produced by this estimator.

        Args:
            vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the model.
                Default: use subnets and security groups from this Estimator.
                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.

        Returns:
            :class:`~sagemaker.amazon.ipinsights.IPInsightsModel`: references the latest s3 model data
            produced by this estimator.
        """
        model_data, role = self.model_data, self.role
        session = self.sagemaker_session
        vpc_config = self.get_vpc_config(vpc_config_override)
        return IPInsightsModel(model_data, role, sagemaker_session=session, vpc_config=vpc_config)

    def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
        """Check ``mini_batch_size`` and hand over to the base-class preparation step.

        Args:
            records: Training records, forwarded unchanged to the base class.
            mini_batch_size (int): Optional. When given, it must lie in [1, 500000].
            job_name (str): Optional. Forwarded unchanged to the base class.

        Raises:
            ValueError: If ``mini_batch_size`` is given and falls outside [1, 500000].
        """
        if mini_batch_size is not None:
            if mini_batch_size < 1 or mini_batch_size > 500000:
                raise ValueError("mini_batch_size must be in [1, 500000]")

        super(IPInsights, self)._prepare_for_training(
            records, mini_batch_size=mini_batch_size, job_name=job_name)
113+
114+
class IPInsightsPredictor(RealTimePredictor):
    """Predictor that scores how compatible an entity is with an IP address.

    The score returned by the endpoint is the dot product of the entity and IP address
    embeddings.

    Input passed to :meth:`~sagemaker.predictor.RealTimePredictor.predict` should be
    two-column data such as a numpy ``ndarray``: the first column holds the entity ID and
    the second column holds the IPv4 address in dot notation. Requests are serialized with
    ``csv_serializer`` and responses are decoded with ``json_deserializer``.
    """

    def __init__(self, endpoint, sagemaker_session=None):
        """Attach to ``endpoint`` using CSV requests and JSON responses.

        Args:
            endpoint: Endpoint identifier, forwarded to ``RealTimePredictor``.
            sagemaker_session: Optional session object, forwarded to ``RealTimePredictor``.
        """
        wire_format = {
            'serializer': csv_serializer,
            'deserializer': json_deserializer,
        }
        super(IPInsightsPredictor, self).__init__(endpoint, sagemaker_session, **wire_format)
127+
128+
129+
class IPInsightsModel(Model):
    """Reference IPInsights s3 model data.

    Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and returns a
    Predictor; this model registers :class:`IPInsightsPredictor` as its ``predictor_cls``.
    """

    def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
        """Resolve the IP Insights container image and initialize the base ``Model``.

        Args:
            model_data: Model data reference, forwarded to ``Model``.
            role: IAM role, forwarded to ``Model``.
            sagemaker_session: Optional session; a new ``Session()`` is created when this
                argument is falsy.
            **kwargs: Additional keyword arguments forwarded to ``Model``.
        """
        if not sagemaker_session:
            sagemaker_session = Session()

        # Image URI has the form "<registry>/<repo_name>:<repo_version>".
        region = sagemaker_session.boto_session.region_name
        image_registry = registry(region, IPInsights.repo_name)
        image = '{}/{}:{}'.format(image_registry, IPInsights.repo_name, IPInsights.repo_version)

        super(IPInsightsModel, self).__init__(
            model_data,
            image,
            role,
            predictor_cls=IPInsightsPredictor,
            sagemaker_session=sagemaker_session,
            **kwargs)

tests/data/ipinsights/train.csv

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
user_1,1.1.1.1
2+
user_1,1.1.1.1
3+
user_1,1.1.1.1
4+
user_1,1.1.1.1
5+
user_1,1.1.1.1
6+
user_1,1.1.1.1
7+
user_1,1.1.1.1
8+
user_1,1.1.1.1
9+
user_1,1.1.1.1
10+
user_1,1.1.1.1

tests/integ/test_ipinsights.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import pytest
17+
18+
from sagemaker import IPInsights, IPInsightsModel
19+
from sagemaker.predictor import RealTimePredictor
20+
from sagemaker.utils import name_from_base
21+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
22+
from tests.integ.record_set import prepare_record_set_from_local_files
23+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
24+
25+
FEATURE_DIM = None
26+
27+
28+
@pytest.mark.continuous_testing
def test_ipinsights(sagemaker_session):
    """Train IP Insights on the sample CSV, deploy the model, and check a single prediction."""
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        data_path = os.path.join(DATA_DIR, 'ipinsights')
        data_filename = 'train.csv'

        # One record per line in the training file.
        with open(os.path.join(data_path, data_filename), 'rb') as f:
            num_records = len(f.readlines())

        ipinsights = IPInsights(
            role='SageMakerRole',
            train_instance_count=1,
            train_instance_type='ml.c4.xlarge',
            num_entity_vectors=10,
            vector_dim=100,
            sagemaker_session=sagemaker_session,
            base_job_name='test-ipinsights')

        record_set = prepare_record_set_from_local_files(data_path, ipinsights.data_location,
                                                         num_records, FEATURE_DIM, sagemaker_session)
        ipinsights.fit(record_set, None)

    endpoint_name = name_from_base('ipinsights')
    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = IPInsightsModel(ipinsights.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
        assert isinstance(predictor, RealTimePredictor)

        predict_input = [['user_1', '1.1.1.1']]
        result = predictor.predict(predict_input)

        # The predictor decodes the response with json_deserializer, so ``result`` is plain
        # JSON data (dicts/lists) with no ``.label`` attribute. The IP Insights response is
        # expected to look like {"predictions": [{"dot_product": <float>}, ...]}.
        predictions = result['predictions']
        assert len(predictions) == len(predict_input)
        for record in predictions:
            assert record['dot_product'] is not None

0 commit comments

Comments
 (0)