Skip to content

Commit 575f8c0

Browse files
committed
fix: remove py_version from SKLearnProcessor (#285)
1 parent bf9b777 commit 575f8c0

File tree

3 files changed

+2
-28
lines changed

3 files changed

+2
-28
lines changed

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@
2323
from sagemaker.processing import ScriptProcessor
2424

2525

26-
_PYTHON_VERSION_TO_COMMAND_MAPPING = {"py2": ["python2"], "py3": ["python3"]}
27-
_VALID_PYTHON_VERSIONS = ["py2", "py3"]
28-
29-
3026
class SKLearnProcessor(ScriptProcessor):
3127
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
3228

@@ -37,7 +33,6 @@ def __init__(
3733
instance_type,
3834
instance_count,
3935
command=None,
40-
py_version="py3",
4136
volume_size_in_gb=30,
4237
volume_kms_key=None,
4338
output_kms_key=None,
@@ -65,7 +60,6 @@ def __init__(
6560
command ([str]): The command to run, along with any command-line flags.
6661
Example: ["python3", "-v"]. If not provided, ["python3"] or ["python2"]
6762
will be chosen based on the py_version parameter.
68-
py_version (str): The python version to use, for example, 'py3'.
6963
volume_size_in_gb (int): Size in GB of the EBS volume to
7064
use for storing data during processing (default: 30).
7165
volume_kms_key (str): A KMS key for the processing
@@ -90,15 +84,10 @@ def __init__(
9084
session = sagemaker_session or Session()
9185
region = session.boto_region_name
9286

93-
if py_version not in _VALID_PYTHON_VERSIONS:
94-
raise ValueError(
95-
"'py_version' must be a valid value. Please provide one of: 'py2', 'py3'"
96-
)
97-
9887
if not command:
99-
command = _PYTHON_VERSION_TO_COMMAND_MAPPING[py_version]
88+
command = ["python3"]
10089

101-
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
90+
image_tag = "{}-{}-{}".format(framework_version, "cpu", "py3")
10291
image_uri = default_framework_uri("scikit-learn", region, image_tag)
10392

10493
super(SKLearnProcessor, self).__init__(

tests/integ/test_processing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def test_sklearn_with_customizations(
105105
command=["python3"],
106106
instance_type=cpu_instance_type,
107107
instance_count=1,
108-
py_version="py3",
109108
volume_size_in_gb=100,
110109
volume_kms_key=None,
111110
output_kms_key=output_kms_key,
@@ -181,7 +180,6 @@ def test_sklearn_with_no_inputs_or_outputs(
181180
command=["python3"],
182181
instance_type=cpu_instance_type,
183182
instance_count=1,
184-
py_version="py3",
185183
volume_size_in_gb=100,
186184
volume_kms_key=None,
187185
max_runtime_in_seconds=3600,

tests/unit/test_processing.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def test_sklearn_with_all_customizations(sagemaker_session):
171171
command=["python3"],
172172
instance_type="ml.m4.xlarge",
173173
instance_count=1,
174-
py_version="py3",
175174
volume_size_in_gb=100,
176175
volume_kms_key=None,
177176
output_kms_key="arn:aws:kms:us-west-2:012345678901:key/kms-key",
@@ -462,15 +461,3 @@ def test_byo_container_with_baked_in_script(sagemaker_session):
462461
"experiment_config": None,
463462
}
464463
sagemaker_session.process.assert_called_with(**expected_args)
465-
466-
467-
def test_sklearn_processor_raises_value_error_if_invalid_py_version_passed_in(sagemaker_session):
468-
with pytest.raises(ValueError):
469-
SKLearnProcessor(
470-
framework_version="0.20.0",
471-
role=ROLE,
472-
py_version="INVALID_PYTHON_VERSION",
473-
instance_type="ml.m4.xlarge",
474-
instance_count=1,
475-
sagemaker_session=sagemaker_session,
476-
)

0 commit comments

Comments
 (0)