19
19
from sagemaker .fw_utils import (
20
20
framework_name_from_image ,
21
21
framework_version_from_tag ,
22
- empty_framework_version_warning ,
22
+ is_version_equal_or_higher ,
23
23
python_deprecation_warning ,
24
24
parameter_v2_rename_warning ,
25
- is_version_equal_or_higher ,
25
+ validate_version_or_image_args ,
26
26
warn_if_parameter_server_with_multi_gpu ,
27
27
)
28
28
from sagemaker .mxnet import defaults
@@ -43,10 +43,10 @@ class MXNet(Framework):
43
43
def __init__ (
44
44
self ,
45
45
entry_point ,
46
+ framework_version = None ,
47
+ py_version = None ,
46
48
source_dir = None ,
47
49
hyperparameters = None ,
48
- py_version = "py2" ,
49
- framework_version = None ,
50
50
image_name = None ,
51
51
distributions = None ,
52
52
** kwargs
@@ -73,6 +73,11 @@ def __init__(
73
73
file which should be executed as the entry point to training.
74
74
If ``source_dir`` is specified, then ``entry_point``
75
75
must point to a file located at the root of ``source_dir``.
76
+ framework_version (str): MXNet version you want to use for executing
77
+ your model training code. List of supported versions. Defaults to ``None``.
78
+ https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
79
+ py_version (str): Python version you want to use for executing your
80
+ model training code. One of 'py2' or 'py3'. Defaults to ``None``.
76
81
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
77
82
with any other training source code dependencies aside from the entry
78
83
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +89,6 @@ def __init__(
84
89
SageMaker. For convenience, this accepts other types for keys
85
90
and values, but ``str()`` will be called to convert them before
86
91
training.
87
- py_version (str): Python version you want to use for executing your
88
- model training code (default: 'py2'). One of 'py2' or 'py3'.
89
- framework_version (str): MXNet version you want to use for executing
90
- your model training code. List of supported versions
91
- https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92
- If not specified, this will default to 1.2.1.
93
92
image_name (str): If specified, the estimator will use this image for training and
94
93
hosting, instead of selecting the appropriate SageMaker official image based on
95
94
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -98,6 +97,9 @@ def __init__(
98
97
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
99
98
* ``custom-image:latest``
100
99
100
+ If ``framework_version`` or ``py_version`` are ``None``, then
101
+ ``image_name`` is required. If also ``None``, then a ``ValueError``
102
+ will be raised.
101
103
distributions (dict): A dictionary with information on how to run distributed
102
104
training (default: None). To have parameter servers launched for training,
103
105
set this value to be ``{'parameter_server': {'enabled': True}}``.
@@ -110,34 +112,32 @@ def __init__(
110
112
:class:`~sagemaker.estimator.Framework` and
111
113
:class:`~sagemaker.estimator.EstimatorBase`.
112
114
"""
113
- if framework_version is None :
115
+ validate_version_or_image_args (framework_version , py_version , image_name )
116
+ if py_version and py_version == "py2" :
114
117
logger .warning (
115
- empty_framework_version_warning ( defaults . MXNET_VERSION , self . LATEST_VERSION )
118
+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
116
119
)
117
- self .framework_version = framework_version or defaults .MXNET_VERSION
120
+ self .framework_version = framework_version
121
+ self .py_version = py_version
118
122
119
123
if "enable_sagemaker_metrics" not in kwargs :
120
124
# enable sagemaker metrics for MXNet v1.6 or greater:
121
- if is_version_equal_or_higher ([1 , 6 ], self .framework_version ):
125
+ if self .framework_version and is_version_equal_or_higher (
126
+ [1 , 6 ], self .framework_version
127
+ ):
122
128
kwargs ["enable_sagemaker_metrics" ] = True
123
129
124
130
super (MXNet , self ).__init__ (
125
131
entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
126
132
)
127
133
128
- if py_version == "py2" :
129
- logger .warning (
130
- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
131
- )
132
-
133
134
if distributions is not None :
134
135
logger .warning (parameter_v2_rename_warning ("distributions" , "distribution" ))
135
136
train_instance_type = kwargs .get ("train_instance_type" )
136
137
warn_if_parameter_server_with_multi_gpu (
137
138
training_instance_type = train_instance_type , distributions = distributions
138
139
)
139
140
140
- self .py_version = py_version
141
141
self ._configure_distribution (distributions )
142
142
143
143
def _configure_distribution (self , distributions ):
@@ -148,7 +148,10 @@ def _configure_distribution(self, distributions):
148
148
if distributions is None :
149
149
return
150
150
151
- if self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION :
151
+ if (
152
+ self .framework_version
153
+ and self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION
154
+ ):
152
155
raise ValueError (
153
156
"The distributions option is valid for only versions {} and higher" .format (
154
157
"." .join (self ._LOWEST_SCRIPT_MODE_VERSION )
@@ -221,12 +224,12 @@ def create_model(
221
224
self .model_data ,
222
225
role or self .role ,
223
226
entry_point or self .entry_point ,
227
+ framework_version = self .framework_version ,
228
+ py_version = self .py_version ,
224
229
source_dir = (source_dir or self ._model_source_dir ()),
225
230
enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
226
231
container_log_level = self .container_log_level ,
227
232
code_location = self .code_location ,
228
- py_version = self .py_version ,
229
- framework_version = self .framework_version ,
230
233
model_server_workers = model_server_workers ,
231
234
sagemaker_session = self .sagemaker_session ,
232
235
vpc_config = self .get_vpc_config (vpc_config_override ),
@@ -254,22 +257,25 @@ class constructor
254
257
image_name = init_params .pop ("image" )
255
258
framework , py_version , tag , _ = framework_name_from_image (image_name )
256
259
260
+ # We switched image tagging scheme from regular image version (e.g. '1.0') to more
261
+ # expressive containing framework version, device type and python version
262
+ # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
263
+ # '0.12' framework version otherwise extract framework version from the tag itself.
264
+ if tag is None :
265
+ framework_version = None
266
+ elif tag == "1.0" :
267
+ framework_version = "0.12"
268
+ else :
269
+ framework_version = framework_version_from_tag (tag )
270
+ init_params ["framework_version" ] = framework_version
271
+ init_params ["py_version" ] = py_version
272
+
257
273
if not framework :
258
274
# If we were unable to parse the framework name from the image it is not one of our
259
275
# officially supported images, in this case just add the image to the init params.
260
276
init_params ["image_name" ] = image_name
261
277
return init_params
262
278
263
- init_params ["py_version" ] = py_version
264
-
265
- # We switched image tagging scheme from regular image version (e.g. '1.0') to more
266
- # expressive containing framework version, device type and python version
267
- # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
268
- # '0.12' framework version otherwise extract framework version from the tag itself.
269
- init_params ["framework_version" ] = (
270
- "0.12" if tag == "1.0" else framework_version_from_tag (tag )
271
- )
272
-
273
279
training_job_name = init_params ["base_job_name" ]
274
280
275
281
if framework != cls .__framework_name__ :
0 commit comments