Skip to content

change: rename __framework_name__ as _framework_name #1775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
class Chainer(Framework):
"""Handle end-to-end training and deployment of custom Chainer code."""

__framework_name__ = "chainer"
_framework_name = "chainer"

# Hyperparameters
_use_mpi = "sagemaker_use_mpi"
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -272,7 +272,7 @@ class constructor
init_params["image_uri"] = image_uri
return init_params

if framework != cls.__framework_name__:
if framework != cls._framework_name:
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ChainerModel(FrameworkModel):
``Endpoint``.
"""

__framework_name__ = "chainer"
_framework_name = "chainer"

def __init__(
self,
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -176,7 +176,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):

"""
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ class Framework(EstimatorBase):
such as training/deployment images and predictor instances.
"""

__framework_name__ = None
_framework_name = None

LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
Expand Down Expand Up @@ -1816,7 +1816,7 @@ def train_image(self):
if self.image_uri:
return self.image_uri
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version, # pylint: disable=no-member
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _set_model_name_if_needed(self):

def _framework(self):
"""Placeholder docstring"""
return getattr(self, "__framework_name__", None)
return getattr(self, "_framework_name", None)

def _get_framework_version(self):
"""Placeholder docstring"""
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class MXNet(Framework):
"""Handle end-to-end training and deployment of custom MXNet code."""

__framework_name__ = "mxnet"
_framework_name = "mxnet"
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]

def __init__(
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -280,7 +280,7 @@ class constructor
init_params["image_uri"] = image_uri
return init_params

if framework != cls.__framework_name__:
if framework != cls._framework_name:
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
class MXNetModel(FrameworkModel):
"""An MXNet SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

__framework_name__ = "mxnet"
_framework_name = "mxnet"
_LOWEST_MMS_VERSION = "1.4.0"

def __init__(
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -184,7 +184,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):

"""
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class PyTorch(Framework):
"""Handle end-to-end training and deployment of custom PyTorch code."""

__framework_name__ = "pytorch"
_framework_name = "pytorch"

def __init__(
self,
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -221,7 +221,7 @@ class constructor
init_params["image_uri"] = image_uri
return init_params

if framework != cls.__framework_name__:
if framework != cls._framework_name:
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PyTorchModel(FrameworkModel):
``Endpoint``.
"""

__framework_name__ = "pytorch"
_framework_name = "pytorch"
_LOWEST_MMS_VERSION = "1.2"

def __init__(
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -183,7 +183,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):

"""
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
class SKLearn(Framework):
"""Handle end-to-end training and deployment of custom Scikit-learn code."""

__framework_name__ = defaults.SKLEARN_NAME
_framework_name = defaults.SKLEARN_NAME

def __init__(
self,
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(

if image_uri is None:
self.image_uri = image_uris.retrieve(
SKLearn.__framework_name__,
SKLearn._framework_name,
self.sagemaker_session.boto_region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class SKLearnModel(FrameworkModel):
``Endpoint``.
"""

__framework_name__ = defaults.SKLEARN_NAME
_framework_name = defaults.SKLEARN_NAME

def __init__(
self,
Expand Down Expand Up @@ -175,7 +175,7 @@ def serving_image_uri(self, region_name, instance_type):

"""
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class TensorFlow(Framework):
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""

__framework_name__ = "tensorflow"
_framework_name = "tensorflow"

_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
fw.validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
fw.python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version
Expand Down Expand Up @@ -221,7 +221,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
if not script_mode:
init_params["image_uri"] = image_uri

if framework != cls.__framework_name__:
if framework != cls._framework_name:
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def predict(self, data, initial_args=None):
class TensorFlowModel(sagemaker.model.FrameworkModel):
"""A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""

__framework_name__ = "tensorflow"
_framework_name = "tensorflow"
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
LOG_LEVEL_MAP = {
logging.DEBUG: "debug",
Expand Down Expand Up @@ -286,7 +286,7 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
return self.image_uri

return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
self.sagemaker_session.boto_region_name,
version=self.framework_version,
instance_type=instance_type,
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class XGBoost(Framework):
"""Handle end-to-end training and deployment of XGBoost booster training or training using
customer provided XGBoost entry point script."""

__framework_name__ = defaults.XGBOOST_NAME
_framework_name = defaults.XGBOOST_NAME

def __init__(
self,
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(

if image_uri is None:
self.image_uri = image_uris.retrieve(
self.__framework_name__,
self._framework_name,
self.sagemaker_session.boto_region_name,
version=framework_version,
py_version=self.py_version,
Expand Down Expand Up @@ -252,7 +252,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
framework, py_version, tag, _ = framework_name_from_image(image_uri)
init_params["py_version"] = py_version

if framework and framework != cls.__framework_name__:
if framework and framework != cls._framework_name:
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
class XGBoostModel(FrameworkModel):
"""An XGBoost SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

__framework_name__ = XGBOOST_NAME
_framework_name = XGBOOST_NAME

def __init__(
self,
Expand Down Expand Up @@ -144,7 +144,7 @@ def serving_image_uri(self, region_name, instance_type):
str: The appropriate image URI based on the given parameters.
"""
return image_uris.retrieve(
self.__framework_name__,
self._framework_name,
region_name,
version=self.framework_version,
py_version=self.py_version,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/tensorflow/test_estimator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_py2_version_is_not_deprecated(sagemaker_session):

def test_framework_name(sagemaker_session):
tf = _build_tf(sagemaker_session, framework_version="1.15.2", py_version="py3")
assert tf.__framework_name__ == "tensorflow"
assert tf._framework_name == "tensorflow"


def test_enable_sm_metrics(sagemaker_session):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def test_estimator_py2_warning(warning, sagemaker_session, chainer_version):
)

assert estimator.py_version == "py2"
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)


@patch("sagemaker.chainer.model.python_deprecation_warning")
Expand All @@ -575,4 +575,4 @@ def test_model_py2_warning(warning, sagemaker_session, chainer_version):
py_version="py2",
)
assert model.py_version == "py2"
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@


class DummyFramework(Framework):
__framework_name__ = "dummy"
_framework_name = "dummy"

def train_image(self):
return IMAGE_URI
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def sagemaker_session():


class DummyFramework(Framework):
__framework_name__ = "dummy"
_framework_name = "dummy"

def train_image(self):
return IMAGE_NAME
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def test_estimator_py2_warning(warning, sagemaker_session):
)

assert estimator.py_version == "py2"
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)


@patch("sagemaker.mxnet.model.python_deprecation_warning")
Expand All @@ -729,7 +729,7 @@ def test_model_py2_warning(warning, sagemaker_session):
sagemaker_session=sagemaker_session,
)
assert model.py_version == "py2"
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)


def test_create_model_with_custom_hosting_image(sagemaker_session):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def test_estimator_py2_warning(warning, sagemaker_session, pytorch_training_vers
)

assert estimator.py_version == "py2"
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(estimator._framework_name, defaults.LATEST_PY2_VERSION)


@patch("sagemaker.pytorch.model.python_deprecation_warning")
Expand All @@ -567,7 +567,7 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_inference_version
py_version="py2",
)
assert model.py_version == "py2"
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION)


def test_pt_enable_sm_metrics(
Expand Down