Skip to content

feature: add support for SKLearn 0.23 #1561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,18 @@ def empty_framework_version_warning(default_version, latest_version):
"""
msgs = [EMPTY_FRAMEWORK_VERSION_WARNING.format(default_version)]
if default_version != latest_version:
msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version))
msgs.append(later_framework_version_warning(latest_version))
return " ".join(msgs)


def later_framework_version_warning(latest_version):
    """Build the warning text that points users at a later framework version.

    Args:
        latest_version: Value substituted for the ``latest`` placeholder of
            ``LATER_FRAMEWORK_VERSION_WARNING``.

    Returns:
        The result of formatting ``LATER_FRAMEWORK_VERSION_WARNING`` with
        ``latest_version``.
    """
    message = LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version)
    return message


def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):
"""Warn the user that training will not fully leverage all the GPU
cores if parameter server is enabled and a multi-GPU instance is selected.
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/sklearn/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

SKLEARN_NAME = "scikit-learn"

# Default SKLearn version for when the framework version is not specified.
# This is intentionally no longer updated so as to not break existing workflows.
SKLEARN_VERSION = "0.20.0"
# Newest SKLearn version; the estimator logs a "later version available" warning
# whenever the requested framework version differs from this value.
SKLEARN_LATEST_VERSION = "0.23-1"
# Framework versions the estimator accepts; any other value makes it raise ValueError.
SKLEARN_SUPPORTED_VERSIONS = [SKLEARN_VERSION, SKLEARN_LATEST_VERSION]


LATEST_PY2_VERSION = "0.20.0"
17 changes: 12 additions & 5 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import (
framework_name_from_image,
empty_framework_version_warning,
get_unsupported_framework_version_error,
later_framework_version_warning,
python_deprecation_warning,
)
from sagemaker.sklearn import defaults
Expand Down Expand Up @@ -127,11 +128,17 @@ def __init__(

self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION)
if framework_version in defaults.SKLEARN_SUPPORTED_VERSIONS:
self.framework_version = framework_version
else:
raise ValueError(
get_unsupported_framework_version_error(
self.__framework_name__, framework_version, defaults.SKLEARN_SUPPORTED_VERSIONS
)
)
self.framework_version = framework_version or defaults.SKLEARN_VERSION

if framework_version != defaults.SKLEARN_LATEST_VERSION:
logger.warning(later_framework_version_warning(defaults.SKLEARN_LATEST_VERSION))

if image_name is None:
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,35 @@ def test_estimator_py2_warning(warning, sagemaker_session):
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)


@patch("sagemaker.sklearn.estimator.later_framework_version_warning")
def test_estimator_later_framework_version_warning(warning, sagemaker_session):
    """SKLearn built without ``framework_version`` warns about the latest version."""
    constructor_kwargs = dict(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    sklearn_estimator = SKLearn(**constructor_kwargs)

    # No explicit version was passed, so the estimator must carry the default one.
    assert sklearn_estimator.framework_version == defaults.SKLEARN_VERSION
    # The patched helper must have been asked to build a message for the latest version.
    warning.assert_called_with(defaults.SKLEARN_LATEST_VERSION)


@patch("sagemaker.sklearn.estimator.get_unsupported_framework_version_error")
def test_estimator_throws_error_for_unsupported_version(error, sagemaker_session):
    """An unsupported ``framework_version`` makes the SKLearn constructor raise ValueError.

    Also verifies that the (patched) error-message builder is called with the
    framework name, the rejected version, and the list of supported versions.
    """
    unsupported_version = "foo"
    # Precondition: the version under test really is outside the supported set.
    assert unsupported_version not in defaults.SKLEARN_SUPPORTED_VERSIONS

    with pytest.raises(ValueError):
        # The constructor is expected to raise, so its result is never bound.
        SKLearn(
            entry_point=SCRIPT_PATH,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            train_instance_count=INSTANCE_COUNT,
            train_instance_type=INSTANCE_TYPE,
            framework_version=unsupported_version,
        )

    # Checked after the ``with`` block: statements placed after the raising call
    # inside the block would never execute, so the assertion would be dead code.
    error.assert_called_with(
        defaults.SKLEARN_NAME, unsupported_version, defaults.SKLEARN_SUPPORTED_VERSIONS
    )


@patch("sagemaker.sklearn.model.python_deprecation_warning")
def test_model_py2_warning(warning, sagemaker_session):
source_dir = "s3://mybucket/source"
Expand Down