Skip to content

Commit c7bd228

Browse files
committed
breaking: deprecated constants from defaults
1 parent dbdaf50 commit c7bd228

26 files changed

+132
-189
lines changed

src/sagemaker/amazon/amazon_estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ def get_image_uri(region_name, repo_name, repo_version=1):
611611
"""Return algorithm image URI for the given AWS region, repository name, and
612612
repository version
613613
614+
TODO: consider refactoring version/repo validation into a module for all frameworks
615+
614616
Args:
615617
region_name:
616618
repo_name:

src/sagemaker/chainer/defaults.py

-8
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
CHAINER_VERSION = "4.1.0"
17-
"""Default Chainer version for when the framework version is not specified.
18-
This is no longer updated so as to not break existing workflows.
19-
"""
20-
21-
LATEST_VERSION = "5.0.0"
22-
"""The latest version of Chainer included in the SageMaker pre-built Docker images."""
23-
2416
LATEST_PY2_VERSION = "5.0.0"

src/sagemaker/chainer/estimator.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
2322
python_deprecation_warning,
2423
)
2524
from sagemaker.chainer import defaults
@@ -40,8 +39,6 @@ class Chainer(Framework):
4039
_process_slots_per_host = "sagemaker_process_slots_per_host"
4140
_additional_mpi_options = "sagemaker_additional_mpi_options"
4241

43-
LATEST_VERSION = defaults.LATEST_VERSION
44-
4542
def __init__(
4643
self,
4744
entry_point,
@@ -126,11 +123,7 @@ def __init__(
126123
:class:`~sagemaker.estimator.Framework` and
127124
:class:`~sagemaker.estimator.EstimatorBase`.
128125
"""
129-
if framework_version is None:
130-
logger.warning(
131-
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
132-
)
133-
self.framework_version = framework_version or defaults.CHAINER_VERSION
126+
self.framework_version = framework_version
134127

135128
super(Chainer, self).__init__(
136129
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs

src/sagemaker/chainer/model.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
from sagemaker import fw_utils
1919

2020
import sagemaker
21-
from sagemaker.fw_utils import (
22-
create_image_uri,
23-
model_code_key_prefix,
24-
python_deprecation_warning,
25-
empty_framework_version_warning,
26-
)
21+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
2722
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2823
from sagemaker.chainer import defaults
2924
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
@@ -117,13 +112,8 @@ def __init__(
117112
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118113
)
119114

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
123-
)
124-
125115
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.CHAINER_VERSION
116+
self.framework_version = framework_version
127117
self.model_server_workers = model_server_workers
128118

129119
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/cli/mxnet.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.common import HostCommand, TrainCommand
17-
from sagemaker.mxnet import defaults
17+
18+
MXNET_VERSION = "1.2"
1819

1920

2021
def train(args):
@@ -42,7 +43,7 @@ def create_estimator(self):
4243

4344
return MXNet(
4445
entry_point=self.script,
45-
framework_version=defaults.MXNET_VERSION,
46+
framework_version=MXNET_VERSION,
4647
py_version=self.python,
4748
role=self.role_name,
4849
base_job_name=self.job_name,
@@ -66,7 +67,7 @@ def create_model(self, model_url):
6667
model_data=model_url,
6768
role=self.role_name,
6869
entry_point=self.script,
69-
framework_version=defaults.MXNET_VERSION,
70+
framework_version=MXNET_VERSION,
7071
py_version=self.python,
7172
name=self.endpoint_name,
7273
env=self.environment,

src/sagemaker/mxnet/defaults.py

-8
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
MXNET_VERSION = "1.2"
17-
"""Default MXNet version for when the framework version is not specified.
18-
This is no longer updated so as to not break existing workflows.
19-
"""
20-
21-
LATEST_VERSION = "1.6.0"
22-
"""The latest version of MXNet included in the SageMaker pre-built Docker images."""
23-
2416
LATEST_PY2_VERSION = "1.6.0"

src/sagemaker/mxnet/estimator.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ class MXNet(Framework):
3838
__framework_name__ = "mxnet"
3939
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]
4040

41-
LATEST_VERSION = defaults.LATEST_VERSION
42-
4341
def __init__(
4442
self,
4543
entry_point,
@@ -115,7 +113,7 @@ def __init__(
115113
:class:`~sagemaker.estimator.EstimatorBase`.
116114
"""
117115
validate_version_or_image_args(framework_version, py_version, image_name)
118-
if py_version and py_version == "py2":
116+
if py_version == "py2":
119117
logger.warning(
120118
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
121119
)

src/sagemaker/mxnet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(
115115
:class:`~sagemaker.model.Model`.
116116
"""
117117
validate_version_or_image_args(framework_version, py_version, image)
118-
if py_version and py_version == "py2":
118+
if py_version == "py2":
119119
logger.warning(
120120
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
121121
)

src/sagemaker/pytorch/defaults.py

-10
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
PYTORCH_VERSION = "0.4"
17-
"""Default PyTorch version for when the framework version is not specified.
18-
The default version is no longer updated so as to not break existing workflows.
19-
"""
20-
21-
LATEST_VERSION = "1.5.0"
22-
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""
23-
24-
PYTHON_VERSION = "py3"
25-
2616
LATEST_PY2_VERSION = "1.3.1"

src/sagemaker/pytorch/estimator.py

-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ class PyTorch(Framework):
3535

3636
__framework_name__ = "pytorch"
3737

38-
LATEST_VERSION = defaults.LATEST_VERSION
39-
4038
def __init__(
4139
self,
4240
entry_point,

src/sagemaker/pytorch/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
:class:`~sagemaker.model.Model`.
115115
"""
116116
validate_version_or_image_args(framework_version, py_version, image)
117-
if py_version and py_version == "py2":
117+
if py_version == "py2":
118118
logger.warning(
119119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
120120
)

src/sagemaker/sklearn/defaults.py

-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,3 @@
1414
from __future__ import absolute_import
1515

1616
SKLEARN_NAME = "scikit-learn"
17-
18-
SKLEARN_VERSION = "0.20.0"
19-
20-
LATEST_PY2_VERSION = "0.20.0"

src/sagemaker/tensorflow/defaults.py

-11
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
TF_VERSION = "1.11"
17-
"""Default TF version for when the framework version is not specified.
18-
This is no longer updated so as to not break existing workflows.
19-
"""
20-
21-
LATEST_VERSION = "2.2.0"
22-
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
23-
24-
LATEST_SERVING_VERSION = "2.1.0"
25-
"""The latest version of TensorFlow Serving included in the SageMaker pre-built Docker images."""
26-
2716
LATEST_PY2_VERSION = "2.1.0"

src/sagemaker/tensorflow/estimator.py

-4
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ class TensorFlow(Framework):
3636
__framework_name__ = "tensorflow"
3737
_ECR_REPO_NAME = "tensorflow-scriptmode"
3838

39-
LATEST_VERSION = defaults.LATEST_VERSION
40-
41-
_LATEST_1X_VERSION = "1.15.2"
42-
4339
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
4440
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
4541

src/sagemaker/xgboost/defaults.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@
1515

1616
XGBOOST_NAME = "xgboost"
1717
XGBOOST_1P_VERSIONS = ["1", "latest"]
18-
XGBOOST_VERSION_0_90 = "0.90"
18+
19+
# TODO: evaluate usefulness of these constants. they are defined here and only used here
1920
XGBOOST_VERSION_0_90_1 = "0.90-1"
2021
XGBOOST_VERSION_0_90_2 = "0.90-2"
22+
2123
XGBOOST_LATEST_VERSION = "1.0-1"
24+
2225
# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
2326
XGBOOST_SUPPORTED_VERSIONS = [
2427
XGBOOST_VERSION_0_90_1,
2528
XGBOOST_VERSION_0_90_2,
2629
XGBOOST_LATEST_VERSION,
2730
]
31+
32+
# TODO: evaluate use of this constant. it's used in precisely one place in different a module
33+
# may possibly be unnecessary indirection
2834
XGBOOST_VERSION_EQUIVALENTS = ["-cpu-py3"]

tests/conftest.py

+39-21
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,8 @@
2121
from botocore.config import Config
2222

2323
from sagemaker import Session, utils
24-
from sagemaker.chainer import Chainer
2524
from sagemaker.local import LocalSession
26-
from sagemaker.mxnet import MXNet
27-
from sagemaker.pytorch import PyTorch
2825
from sagemaker.rl import RLEstimator
29-
from sagemaker.sklearn.defaults import SKLEARN_VERSION
30-
from sagemaker.tensorflow import TensorFlow
31-
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION
3226

3327
DEFAULT_REGION = "us-west-2"
3428
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
@@ -44,15 +38,19 @@
4438

4539
NO_T2_REGIONS = ["eu-north-1", "ap-east-1", "me-south-1"]
4640

41+
# TODO: refactor handling of versions, repo, image uris, validations for all frameworks
42+
TENSORFLOW_LATEST_VERSION = "2.2.0"
43+
TENSORFLOW_LATEST_1X_VERSION = "1.15.2"
44+
4745

4846
def pytest_addoption(parser):
4947
parser.addoption("--sagemaker-client-config", action="store", default=None)
5048
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
5149
parser.addoption("--boto-config", action="store", default=None)
52-
parser.addoption("--chainer-full-version", action="store", default=Chainer.LATEST_VERSION)
53-
parser.addoption("--mxnet-full-version", action="store", default=MXNet.LATEST_VERSION)
50+
parser.addoption("--chainer-full-version", action="store", default="5.0.0")
51+
parser.addoption("--mxnet-full-version", action="store", default="1.6.0")
5452
parser.addoption("--ei-mxnet-full-version", action="store", default="1.5.1")
55-
parser.addoption("--pytorch-full-version", action="store", default=PyTorch.LATEST_VERSION)
53+
parser.addoption("--pytorch-full-version", action="store", default="1.5.0")
5654
parser.addoption(
5755
"--rl-coach-mxnet-full-version",
5856
action="store",
@@ -64,10 +62,10 @@ def pytest_addoption(parser):
6462
parser.addoption(
6563
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
6664
)
67-
parser.addoption("--sklearn-full-version", action="store", default=SKLEARN_VERSION)
65+
parser.addoption("--sklearn-full-version", action="store", default="0.20.0")
6866
parser.addoption("--tf-full-version", action="store")
6967
parser.addoption("--ei-tf-full-version", action="store")
70-
parser.addoption("--xgboost-full-version", action="store", default=SKLEARN_VERSION)
68+
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
7169

7270

7371
def pytest_configure(config):
@@ -291,7 +289,27 @@ def sklearn_full_version(request):
291289
return request.config.getoption("--sklearn-full-version")
292290

293291

294-
@pytest.fixture(scope="module", params=[TensorFlow._LATEST_1X_VERSION, LATEST_VERSION])
292+
@pytest.fixture(scope="module", params=[TENSORFLOW_LATEST_VERSION])
293+
def tf_latest_version(request):
294+
return request.param
295+
296+
297+
@pytest.fixture(scope="module")
298+
def tf_latest_py_version():
299+
return "py37"
300+
301+
302+
@pytest.fixture(scope="module", params=[TENSORFLOW_LATEST_1X_VERSION])
303+
def tf_latest_1x_version(request):
304+
return request.param
305+
306+
307+
@pytest.fixture(scope="module")
308+
def tf_latest_serving_version():
309+
return "2.1.0"
310+
311+
312+
@pytest.fixture(scope="module", params=[TENSORFLOW_LATEST_VERSION, TENSORFLOW_LATEST_1X_VERSION])
295313
def tf_full_version(request):
296314
tf_version = request.config.getoption("--tf-full-version")
297315
if tf_version is None:
@@ -301,7 +319,7 @@ def tf_full_version(request):
301319

302320

303321
@pytest.fixture(scope="module")
304-
def tf_full_py_version(tf_full_version, request):
322+
def tf_full_py_version(tf_full_version, tf_latest_version, tf_latest_1x_version):
305323
"""fixture to match tf_full_version
306324
307325
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
@@ -312,11 +330,18 @@ def tf_full_py_version(tf_full_version, request):
312330
version = [int(val) for val in tf_full_version.split(".")]
313331
if version < [1, 11]:
314332
return "py2"
315-
if tf_full_version in [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]:
333+
if tf_full_version in [tf_latest_version, tf_latest_1x_version]:
316334
return "py37"
317335
return "py3"
318336

319337

338+
@pytest.fixture(scope="module")
339+
def tf_serving_version(tf_full_version, tf_latest_version, tf_latest_serving_version):
340+
if tf_full_version == tf_latest_version:
341+
return tf_latest_serving_version
342+
return tf_full_version
343+
344+
320345
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
321346
def ei_tf_full_version(request):
322347
tf_ei_version = request.config.getoption("--ei-tf-full-version")
@@ -384,10 +409,3 @@ def pytest_generate_tests(metafunc):
384409
@pytest.fixture(scope="module")
385410
def xgboost_full_version(request):
386411
return request.config.getoption("--xgboost-full-version")
387-
388-
389-
@pytest.fixture(scope="module")
390-
def tf_serving_version(tf_full_version):
391-
if tf_full_version == LATEST_VERSION:
392-
return LATEST_SERVING_VERSION
393-
return tf_full_version

0 commit comments

Comments
 (0)