From 58a07ae749391421835e7820edcf0914b801b390 Mon Sep 17 00:00:00 2001 From: Kim Date: Fri, 5 Jun 2020 16:00:38 -0700 Subject: [PATCH] feature: add support for SKLearn 0.23 --- src/sagemaker/fw_utils.py | 10 +++++++++- src/sagemaker/sklearn/defaults.py | 5 +++++ src/sagemaker/sklearn/estimator.py | 17 ++++++++++++----- tests/unit/test_sklearn.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 9141ae8c72..75a6ccffdb 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -587,10 +587,18 @@ def empty_framework_version_warning(default_version, latest_version): """ msgs = [EMPTY_FRAMEWORK_VERSION_WARNING.format(default_version)] if default_version != latest_version: - msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version)) + msgs.append(later_framework_version_warning(latest_version)) return " ".join(msgs) +def later_framework_version_warning(latest_version): + """ + Args: + latest_version: + """ + return LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version) + + def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions): """Warn the user that training will not fully leverage all the GPU cores if parameter server is enabled and a multi-GPU instance is selected. diff --git a/src/sagemaker/sklearn/defaults.py b/src/sagemaker/sklearn/defaults.py index 018611786a..34f5a3d2ef 100644 --- a/src/sagemaker/sklearn/defaults.py +++ b/src/sagemaker/sklearn/defaults.py @@ -15,6 +15,11 @@ SKLEARN_NAME = "scikit-learn" +# Default SKLearn version for when the framework version is not specified. +# This is no longer updated so as to not break existing workflows. 
SKLEARN_VERSION = "0.20.0" +SKLEARN_LATEST_VERSION = "0.23-1" +SKLEARN_SUPPORTED_VERSIONS = [SKLEARN_VERSION, SKLEARN_LATEST_VERSION] + LATEST_PY2_VERSION = "0.20.0" diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index e0fc4f76b9..a7e164c56e 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -19,7 +19,8 @@ from sagemaker.fw_registry import default_framework_uri from sagemaker.fw_utils import ( framework_name_from_image, - empty_framework_version_warning, + get_unsupported_framework_version_error, + later_framework_version_warning, python_deprecation_warning, ) from sagemaker.sklearn import defaults @@ -127,11 +128,17 @@ def __init__( self.py_version = py_version - if framework_version is None: - logger.warning( - empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION) + if framework_version in defaults.SKLEARN_SUPPORTED_VERSIONS: + self.framework_version = framework_version + else: + raise ValueError( + get_unsupported_framework_version_error( + self.__framework_name__, framework_version, defaults.SKLEARN_SUPPORTED_VERSIONS + ) ) - self.framework_version = framework_version or defaults.SKLEARN_VERSION + + if framework_version != defaults.SKLEARN_LATEST_VERSION: + logger.warning(later_framework_version_warning(defaults.SKLEARN_LATEST_VERSION)) if image_name is None: image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 3acba6b8ee..ddc22230b6 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -574,6 +574,35 @@ def test_estimator_py2_warning(warning, sagemaker_session): warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION) +@patch("sagemaker.sklearn.estimator.later_framework_version_warning") +def test_estimator_later_framework_version_warning(warning, sagemaker_session): + estimator = SKLearn( + 
entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + + assert estimator.framework_version == defaults.SKLEARN_VERSION + warning.assert_called_with(defaults.SKLEARN_LATEST_VERSION) + + +@patch("sagemaker.sklearn.estimator.get_unsupported_framework_version_error") +def test_estimator_throws_error_for_unsupported_version(error, sagemaker_session): + with pytest.raises(ValueError): + estimator = SKLearn( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version="foo", + ) + assert estimator.framework_version not in defaults.SKLEARN_SUPPORTED_VERSIONS + error.assert_called_with(defaults.SKLEARN_NAME, "foo", defaults.SKLEARN_SUPPORTED_VERSIONS) + + @patch("sagemaker.sklearn.model.python_deprecation_warning") def test_model_py2_warning(warning, sagemaker_session): source_dir = "s3://mybucket/source"