Skip to content

Commit a175112

Browse files
committed
fix: revert #2251 changes for sklearn
1 parent fed5d0b commit a175112

File tree

2 files changed

+82
-64
lines changed

2 files changed

+82
-64
lines changed

src/sagemaker/sklearn/processing.py

+69-46
Original file line numberDiff line numberDiff line change
@@ -11,73 +11,96 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code related to SKLearn Processors which are used for Processing jobs.
14-
1514
These jobs let customers perform data pre-processing, post-processing, feature engineering,
1615
data validation, and model evaluation and interpretation on SageMaker.
1716
"""
1817
from __future__ import absolute_import
1918

20-
from sagemaker.processing import FrameworkProcessor
21-
from sagemaker.sklearn.estimator import SKLearn
22-
23-
24-
class SKLearnProcessor(FrameworkProcessor):
25-
"""Initialize an ``SKLearnProcessor`` instance.
26-
27-
The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn.
28-
29-
Unless ``image_uri`` is specified, the scikit-learn environment is an
30-
Amazon-built Docker container that executes functions defined in the supplied
31-
``code`` Python script.
32-
33-
The arguments have the exact same meaning as in ``FrameworkProcessor``.
34-
35-
.. tip::
19+
from sagemaker import image_uris, Session
20+
from sagemaker.processing import ScriptProcessor
21+
from sagemaker.sklearn import defaults
3622

37-
You can find additional parameters for initializing this class at
38-
:class:`~sagemaker.processing.FrameworkProcessor`.
39-
"""
4023

41-
estimator_cls = SKLearn
24+
class SKLearnProcessor(ScriptProcessor):
25+
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
4226

4327
def __init__(
4428
self,
45-
framework_version, # New arg
29+
framework_version,
4630
role,
47-
instance_count,
4831
instance_type,
49-
py_version="py3", # New kwarg
50-
image_uri=None,
32+
instance_count,
5133
command=None,
5234
volume_size_in_gb=30,
5335
volume_kms_key=None,
5436
output_kms_key=None,
55-
code_location=None, # New arg
5637
max_runtime_in_seconds=None,
5738
base_job_name=None,
5839
sagemaker_session=None,
5940
env=None,
6041
tags=None,
6142
network_config=None,
6243
):
63-
"""This processor executes a Python script in a scikit-learn execution environment."""
64-
super().__init__(
65-
self.estimator_cls,
66-
framework_version,
67-
role,
68-
instance_count,
69-
instance_type,
70-
py_version,
71-
image_uri,
72-
command,
73-
volume_size_in_gb,
74-
volume_kms_key,
75-
output_kms_key,
76-
code_location,
77-
max_runtime_in_seconds,
78-
base_job_name,
79-
sagemaker_session,
80-
env,
81-
tags,
82-
network_config,
44+
"""Initialize an ``SKLearnProcessor`` instance.
45+
The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn.
46+
Args:
47+
framework_version (str): The version of scikit-learn.
48+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs
49+
and APIs that create Amazon SageMaker endpoints use this role
50+
to access training data and model artifacts. After the endpoint
51+
is created, the inference code might use the IAM role, if it
52+
needs to access an AWS resource.
53+
instance_type (str): Type of EC2 instance to use for
54+
processing, for example, 'ml.c4.xlarge'.
55+
instance_count (int): The number of instances to run
56+
the Processing job with.
57+
command ([str]): The command to run, along with any command-line flags.
58+
Example: ["python3", "-v"]. If not provided, the processor
59+
uses ["python3"].
60+
volume_size_in_gb (int): Size in GB of the EBS volume to
61+
use for storing data during processing (default: 30).
62+
volume_kms_key (str): A KMS key for the processing
63+
volume.
64+
output_kms_key (str): The KMS key id for all ProcessingOutputs.
65+
max_runtime_in_seconds (int): Timeout in seconds.
66+
After this amount of time Amazon SageMaker terminates the job
67+
regardless of its current status.
68+
base_job_name (str): Prefix for processing name. If not specified,
69+
the processor generates a default job name, based on the
70+
training image name and current timestamp.
71+
sagemaker_session (sagemaker.session.Session): Session object which
72+
manages interactions with Amazon SageMaker APIs and any other
73+
AWS services needed. If not specified, the processor creates one
74+
using the default AWS configuration chain.
75+
env (dict): Environment variables to be passed to the processing job.
76+
tags ([dict]): List of tags to be passed to the processing job.
77+
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
78+
object that configures network isolation, encryption of
79+
inter-container traffic, security group IDs, and subnets.
80+
"""
81+
if not command:
82+
command = ["python3"]
83+
84+
session = sagemaker_session or Session()
85+
region = session.boto_region_name
86+
87+
image_uri = image_uris.retrieve(
88+
defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type
89+
)
90+
91+
super(SKLearnProcessor, self).__init__(
92+
role=role,
93+
image_uri=image_uri,
94+
instance_count=instance_count,
95+
instance_type=instance_type,
96+
command=command,
97+
volume_size_in_gb=volume_size_in_gb,
98+
volume_kms_key=volume_kms_key,
99+
output_kms_key=output_kms_key,
100+
max_runtime_in_seconds=max_runtime_in_seconds,
101+
base_job_name=base_job_name,
102+
sagemaker_session=session,
103+
env=env,
104+
tags=tags,
105+
network_config=network_config,
83106
)

tests/integ/test_processing.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def test_sklearn(sagemaker_session, sklearn_latest_version, cpu_instance_type):
125125
role=ROLE,
126126
instance_type=cpu_instance_type,
127127
instance_count=1,
128+
command=["python3"],
128129
sagemaker_session=sagemaker_session,
129130
base_job_name="test-sklearn",
130131
)
@@ -138,16 +139,16 @@ def test_sklearn(sagemaker_session, sklearn_latest_version, cpu_instance_type):
138139

139140
job_description = sklearn_processor.latest_job.describe()
140141

141-
assert len(job_description["ProcessingInputs"]) == 3
142+
assert len(job_description["ProcessingInputs"]) == 2
142143
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] == 1
143144
assert (
144145
job_description["ProcessingResources"]["ClusterConfig"]["InstanceType"] == cpu_instance_type
145146
)
146147
assert job_description["ProcessingResources"]["ClusterConfig"]["VolumeSizeInGB"] == 30
147148
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 86400}
148149
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
149-
"/bin/bash",
150-
"/opt/ml/processing/input/entrypoint/runproc.sh",
150+
"python3",
151+
"/opt/ml/processing/input/code/dummy_script.py",
151152
]
152153
assert ROLE in job_description["RoleArn"]
153154

@@ -203,7 +204,6 @@ def test_sklearn_with_customizations(
203204
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
204205

205206
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
206-
assert job_description["ProcessingInputs"][2]["InputName"] == "entrypoint"
207207

208208
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
209209

@@ -220,8 +220,8 @@ def test_sklearn_with_customizations(
220220

221221
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
222222
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
223-
"/bin/bash",
224-
"/opt/ml/processing/input/entrypoint/runproc.sh",
223+
"python3",
224+
"/opt/ml/processing/input/code/dummy_script.py",
225225
]
226226
assert job_description["AppSpecification"]["ImageUri"] == image_uri
227227

@@ -245,6 +245,7 @@ def test_sklearn_with_custom_default_bucket(
245245
sklearn_processor = SKLearnProcessor(
246246
framework_version=sklearn_latest_version,
247247
role=ROLE,
248+
command=["python3"],
248249
instance_type=cpu_instance_type,
249250
instance_count=1,
250251
volume_size_in_gb=100,
@@ -287,9 +288,6 @@ def test_sklearn_with_custom_default_bucket(
287288
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
288289
assert custom_bucket_name in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
289290

290-
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
291-
assert custom_bucket_name in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
292-
293291
assert job_description["ProcessingInputs"][2]["InputName"] == "entrypoint"
294292
assert custom_bucket_name in job_description["ProcessingInputs"][2]["S3Input"]["S3Uri"]
295293

@@ -308,8 +306,8 @@ def test_sklearn_with_custom_default_bucket(
308306

309307
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
310308
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
311-
"/bin/bash",
312-
"/opt/ml/processing/input/entrypoint/runproc.sh",
309+
"python3",
310+
"/opt/ml/processing/input/code/dummy_script.py",
313311
]
314312
assert job_description["AppSpecification"]["ImageUri"] == image_uri
315313

@@ -326,6 +324,7 @@ def test_sklearn_with_no_inputs_or_outputs(
326324
sklearn_processor = SKLearnProcessor(
327325
framework_version=sklearn_latest_version,
328326
role=ROLE,
327+
command=["python3"],
329328
instance_type=cpu_instance_type,
330329
instance_count=1,
331330
volume_size_in_gb=100,
@@ -338,16 +337,12 @@ def test_sklearn_with_no_inputs_or_outputs(
338337
)
339338

340339
sklearn_processor.run(
341-
code=os.path.join(DATA_DIR, "dummy_script.py"),
342-
arguments=["-v"],
343-
wait=True,
344-
logs=True,
340+
code=os.path.join(DATA_DIR, "dummy_script.py"), arguments=["-v"], wait=True, logs=True
345341
)
346342

347343
job_description = sklearn_processor.latest_job.describe()
348344

349345
assert job_description["ProcessingInputs"][0]["InputName"] == "code"
350-
assert job_description["ProcessingInputs"][1]["InputName"] == "entrypoint"
351346

352347
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-no-inputs")
353348

@@ -361,8 +356,8 @@ def test_sklearn_with_no_inputs_or_outputs(
361356

362357
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
363358
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
364-
"/bin/bash",
365-
"/opt/ml/processing/input/entrypoint/runproc.sh",
359+
"python3",
360+
"/opt/ml/processing/input/code/dummy_script.py",
366361
]
367362
assert job_description["AppSpecification"]["ImageUri"] == image_uri
368363

0 commit comments

Comments
 (0)