Skip to content

change: add py2 deprecation message to estimators and models #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_

super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning('chainer py2 container will be deprecated soon.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • s/chainer py2 container/The Python 2 Chainer images
    • same applies to the others (also make sure the right framework name is in the message for some of the ones below)
  • I'd add a more actionable statement after the warning, e.g. "Set py_version='py3' to use a Python 3 image." This would be especially helpful with the frameworks where Python 2 is the default - users may not be aware of the py_version kwarg when using those.


self.py_version = py_version
self.use_mpi = use_mpi
self.num_processes = num_processes
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger('sagemaker')


class ChainerPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against Chainer Endpoints.
Expand Down Expand Up @@ -66,6 +70,9 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
"""
super(ChainerModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)
if py_version == 'py2':
logger.warning('chainer py2 container will be deprecated soon.')

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio

super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning('mxnet py2 container will be deprecated soon.')

self.py_version = py_version
self._configure_distribution(distributions)

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

logger = logging.getLogger('sagemaker')


class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.
Expand Down Expand Up @@ -66,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
"""
super(MXNetModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)

if py_version == 'py2':
logger.warning('chainer py2 container will be deprecated soon.')

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
self.framework_version = framework_version or PYTORCH_VERSION

super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs)

if py_version == 'py2':
logger.warning('pytorch py2 container will be deprecated soon.')

self.py_version = py_version

def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger('sagemaker')


class PyTorchPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against PyTorch Endpoints.
Expand Down Expand Up @@ -65,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
"""
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)

if py_version == 'py2':
logger.warning('pytorch py2 container will be deprecated soon.')

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N

super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
self.checkpoint_path = checkpoint_path

if py_version == 'py2':
LOGGER.warning('tensorflow py2 container will be deprecated soon.')

self.py_version = py_version
self.training_steps = training_steps
self.evaluation_steps = evaluation_steps
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer

LOGGER = logging.getLogger('sagemaker')


class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against TensorFlow endpoint.
Expand Down Expand Up @@ -67,6 +71,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
"""
super(TensorFlowModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
**kwargs)

if py_version == 'py2':
LOGGER.warning('tensorflow py2 container will be deprecated soon.')

self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers
Expand Down