Skip to content

Commit 5b0a720

Browse files
saswatacyangaws
authored andcommitted
Add multiclass support for linear learner (#287)
1 parent 16f5d25 commit 5b0a720

File tree

4 files changed

+95
-19
lines changed

4 files changed

+95
-19
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.6.1
6+
=====
7+
8+
* feature: Added multiclass classification support for linear learner algorithm.
9+
510
1.6.0
611
=====
712

src/sagemaker/amazon/linear_learner.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1616
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
1717
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
18-
from sagemaker.amazon.validation import isin, gt, lt, ge
18+
from sagemaker.amazon.validation import isin, gt, lt, ge, le
1919
from sagemaker.predictor import RealTimePredictor
2020
from sagemaker.model import Model
2121
from sagemaker.session import Session
@@ -28,28 +28,28 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
2828
DEFAULT_MINI_BATCH_SIZE = 1000
2929

3030
binary_classifier_model_selection_criteria = hp('binary_classifier_model_selection_criteria',
31-
isin('accuracy', 'f1', 'precision_at_target_recall',
32-
'recall_at_target_precision', 'cross_entropy_loss'),
33-
data_type=str)
31+
isin('accuracy', 'f1', 'f_beta', 'precision_at_target_recall',
32+
'recall_at_target_precision', 'cross_entropy_loss',
33+
'loss_function'), data_type=str)
3434
target_recall = hp('target_recall', (gt(0), lt(1)), "A float in (0,1)", float)
3535
target_precision = hp('target_precision', (gt(0), lt(1)), "A float in (0,1)", float)
3636
positive_example_weight_mult = hp('positive_example_weight_mult', (),
3737
"A float greater than 0 or 'auto' or 'balanced'", str)
3838
epochs = hp('epochs', gt(0), "An integer greater-than 0", int)
39-
predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'),
40-
'One of "binary_classifier" or "regressor"', str)
39+
predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor', 'multiclass_classifier'),
40+
'One of "binary_classifier" or "multiclass_classifier" or "regressor"', str)
4141
use_bias = hp('use_bias', (), "Either True or False", bool)
4242
num_models = hp('num_models', gt(0), "An integer greater-than 0", int)
4343
num_calibration_samples = hp('num_calibration_samples', gt(0), "An integer greater-than 0", int)
4444
init_method = hp('init_method', isin('uniform', 'normal'), 'One of "uniform" or "normal"', str)
4545
init_scale = hp('init_scale', gt(0), 'A float greater-than 0', float)
4646
init_sigma = hp('init_sigma', gt(0), 'A float greater-than 0', float)
4747
init_bias = hp('init_bias', (), 'A number', float)
48-
optimizer = hp('optimizer', isin('sgd', 'adam', 'auto'), 'One of "sgd", "adam" or "auto', str)
48+
optimizer = hp('optimizer', isin('sgd', 'adam', 'rmsprop', 'auto'), 'One of "sgd", "adam", "rmsprop" or "auto', str)
4949
loss = hp('loss', isin('logistic', 'squared_loss', 'absolute_loss', 'hinge_loss', 'eps_insensitive_squared_loss',
50-
'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss', 'auto'),
50+
'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss', 'softmax_loss', 'auto'),
5151
'"logistic", "squared_loss", "absolute_loss", "hinge_loss", "eps_insensitive_squared_loss", '
52-
'"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss" or "auto"', str)
52+
'"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"', str)
5353
wd = hp('wd', ge(0), 'A float greater-than or equal to 0', float)
5454
l1 = hp('l1', ge(0), 'A float greater-than or equal to 0', float)
5555
momentum = hp('momentum', (ge(0), lt(1)), 'A float in [0,1)', float)
@@ -73,6 +73,10 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
7373
huber_delta = hp('huber_delta', ge(0), 'A float greater-than or equal to 0', float)
7474
early_stopping_patience = hp('early_stopping_patience', gt(0), 'An integer greater-than 0', int)
7575
early_stopping_tolerance = hp('early_stopping_tolerance', gt(0), 'A float greater-than 0', float)
76+
num_classes = hp('num_classes', (gt(0), le(1000000)), 'An integer in [1,1000000]', int)
77+
accuracy_top_k = hp('accuracy_top_k', (gt(0), le(1000000)), 'An integer in [1,1000000]', int)
78+
f_beta = hp('f_beta', gt(0), 'A float greater-than 0', float)
79+
balance_multiclass_weights = hp('balance_multiclass_weights', (), 'A boolean', bool)
7680

7781
def __init__(self, role, train_instance_count, train_instance_type, predictor_type,
7882
binary_classifier_model_selection_criteria=None, target_recall=None, target_precision=None,
@@ -83,7 +87,8 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
8387
lr_scheduler_factor=None, lr_scheduler_minimum_lr=None, normalize_data=None,
8488
normalize_label=None, unbias_data=None, unbias_label=None, num_point_for_scaler=None, margin=None,
8589
quantile=None, loss_insensitivity=None, huber_delta=None, early_stopping_patience=None,
86-
early_stopping_tolerance=None, **kwargs):
90+
early_stopping_tolerance=None, num_classes=None, accuracy_top_k=None, f_beta=None,
91+
balance_multiclass_weights=None, **kwargs):
8792
"""An :class:`Estimator` for binary classification and regression.
8893
8994
Amazon SageMaker Linear Learner provides a solution for both classification and regression problems, allowing
@@ -119,9 +124,10 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
119124
the inference code might use the IAM role, if accessing AWS resource.
120125
train_instance_count (int): Number of Amazon EC2 instances to use for training.
121126
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
122-
predictor_type (str): The type of predictor to learn. Either "binary_classifier" or "regressor".
123-
binary_classifier_model_selection_criteria (str): One of 'accuracy', 'f1', 'precision_at_target_recall',
124-
'recall_at_target_precision', 'cross_entropy_loss'
127+
predictor_type (str): The type of predictor to learn. Either "binary_classifier" or
128+
"multiclass_classifier" or "regressor".
129+
binary_classifier_model_selection_criteria (str): One of 'accuracy', 'f1', 'f_beta',
130+
'precision_at_target_recall', 'recall_at_target_precision', 'cross_entropy_loss', 'loss_function'
125131
target_recall (float): Target recall. Only applicable if binary_classifier_model_selection_criteria is
126132
precision_at_target_recall.
127133
target_precision (float): Target precision. Only applicable if binary_classifier_model_selection_criteria
@@ -139,9 +145,10 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
139145
init_scale (float): For "uniform" init, the range of values.
140146
init_sigma (float): For "normal" init, the standard-deviation.
141147
init_bias (float): Initial weight for bias term
142-
optimizer (str): One of 'sgd', 'adam' or 'auto'
148+
optimizer (str): One of 'sgd', 'adam', 'rmsprop' or 'auto'
143149
loss (str): One of 'logistic', 'squared_loss', 'absolute_loss', 'hinge_loss',
144-
'eps_insensitive_squared_loss', 'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss' or 'auto'
150+
'eps_insensitive_squared_loss', 'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss' or
151+
'softmax_loss' or 'auto'.
145152
wd (float): L2 regularization parameter i.e. the weight decay parameter. Use 0 for no L2 regularization.
146153
l1 (float): L1 regularization parameter. Use 0 for no L1 regularization.
147154
momentum (float): Momentum parameter of sgd optimizer.
@@ -180,6 +187,15 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
180187
early_stopping_tolerance (float): Relative tolerance to measure an improvement in loss. If the ratio of
181188
the improvement in loss divided by the previous best loss is smaller than this value, early stopping will
182189
consider the improvement to be zero.
190+
num_classes (int): The number of classes for the response variable. Required when predictor_type is
191+
multiclass_classifier and ignored otherwise. The classes are assumed to be labeled 0, ..., num_classes - 1.
192+
accuracy_top_k (int): The value of k when computing the Top K Accuracy metric for multiclass
193+
classification. An example is scored as correct if the model assigns one of the top k scores to the true
194+
label.
195+
f_beta (float): The value of beta to use when calculating F score metrics for binary or multiclass
196+
classification. Also used if binary_classifier_model_selection_criteria is f_beta.
197+
balance_multiclass_weights (bool): Whether to use class weights which give each class equal importance in
198+
the loss function. Only used when predictor_type is multiclass_classifier.
183199
**kwargs: base class keyword argument values.
184200
"""
185201
super(LinearLearner, self).__init__(role, train_instance_count, train_instance_type, **kwargs)
@@ -221,6 +237,14 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
221237
self.huber_delta = huber_delta
222238
self.early_stopping_patience = early_stopping_patience
223239
self.early_stopping_tolerance = early_stopping_tolerance
240+
self.num_classes = num_classes
241+
self.accuracy_top_k = accuracy_top_k
242+
self.f_beta = f_beta
243+
self.balance_multiclass_weights = balance_multiclass_weights
244+
245+
if self.predictor_type == 'multiclass_classifier' and (num_classes is None or num_classes < 3):
246+
raise ValueError(
247+
"For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2.")
224248

225249
def create_model(self):
226250
"""Return a :class:`~sagemaker.amazon.kmeans.LinearLearnerModel` referencing the latest

tests/integ/test_linear_learner.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,35 @@ def test_linear_learner(sagemaker_session):
9292
assert record.label["score"] is not None
9393

9494

95+
def test_linear_learner_multiclass(sagemaker_session):
96+
with timeout(minutes=15):
97+
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
98+
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
99+
100+
# Load the data into memory as numpy arrays
101+
with gzip.open(data_path, 'rb') as f:
102+
train_set, _, _ = pickle.load(f, **pickle_args)
103+
104+
train_set = train_set[0], train_set[1].astype(np.dtype('float32'))
105+
106+
ll = LinearLearner('SageMakerRole', 1, 'ml.c4.2xlarge', base_job_name='test-linear-learner',
107+
predictor_type='multiclass_classifier', num_classes=10, sagemaker_session=sagemaker_session)
108+
109+
ll.epochs = 1
110+
ll.fit(ll.record_set(train_set[0][:200], train_set[1][:200]))
111+
112+
endpoint_name = name_from_base('linear-learner')
113+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
114+
115+
predictor = ll.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
116+
117+
result = predictor.predict(train_set[0][0:100])
118+
assert len(result) == 100
119+
for record in result:
120+
assert record.label["predicted_label"] is not None
121+
assert record.label["score"] is not None
122+
123+
95124
def test_async_linear_learner(sagemaker_session):
96125
training_job_name = ""
97126
endpoint_name = 'test-linear-learner-async-{}'.format(sagemaker_timestamp())

tests/unit/test_linear_learner.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def test_all_hyperparameters(sagemaker_session):
8181
lr_scheduler_minimum_lr=0.001, normalize_data=False, normalize_label=True,
8282
unbias_data=True, unbias_label=False, num_point_for_scaler=3, margin=1.0,
8383
quantile=0.5, loss_insensitivity=0.1, huber_delta=0.1, early_stopping_patience=3,
84-
early_stopping_tolerance=0.001, **ALL_REQ_ARGS)
84+
early_stopping_tolerance=0.001, num_classes=1, accuracy_top_k=3, f_beta=1.0,
85+
balance_multiclass_weights=False, **ALL_REQ_ARGS)
8586

8687
assert lr.hyperparameters() == dict(
8788
predictor_type='binary_classifier', binary_classifier_model_selection_criteria='accuracy',
@@ -93,7 +94,8 @@ def test_all_hyperparameters(sagemaker_session):
9394
lr_scheduler_factor='0.03', lr_scheduler_minimum_lr='0.001', normalize_data='False',
9495
normalize_label='True', unbias_data='True', unbias_label='False', num_point_for_scaler='3', margin='1.0',
9596
quantile='0.5', loss_insensitivity='0.1', huber_delta='0.1', early_stopping_patience='3',
96-
early_stopping_tolerance='0.001',
97+
early_stopping_tolerance='0.001', num_classes='1', accuracy_top_k='3', f_beta='1.0',
98+
balance_multiclass_weights='False',
9799
)
98100

99101

@@ -122,6 +124,15 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
122124
LinearLearner(sagemaker_session=sagemaker_session, **test_params)
123125

124126

127+
def test_num_classes_is_required_for_multiclass_classifier(sagemaker_session):
128+
with pytest.raises(ValueError) as excinfo:
129+
test_params = ALL_REQ_ARGS.copy()
130+
test_params["predictor_type"] = 'multiclass_classifier'
131+
LinearLearner(sagemaker_session=sagemaker_session, **test_params)
132+
assert "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2." in str(
133+
excinfo.value)
134+
135+
125136
@pytest.mark.parametrize('iterable_hyper_parameters, value', [
126137
('eval_metrics', 0)
127138
])
@@ -162,7 +173,10 @@ def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parame
162173
('loss_insensitivity', 'string'),
163174
('huber_delta', 'string'),
164175
('early_stopping_patience', 'string'),
165-
('early_stopping_tolerance', 'string')
176+
('early_stopping_tolerance', 'string'),
177+
('num_classes', 'string'),
178+
('accuracy_top_k', 'string'),
179+
('f_beta', 'string'),
166180
])
167181
def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
168182
with pytest.raises(ValueError):
@@ -204,7 +218,11 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame
204218
('loss_insensitivity', 0),
205219
('huber_delta', -1),
206220
('early_stopping_patience', 0),
207-
('early_stopping_tolerance', 0)
221+
('early_stopping_tolerance', 0),
222+
('num_classes', 0),
223+
('accuracy_top_k', 0),
224+
('f_beta', -1.0),
225+
208226
])
209227
def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
210228
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)