Skip to content

change: add py2 deprecation message to estimators and models #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import logging

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
python_deprecation_warning
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.chainer.model import ChainerModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -90,6 +91,10 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_

super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.use_mpi = use_mpi
self.num_processes = num_processes
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger('sagemaker')


class ChainerPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against Chainer Endpoints.
Expand Down Expand Up @@ -66,6 +70,9 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
"""
super(ChainerModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)
if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
LATER_FRAMEWORK_VERSION_WARNING = 'This is not the latest supported version. ' \
'If you would like to use version {latest}, ' \
'please add framework_version={latest} to your constructor.'
PYTHON_2_DEPRECATION_WARNING = 'The Python 2 {framework} images will be soon deprecated and may not be ' \
'supported for newer upcoming versions of the {framework} images.\n' \
'Please set the argument \"py_version=\'py3\'\" to use the Python 3 {framework} image.'


EMPTY_FRAMEWORK_VERSION_ERROR = 'framework_version is required for script mode estimator. ' \
'Please add framework_version={} to your constructor to avoid this error.'
Expand Down Expand Up @@ -303,3 +307,7 @@ def empty_framework_version_warning(default_version, latest_version):
if default_version != latest_version:
msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version))
return ' '.join(msgs)


def python_deprecation_warning(framework):
return PYTHON_2_DEPRECATION_WARNING.format(framework=framework)
7 changes: 6 additions & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import logging

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
python_deprecation_warning
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.mxnet.model import MXNetModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -79,6 +80,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio

super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self._configure_distribution(distributions)

Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

logger = logging.getLogger('sagemaker')


class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.
Expand Down Expand Up @@ -66,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
"""
super(MXNetModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import logging

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
python_deprecation_warning
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -74,6 +75,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
self.framework_version = framework_version or PYTORCH_VERSION

super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version

def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger('sagemaker')


class PyTorchPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against PyTorch Endpoints.
Expand Down Expand Up @@ -65,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
"""
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from sagemaker.estimator import Framework
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning, python_deprecation_warning
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
from sagemaker.sklearn.model import SKLearnModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -79,6 +79,9 @@ def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=No
super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name,
**dict(kwargs, train_instance_count=1))

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version

if framework_version is None:
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME

logger = logging.getLogger('sagemaker')


class SKLearnPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against Scikit-learn Endpoints.
Expand Down Expand Up @@ -68,6 +72,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
"""
super(SKLearnModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

LOGGER = logging.getLogger('sagemaker')
logger = logging.getLogger('sagemaker')


_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path')
Expand Down Expand Up @@ -154,7 +154,7 @@ def run(self):
"""Run TensorBoard process."""
port, tensorboard_process = self.create_tensorboard_process()

LOGGER.info('TensorBoard 0.1.7 at http://localhost:{}'.format(port))
logger.info('TensorBoard 0.1.7 at http://localhost:{}'.format(port))
while not self.estimator.checkpoint_path:
self.event.wait(1)
with self._temporary_directory() as aws_sync_dir:
Expand Down Expand Up @@ -231,11 +231,15 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
**kwargs: Additional kwargs passed to the Framework constructor.
"""
if framework_version is None:
LOGGER.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
self.framework_version = framework_version or TF_VERSION

super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
self.checkpoint_path = checkpoint_path

if py_version == 'py2':
logger.warning('tensorflow py2 container will be deprecated soon.')

self.py_version = py_version
self.training_steps = training_steps
self.evaluation_steps = evaluation_steps
Expand Down Expand Up @@ -320,7 +324,7 @@ def fit_super():
raise ValueError("Tensorboard is not supported with async fit")

if self._script_mode_enabled() and run_tensorboard_locally:
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
logger.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
fit_super()
elif run_tensorboard_locally:
tensorboard = Tensorboard(self)
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer

logger = logging.getLogger('sagemaker')


class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against TensorFlow endpoint.
Expand Down Expand Up @@ -67,6 +71,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
"""
super(TensorFlowModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)

if py_version == 'py2':
logger.warning(python_deprecation_warning(self.__framework_name__))

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down