19
19
from sagemaker .fw_utils import (
20
20
framework_name_from_image ,
21
21
framework_version_from_tag ,
22
- empty_framework_version_warning ,
22
+ is_version_equal_or_higher ,
23
23
python_deprecation_warning ,
24
24
parameter_v2_rename_warning ,
25
- is_version_equal_or_higher ,
25
+ validate_version_or_image_args ,
26
26
warn_if_parameter_server_with_multi_gpu ,
27
27
)
28
28
from sagemaker .mxnet import defaults
@@ -43,10 +43,10 @@ class MXNet(Framework):
43
43
def __init__ (
44
44
self ,
45
45
entry_point ,
46
+ framework_version = None ,
47
+ py_version = None ,
46
48
source_dir = None ,
47
49
hyperparameters = None ,
48
- py_version = "py2" ,
49
- framework_version = None ,
50
50
image_name = None ,
51
51
distributions = None ,
52
52
** kwargs
@@ -73,6 +73,13 @@ def __init__(
73
73
file which should be executed as the entry point to training.
74
74
If ``source_dir`` is specified, then ``entry_point``
75
75
must point to a file located at the root of ``source_dir``.
76
+ framework_version (str): MXNet version you want to use for executing
77
+ your model training code. Defaults to `None`. Required unless
78
+ ``image_name`` is provided. List of supported versions.
79
+ https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
80
+ py_version (str): Python version you want to use for executing your
81
+ model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
82
+ unless ``image_name`` is provided.
76
83
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
77
84
with any other training source code dependencies aside from the entry
78
85
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +91,6 @@ def __init__(
84
91
SageMaker. For convenience, this accepts other types for keys
85
92
and values, but ``str()`` will be called to convert them before
86
93
training.
87
- py_version (str): Python version you want to use for executing your
88
- model training code (default: 'py2'). One of 'py2' or 'py3'.
89
- framework_version (str): MXNet version you want to use for executing
90
- your model training code. List of supported versions
91
- https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92
- If not specified, this will default to 1.2.1.
93
94
image_name (str): If specified, the estimator will use this image for training and
94
95
hosting, instead of selecting the appropriate SageMaker official image based on
95
96
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -98,6 +99,9 @@ def __init__(
98
99
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
99
100
* ``custom-image:latest``
100
101
102
+ If ``framework_version`` or ``py_version`` are ``None``, then
103
+ ``image_name`` is required. If also ``None``, then a ``ValueError``
104
+ will be raised.
101
105
distributions (dict): A dictionary with information on how to run distributed
102
106
training (default: None). To have parameter servers launched for training,
103
107
set this value to be ``{'parameter_server': {'enabled': True}}``.
@@ -110,34 +114,32 @@ def __init__(
110
114
:class:`~sagemaker.estimator.Framework` and
111
115
:class:`~sagemaker.estimator.EstimatorBase`.
112
116
"""
113
- if framework_version is None :
117
+ validate_version_or_image_args (framework_version , py_version , image_name )
118
+ if py_version and py_version == "py2" :
114
119
logger .warning (
115
- empty_framework_version_warning ( defaults . MXNET_VERSION , self . LATEST_VERSION )
120
+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
116
121
)
117
- self .framework_version = framework_version or defaults .MXNET_VERSION
122
+ self .framework_version = framework_version
123
+ self .py_version = py_version
118
124
119
125
if "enable_sagemaker_metrics" not in kwargs :
120
126
# enable sagemaker metrics for MXNet v1.6 or greater:
121
- if is_version_equal_or_higher ([1 , 6 ], self .framework_version ):
127
+ if self .framework_version and is_version_equal_or_higher (
128
+ [1 , 6 ], self .framework_version
129
+ ):
122
130
kwargs ["enable_sagemaker_metrics" ] = True
123
131
124
132
super (MXNet , self ).__init__ (
125
133
entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
126
134
)
127
135
128
- if py_version == "py2" :
129
- logger .warning (
130
- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
131
- )
132
-
133
136
if distributions is not None :
134
137
logger .warning (parameter_v2_rename_warning ("distributions" , "distribution" ))
135
138
train_instance_type = kwargs .get ("train_instance_type" )
136
139
warn_if_parameter_server_with_multi_gpu (
137
140
training_instance_type = train_instance_type , distributions = distributions
138
141
)
139
142
140
- self .py_version = py_version
141
143
self ._configure_distribution (distributions )
142
144
143
145
def _configure_distribution (self , distributions ):
@@ -148,7 +150,10 @@ def _configure_distribution(self, distributions):
148
150
if distributions is None :
149
151
return
150
152
151
- if self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION :
153
+ if (
154
+ self .framework_version
155
+ and self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION
156
+ ):
152
157
raise ValueError (
153
158
"The distributions option is valid for only versions {} and higher" .format (
154
159
"." .join (self ._LOWEST_SCRIPT_MODE_VERSION )
@@ -221,12 +226,12 @@ def create_model(
221
226
self .model_data ,
222
227
role or self .role ,
223
228
entry_point or self .entry_point ,
229
+ framework_version = self .framework_version ,
230
+ py_version = self .py_version ,
224
231
source_dir = (source_dir or self ._model_source_dir ()),
225
232
enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
226
233
container_log_level = self .container_log_level ,
227
234
code_location = self .code_location ,
228
- py_version = self .py_version ,
229
- framework_version = self .framework_version ,
230
235
model_server_workers = model_server_workers ,
231
236
sagemaker_session = self .sagemaker_session ,
232
237
vpc_config = self .get_vpc_config (vpc_config_override ),
@@ -254,22 +259,25 @@ class constructor
254
259
image_name = init_params .pop ("image" )
255
260
framework , py_version , tag , _ = framework_name_from_image (image_name )
256
261
262
+ # We switched image tagging scheme from regular image version (e.g. '1.0') to more
263
+ # expressive containing framework version, device type and python version
264
+ # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
265
+ # '0.12' framework version otherwise extract framework version from the tag itself.
266
+ if tag is None :
267
+ framework_version = None
268
+ elif tag == "1.0" :
269
+ framework_version = "0.12"
270
+ else :
271
+ framework_version = framework_version_from_tag (tag )
272
+ init_params ["framework_version" ] = framework_version
273
+ init_params ["py_version" ] = py_version
274
+
257
275
if not framework :
258
276
# If we were unable to parse the framework name from the image it is not one of our
259
277
# officially supported images, in this case just add the image to the init params.
260
278
init_params ["image_name" ] = image_name
261
279
return init_params
262
280
263
- init_params ["py_version" ] = py_version
264
-
265
- # We switched image tagging scheme from regular image version (e.g. '1.0') to more
266
- # expressive containing framework version, device type and python version
267
- # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
268
- # '0.12' framework version otherwise extract framework version from the tag itself.
269
- init_params ["framework_version" ] = (
270
- "0.12" if tag == "1.0" else framework_version_from_tag (tag )
271
- )
272
-
273
281
training_job_name = init_params ["base_job_name" ]
274
282
275
283
if framework != cls .__framework_name__ :
0 commit comments