Skip to content

Commit 6aa5da7

Browse files
authored
Merge branch 'zwei' into deprecate-util-functions
2 parents 5ebe225 + 1562d90 commit 6aa5da7

20 files changed

+44
-44
lines changed

src/sagemaker/chainer/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class Chainer(Framework):
3333
"""Handle end-to-end training and deployment of custom Chainer code."""
3434

35-
__framework_name__ = "chainer"
35+
_framework_name = "chainer"
3636

3737
# Hyperparameters
3838
_use_mpi = "sagemaker_use_mpi"
@@ -131,7 +131,7 @@ def __init__(
131131
validate_version_or_image_args(framework_version, py_version, image_uri)
132132
if py_version == "py2":
133133
logger.warning(
134-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
134+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
135135
)
136136
self.framework_version = framework_version
137137
self.py_version = py_version
@@ -272,7 +272,7 @@ class constructor
272272
init_params["image_uri"] = image_uri
273273
return init_params
274274

275-
if framework != cls.__framework_name__:
275+
if framework != cls._framework_name:
276276
raise ValueError(
277277
"Training job: {} didn't use image for requested framework".format(
278278
job_details["TrainingJobName"]

src/sagemaker/chainer/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ChainerModel(FrameworkModel):
5959
``Endpoint``.
6060
"""
6161

62-
__framework_name__ = "chainer"
62+
_framework_name = "chainer"
6363

6464
def __init__(
6565
self,
@@ -116,7 +116,7 @@ def __init__(
116116
validate_version_or_image_args(framework_version, py_version, image_uri)
117117
if py_version == "py2":
118118
logger.warning(
119-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
120120
)
121121
self.framework_version = framework_version
122122
self.py_version = py_version
@@ -176,7 +176,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
176176
177177
"""
178178
return image_uris.retrieve(
179-
self.__framework_name__,
179+
self._framework_name,
180180
region_name,
181181
version=self.framework_version,
182182
py_version=self.py_version,

src/sagemaker/estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1418,7 +1418,7 @@ class Framework(EstimatorBase):
14181418
such as training/deployment images and predictor instances.
14191419
"""
14201420

1421-
__framework_name__ = None
1421+
_framework_name = None
14221422

14231423
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
14241424
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
@@ -1816,7 +1816,7 @@ def train_image(self):
18161816
if self.image_uri:
18171817
return self.image_uri
18181818
return image_uris.retrieve(
1819-
self.__framework_name__,
1819+
self._framework_name,
18201820
self.sagemaker_session.boto_region_name,
18211821
instance_type=self.instance_type,
18221822
version=self.framework_version, # pylint: disable=no-member

src/sagemaker/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _set_model_name_if_needed(self):
184184

185185
def _framework(self):
186186
"""Placeholder docstring"""
187-
return getattr(self, "__framework_name__", None)
187+
return getattr(self, "_framework_name", None)
188188

189189
def _get_framework_version(self):
190190
"""Placeholder docstring"""

src/sagemaker/mxnet/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
class MXNet(Framework):
3535
"""Handle end-to-end training and deployment of custom MXNet code."""
3636

37-
__framework_name__ = "mxnet"
37+
_framework_name = "mxnet"
3838
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]
3939

4040
def __init__(
@@ -114,7 +114,7 @@ def __init__(
114114
validate_version_or_image_args(framework_version, py_version, image_uri)
115115
if py_version == "py2":
116116
logger.warning(
117-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
117+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
118118
)
119119
self.framework_version = framework_version
120120
self.py_version = py_version
@@ -280,7 +280,7 @@ class constructor
280280
init_params["image_uri"] = image_uri
281281
return init_params
282282

283-
if framework != cls.__framework_name__:
283+
if framework != cls._framework_name:
284284
raise ValueError(
285285
"Training job: {} didn't use image for requested framework".format(
286286
job_details["TrainingJobName"]

src/sagemaker/mxnet/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5959
class MXNetModel(FrameworkModel):
6060
"""An MXNet SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
6161

62-
__framework_name__ = "mxnet"
62+
_framework_name = "mxnet"
6363
_LOWEST_MMS_VERSION = "1.4.0"
6464

6565
def __init__(
@@ -119,7 +119,7 @@ def __init__(
119119
validate_version_or_image_args(framework_version, py_version, image_uri)
120120
if py_version == "py2":
121121
logger.warning(
122-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
122+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
123123
)
124124
self.framework_version = framework_version
125125
self.py_version = py_version
@@ -184,7 +184,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
184184
185185
"""
186186
return image_uris.retrieve(
187-
self.__framework_name__,
187+
self._framework_name,
188188
region_name,
189189
version=self.framework_version,
190190
py_version=self.py_version,

src/sagemaker/pytorch/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class PyTorch(Framework):
3434
"""Handle end-to-end training and deployment of custom PyTorch code."""
3535

36-
__framework_name__ = "pytorch"
36+
_framework_name = "pytorch"
3737

3838
def __init__(
3939
self,
@@ -109,7 +109,7 @@ def __init__(
109109
validate_version_or_image_args(framework_version, py_version, image_uri)
110110
if py_version == "py2":
111111
logger.warning(
112-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
112+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
113113
)
114114
self.framework_version = framework_version
115115
self.py_version = py_version
@@ -221,7 +221,7 @@ class constructor
221221
init_params["image_uri"] = image_uri
222222
return init_params
223223

224-
if framework != cls.__framework_name__:
224+
if framework != cls._framework_name:
225225
raise ValueError(
226226
"Training job: {} didn't use image for requested framework".format(
227227
job_details["TrainingJobName"]

src/sagemaker/pytorch/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class PyTorchModel(FrameworkModel):
6060
``Endpoint``.
6161
"""
6262

63-
__framework_name__ = "pytorch"
63+
_framework_name = "pytorch"
6464
_LOWEST_MMS_VERSION = "1.2"
6565

6666
def __init__(
@@ -118,7 +118,7 @@ def __init__(
118118
validate_version_or_image_args(framework_version, py_version, image_uri)
119119
if py_version == "py2":
120120
logger.warning(
121-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
121+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
122122
)
123123
self.framework_version = framework_version
124124
self.py_version = py_version
@@ -183,7 +183,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
183183
184184
"""
185185
return image_uris.retrieve(
186-
self.__framework_name__,
186+
self._framework_name,
187187
region_name,
188188
version=self.framework_version,
189189
py_version=self.py_version,

src/sagemaker/sklearn/estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class SKLearn(Framework):
3333
"""Handle end-to-end training and deployment of custom Scikit-learn code."""
3434

35-
__framework_name__ = defaults.SKLEARN_NAME
35+
_framework_name = defaults.SKLEARN_NAME
3636

3737
def __init__(
3838
self,
@@ -138,7 +138,7 @@ def __init__(
138138

139139
if image_uri is None:
140140
self.image_uri = image_uris.retrieve(
141-
SKLearn.__framework_name__,
141+
SKLearn._framework_name,
142142
self.sagemaker_session.boto_region_name,
143143
version=self.framework_version,
144144
py_version=self.py_version,

src/sagemaker/sklearn/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SKLearnModel(FrameworkModel):
5555
``Endpoint``.
5656
"""
5757

58-
__framework_name__ = defaults.SKLEARN_NAME
58+
_framework_name = defaults.SKLEARN_NAME
5959

6060
def __init__(
6161
self,
@@ -175,7 +175,7 @@ def serving_image_uri(self, region_name, instance_type):
175175
176176
"""
177177
return image_uris.retrieve(
178-
self.__framework_name__,
178+
self._framework_name,
179179
region_name,
180180
version=self.framework_version,
181181
py_version=self.py_version,

src/sagemaker/tensorflow/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class TensorFlow(Framework):
3434
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
3535

36-
__framework_name__ = "tensorflow"
36+
_framework_name = "tensorflow"
3737

3838
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
3939
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
@@ -116,7 +116,7 @@ def __init__(
116116
fw.validate_version_or_image_args(framework_version, py_version, image_uri)
117117
if py_version == "py2":
118118
logger.warning(
119-
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119+
fw.python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
120120
)
121121
self.framework_version = framework_version
122122
self.py_version = py_version
@@ -221,7 +221,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
221221
if not script_mode:
222222
init_params["image_uri"] = image_uri
223223

224-
if framework != cls.__framework_name__:
224+
if framework != cls._framework_name:
225225
raise ValueError(
226226
"Training job: {} didn't use image for requested framework".format(
227227
job_details["TrainingJobName"]

src/sagemaker/tensorflow/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def predict(self, data, initial_args=None):
121121
class TensorFlowModel(sagemaker.model.FrameworkModel):
122122
"""A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""
123123

124-
__framework_name__ = "tensorflow"
124+
_framework_name = "tensorflow"
125125
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
126126
LOG_LEVEL_MAP = {
127127
logging.DEBUG: "debug",
@@ -286,7 +286,7 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
286286
return self.image_uri
287287

288288
return image_uris.retrieve(
289-
self.__framework_name__,
289+
self._framework_name,
290290
self.sagemaker_session.boto_region_name,
291291
version=self.framework_version,
292292
instance_type=instance_type,

src/sagemaker/xgboost/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class XGBoost(Framework):
3434
"""Handle end-to-end training and deployment of XGBoost booster training or training using
3535
customer provided XGBoost entry point script."""
3636

37-
__framework_name__ = defaults.XGBOOST_NAME
37+
_framework_name = defaults.XGBOOST_NAME
3838

3939
def __init__(
4040
self,
@@ -103,7 +103,7 @@ def __init__(
103103

104104
if image_uri is None:
105105
self.image_uri = image_uris.retrieve(
106-
self.__framework_name__,
106+
self._framework_name,
107107
self.sagemaker_session.boto_region_name,
108108
version=framework_version,
109109
py_version=self.py_version,
@@ -252,7 +252,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
252252
framework, py_version, tag, _ = framework_name_from_image(image_uri)
253253
init_params["py_version"] = py_version
254254

255-
if framework and framework != cls.__framework_name__:
255+
if framework and framework != cls._framework_name:
256256
raise ValueError(
257257
"Training job: {} didn't use image for requested framework".format(
258258
job_details["TrainingJobName"]

src/sagemaker/xgboost/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5151
class XGBoostModel(FrameworkModel):
5252
"""An XGBoost SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
5353

54-
__framework_name__ = XGBOOST_NAME
54+
_framework_name = XGBOOST_NAME
5555

5656
def __init__(
5757
self,
@@ -144,7 +144,7 @@ def serving_image_uri(self, region_name, instance_type):
144144
str: The appropriate image URI based on the given parameters.
145145
"""
146146
return image_uris.retrieve(
147-
self.__framework_name__,
147+
self._framework_name,
148148
region_name,
149149
version=self.framework_version,
150150
py_version=self.py_version,

tests/unit/sagemaker/tensorflow/test_estimator_init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_py2_version_is_not_deprecated(sagemaker_session):
6565

6666
def test_framework_name(sagemaker_session):
6767
tf = _build_tf(sagemaker_session, framework_version="1.15.2", py_version="py3")
68-
assert tf.__framework_name__ == "tensorflow"
68+
assert tf._framework_name == "tensorflow"
6969

7070

7171
def test_enable_sm_metrics(sagemaker_session):

tests/unit/test_chainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def test_estimator_py2_warning(warning, sagemaker_session, chainer_version):
561561
)
562562

563563
assert estimator.py_version == "py2"
564-
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
564+
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)
565565

566566

567567
@patch("sagemaker.chainer.model.python_deprecation_warning")
@@ -575,4 +575,4 @@ def test_model_py2_warning(warning, sagemaker_session, chainer_version):
575575
py_version="py2",
576576
)
577577
assert model.py_version == "py2"
578-
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
578+
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)

tests/unit/test_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107

108108

109109
class DummyFramework(Framework):
110-
__framework_name__ = "dummy"
110+
_framework_name = "dummy"
111111

112112
def train_image(self):
113113
return IMAGE_URI

tests/unit/test_job.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def sagemaker_session():
8181

8282

8383
class DummyFramework(Framework):
84-
__framework_name__ = "dummy"
84+
_framework_name = "dummy"
8585

8686
def train_image(self):
8787
return IMAGE_NAME

tests/unit/test_mxnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ def test_estimator_py2_warning(warning, sagemaker_session):
715715
)
716716

717717
assert estimator.py_version == "py2"
718-
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
718+
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)
719719

720720

721721
@patch("sagemaker.mxnet.model.python_deprecation_warning")
@@ -729,7 +729,7 @@ def test_model_py2_warning(warning, sagemaker_session):
729729
sagemaker_session=sagemaker_session,
730730
)
731731
assert model.py_version == "py2"
732-
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
732+
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)
733733

734734

735735
def test_create_model_with_custom_hosting_image(sagemaker_session):

tests/unit/test_pytorch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def test_estimator_py2_warning(warning, sagemaker_session, pytorch_training_vers
553553
)
554554

555555
assert estimator.py_version == "py2"
556-
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
556+
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)
557557

558558

559559
@patch("sagemaker.pytorch.model.python_deprecation_warning")
@@ -567,7 +567,7 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_inference_version
567567
py_version="py2",
568568
)
569569
assert model.py_version == "py2"
570-
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
570+
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)
571571

572572

573573
def test_pt_enable_sm_metrics(

0 commit comments

Comments
 (0)