|
26 | 26 | from sagemaker.tensorflow.defaults import TF_VERSION
|
27 | 27 | from sagemaker.tensorflow.model import TensorFlowModel
|
28 | 28 | from sagemaker.tensorflow.serving import Model
|
29 |
| -from sagemaker.utils import get_config_value |
| 29 | +from sagemaker.utils import get_config_value, get_short_version |
30 | 30 | from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
|
31 | 31 |
|
32 | 32 | logger = logging.getLogger('sagemaker')
|
@@ -171,9 +171,11 @@ class TensorFlow(Framework):
|
171 | 171 |
|
172 | 172 | __framework_name__ = 'tensorflow'
|
173 | 173 |
|
174 |
| - LATEST_VERSION = '1.12' |
| 174 | + LATEST_VERSION = '1.13' |
175 | 175 | """The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
|
176 | 176 |
|
| 177 | + _LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13] |
| 178 | + |
177 | 179 | def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
|
178 | 180 | framework_version=None, model_dir=None, requirements_file='', image_name=None,
|
179 | 181 | script_mode=False, distributions=None, **kwargs):
|
@@ -276,6 +278,13 @@ def _validate_args(self, py_version, script_mode, framework_version, training_st
|
276 | 278 | .format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args))
|
277 | 279 | )
|
278 | 280 |
|
| 281 | + if (not self._script_mode_enabled()) and \ |
| 282 | + [int(s) for s in self.framework_version.split('.')] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION: |
| 283 | + raise AttributeError( |
| 284 | + 'Legacy mode is deprecated in versions 1.13 and higher.' |
| 285 | + 'Please set the script_mode argument to True to use Script Mode' |
| 286 | + ) |
| 287 | + |
279 | 288 | def _validate_requirements_file(self, requirements_file):
|
280 | 289 | if not requirements_file:
|
281 | 290 | return
|
@@ -427,7 +436,7 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
|
427 | 436 | image=self.image_name,
|
428 | 437 | name=self._current_job_name,
|
429 | 438 | container_log_level=self.container_log_level,
|
430 |
| - framework_version=self.framework_version, |
| 439 | + framework_version=get_short_version(self.framework_version), |
431 | 440 | sagemaker_session=self.sagemaker_session,
|
432 | 441 | vpc_config=self.get_vpc_config(vpc_config_override))
|
433 | 442 |
|
|
0 commit comments