13
13
"""Test docstring"""
14
14
from __future__ import absolute_import
15
15
16
+ from typing import Optional , Union , Dict , List
17
+
16
18
import sagemaker
17
19
import sagemaker .parameter
18
20
from sagemaker import vpc_utils
19
21
from sagemaker .deserializers import BytesDeserializer
20
22
from sagemaker .deprecations import removed_kwargs
21
23
from sagemaker .estimator import EstimatorBase
24
+ from sagemaker .inputs import TrainingInput , FileSystemInput
22
25
from sagemaker .serializers import IdentitySerializer
23
26
from sagemaker .transformer import Transformer
24
27
from sagemaker .predictor import Predictor
28
+ from sagemaker .session import Session
29
+ from sagemaker .workflow .entities import PipelineVariable
30
+
31
+ from sagemaker .workflow import is_pipeline_variable
25
32
26
33
27
34
class AlgorithmEstimator (EstimatorBase ):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):
37
44
38
45
def __init__ (
39
46
self ,
40
- algorithm_arn ,
41
- role ,
42
- instance_count ,
43
- instance_type ,
44
- volume_size = 30 ,
45
- volume_kms_key = None ,
46
- max_run = 24 * 60 * 60 ,
47
- input_mode = "File" ,
48
- output_path = None ,
49
- output_kms_key = None ,
50
- base_job_name = None ,
51
- sagemaker_session = None ,
52
- hyperparameters = None ,
53
- tags = None ,
54
- subnets = None ,
55
- security_group_ids = None ,
56
- model_uri = None ,
57
- model_channel_name = "model" ,
58
- metric_definitions = None ,
59
- encrypt_inter_container_traffic = False ,
60
- use_spot_instances = False ,
61
- max_wait = None ,
47
+ algorithm_arn : str ,
48
+ role : str ,
49
+ instance_count : Optional [ Union [ int , PipelineVariable ]] = None ,
50
+ instance_type : Optional [ Union [ str , PipelineVariable ]] = None ,
51
+ volume_size : Union [ int , PipelineVariable ] = 30 ,
52
+ volume_kms_key : Optional [ Union [ str , PipelineVariable ]] = None ,
53
+ max_run : Union [ int , PipelineVariable ] = 24 * 60 * 60 ,
54
+ input_mode : Union [ str , PipelineVariable ] = "File" ,
55
+ output_path : Optional [ Union [ str , PipelineVariable ]] = None ,
56
+ output_kms_key : Optional [ Union [ str , PipelineVariable ]] = None ,
57
+ base_job_name : Optional [ str ] = None ,
58
+ sagemaker_session : Optional [ Session ] = None ,
59
+ hyperparameters : Optional [ Dict [ str , Union [ str , PipelineVariable ]]] = None ,
60
+ tags : Optional [ List [ Dict [ str , Union [ str , PipelineVariable ]]]] = None ,
61
+ subnets : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
62
+ security_group_ids : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
63
+ model_uri : Optional [ str ] = None ,
64
+ model_channel_name : Union [ str , PipelineVariable ] = "model" ,
65
+ metric_definitions : Optional [ List [ Dict [ str , Union [ str , PipelineVariable ]]]] = None ,
66
+ encrypt_inter_container_traffic : Union [ bool , PipelineVariable ] = False ,
67
+ use_spot_instances : Union [ bool , PipelineVariable ] = False ,
68
+ max_wait : Optional [ Union [ int , PipelineVariable ]] = None ,
62
69
** kwargs # pylint: disable=W0613
63
70
):
64
71
"""Initialize an ``AlgorithmEstimator`` instance.
@@ -71,18 +78,21 @@ def __init__(
71
78
access training data and model artifacts. After the endpoint
72
79
is created, the inference code might use the IAM role, if it
73
80
needs to access an AWS resource.
74
- instance_count (int): Number of Amazon EC2 instances to use for training.
75
- instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
76
- volume_size (int): Size in GB of the EBS volume to use for
81
+ instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
82
+ for training.
83
+ instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
84
+ for example, 'ml.c4.xlarge'.
85
+ volume_size (int or PipelineVariable): Size in GB of the EBS volume to use for
77
86
storing input data during training (default: 30). Must be large enough to store
78
87
training data if File Mode is used (which is the default).
79
- volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached
80
- to the training instance (default: None).
81
- max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
88
+ volume_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting
89
+ EBS volume attached to the training instance (default: None).
90
+ max_run (int or PipelineVariable): Timeout in seconds for training
91
+ (default: 24 * 60 * 60).
82
92
After this amount of time Amazon SageMaker terminates the
83
93
job regardless of its current status.
84
- input_mode (str): The input mode that the algorithm supports
85
- (default: 'File'). Valid modes:
94
+ input_mode (str or PipelineVariable): The input mode that the algorithm supports
95
+ (default: 'File'). Valid modes:
86
96
87
97
* 'File' - Amazon SageMaker copies the training dataset from
88
98
the S3 location to a local directory.
@@ -92,13 +102,14 @@ def __init__(
92
102
This argument can be overridden on a per-channel basis using
93
103
``sagemaker.inputs.TrainingInput.input_mode``.
94
104
95
- output_path (str): S3 location for saving the training result (model artifacts and
96
- output files). If not specified, results are stored to a default bucket. If
105
+ output_path (str or PipelineVariable): S3 location for saving the training result
106
+ (model artifacts and output files). If not specified,
107
+ results are stored to a default bucket. If
97
108
the bucket with the specific name does not exist, the
98
109
estimator creates the bucket during the
99
110
:meth:`~sagemaker.estimator.EstimatorBase.fit` method
100
111
execution.
101
- output_kms_key (str): Optional. KMS key ID for encrypting the
112
+ output_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting the
102
113
training output (default: None). base_job_name (str): Prefix for
103
114
training job name when the
104
115
:meth:`~sagemaker.estimator.EstimatorBase.fit`
@@ -109,9 +120,10 @@ def __init__(
109
120
interactions with Amazon SageMaker APIs and any other AWS services needed. If
110
121
not specified, the estimator creates one using the default
111
122
AWS configuration chain.
112
- tags (list[dict]): List of tags for labeling a training job. For more, see
123
+ tags (list[dict[str, str]] or list[dict[str, PipelineVariable]]): List of tags for
124
+ labeling a training job. For more, see
113
125
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
114
- subnets (list[str]): List of subnet ids. If not specified
126
+ subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified
115
127
training job will be created without VPC config.
116
128
security_group_ids (list[str]): List of security group ids. If
117
129
not specified training job will be created without VPC config.
@@ -122,22 +134,22 @@ def __init__(
122
134
other artifacts coming from a different source.
123
135
More information:
124
136
https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
125
- model_channel_name (str): Name of the channel where 'model_uri'
137
+ model_channel_name (str or PipelineVariable): Name of the channel where 'model_uri'
126
138
will be downloaded (default: 'model'). metric_definitions
127
139
(list[dict]): A list of dictionaries that defines the metric(s)
128
140
used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
129
141
the name of the metric, and 'Regex' for the regular
130
142
expression used to extract the metric from the logs.
131
- encrypt_inter_container_traffic (bool): Specifies whether traffic between training
132
- containers is encrypted for the training job (default: ``False``).
133
- use_spot_instances (bool): Specifies whether to use SageMaker
143
+ encrypt_inter_container_traffic (bool or PipelineVariable): Specifies whether traffic
144
+ between training containers is encrypted for the training job (default: ``False``).
145
+ use_spot_instances (bool or PipelineVariable): Specifies whether to use SageMaker
134
146
Managed Spot instances for training. If enabled then the
135
147
`max_wait` arg should also be set.
136
148
137
149
More information:
138
150
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
139
151
(default: ``False``).
140
- max_wait (int): Timeout in seconds waiting for spot training
152
+ max_wait (int or PipelineVariable): Timeout in seconds waiting for spot training
141
153
instances (default: None). After this amount of time Amazon
142
154
SageMaker will stop waiting for Spot instances to become
143
155
available (default: ``None``).
@@ -186,22 +198,25 @@ def validate_train_spec(self):
186
198
# Check that the input mode provided is compatible with the training input modes for the
187
199
# algorithm.
188
200
input_modes = self ._algorithm_training_input_modes (train_spec ["TrainingChannels" ])
189
- if self .input_mode not in input_modes :
201
+ if not is_pipeline_variable ( self . input_mode ) and self .input_mode not in input_modes :
190
202
raise ValueError (
191
203
"Invalid input mode: %s. %s only supports: %s"
192
204
% (self .input_mode , algorithm_name , input_modes )
193
205
)
194
206
195
207
# Check that the training instance type is compatible with the algorithm.
196
208
supported_instances = train_spec ["SupportedTrainingInstanceTypes" ]
197
- if self .instance_type not in supported_instances :
209
+ if (
210
+ not is_pipeline_variable (self .instance_type )
211
+ and self .instance_type not in supported_instances
212
+ ):
198
213
raise ValueError (
199
214
"Invalid instance_type: %s. %s supports the following instance types: %s"
200
215
% (self .instance_type , algorithm_name , supported_instances )
201
216
)
202
217
203
218
# Verify if distributed training is supported by the algorithm
204
- if (
219
+ if not is_pipeline_variable ( self . instance_count ) and (
205
220
self .instance_count > 1
206
221
and "SupportsDistributedTraining" in train_spec
207
222
and not train_spec ["SupportsDistributedTraining" ]
@@ -414,12 +429,18 @@ def _prepare_for_training(self, job_name=None):
414
429
415
430
super (AlgorithmEstimator , self )._prepare_for_training (job_name )
416
431
417
- def fit (self , inputs = None , wait = True , logs = True , job_name = None ):
432
+ def fit (
433
+ self ,
434
+ inputs : Optional [Union [str , Dict , TrainingInput , FileSystemInput ]] = None ,
435
+ wait : bool = True ,
436
+ logs : bool = True ,
437
+ job_name : Optional [str ] = None ,
438
+ ):
418
439
"""Placeholder docstring"""
419
440
if inputs :
420
441
self ._validate_input_channels (inputs )
421
442
422
- super (AlgorithmEstimator , self ).fit (inputs , wait , logs , job_name )
443
+ return super (AlgorithmEstimator , self ).fit (inputs , wait , logs , job_name )
423
444
424
445
def _validate_input_channels (self , channels ):
425
446
"""Placeholder docstring"""
0 commit comments