Skip to content

Commit 059f7ff

Browse files
author
Dewen Qi
committed
change: Add PipelineVariable annotation to Amazon estimators
1 parent cb5c991 commit 059f7ff

11 files changed

+426
-380
lines changed

src/sagemaker/amazon/amazon_estimator.py

+20-14
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import logging
1818
import tempfile
19-
from typing import Union
19+
from typing import Union, Optional, Dict
2020

2121
from six.moves.urllib.parse import urlparse
2222

@@ -30,6 +30,7 @@
3030
from sagemaker.utils import sagemaker_timestamp
3131
from sagemaker.workflow.entities import PipelineVariable
3232
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
33+
from sagemaker.workflow import is_pipeline_variable
3334

3435
logger = logging.getLogger(__name__)
3536

@@ -42,16 +43,16 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
4243

4344
feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
4445
mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)
45-
repo_name = None
46-
repo_version = None
46+
repo_name: Optional[str] = None
47+
repo_version: Optional[str] = None
4748

4849
def __init__(
4950
self,
50-
role,
51-
instance_count=None,
52-
instance_type=None,
53-
data_location=None,
54-
enable_network_isolation=False,
51+
role: str,
52+
instance_count: Optional[Union[int]] = None,
53+
instance_type: Optional[Union[str, PipelineVariable]] = None,
54+
data_location: Optional[str] = None,
55+
enable_network_isolation: Union[bool, PipelineVariable] = False,
5556
**kwargs
5657
):
5758
"""Initialize an AmazonAlgorithmEstimatorBase.
@@ -115,6 +116,11 @@ def data_location(self):
115116
@data_location.setter
116117
def data_location(self, data_location):
117118
"""Placeholder docstring"""
119+
if is_pipeline_variable(data_location):
120+
raise ValueError(
121+
"data_location argument has to be an integer " + "rather than a pipeline variable"
122+
)
123+
118124
if not data_location.startswith("s3://"):
119125
raise ValueError(
120126
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
@@ -198,12 +204,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
198204
@runnable_by_pipeline
199205
def fit(
200206
self,
201-
records,
202-
mini_batch_size=None,
203-
wait=True,
204-
logs=True,
205-
job_name=None,
206-
experiment_config=None,
207+
records: "RecordSet",
208+
mini_batch_size: Optional[int] = None,
209+
wait: bool = True,
210+
logs: bool = True,
211+
job_name: Optional[str] = None,
212+
experiment_config: Optional[Dict[str, str]] = None,
207213
):
208214
"""Fit this Estimator on serialized Record objects, stored in S3.
209215

src/sagemaker/amazon/factorization_machines.py

+53-53
Original file line number | Diff line number | Diff line change
@@ -37,83 +37,83 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase):
3737
sparse datasets economically.
3838
"""
3939

40-
repo_name = "factorization-machines"
41-
repo_version = 1
40+
repo_name: str = "factorization-machines"
41+
repo_version: int = 1
4242

43-
num_factors = hp("num_factors", gt(0), "An integer greater than zero", int)
44-
predictor_type = hp(
43+
num_factors: hp = hp("num_factors", gt(0), "An integer greater than zero", int)
44+
predictor_type: hp = hp(
4545
"predictor_type",
4646
isin("binary_classifier", "regressor"),
4747
'Value "binary_classifier" or "regressor"',
4848
str,
4949
)
50-
epochs = hp("epochs", gt(0), "An integer greater than 0", int)
51-
clip_gradient = hp("clip_gradient", (), "A float value", float)
52-
eps = hp("eps", (), "A float value", float)
53-
rescale_grad = hp("rescale_grad", (), "A float value", float)
54-
bias_lr = hp("bias_lr", ge(0), "A non-negative float", float)
55-
linear_lr = hp("linear_lr", ge(0), "A non-negative float", float)
56-
factors_lr = hp("factors_lr", ge(0), "A non-negative float", float)
57-
bias_wd = hp("bias_wd", ge(0), "A non-negative float", float)
58-
linear_wd = hp("linear_wd", ge(0), "A non-negative float", float)
59-
factors_wd = hp("factors_wd", ge(0), "A non-negative float", float)
60-
bias_init_method = hp(
50+
epochs: hp = hp("epochs", gt(0), "An integer greater than 0", int)
51+
clip_gradient: hp = hp("clip_gradient", (), "A float value", float)
52+
eps: hp = hp("eps", (), "A float value", float)
53+
rescale_grad: hp = hp("rescale_grad", (), "A float value", float)
54+
bias_lr: hp = hp("bias_lr", ge(0), "A non-negative float", float)
55+
linear_lr: hp = hp("linear_lr", ge(0), "A non-negative float", float)
56+
factors_lr: hp = hp("factors_lr", ge(0), "A non-negative float", float)
57+
bias_wd: hp = hp("bias_wd", ge(0), "A non-negative float", float)
58+
linear_wd: hp = hp("linear_wd", ge(0), "A non-negative float", float)
59+
factors_wd: hp = hp("factors_wd", ge(0), "A non-negative float", float)
60+
bias_init_method: hp = hp(
6161
"bias_init_method",
6262
isin("normal", "uniform", "constant"),
6363
'Value "normal", "uniform" or "constant"',
6464
str,
6565
)
66-
bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float)
67-
bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68-
bias_init_value = hp("bias_init_value", (), "A float value", float)
69-
linear_init_method = hp(
66+
bias_init_scale: hp = hp("bias_init_scale", ge(0), "A non-negative float", float)
67+
bias_init_sigma: hp = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68+
bias_init_value: hp = hp("bias_init_value", (), "A float value", float)
69+
linear_init_method: hp = hp(
7070
"linear_init_method",
7171
isin("normal", "uniform", "constant"),
7272
'Value "normal", "uniform" or "constant"',
7373
str,
7474
)
75-
linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float)
76-
linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77-
linear_init_value = hp("linear_init_value", (), "A float value", float)
78-
factors_init_method = hp(
75+
linear_init_scale: hp = hp("linear_init_scale", ge(0), "A non-negative float", float)
76+
linear_init_sigma: hp = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77+
linear_init_value: hp = hp("linear_init_value", (), "A float value", float)
78+
factors_init_method: hp = hp(
7979
"factors_init_method",
8080
isin("normal", "uniform", "constant"),
8181
'Value "normal", "uniform" or "constant"',
8282
str,
8383
)
84-
factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float)
85-
factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86-
factors_init_value = hp("factors_init_value", (), "A float value", float)
84+
factors_init_scale: hp = hp("factors_init_scale", ge(0), "A non-negative float", float)
85+
factors_init_sigma: hp = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86+
factors_init_value: hp = hp("factors_init_value", (), "A float value", float)
8787

8888
def __init__(
8989
self,
90-
role,
91-
instance_count=None,
92-
instance_type=None,
93-
num_factors=None,
94-
predictor_type=None,
95-
epochs=None,
96-
clip_gradient=None,
97-
eps=None,
98-
rescale_grad=None,
99-
bias_lr=None,
100-
linear_lr=None,
101-
factors_lr=None,
102-
bias_wd=None,
103-
linear_wd=None,
104-
factors_wd=None,
105-
bias_init_method=None,
106-
bias_init_scale=None,
107-
bias_init_sigma=None,
108-
bias_init_value=None,
109-
linear_init_method=None,
110-
linear_init_scale=None,
111-
linear_init_sigma=None,
112-
linear_init_value=None,
113-
factors_init_method=None,
114-
factors_init_scale=None,
115-
factors_init_sigma=None,
116-
factors_init_value=None,
90+
role: str,
91+
instance_count: Optional[Union[int, PipelineVariable]] = None,
92+
instance_type: Optional[Union[str, PipelineVariable]] = None,
93+
num_factors: Optional[int] = None,
94+
predictor_type: Optional[str] = None,
95+
epochs: Optional[int] = None,
96+
clip_gradient: Optional[float] = None,
97+
eps: Optional[float] = None,
98+
rescale_grad: Optional[float] = None,
99+
bias_lr: Optional[float] = None,
100+
linear_lr: Optional[float] = None,
101+
factors_lr: Optional[float] = None,
102+
bias_wd: Optional[float] = None,
103+
linear_wd: Optional[float] = None,
104+
factors_wd: Optional[float] = None,
105+
bias_init_method: Optional[str] = None,
106+
bias_init_scale: Optional[float] = None,
107+
bias_init_sigma: Optional[float] = None,
108+
bias_init_value: Optional[float] = None,
109+
linear_init_method: Optional[str] = None,
110+
linear_init_scale: Optional[float] = None,
111+
linear_init_sigma: Optional[float] = None,
112+
linear_init_value: Optional[float] = None,
113+
factors_init_method: Optional[str] = None,
114+
factors_init_scale: Optional[float] = None,
115+
factors_init_sigma: Optional[float] = None,
116+
factors_init_value: Optional[float] = None,
117117
**kwargs
118118
):
119119
"""Factorization Machines is :class:`Estimator` for general-purpose supervised learning.

src/sagemaker/amazon/ipinsights.py

+24-24
Original file line number | Diff line number | Diff line change
@@ -36,45 +36,45 @@ class IPInsights(AmazonAlgorithmEstimatorBase):
3636
as user IDs or account numbers.
3737
"""
3838

39-
repo_name = "ipinsights"
40-
repo_version = 1
41-
MINI_BATCH_SIZE = 10000
39+
repo_name: str = "ipinsights"
40+
repo_version: int = 1
41+
MINI_BATCH_SIZE: int = 10000
4242

43-
num_entity_vectors = hp(
43+
num_entity_vectors: hp = hp(
4444
"num_entity_vectors", (ge(1), le(250000000)), "An integer in [1, 250000000]", int
4545
)
46-
vector_dim = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int)
46+
vector_dim: hp = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int)
4747

48-
batch_metrics_publish_interval = hp(
48+
batch_metrics_publish_interval: hp = hp(
4949
"batch_metrics_publish_interval", (ge(1)), "An integer greater than 0", int
5050
)
51-
epochs = hp("epochs", (ge(1)), "An integer greater than 0", int)
52-
learning_rate = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float)
53-
num_ip_encoder_layers = hp(
51+
epochs: hp = hp("epochs", (ge(1)), "An integer greater than 0", int)
52+
learning_rate: hp = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float)
53+
num_ip_encoder_layers: hp = hp(
5454
"num_ip_encoder_layers", (ge(0), le(100)), "An integer in [0, 100]", int
5555
)
56-
random_negative_sampling_rate = hp(
56+
random_negative_sampling_rate: hp = hp(
5757
"random_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int
5858
)
59-
shuffled_negative_sampling_rate = hp(
59+
shuffled_negative_sampling_rate: hp = hp(
6060
"shuffled_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int
6161
)
62-
weight_decay = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float)
62+
weight_decay: hp = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float)
6363

6464
def __init__(
6565
self,
66-
role,
67-
instance_count=None,
68-
instance_type=None,
69-
num_entity_vectors=None,
70-
vector_dim=None,
71-
batch_metrics_publish_interval=None,
72-
epochs=None,
73-
learning_rate=None,
74-
num_ip_encoder_layers=None,
75-
random_negative_sampling_rate=None,
76-
shuffled_negative_sampling_rate=None,
77-
weight_decay=None,
66+
role: str,
67+
instance_count: Optional[Union[int, PipelineVariable]] = None,
68+
instance_type: Optional[Union[str, PipelineVariable]] = None,
69+
num_entity_vectors: Optional[int] = None,
70+
vector_dim: Optional[int] = None,
71+
batch_metrics_publish_interval: Optional[int] = None,
72+
epochs: Optional[int] = None,
73+
learning_rate: Optional[float] = None,
74+
num_ip_encoder_layers: Optional[int] = None,
75+
random_negative_sampling_rate: Optional[int] = None,
76+
shuffled_negative_sampling_rate: Optional[int] = None,
77+
weight_decay: Optional[float] = None,
7878
**kwargs
7979
):
8080
"""This estimator is for IP Insights.

src/sagemaker/amazon/kmeans.py

+28-26
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from typing import Union, Optional
16+
from typing import Union, Optional, List
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
@@ -36,43 +36,45 @@ class KMeans(AmazonAlgorithmEstimatorBase):
3636
the algorithm to use to determine similarity.
3737
"""
3838

39-
repo_name = "kmeans"
40-
repo_version = 1
39+
repo_name: str = "kmeans"
40+
repo_version: str = 1
4141

42-
k = hp("k", gt(1), "An integer greater-than 1", int)
43-
init_method = hp("init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str)
44-
max_iterations = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int)
45-
tol = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float)
46-
num_trials = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int)
47-
local_init_method = hp(
42+
k: hp = hp("k", gt(1), "An integer greater-than 1", int)
43+
init_method: hp = hp(
44+
"init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str
45+
)
46+
max_iterations: hp = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int)
47+
tol: hp = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float)
48+
num_trials: hp = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int)
49+
local_init_method: hp = hp(
4850
"local_lloyd_init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str
4951
)
50-
half_life_time_size = hp(
52+
half_life_time_size: hp = hp(
5153
"half_life_time_size", ge(0), "An integer greater-than-or-equal-to 0", int
5254
)
53-
epochs = hp("epochs", gt(0), "An integer greater-than 0", int)
54-
center_factor = hp("extra_center_factor", gt(0), "An integer greater-than 0", int)
55-
eval_metrics = hp(
55+
epochs: hp = hp("epochs", gt(0), "An integer greater-than 0", int)
56+
center_factor: hp = hp("extra_center_factor", gt(0), "An integer greater-than 0", int)
57+
eval_metrics: hp = hp(
5658
name="eval_metrics",
5759
validation_message='A comma separated list of "msd" or "ssd"',
5860
data_type=list,
5961
)
6062

6163
def __init__(
6264
self,
63-
role,
64-
instance_count=None,
65-
instance_type=None,
66-
k=None,
67-
init_method=None,
68-
max_iterations=None,
69-
tol=None,
70-
num_trials=None,
71-
local_init_method=None,
72-
half_life_time_size=None,
73-
epochs=None,
74-
center_factor=None,
75-
eval_metrics=None,
65+
role: str,
66+
instance_count: Optional[Union[int, PipelineVariable]] = None,
67+
instance_type: Optional[Union[str, PipelineVariable]] = None,
68+
k: Optional[int] = None,
69+
init_method: Optional[str] = None,
70+
max_iterations: Optional[int] = None,
71+
tol: Optional[float] = None,
72+
num_trials: Optional[int] = None,
73+
local_init_method: Optional[str] = None,
74+
half_life_time_size: Optional[int] = None,
75+
epochs: Optional[int] = None,
76+
center_factor: Optional[int] = None,
77+
eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None,
7678
**kwargs
7779
):
7880
"""A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`.

0 commit comments

Comments
 (0)