Skip to content

Commit 8ff10cc

Browse files
authored
change: use images_uris.retrieve() for scikit-learn classes (#1728)
1 parent 67810d7 commit 8ff10cc

File tree

15 files changed

+242
-157
lines changed

15 files changed

+242
-157
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
{
2+
"processors": ["cpu"],
3+
"scope": ["inference", "training"],
4+
"versions": {
5+
"0.20.0": {
6+
"py_versions": ["py3"],
7+
"registries": {
8+
"ap-east-1": "651117190479",
9+
"ap-northeast-1": "354813040037",
10+
"ap-northeast-2": "366743142698",
11+
"ap-south-1": "720646828776",
12+
"ap-southeast-1": "121021644041",
13+
"ap-southeast-2": "783357654285",
14+
"ca-central-1": "341280168497",
15+
"cn-north-1": "450853457545",
16+
"cn-northwest-1": "451049120500",
17+
"eu-central-1": "492215442770",
18+
"eu-north-1": "662702820516",
19+
"eu-west-1": "141502667606",
20+
"eu-west-2": "764974769150",
21+
"eu-west-3": "659782779980",
22+
"me-south-1": "801668240914",
23+
"sa-east-1": "737474898029",
24+
"us-east-1": "683313688378",
25+
"us-east-2": "257758044811",
26+
"us-gov-west-1": "414596584902",
27+
"us-iso-east-1": "833128469047",
28+
"us-west-1": "746614075791",
29+
"us-west-2": "246618743249"
30+
},
31+
"repository": "sagemaker-scikit-learn"
32+
},
33+
"0.23-1": {
34+
"py_versions": ["py3"],
35+
"registries": {
36+
"ap-east-1": "651117190479",
37+
"ap-northeast-1": "354813040037",
38+
"ap-northeast-2": "366743142698",
39+
"ap-south-1": "720646828776",
40+
"ap-southeast-1": "121021644041",
41+
"ap-southeast-2": "783357654285",
42+
"ca-central-1": "341280168497",
43+
"cn-north-1": "450853457545",
44+
"cn-northwest-1": "451049120500",
45+
"eu-central-1": "492215442770",
46+
"eu-north-1": "662702820516",
47+
"eu-west-1": "141502667606",
48+
"eu-west-2": "764974769150",
49+
"eu-west-3": "659782779980",
50+
"me-south-1": "801668240914",
51+
"sa-east-1": "737474898029",
52+
"us-east-1": "683313688378",
53+
"us-east-2": "257758044811",
54+
"us-gov-west-1": "414596584902",
55+
"us-iso-east-1": "833128469047",
56+
"us-west-1": "746614075791",
57+
"us-west-2": "246618743249"
58+
},
59+
"repository": "sagemaker-scikit-learn"
60+
}
61+
}
62+
}

src/sagemaker/sklearn/defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
SKLEARN_NAME = "scikit-learn"
16+
SKLEARN_NAME = "sklearn"

src/sagemaker/sklearn/estimator.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
import logging
1717

18+
from sagemaker import image_uris
1819
from sagemaker.estimator import Framework
19-
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import (
2121
framework_name_from_image,
2222
framework_version_from_tag,
@@ -137,9 +137,12 @@ def __init__(
137137
)
138138

139139
if image_uri is None:
140-
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
141-
self.image_uri = default_framework_uri(
142-
SKLearn.__framework_name__, self.sagemaker_session.boto_region_name, image_tag
140+
self.image_uri = image_uris.retrieve(
141+
SKLearn.__framework_name__,
142+
self.sagemaker_session.boto_region_name,
143+
version=self.framework_version,
144+
py_version=self.py_version,
145+
instance_type=instance_type,
143146
)
144147

145148
def create_model(
@@ -243,7 +246,7 @@ class constructor
243246
init_params["image_uri"] = image_uri
244247
return init_params
245248

246-
if framework and framework != cls.__framework_name__:
249+
if framework and framework != "scikit-learn":
247250
raise ValueError(
248251
"Training job: {} didn't use image for requested framework".format(
249252
job_details["TrainingJobName"]

src/sagemaker/sklearn/model.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker import image_uris
1920
from sagemaker.deserializers import NumpyDeserializer
20-
from sagemaker.fw_registry import default_framework_uri
2121
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2323
from sagemaker.predictor import Predictor
@@ -163,17 +163,21 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
163163
)
164164
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
165165

166-
def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
166+
def serving_image_uri(self, region_name, instance_type):
167167
"""Create a URI for the serving image.
168168
169169
Args:
170170
region_name (str): AWS region where the image is uploaded.
171-
instance_type (str): SageMaker instance type. This parameter is unused because
172-
Scikit-learn supports only CPU.
171+
instance_type (str): SageMaker instance type.
173172
174173
Returns:
175174
str: The appropriate image URI based on the given parameters.
176175
177176
"""
178-
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
179-
return default_framework_uri(self.__framework_name__, region_name, image_tag)
177+
return image_uris.retrieve(
178+
self.__framework_name__,
179+
region_name,
180+
version=self.framework_version,
181+
py_version=self.py_version,
182+
instance_type=instance_type,
183+
)

src/sagemaker/sklearn/processing.py

+7-17
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,14 @@
1717
"""
1818
from __future__ import absolute_import
1919

20-
from sagemaker.fw_registry import default_framework_uri
21-
22-
from sagemaker import Session
20+
from sagemaker import image_uris, Session
2321
from sagemaker.processing import ScriptProcessor
22+
from sagemaker.sklearn import defaults
2423

2524

2625
class SKLearnProcessor(ScriptProcessor):
2726
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
2827

29-
_valid_framework_versions = ["0.20.0"]
30-
3128
def __init__(
3229
self,
3330
framework_version,
@@ -83,21 +80,14 @@ def __init__(
8380
object that configures network isolation, encryption of
8481
inter-container traffic, security group IDs, and subnets.
8582
"""
86-
session = sagemaker_session or Session()
87-
region = session.boto_region_name
88-
89-
if framework_version not in self._valid_framework_versions:
90-
raise ValueError(
91-
"scikit-learn version {} is not supported. Supported versions are {}".format(
92-
framework_version, self._valid_framework_versions
93-
)
94-
)
95-
9683
if not command:
9784
command = ["python3"]
9885

99-
image_tag = "{}-{}-{}".format(framework_version, "cpu", "py3")
100-
image_uri = default_framework_uri("scikit-learn", region, image_tag)
86+
session = sagemaker_session or Session()
87+
region = session.boto_region_name
88+
image_uri = image_uris.retrieve(
89+
defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type
90+
)
10191

10292
super(SKLearnProcessor, self).__init__(
10393
role=role,

tests/conftest.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,6 @@ def pytorch_eia_py_version():
141141
return "py3"
142142

143143

144-
@pytest.fixture(scope="module", params=["0.20.0"])
145-
def sklearn_version(request):
146-
return request.param
147-
148-
149144
@pytest.fixture(scope="module")
150145
def xgboost_framework_version(xgboost_version):
151146
if xgboost_version in ("1", "latest"):
@@ -202,16 +197,6 @@ def rl_ray_full_version():
202197
return RLEstimator.RAY_LATEST_VERSION
203198

204199

205-
@pytest.fixture(scope="module")
206-
def sklearn_full_version():
207-
return "0.20.0"
208-
209-
210-
@pytest.fixture(scope="module")
211-
def sklearn_full_py_version():
212-
return "py3"
213-
214-
215200
@pytest.fixture(scope="module")
216201
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
217202
"""Fixture for TF tests that test both training and inference.
@@ -300,7 +285,7 @@ def pytest_generate_tests(metafunc):
300285

301286

302287
def _generate_all_framework_version_fixtures(metafunc):
303-
for fw in ("chainer", "mxnet", "pytorch", "tensorflow", "xgboost"):
288+
for fw in ("chainer", "mxnet", "pytorch", "sklearn", "tensorflow", "xgboost"):
304289
config = image_uris.config_for_framework(fw)
305290
if "scope" in config:
306291
_parametrize_framework_version_fixtures(metafunc, fw, config)

tests/data/sklearn_mnist/mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import numpy as np
1717
import os
1818

19+
import joblib
1920
from sklearn import svm
20-
from sklearn.externals import joblib
2121

2222

2323
def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format):

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def test_mxnet_airflow_config_uploads_data_source_to_s3(
478478

479479
@pytest.mark.canary_quick
480480
def test_sklearn_airflow_config_uploads_data_source_to_s3(
481-
sagemaker_session, cpu_instance_type, sklearn_full_version, sklearn_full_py_version
481+
sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version,
482482
):
483483
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
484484
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -488,8 +488,8 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
488488
entry_point=script_path,
489489
role=ROLE,
490490
instance_type=cpu_instance_type,
491-
framework_version=sklearn_full_version,
492-
py_version=sklearn_full_py_version,
491+
framework_version=sklearn_latest_version,
492+
py_version=sklearn_latest_py_version,
493493
sagemaker_session=sagemaker_session,
494494
hyperparameters={"epochs": 1},
495495
)

tests/integ/test_git.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_private_github(
138138
@pytest.mark.local_mode
139139
@pytest.mark.skip("needs a secure authentication approach")
140140
def test_private_github_with_2fa(
141-
sagemaker_local_session, sklearn_full_version, sklearn_full_py_version
141+
sagemaker_local_session, sklearn_latest_version, sklearn_latest_py_version
142142
):
143143
script_path = "mnist.py"
144144
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
@@ -155,11 +155,11 @@ def test_private_github_with_2fa(
155155
entry_point=script_path,
156156
role="SageMakerRole",
157157
source_dir=source_dir,
158-
py_version=sklearn_full_py_version,
158+
py_version=sklearn_latest_py_version,
159159
instance_count=1,
160160
instance_type="local",
161161
sagemaker_session=sagemaker_local_session,
162-
framework_version=sklearn_full_version,
162+
framework_version=sklearn_latest_version,
163163
hyperparameters={"epochs": 1},
164164
git_config=git_config,
165165
)
@@ -178,7 +178,7 @@ def test_private_github_with_2fa(
178178
model_data,
179179
"SageMakerRole",
180180
entry_point=script_path,
181-
framework_version=sklearn_full_version,
181+
framework_version=sklearn_latest_version,
182182
source_dir=source_dir,
183183
sagemaker_session=sagemaker_local_session,
184184
git_config=git_config,
@@ -194,7 +194,7 @@ def test_private_github_with_2fa(
194194

195195
@pytest.mark.local_mode
196196
def test_github_with_ssh_passphrase_not_configured(
197-
sagemaker_local_session, sklearn_full_version, sklearn_full_py_version
197+
sagemaker_local_session, sklearn_latest_version, sklearn_latest_py_version
198198
):
199199
script_path = "mnist.py"
200200
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
@@ -212,8 +212,8 @@ def test_github_with_ssh_passphrase_not_configured(
212212
instance_count=1,
213213
instance_type="local",
214214
sagemaker_session=sagemaker_local_session,
215-
framework_version=sklearn_full_version,
216-
py_version=sklearn_full_py_version,
215+
framework_version=sklearn_latest_version,
216+
py_version=sklearn_latest_py_version,
217217
hyperparameters={"epochs": 1},
218218
git_config=git_config,
219219
)

0 commit comments

Comments
 (0)