From 35ee1f56849c8fd6af95cd9be86031d3e9637b90 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Thu, 9 Jan 2020 13:43:14 -0800 Subject: [PATCH 1/3] fix: update py2 warning message since py2 is deprecated --- src/sagemaker/chainer/defaults.py | 2 ++ src/sagemaker/chainer/estimator.py | 4 ++-- src/sagemaker/chainer/model.py | 5 +++-- src/sagemaker/fw_utils.py | 11 ++++++---- src/sagemaker/mxnet/defaults.py | 2 ++ src/sagemaker/mxnet/estimator.py | 4 ++-- src/sagemaker/mxnet/model.py | 5 +++-- src/sagemaker/pytorch/defaults.py | 2 ++ src/sagemaker/pytorch/estimator.py | 4 ++-- src/sagemaker/pytorch/model.py | 5 +++-- src/sagemaker/sklearn/defaults.py | 2 ++ src/sagemaker/sklearn/estimator.py | 4 ++-- src/sagemaker/sklearn/model.py | 4 ++-- src/sagemaker/tensorflow/defaults.py | 2 ++ src/sagemaker/tensorflow/estimator.py | 11 +++++----- src/sagemaker/tensorflow/model.py | 5 +++-- src/sagemaker/xgboost/estimator.py | 3 +-- src/sagemaker/xgboost/model.py | 4 ++-- tests/unit/test_chainer.py | 28 +++++++++++++++++++++++++ tests/unit/test_mxnet.py | 28 +++++++++++++++++++++++++ tests/unit/test_pytorch.py | 28 +++++++++++++++++++++++++ tests/unit/test_sklearn.py | 30 +++++++++++++++++++++++++++ tests/unit/test_tf_estimator.py | 27 +++++++++++++++++++++++- tests/unit/test_xgboost.py | 27 ++++++++++++++++++++++++ 24 files changed, 215 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/chainer/defaults.py b/src/sagemaker/chainer/defaults.py index 6f429e1ced..49966e9e34 100644 --- a/src/sagemaker/chainer/defaults.py +++ b/src/sagemaker/chainer/defaults.py @@ -20,3 +20,5 @@ LATEST_VERSION = "5.0.0" """The latest version of Chainer included in the SageMaker pre-built Docker images.""" + +LATEST_PY2_VERSION = "5.0.0" diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 884fe0302c..909be9289b 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -22,7 +22,7 @@ empty_framework_version_warning, python_deprecation_warning, ) -from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION +from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.chainer.model import ChainerModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -134,7 +134,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) self.py_version = py_version self.use_mpi = use_mpi diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 32dd121f2c..7c8f8355ee 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -23,7 +23,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION +from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer logger = logging.getLogger("sagemaker") @@ -111,7 +111,8 @@ def __init__( model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + if framework_version is None: logger.warning(empty_framework_version_warning(CHAINER_VERSION, LATEST_VERSION)) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 59382594e5..bfb8abc453 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -38,8 +38,8 @@ "please add framework_version={latest} to your constructor." ) PYTHON_2_DEPRECATION_WARNING = ( - "The Python 2 {framework} images will be soon deprecated and may not be " - "supported for newer upcoming versions of the {framework} images.\n" + "{latest_supported_version} is the latest version of {framework} that supports " + "Python 2. Newer versions of {framework} will only be available for Python 3." "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image." ) @@ -495,9 +495,12 @@ def get_unsupported_framework_version_error( ) -def python_deprecation_warning(framework): +def python_deprecation_warning(framework, latest_supported_version): """ Args: framework: + latest_supported_version: """ - return PYTHON_2_DEPRECATION_WARNING.format(framework=framework) + return PYTHON_2_DEPRECATION_WARNING.format( + framework=framework, latest_supported_version=latest_supported_version + ) diff --git a/src/sagemaker/mxnet/defaults.py b/src/sagemaker/mxnet/defaults.py index 559cb1d11e..25981bc2a0 100644 --- a/src/sagemaker/mxnet/defaults.py +++ b/src/sagemaker/mxnet/defaults.py @@ -20,3 +20,5 @@ LATEST_VERSION = "1.6.0" """The latest version of MXNet included in the SageMaker pre-built Docker images.""" + +LATEST_PY2_VERSION = "1.6.0" diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 069dab1469..a5763d9bb5 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -23,7 +23,7 @@ python_deprecation_warning, is_version_equal_or_higher, ) -from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION +from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.mxnet.model import MXNetModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -120,7 +120,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) self.py_version = py_version self._configure_distribution(distributions) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 9d50710e9f..7f1eb72e2c 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -25,7 +25,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION +from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer logger = logging.getLogger("sagemaker") @@ -113,7 +113,8 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + if framework_version is None: logger.warning(empty_framework_version_warning(MXNET_VERSION, LATEST_VERSION)) diff --git a/src/sagemaker/pytorch/defaults.py b/src/sagemaker/pytorch/defaults.py index df51efe637..86c97b8e4d 100644 --- a/src/sagemaker/pytorch/defaults.py +++ b/src/sagemaker/pytorch/defaults.py @@ -23,3 +23,5 @@ """The latest version of PyTorch included in the SageMaker pre-built Docker images.""" PYTHON_VERSION = "py3" + +LATEST_PY2_VERSION = "1.3.1" diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index b047154ccf..53dc0c185b 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -23,7 +23,7 @@ python_deprecation_warning, is_version_equal_or_higher, ) -from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION +from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -116,7 +116,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) self.py_version = py_version diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 1b97f5a866..02338debfe 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -24,7 +24,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION +from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer logger = logging.getLogger("sagemaker") @@ -114,7 +114,8 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + if framework_version is None: logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION)) diff --git a/src/sagemaker/sklearn/defaults.py b/src/sagemaker/sklearn/defaults.py index fdc1b61fe0..018611786a 100644 --- a/src/sagemaker/sklearn/defaults.py +++ b/src/sagemaker/sklearn/defaults.py @@ -16,3 +16,5 @@ SKLEARN_NAME = "scikit-learn" SKLEARN_VERSION = "0.20.0" + +LATEST_PY2_VERSION = "0.20.0" diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 39aa284e0d..0f0436e5a1 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -22,7 +22,7 @@ empty_framework_version_warning, python_deprecation_warning, ) -from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME +from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION from sagemaker.sklearn.model import SKLearnModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -119,7 +119,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) self.py_version = py_version diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index faa2ef9610..f6b5c2dc07 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -20,7 +20,7 @@ from sagemaker.fw_registry import default_framework_uri from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer -from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME +from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION logger = logging.getLogger("sagemaker") @@ -108,7 +108,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) self.py_version = py_version self.framework_version = framework_version diff --git a/src/sagemaker/tensorflow/defaults.py b/src/sagemaker/tensorflow/defaults.py index a3c0ef456c..5f7ab01e12 100644 --- a/src/sagemaker/tensorflow/defaults.py +++ b/src/sagemaker/tensorflow/defaults.py @@ -20,3 +20,5 @@ LATEST_VERSION = "2.0.0" """The latest version of TensorFlow included in the SageMaker pre-built Docker images.""" + +LATEST_PY2_VERSION = "2.0.0" diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index ee4f00f105..cacd4516cb 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -25,7 +25,7 @@ from sagemaker.debugger import DebuggerHookConfig from sagemaker.estimator import Framework import sagemaker.fw_utils as fw -from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION +from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.tensorflow.model import TensorFlowModel from sagemaker.tensorflow.serving import Model from sagemaker.transformer import Transformer @@ -293,6 +293,10 @@ def __init__( if not py_version: py_version = "py3" if self._only_python_3_supported() else "py2" + if py_version == "py2": + logger.warning( + fw.python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION) + ) if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for TF v1.15 or greater: @@ -302,9 +306,6 @@ def __init__( super(TensorFlow, self).__init__(image_name=image_name, **kwargs) self.checkpoint_path = checkpoint_path - if py_version == "py2": - logger.warning("tensorflow py2 container will be deprecated soon.") - self.py_version = py_version self.training_steps = training_steps self.evaluation_steps = evaluation_steps @@ -359,7 +360,7 @@ def _validate_args( if py_version == "py2" and self._only_python_3_supported(): msg = ( - "Python 2 containers are only available until January 1st, 2020. " + "Python 2 containers are only available before January 1st, 2020. " "Please use a Python 3 container." ) raise AttributeError(msg) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 59ef3d7958..bea7fca282 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -24,7 +24,7 @@ ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import RealTimePredictor -from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION +from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION, LATEST_PY2_VERSION from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer logger = logging.getLogger("sagemaker") @@ -111,7 +111,8 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + if framework_version is None: logger.warning(empty_framework_version_warning(TF_VERSION, LATEST_VERSION)) diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index 31c66e4e55..e8a0508e3e 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -20,7 +20,6 @@ from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, - python_deprecation_warning, get_unsupported_framework_version_error, UploadedCode, ) @@ -112,7 +111,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + raise AttributeError("XGBoost container does not support Python 2, please use Python 3") self.py_version = py_version if framework_version in XGBOOST_SUPPORTED_VERSIONS: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index fd8365602d..a584f34d27 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -16,7 +16,7 @@ import logging import sagemaker -from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning +from sagemaker.fw_utils import model_code_key_prefix from sagemaker.fw_registry import default_framework_uri from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import RealTimePredictor, npy_serializer, csv_deserializer @@ -99,7 +99,7 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__)) + raise AttributeError("XGBoost container does not support Python 2, please use Python 3") self.py_version = py_version self.framework_version = framework_version diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 597977e380..2254babbef 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -588,6 +588,34 @@ def test_attach_custom_image(sagemaker_session): assert estimator.train_image() == training_image +@patch("sagemaker.chainer.estimator.python_deprecation_warning") +def test_estimator_py2_warning(warning, sagemaker_session): + estimator = Chainer( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) + + assert estimator.py_version == "py2" + warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) + + +@patch("sagemaker.chainer.model.python_deprecation_warning") +def test_model_py2_warning(warning, sagemaker_session): + model = ChainerModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + py_version="py2", + ) + assert model.py_version == "py2" + warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) + + @patch("sagemaker.chainer.estimator.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): estimator = Chainer( diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index d202f6d548..786b1795d1 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -707,6 +707,34 @@ def test_estimator_wrong_version_launch_parameter_server(sagemaker_session): assert "The distributions option is valid for only versions 1.3 and higher" in str(e) +@patch("sagemaker.mxnet.estimator.python_deprecation_warning") +def test_estimator_py2_warning(warning, sagemaker_session): + estimator = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) + + assert estimator.py_version == "py2" + warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) + + +@patch("sagemaker.mxnet.model.python_deprecation_warning") +def test_model_py2_warning(warning, sagemaker_session): + model = MXNetModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + py_version="py2", + ) + assert model.py_version == "py2" + warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) + + @patch("sagemaker.mxnet.estimator.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): mx = MXNet( diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 3fac510afb..20191928d0 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -517,6 +517,34 @@ def test_attach_custom_image(sagemaker_session): assert estimator.train_image() == training_image +@patch("sagemaker.pytorch.estimator.python_deprecation_warning") +def test_estimator_py2_warning(warning, sagemaker_session): + estimator = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) + + assert estimator.py_version == "py2" + warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) + + +@patch("sagemaker.pytorch.model.python_deprecation_warning") +def test_model_py2_warning(warning, sagemaker_session): + model = PyTorchModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + py_version="py2", + ) + assert model.py_version == "py2" + warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) + + @patch("sagemaker.pytorch.estimator.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): estimator = PyTorch( diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 938050843e..844e8191fc 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -540,3 +540,33 @@ def test_attach_custom_image(sagemaker_session): estimator = SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.image_name == training_image assert estimator.train_image() == training_image + + +@patch("sagemaker.sklearn.estimator.python_deprecation_warning") +def test_estimator_py2_warning(warning, sagemaker_session): + estimator = SKLearn( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) + + assert estimator.py_version == "py2" + warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) + + +@patch("sagemaker.sklearn.model.python_deprecation_warning") +def test_model_py2_warning(warning, sagemaker_session): + source_dir = "s3://mybucket/source" + + model = SKLearnModel( + model_data=source_dir, + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + py_version="py2", + ) + assert model.py_version == "py2" + warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 066cdd4e6c..68500efc80 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -924,6 +924,31 @@ def test_attach_custom_image(sagemaker_session): assert estimator.train_image() == training_image +@patch("sagemaker.fw_utils.python_deprecation_warning") +def test_estimator_py2_deprecation_warning(warning, sagemaker_session): + estimator = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) + + assert estimator.py_version == "py2" + warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) + + model = TensorFlowModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + py_version="py2", + ) + assert model.py_version == "py2" + warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) + + @patch("sagemaker.fw_utils.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): estimator = TensorFlow( @@ -995,7 +1020,7 @@ def test_py2_version_deprecated(sagemaker_session): with pytest.raises(AttributeError) as e: _build_tf(sagemaker_session=sagemaker_session, framework_version="2.0.1", py_version="py2") - msg = "Python 2 containers are only available until January 1st, 2020. Please use a Python 3 container." + msg = "Python 2 containers are only available before January 1st, 2020. Please use a Python 3 container." assert msg in str(e.value) diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 879eac6bd9..6ef43bea58 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -534,3 +534,30 @@ def test_attach_custom_image(sagemaker_session): with pytest.raises(TypeError) as error: XGBoost.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert "expected string" in str(error) + + +def test_py2_xgboost_attribute_error(sagemaker_session): + with pytest.raises(AttributeError) as error1: + XGBoost( + entry_point=SCRIPT_PATH, + role=ROLE, + framework_version=XGBOOST_LATEST_VERSION, + sagemaker_session=sagemaker_session, + train_instance_type=INSTANCE_TYPE, + train_instance_count=1, + py_version="py2", + ) + + with pytest.raises(AttributeError) as error2: + XGBoostModel( + model_data=DATA_DIR, + role=ROLE, + sagemaker_session=sagemaker_session, + entry_point=SCRIPT_PATH, + framework_version=XGBOOST_LATEST_VERSION, + py_version="py2", + ) + + error_message = "XGBoost container does not support Python 2, please use Python 3" + assert error_message in str(error1) + assert error_message in str(error2) From 2650374e1a286d99837fb0294f705f4a7bab653f Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Thu, 9 Jan 2020 17:01:06 -0800 Subject: [PATCH 2/3] fix black check errors --- src/sagemaker/pytorch/estimator.py | 7 ++++++- src/sagemaker/pytorch/model.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 53dc0c185b..013bf3f838 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -23,7 +23,12 @@ python_deprecation_warning, is_version_equal_or_higher, ) -from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.pytorch.defaults import ( + PYTORCH_VERSION, + PYTHON_VERSION, + LATEST_VERSION, + LATEST_PY2_VERSION, +) from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 02338debfe..4b01cc035a 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -24,7 +24,12 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.pytorch.defaults import ( + PYTORCH_VERSION, + PYTHON_VERSION, + LATEST_VERSION, + LATEST_PY2_VERSION, +) from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer logger = logging.getLogger("sagemaker") From 4763545402cfd928e419bdae2c4e54bf97bd3411 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Fri, 10 Jan 2020 15:18:46 -0800 Subject: [PATCH 3/3] improve imports and error message --- src/sagemaker/chainer/estimator.py | 14 +++++++++----- src/sagemaker/chainer/model.py | 12 ++++++++---- src/sagemaker/mxnet/estimator.py | 14 +++++++++----- src/sagemaker/mxnet/model.py | 12 ++++++++---- src/sagemaker/pytorch/estimator.py | 21 ++++++++++----------- src/sagemaker/pytorch/model.py | 19 +++++++++---------- src/sagemaker/sklearn/estimator.py | 16 ++++++++++------ src/sagemaker/sklearn/model.py | 10 ++++++---- src/sagemaker/tensorflow/estimator.py | 16 +++++++++------- src/sagemaker/tensorflow/model.py | 12 ++++++++---- src/sagemaker/xgboost/estimator.py | 8 ++++---- tests/unit/test_tf_estimator.py | 15 +++++++++++++-- 12 files changed, 103 insertions(+), 66 deletions(-) diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 909be9289b..cf03d4656a 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -22,7 +22,7 @@ empty_framework_version_warning, python_deprecation_warning, ) -from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.chainer import defaults from sagemaker.chainer.model import ChainerModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -40,7 +40,7 @@ class Chainer(Framework): _process_slots_per_host = "sagemaker_process_slots_per_host" _additional_mpi_options = "sagemaker_additional_mpi_options" - LATEST_VERSION = LATEST_VERSION + LATEST_VERSION = defaults.LATEST_VERSION def __init__( self, @@ -126,15 +126,19 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ if framework_version is None: - logger.warning(empty_framework_version_warning(CHAINER_VERSION, self.LATEST_VERSION)) - self.framework_version = framework_version or CHAINER_VERSION + logger.warning( + empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION) + ) + self.framework_version = framework_version or defaults.CHAINER_VERSION super(Chainer, self).__init__( entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) self.py_version = py_version self.use_mpi = use_mpi diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 7c8f8355ee..a6a7b68148 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -23,7 +23,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.chainer import defaults from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer logger = logging.getLogger("sagemaker") @@ -111,13 +111,17 @@ def __init__( model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) if framework_version is None: - logger.warning(empty_framework_version_warning(CHAINER_VERSION, LATEST_VERSION)) + logger.warning( + empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION) + ) self.py_version = py_version - self.framework_version = framework_version or CHAINER_VERSION + self.framework_version = framework_version or defaults.CHAINER_VERSION self.model_server_workers = model_server_workers def prepare_container_def(self, instance_type, accelerator_type=None): diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index a5763d9bb5..04a76d956d 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -23,7 +23,7 @@ python_deprecation_warning, is_version_equal_or_higher, ) -from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.mxnet import defaults from sagemaker.mxnet.model import MXNetModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -36,7 +36,7 @@ class MXNet(Framework): __framework_name__ = "mxnet" _LOWEST_SCRIPT_MODE_VERSION = ["1", "3"] - LATEST_VERSION = LATEST_VERSION + LATEST_VERSION = defaults.LATEST_VERSION def __init__( self, @@ -107,8 +107,10 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ if framework_version is None: - logger.warning(empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION)) - self.framework_version = framework_version or MXNET_VERSION + logger.warning( + empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION) + ) + self.framework_version = framework_version or defaults.MXNET_VERSION if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for MXNet v1.6 or greater: @@ -120,7 +122,9 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) self.py_version = py_version self._configure_distribution(distributions) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 7f1eb72e2c..5ec7d24866 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -25,7 +25,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.mxnet import defaults from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer logger = logging.getLogger("sagemaker") @@ -113,13 +113,17 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) if framework_version is None: - logger.warning(empty_framework_version_warning(MXNET_VERSION, LATEST_VERSION)) + logger.warning( + empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION) + ) self.py_version = py_version - self.framework_version = framework_version or MXNET_VERSION + self.framework_version = framework_version or defaults.MXNET_VERSION self.model_server_workers = model_server_workers def prepare_container_def(self, instance_type, accelerator_type=None): diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 013bf3f838..16a985db7f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -23,12 +23,7 @@ python_deprecation_warning, is_version_equal_or_higher, ) -from sagemaker.pytorch.defaults import ( - PYTORCH_VERSION, - PYTHON_VERSION, - LATEST_VERSION, - LATEST_PY2_VERSION, -) +from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -40,14 +35,14 @@ class PyTorch(Framework): __framework_name__ = "pytorch" - LATEST_VERSION = LATEST_VERSION + LATEST_VERSION = defaults.LATEST_VERSION def __init__( self, entry_point, source_dir=None, hyperparameters=None, - py_version=PYTHON_VERSION, + py_version=defaults.PYTHON_VERSION, framework_version=None, image_name=None, **kwargs @@ -108,8 +103,10 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ if framework_version is None: - logger.warning(empty_framework_version_warning(PYTORCH_VERSION, self.LATEST_VERSION)) - self.framework_version = framework_version or PYTORCH_VERSION + logger.warning( + empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION) + ) + self.framework_version = framework_version or defaults.PYTORCH_VERSION if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for PT v1.3 or greater: @@ -121,7 +118,9 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) self.py_version = py_version diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 4b01cc035a..e7b6a95638 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -24,12 +24,7 @@ empty_framework_version_warning, ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.pytorch.defaults import ( - PYTORCH_VERSION, - PYTHON_VERSION, - LATEST_VERSION, - LATEST_PY2_VERSION, -) +from sagemaker.pytorch import defaults from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer logger = logging.getLogger("sagemaker") @@ -72,7 +67,7 @@ def __init__( role, entry_point, image=None, - py_version=PYTHON_VERSION, + py_version=defaults.PYTHON_VERSION, framework_version=None, predictor_cls=PyTorchPredictor, model_server_workers=None, @@ -119,13 +114,17 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) if framework_version is None: - logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION)) + logger.warning( + empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION) + ) self.py_version = py_version - self.framework_version = framework_version or PYTORCH_VERSION + self.framework_version = framework_version or defaults.PYTORCH_VERSION self.model_server_workers = model_server_workers def prepare_container_def(self, instance_type, accelerator_type=None): diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 0f0436e5a1..5dde0aec56 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -22,7 +22,7 @@ empty_framework_version_warning, python_deprecation_warning, ) -from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION +from sagemaker.sklearn import defaults from sagemaker.sklearn.model import SKLearnModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -32,12 +32,12 @@ class SKLearn(Framework): """Handle end-to-end training and deployment of custom Scikit-learn code.""" - __framework_name__ = SKLEARN_NAME + __framework_name__ = defaults.SKLEARN_NAME def __init__( self, entry_point, - framework_version=SKLEARN_VERSION, + framework_version=defaults.SKLEARN_VERSION, source_dir=None, hyperparameters=None, py_version="py3", @@ -119,13 +119,17 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) self.py_version = py_version if framework_version is None: - logger.warning(empty_framework_version_warning(SKLEARN_VERSION, SKLEARN_VERSION)) - self.framework_version = framework_version or SKLEARN_VERSION + logger.warning( + empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION) + ) + self.framework_version = framework_version or defaults.SKLEARN_VERSION if image_name is None: image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index f6b5c2dc07..1f9cd71fbb 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -20,7 +20,7 @@ from sagemaker.fw_registry import default_framework_uri from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer -from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME, LATEST_PY2_VERSION +from sagemaker.sklearn import defaults logger = logging.getLogger("sagemaker") @@ -53,7 +53,7 @@ class SKLearnModel(FrameworkModel): ``Endpoint``. """ - __framework_name__ = SKLEARN_NAME + __framework_name__ = defaults.SKLEARN_NAME def __init__( self, @@ -62,7 +62,7 @@ def __init__( entry_point, image=None, py_version="py3", - framework_version=SKLEARN_VERSION, + framework_version=defaults.SKLEARN_VERSION, predictor_cls=SKLearnPredictor, model_server_workers=None, **kwargs @@ -108,7 +108,9 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) self.py_version = py_version self.framework_version = framework_version diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index cacd4516cb..4f3d2cd365 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -25,7 +25,7 @@ from sagemaker.debugger import DebuggerHookConfig from sagemaker.estimator import Framework import sagemaker.fw_utils as fw -from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.tensorflow import defaults from sagemaker.tensorflow.model import TensorFlowModel from sagemaker.tensorflow.serving import Model from sagemaker.transformer import Transformer @@ -197,7 +197,7 @@ class TensorFlow(Framework): __framework_name__ = "tensorflow" - LATEST_VERSION = LATEST_VERSION + LATEST_VERSION = defaults.LATEST_VERSION _LATEST_1X_VERSION = "1.15.0" @@ -288,14 +288,16 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ if framework_version is None: - logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION)) - self.framework_version = framework_version or TF_VERSION + logger.warning( + fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION) + ) + self.framework_version = framework_version or defaults.TF_VERSION if not py_version: py_version = "py3" if self._only_python_3_supported() else "py2" if py_version == "py2": logger.warning( - fw.python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION) + fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) ) if "enable_sagemaker_metrics" not in kwargs: @@ -360,8 +362,8 @@ def _validate_args( if py_version == "py2" and self._only_python_3_supported(): msg = ( - "Python 2 containers are only available before January 1st, 2020. " - "Please use a Python 3 container." + "Python 2 containers are only available with {} and lower versions. " + "Please use a Python 3 container.".format(defaults.LATEST_PY2_VERSION) ) raise AttributeError(msg) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index bea7fca282..8d00b13544 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -24,7 +24,7 @@ ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import RealTimePredictor -from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION, LATEST_PY2_VERSION +from sagemaker.tensorflow import defaults from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer logger = logging.getLogger("sagemaker") @@ -111,13 +111,17 @@ def __init__( ) if py_version == "py2": - logger.warning(python_deprecation_warning(self.__framework_name__, LATEST_PY2_VERSION)) + logger.warning( + python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION) + ) if framework_version is None: - logger.warning(empty_framework_version_warning(TF_VERSION, LATEST_VERSION)) + logger.warning( + empty_framework_version_warning(defaults.TF_VERSION, defaults.LATEST_VERSION) + ) self.py_version = py_version - self.framework_version = framework_version or TF_VERSION + self.framework_version = framework_version or defaults.TF_VERSION self.model_server_workers = model_server_workers def prepare_container_def(self, instance_type, accelerator_type=None): diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index e8a0508e3e..a042914163 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -31,7 +31,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -from sagemaker.xgboost.defaults import XGBOOST_NAME, XGBOOST_SUPPORTED_VERSIONS +from sagemaker.xgboost import defaults from sagemaker.xgboost.model import XGBoostModel logger = logging.getLogger("sagemaker") @@ -47,7 +47,7 @@ class XGBoost(Framework): """Handle end-to-end training and deployment of XGBoost booster training or training using customer provided XGBoost entry point script.""" - __framework_name__ = XGBOOST_NAME + __framework_name__ = defaults.XGBOOST_NAME def __init__( self, @@ -114,12 +114,12 @@ def __init__( raise AttributeError("XGBoost container does not support Python 2, please use Python 3") self.py_version = py_version - if framework_version in XGBOOST_SUPPORTED_VERSIONS: + if framework_version in defaults.XGBOOST_SUPPORTED_VERSIONS: self.framework_version = framework_version else: raise ValueError( get_unsupported_framework_version_error( - self.__framework_name__, framework_version, XGBOOST_SUPPORTED_VERSIONS + self.__framework_name__, framework_version, defaults.XGBOOST_SUPPORTED_VERSIONS ) ) diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 68500efc80..4fc08a872f 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -1018,9 +1018,20 @@ def test_script_mode_deprecated_args(sagemaker_session): def test_py2_version_deprecated(sagemaker_session): with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, framework_version="2.0.1", py_version="py2") + TensorFlow( + entry_point=SCRIPT_PATH, + framework_version="2.0.1", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + py_version="py2", + ) - msg = "Python 2 containers are only available before January 1st, 2020. Please use a Python 3 container." + msg = ( + "Python 2 containers are only available with 2.0.0 and lower versions. " + "Please use a Python 3 container." + ) assert msg in str(e.value)