Skip to content

Commit 4763545

Browse files
author
Chuyang Deng
committed
improve imports and error message
1 parent 2650374 commit 4763545

File tree

12 files changed

+103
-66
lines changed

12 files changed

+103
-66
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
empty_framework_version_warning,
2323
python_deprecation_warning,
2424
)
25-
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION
25+
from sagemaker.chainer import defaults
2626
from sagemaker.chainer.model import ChainerModel
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

@@ -40,7 +40,7 @@ class Chainer(Framework):
4040
_process_slots_per_host = "sagemaker_process_slots_per_host"
4141
_additional_mpi_options = "sagemaker_additional_mpi_options"
4242

43-
LATEST_VERSION = LATEST_VERSION
43+
LATEST_VERSION = defaults.LATEST_VERSION
4444

4545
def __init__(
4646
self,
@@ -126,15 +126,19 @@ def __init__(
126126
:class:`~sagemaker.estimator.EstimatorBase`.
127127
"""
128128
if framework_version is None:
129-
logger.warning(empty_framework_version_warning(CHAINER_VERSION, self.LATEST_VERSION))
130-
self.framework_version = framework_version or CHAINER_VERSION
129+
logger.warning(
130+
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
131+
)
132+
self.framework_version = framework_version or defaults.CHAINER_VERSION
131133

132134
super(Chainer, self).__init__(
133135
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
134136
)
135137

136138
if py_version == "py2":
137-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
139+
logger.warning(
140+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
141+
)
138142

139143
self.py_version = py_version
140144
self.use_mpi = use_mpi

src/sagemaker/chainer/model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
empty_framework_version_warning,
2424
)
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
26-
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION
26+
from sagemaker.chainer import defaults
2727
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2828

2929
logger = logging.getLogger("sagemaker")
@@ -111,13 +111,17 @@ def __init__(
111111
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
112112
)
113113
if py_version == "py2":
114-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
114+
logger.warning(
115+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
116+
)
115117

116118
if framework_version is None:
117-
logger.warning(empty_framework_version_warning(CHAINER_VERSION, LATEST_VERSION))
119+
logger.warning(
120+
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
121+
)
118122

119123
self.py_version = py_version
120-
self.framework_version = framework_version or CHAINER_VERSION
124+
self.framework_version = framework_version or defaults.CHAINER_VERSION
121125
self.model_server_workers = model_server_workers
122126

123127
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/mxnet/estimator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
python_deprecation_warning,
2424
is_version_equal_or_higher,
2525
)
26-
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION
26+
from sagemaker.mxnet import defaults
2727
from sagemaker.mxnet.model import MXNetModel
2828
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2929

@@ -36,7 +36,7 @@ class MXNet(Framework):
3636
__framework_name__ = "mxnet"
3737
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]
3838

39-
LATEST_VERSION = LATEST_VERSION
39+
LATEST_VERSION = defaults.LATEST_VERSION
4040

4141
def __init__(
4242
self,
@@ -107,8 +107,10 @@ def __init__(
107107
:class:`~sagemaker.estimator.EstimatorBase`.
108108
"""
109109
if framework_version is None:
110-
logger.warning(empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION))
111-
self.framework_version = framework_version or MXNET_VERSION
110+
logger.warning(
111+
empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION)
112+
)
113+
self.framework_version = framework_version or defaults.MXNET_VERSION
112114

113115
if "enable_sagemaker_metrics" not in kwargs:
114116
# enable sagemaker metrics for MXNet v1.6 or greater:
@@ -120,7 +122,9 @@ def __init__(
120122
)
121123

122124
if py_version == "py2":
123-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
125+
logger.warning(
126+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
127+
)
124128

125129
self.py_version = py_version
126130
self._configure_distribution(distributions)

src/sagemaker/mxnet/model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
empty_framework_version_warning,
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
28-
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION
28+
from sagemaker.mxnet import defaults
2929
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
3030

3131
logger = logging.getLogger("sagemaker")
@@ -113,13 +113,17 @@ def __init__(
113113
)
114114

115115
if py_version == "py2":
116-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
116+
logger.warning(
117+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118+
)
117119

118120
if framework_version is None:
119-
logger.warning(empty_framework_version_warning(MXNET_VERSION, LATEST_VERSION))
121+
logger.warning(
122+
empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
123+
)
120124

121125
self.py_version = py_version
122-
self.framework_version = framework_version or MXNET_VERSION
126+
self.framework_version = framework_version or defaults.MXNET_VERSION
123127
self.model_server_workers = model_server_workers
124128

125129
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/pytorch/estimator.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@
2323
python_deprecation_warning,
2424
is_version_equal_or_higher,
2525
)
26-
from sagemaker.pytorch.defaults import (
27-
PYTORCH_VERSION,
28-
PYTHON_VERSION,
29-
LATEST_VERSION,
30-
LATEST_PY2_VERSION,
31-
)
26+
from sagemaker.pytorch import defaults
3227
from sagemaker.pytorch.model import PyTorchModel
3328
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3429

@@ -40,14 +35,14 @@ class PyTorch(Framework):
4035

4136
__framework_name__ = "pytorch"
4237

43-
LATEST_VERSION = LATEST_VERSION
38+
LATEST_VERSION = defaults.LATEST_VERSION
4439

4540
def __init__(
4641
self,
4742
entry_point,
4843
source_dir=None,
4944
hyperparameters=None,
50-
py_version=PYTHON_VERSION,
45+
py_version=defaults.PYTHON_VERSION,
5146
framework_version=None,
5247
image_name=None,
5348
**kwargs
@@ -108,8 +103,10 @@ def __init__(
108103
:class:`~sagemaker.estimator.EstimatorBase`.
109104
"""
110105
if framework_version is None:
111-
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, self.LATEST_VERSION))
112-
self.framework_version = framework_version or PYTORCH_VERSION
106+
logger.warning(
107+
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
108+
)
109+
self.framework_version = framework_version or defaults.PYTORCH_VERSION
113110

114111
if "enable_sagemaker_metrics" not in kwargs:
115112
# enable sagemaker metrics for PT v1.3 or greater:
@@ -121,7 +118,9 @@ def __init__(
121118
)
122119

123120
if py_version == "py2":
124-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
121+
logger.warning(
122+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
123+
)
125124

126125
self.py_version = py_version
127126

src/sagemaker/pytorch/model.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@
2424
empty_framework_version_warning,
2525
)
2626
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
27-
from sagemaker.pytorch.defaults import (
28-
PYTORCH_VERSION,
29-
PYTHON_VERSION,
30-
LATEST_VERSION,
31-
LATEST_PY2_VERSION,
32-
)
27+
from sagemaker.pytorch import defaults
3328
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
3429

3530
logger = logging.getLogger("sagemaker")
@@ -72,7 +67,7 @@ def __init__(
7267
role,
7368
entry_point,
7469
image=None,
75-
py_version=PYTHON_VERSION,
70+
py_version=defaults.PYTHON_VERSION,
7671
framework_version=None,
7772
predictor_cls=PyTorchPredictor,
7873
model_server_workers=None,
@@ -119,13 +114,17 @@ def __init__(
119114
)
120115

121116
if py_version == "py2":
122-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
117+
logger.warning(
118+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119+
)
123120

124121
if framework_version is None:
125-
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION))
122+
logger.warning(
123+
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
124+
)
126125

127126
self.py_version = py_version
128-
self.framework_version = framework_version or PYTORCH_VERSION
127+
self.framework_version = framework_version or defaults.PYTORCH_VERSION
129128
self.model_server_workers = model_server_workers
130129

131130
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/sklearn/estimator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
empty_framework_version_warning,
2323
python_deprecation_warning,
2424
)
25-
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION
25+
from sagemaker.sklearn import defaults
2626
from sagemaker.sklearn.model import SKLearnModel
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

@@ -32,12 +32,12 @@
3232
class SKLearn(Framework):
3333
"""Handle end-to-end training and deployment of custom Scikit-learn code."""
3434

35-
__framework_name__ = SKLEARN_NAME
35+
__framework_name__ = defaults.SKLEARN_NAME
3636

3737
def __init__(
3838
self,
3939
entry_point,
40-
framework_version=SKLEARN_VERSION,
40+
framework_version=defaults.SKLEARN_VERSION,
4141
source_dir=None,
4242
hyperparameters=None,
4343
py_version="py3",
@@ -119,13 +119,17 @@ def __init__(
119119
)
120120

121121
if py_version == "py2":
122-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
122+
logger.warning(
123+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
124+
)
123125

124126
self.py_version = py_version
125127

126128
if framework_version is None:
127-
logger.warning(empty_framework_version_warning(SKLEARN_VERSION, SKLEARN_VERSION))
128-
self.framework_version = framework_version or SKLEARN_VERSION
129+
logger.warning(
130+
empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION)
131+
)
132+
self.framework_version = framework_version or defaults.SKLEARN_VERSION
129133

130134
if image_name is None:
131135
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)

src/sagemaker/sklearn/model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.fw_registry import default_framework_uri
2121
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2222
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
23-
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION
23+
from sagemaker.sklearn import defaults
2424

2525
logger = logging.getLogger("sagemaker")
2626

@@ -53,7 +53,7 @@ class SKLearnModel(FrameworkModel):
5353
``Endpoint``.
5454
"""
5555

56-
__framework_name__ = SKLEARN_NAME
56+
__framework_name__ = defaults.SKLEARN_NAME
5757

5858
def __init__(
5959
self,
@@ -62,7 +62,7 @@ def __init__(
6262
entry_point,
6363
image=None,
6464
py_version="py3",
65-
framework_version=SKLEARN_VERSION,
65+
framework_version=defaults.SKLEARN_VERSION,
6666
predictor_cls=SKLearnPredictor,
6767
model_server_workers=None,
6868
**kwargs
@@ -108,7 +108,9 @@ def __init__(
108108
)
109109

110110
if py_version == "py2":
111-
logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION))
111+
logger.warning(
112+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
113+
)
112114

113115
self.py_version = py_version
114116
self.framework_version = framework_version

src/sagemaker/tensorflow/estimator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sagemaker.debugger import DebuggerHookConfig
2626
from sagemaker.estimator import Framework
2727
import sagemaker.fw_utils as fw
28-
from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION, LATEST_PY2_VERSION
28+
from sagemaker.tensorflow import defaults
2929
from sagemaker.tensorflow.model import TensorFlowModel
3030
from sagemaker.tensorflow.serving import Model
3131
from sagemaker.transformer import Transformer
@@ -197,7 +197,7 @@ class TensorFlow(Framework):
197197

198198
__framework_name__ = "tensorflow"
199199

200-
LATEST_VERSION = LATEST_VERSION
200+
LATEST_VERSION = defaults.LATEST_VERSION
201201

202202
_LATEST_1X_VERSION = "1.15.0"
203203

@@ -288,14 +288,16 @@ def __init__(
288288
:class:`~sagemaker.estimator.EstimatorBase`.
289289
"""
290290
if framework_version is None:
291-
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
292-
self.framework_version = framework_version or TF_VERSION
291+
logger.warning(
292+
fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION)
293+
)
294+
self.framework_version = framework_version or defaults.TF_VERSION
293295

294296
if not py_version:
295297
py_version = "py3" if self._only_python_3_supported() else "py2"
296298
if py_version == "py2":
297299
logger.warning(
298-
fw.python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)
300+
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
299301
)
300302

301303
if "enable_sagemaker_metrics" not in kwargs:
@@ -360,8 +362,8 @@ def _validate_args(
360362

361363
if py_version == "py2" and self._only_python_3_supported():
362364
msg = (
363-
"Python 2 containers are only available before January 1st, 2020. "
364-
"Please use a Python 3 container."
365+
"Python 2 containers are only available with {} and lower versions. "
366+
"Please use a Python 3 container.".format(defaults.LATEST_PY2_VERSION)
365367
)
366368
raise AttributeError(msg)
367369

0 commit comments

Comments
 (0)