Skip to content

Commit f939f89

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
change: Add PipelineVariable annotation to Amazon estimators (#3373)
Co-authored-by: Dewen Qi <[email protected]>
1 parent ac353b2 commit f939f89

11 files changed

+481
-425
lines changed

src/sagemaker/amazon/amazon_estimator.py

+44-21
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import logging
1818
import tempfile
19-
from typing import Union
19+
from typing import Union, Optional, Dict
2020

2121
from six.moves.urllib.parse import urlparse
2222

@@ -30,6 +30,7 @@
3030
from sagemaker.utils import sagemaker_timestamp
3131
from sagemaker.workflow.entities import PipelineVariable
3232
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
33+
from sagemaker.workflow import is_pipeline_variable
3334

3435
logger = logging.getLogger(__name__)
3536

@@ -40,18 +41,20 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
4041
This class isn't intended to be instantiated directly.
4142
"""
4243

43-
feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
44-
mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)
45-
repo_name = None
46-
repo_version = None
44+
feature_dim: hp = hp("feature_dim", validation.gt(0), data_type=int)
45+
mini_batch_size: hp = hp("mini_batch_size", validation.gt(0), data_type=int)
46+
repo_name: Optional[str] = None
47+
repo_version: Optional[str] = None
48+
49+
DEFAULT_MINI_BATCH_SIZE: Optional[int] = None
4750

4851
def __init__(
4952
self,
50-
role,
51-
instance_count=None,
52-
instance_type=None,
53-
data_location=None,
54-
enable_network_isolation=False,
53+
role: str,
54+
instance_count: Optional[Union[int, PipelineVariable]] = None,
55+
instance_type: Optional[Union[str, PipelineVariable]] = None,
56+
data_location: Optional[str] = None,
57+
enable_network_isolation: Union[bool, PipelineVariable] = False,
5558
**kwargs
5659
):
5760
"""Initialize an AmazonAlgorithmEstimatorBase.
@@ -62,16 +65,16 @@ def __init__(
6265
endpoints use this role to access training data and model
6366
artifacts. After the endpoint is created, the inference code
6467
might use the IAM role, if it needs to access an AWS resource.
65-
instance_count (int): Number of Amazon EC2 instances to use
68+
instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
6669
for training. Required.
67-
instance_type (str): Type of EC2 instance to use for training,
70+
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
6871
for example, 'ml.c4.xlarge'. Required.
6972
data_location (str or None): The s3 prefix to upload RecordSet
7073
objects to, expressed as an S3 url. For example
7174
"s3://example-bucket/some-key-prefix/". Objects will be saved in
7275
a unique sub-directory of the specified location. If None, a
7376
default data location will be used.
74-
enable_network_isolation (bool): Specifies whether container will
77+
enable_network_isolation (bool or PipelineVariable): Specifies whether container will
7578
run in network isolation mode. Network isolation mode restricts
7679
the container access to outside networks (such as the internet).
7780
Also known as internet-free mode (default: ``False``).
@@ -113,8 +116,14 @@ def data_location(self):
113116
return self._data_location
114117

115118
@data_location.setter
116-
def data_location(self, data_location):
119+
def data_location(self, data_location: str):
117120
"""Placeholder docstring"""
121+
if is_pipeline_variable(data_location):
122+
raise TypeError(
123+
"Invalid input: data_location should be a plain string "
124+
"rather than a pipeline variable - ({}).".format(type(data_location))
125+
)
126+
118127
if not data_location.startswith("s3://"):
119128
raise ValueError(
120129
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
@@ -198,12 +207,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
198207
@runnable_by_pipeline
199208
def fit(
200209
self,
201-
records,
202-
mini_batch_size=None,
203-
wait=True,
204-
logs=True,
205-
job_name=None,
206-
experiment_config=None,
210+
records: "RecordSet",
211+
mini_batch_size: Optional[int] = None,
212+
wait: bool = True,
213+
logs: bool = True,
214+
job_name: Optional[str] = None,
215+
experiment_config: Optional[Dict[str, str]] = None,
207216
):
208217
"""Fit this Estimator on serialized Record objects, stored in S3.
209218
@@ -301,6 +310,20 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
301310
channel=channel,
302311
)
303312

313+
def _get_default_mini_batch_size(self, num_records: int):
314+
"""Generate the default mini_batch_size"""
315+
if is_pipeline_variable(self.instance_count):
316+
logger.warning(
317+
"mini_batch_size is not given in .fit() and instance_count is a "
318+
"pipeline variable (%s) which is only interpreted in pipeline execution time. "
319+
"Thus setting mini_batch_size to 1, since it can't be greater than "
320+
"number of records per instance_count, otherwise the training job fails.",
321+
type(self.instance_count),
322+
)
323+
return 1
324+
325+
return min(self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)))
326+
304327

305328
class RecordSet(object):
306329
"""Placeholder docstring"""
@@ -461,7 +484,7 @@ def upload_numpy_to_s3_shards(
461484
raise ex
462485

463486

464-
def get_image_uri(region_name, repo_name, repo_version=1):
487+
def get_image_uri(region_name, repo_name, repo_version="1"):
465488
"""Deprecated method. Please use sagemaker.image_uris.retrieve().
466489
467490
Args:

src/sagemaker/amazon/factorization_machines.py

+58-58
Original file line numberDiff line numberDiff line change
@@ -37,83 +37,83 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase):
3737
sparse datasets economically.
3838
"""
3939

40-
repo_name = "factorization-machines"
41-
repo_version = 1
40+
repo_name: str = "factorization-machines"
41+
repo_version: str = "1"
4242

43-
num_factors = hp("num_factors", gt(0), "An integer greater than zero", int)
44-
predictor_type = hp(
43+
num_factors: hp = hp("num_factors", gt(0), "An integer greater than zero", int)
44+
predictor_type: hp = hp(
4545
"predictor_type",
4646
isin("binary_classifier", "regressor"),
4747
'Value "binary_classifier" or "regressor"',
4848
str,
4949
)
50-
epochs = hp("epochs", gt(0), "An integer greater than 0", int)
51-
clip_gradient = hp("clip_gradient", (), "A float value", float)
52-
eps = hp("eps", (), "A float value", float)
53-
rescale_grad = hp("rescale_grad", (), "A float value", float)
54-
bias_lr = hp("bias_lr", ge(0), "A non-negative float", float)
55-
linear_lr = hp("linear_lr", ge(0), "A non-negative float", float)
56-
factors_lr = hp("factors_lr", ge(0), "A non-negative float", float)
57-
bias_wd = hp("bias_wd", ge(0), "A non-negative float", float)
58-
linear_wd = hp("linear_wd", ge(0), "A non-negative float", float)
59-
factors_wd = hp("factors_wd", ge(0), "A non-negative float", float)
60-
bias_init_method = hp(
50+
epochs: hp = hp("epochs", gt(0), "An integer greater than 0", int)
51+
clip_gradient: hp = hp("clip_gradient", (), "A float value", float)
52+
eps: hp = hp("eps", (), "A float value", float)
53+
rescale_grad: hp = hp("rescale_grad", (), "A float value", float)
54+
bias_lr: hp = hp("bias_lr", ge(0), "A non-negative float", float)
55+
linear_lr: hp = hp("linear_lr", ge(0), "A non-negative float", float)
56+
factors_lr: hp = hp("factors_lr", ge(0), "A non-negative float", float)
57+
bias_wd: hp = hp("bias_wd", ge(0), "A non-negative float", float)
58+
linear_wd: hp = hp("linear_wd", ge(0), "A non-negative float", float)
59+
factors_wd: hp = hp("factors_wd", ge(0), "A non-negative float", float)
60+
bias_init_method: hp = hp(
6161
"bias_init_method",
6262
isin("normal", "uniform", "constant"),
6363
'Value "normal", "uniform" or "constant"',
6464
str,
6565
)
66-
bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float)
67-
bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68-
bias_init_value = hp("bias_init_value", (), "A float value", float)
69-
linear_init_method = hp(
66+
bias_init_scale: hp = hp("bias_init_scale", ge(0), "A non-negative float", float)
67+
bias_init_sigma: hp = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68+
bias_init_value: hp = hp("bias_init_value", (), "A float value", float)
69+
linear_init_method: hp = hp(
7070
"linear_init_method",
7171
isin("normal", "uniform", "constant"),
7272
'Value "normal", "uniform" or "constant"',
7373
str,
7474
)
75-
linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float)
76-
linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77-
linear_init_value = hp("linear_init_value", (), "A float value", float)
78-
factors_init_method = hp(
75+
linear_init_scale: hp = hp("linear_init_scale", ge(0), "A non-negative float", float)
76+
linear_init_sigma: hp = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77+
linear_init_value: hp = hp("linear_init_value", (), "A float value", float)
78+
factors_init_method: hp = hp(
7979
"factors_init_method",
8080
isin("normal", "uniform", "constant"),
8181
'Value "normal", "uniform" or "constant"',
8282
str,
8383
)
84-
factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float)
85-
factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86-
factors_init_value = hp("factors_init_value", (), "A float value", float)
84+
factors_init_scale: hp = hp("factors_init_scale", ge(0), "A non-negative float", float)
85+
factors_init_sigma: hp = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86+
factors_init_value: hp = hp("factors_init_value", (), "A float value", float)
8787

8888
def __init__(
8989
self,
90-
role,
91-
instance_count=None,
92-
instance_type=None,
93-
num_factors=None,
94-
predictor_type=None,
95-
epochs=None,
96-
clip_gradient=None,
97-
eps=None,
98-
rescale_grad=None,
99-
bias_lr=None,
100-
linear_lr=None,
101-
factors_lr=None,
102-
bias_wd=None,
103-
linear_wd=None,
104-
factors_wd=None,
105-
bias_init_method=None,
106-
bias_init_scale=None,
107-
bias_init_sigma=None,
108-
bias_init_value=None,
109-
linear_init_method=None,
110-
linear_init_scale=None,
111-
linear_init_sigma=None,
112-
linear_init_value=None,
113-
factors_init_method=None,
114-
factors_init_scale=None,
115-
factors_init_sigma=None,
116-
factors_init_value=None,
90+
role: str,
91+
instance_count: Optional[Union[int, PipelineVariable]] = None,
92+
instance_type: Optional[Union[str, PipelineVariable]] = None,
93+
num_factors: Optional[int] = None,
94+
predictor_type: Optional[str] = None,
95+
epochs: Optional[int] = None,
96+
clip_gradient: Optional[float] = None,
97+
eps: Optional[float] = None,
98+
rescale_grad: Optional[float] = None,
99+
bias_lr: Optional[float] = None,
100+
linear_lr: Optional[float] = None,
101+
factors_lr: Optional[float] = None,
102+
bias_wd: Optional[float] = None,
103+
linear_wd: Optional[float] = None,
104+
factors_wd: Optional[float] = None,
105+
bias_init_method: Optional[str] = None,
106+
bias_init_scale: Optional[float] = None,
107+
bias_init_sigma: Optional[float] = None,
108+
bias_init_value: Optional[float] = None,
109+
linear_init_method: Optional[str] = None,
110+
linear_init_scale: Optional[float] = None,
111+
linear_init_sigma: Optional[float] = None,
112+
linear_init_value: Optional[float] = None,
113+
factors_init_method: Optional[str] = None,
114+
factors_init_scale: Optional[float] = None,
115+
factors_init_sigma: Optional[float] = None,
116+
factors_init_value: Optional[float] = None,
117117
**kwargs
118118
):
119119
"""Factorization Machines is :class:`Estimator` for general-purpose supervised learning.
@@ -160,9 +160,9 @@ def __init__(
160160
endpoints use this role to access training data and model
161161
artifacts. After the endpoint is created, the inference code
162162
might use the IAM role, if accessing AWS resource.
163-
instance_count (int): Number of Amazon EC2 instances to use
163+
instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
164164
for training.
165-
instance_type (str): Type of EC2 instance to use for training,
165+
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
166166
for example, 'ml.c4.xlarge'.
167167
num_factors (int): Dimensionality of factorization.
168168
predictor_type (str): Type of predictor 'binary_classifier' or
@@ -183,7 +183,7 @@ def __init__(
183183
linear_wd (float): Non-negative weight decay for linear terms.
184184
factors_wd (float): Non-negative weight decay for factorization
185185
terms.
186-
bias_init_method (string): Initialization method for the bias term:
186+
bias_init_method (str): Initialization method for the bias term:
187187
'normal', 'uniform' or 'constant'.
188188
bias_init_scale (float): Non-negative range for initialization of
189189
the bias term that takes effect when bias_init_method parameter
@@ -193,7 +193,7 @@ def __init__(
193193
bias_init_method parameter is 'normal'.
194194
bias_init_value (float): Initial value of the bias term that takes
195195
effect when bias_init_method parameter is 'constant'.
196-
linear_init_method (string): Initialization method for linear term:
196+
linear_init_method (str): Initialization method for linear term:
197197
'normal', 'uniform' or 'constant'.
198198
linear_init_scale (float): Non-negative range for initialization of
199199
linear terms that takes effect when linear_init_method parameter
@@ -203,7 +203,7 @@ def __init__(
203203
linear_init_method parameter is 'normal'.
204204
linear_init_value (float): Initial value of linear terms that takes
205205
effect when linear_init_method parameter is 'constant'.
206-
factors_init_method (string): Initialization method for
206+
factors_init_method (str): Initialization method for
207207
factorization term: 'normal', 'uniform' or 'constant'.
208208
factors_init_scale (float): Non-negative range for initialization of
209209
factorization terms that takes effect when factors_init_method

src/sagemaker/amazon/ipinsights.py

+26-26
Original file line numberDiff line numberDiff line change
@@ -36,45 +36,45 @@ class IPInsights(AmazonAlgorithmEstimatorBase):
3636
as user IDs or account numbers.
3737
"""
3838

39-
repo_name = "ipinsights"
40-
repo_version = 1
41-
MINI_BATCH_SIZE = 10000
39+
repo_name: str = "ipinsights"
40+
repo_version: str = "1"
41+
MINI_BATCH_SIZE: int = 10000
4242

43-
num_entity_vectors = hp(
43+
num_entity_vectors: hp = hp(
4444
"num_entity_vectors", (ge(1), le(250000000)), "An integer in [1, 250000000]", int
4545
)
46-
vector_dim = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int)
46+
vector_dim: hp = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int)
4747

48-
batch_metrics_publish_interval = hp(
48+
batch_metrics_publish_interval: hp = hp(
4949
"batch_metrics_publish_interval", (ge(1)), "An integer greater than 0", int
5050
)
51-
epochs = hp("epochs", (ge(1)), "An integer greater than 0", int)
52-
learning_rate = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float)
53-
num_ip_encoder_layers = hp(
51+
epochs: hp = hp("epochs", (ge(1)), "An integer greater than 0", int)
52+
learning_rate: hp = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float)
53+
num_ip_encoder_layers: hp = hp(
5454
"num_ip_encoder_layers", (ge(0), le(100)), "An integer in [0, 100]", int
5555
)
56-
random_negative_sampling_rate = hp(
56+
random_negative_sampling_rate: hp = hp(
5757
"random_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int
5858
)
59-
shuffled_negative_sampling_rate = hp(
59+
shuffled_negative_sampling_rate: hp = hp(
6060
"shuffled_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int
6161
)
62-
weight_decay = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float)
62+
weight_decay: hp = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float)
6363

6464
def __init__(
6565
self,
66-
role,
67-
instance_count=None,
68-
instance_type=None,
69-
num_entity_vectors=None,
70-
vector_dim=None,
71-
batch_metrics_publish_interval=None,
72-
epochs=None,
73-
learning_rate=None,
74-
num_ip_encoder_layers=None,
75-
random_negative_sampling_rate=None,
76-
shuffled_negative_sampling_rate=None,
77-
weight_decay=None,
66+
role: str,
67+
instance_count: Optional[Union[int, PipelineVariable]] = None,
68+
instance_type: Optional[Union[str, PipelineVariable]] = None,
69+
num_entity_vectors: Optional[int] = None,
70+
vector_dim: Optional[int] = None,
71+
batch_metrics_publish_interval: Optional[int] = None,
72+
epochs: Optional[int] = None,
73+
learning_rate: Optional[float] = None,
74+
num_ip_encoder_layers: Optional[int] = None,
75+
random_negative_sampling_rate: Optional[int] = None,
76+
shuffled_negative_sampling_rate: Optional[int] = None,
77+
weight_decay: Optional[float] = None,
7878
**kwargs
7979
):
8080
"""This estimator is for IP Insights.
@@ -106,9 +106,9 @@ def __init__(
106106
endpoints use this role to access training data and model
107107
artifacts. After the endpoint is created, the inference code
108108
might use the IAM role, if accessing AWS resource.
109-
instance_count (int): Number of Amazon EC2 instances to use
109+
instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
110110
for training.
111-
instance_type (str): Type of EC2 instance to use for training,
111+
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
112112
for example, 'ml.m5.xlarge'.
113113
num_entity_vectors (int): Required. The number of embeddings to
114114
train for entities accessing online resources. We recommend 2x

0 commit comments

Comments
 (0)