Skip to content

Commit a787e47

Browse files
committed
infra: use fixture for Python version in scikit-learn tests
1 parent c24e0b5 commit a787e47

File tree

4 files changed

+45
-29
lines changed

4 files changed

+45
-29
lines changed

tests/conftest.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def pytest_addoption(parser):
5757
parser.addoption(
5858
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
5959
)
60-
parser.addoption("--sklearn-full-version", action="store", default="0.20.0")
6160
parser.addoption("--ei-tf-full-version", action="store")
6261
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
6362

@@ -299,8 +298,13 @@ def rl_ray_full_version(request):
299298

300299

301300
@pytest.fixture(scope="module")
302-
def sklearn_full_version(request):
303-
return request.config.getoption("--sklearn-full-version")
301+
def sklearn_full_version():
302+
return "0.20.0"
303+
304+
305+
@pytest.fixture(scope="module")
306+
def sklearn_full_py_version():
307+
return "py3"
304308

305309

306310
@pytest.fixture(scope="module")

tests/integ/test_airflow_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def test_mxnet_airflow_config_uploads_data_source_to_s3(
475475

476476
@pytest.mark.canary_quick
477477
def test_sklearn_airflow_config_uploads_data_source_to_s3(
478-
sagemaker_session, cpu_instance_type, sklearn_full_version
478+
sagemaker_session, cpu_instance_type, sklearn_full_version, sklearn_full_py_version
479479
):
480480
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
481481
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -486,7 +486,7 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
486486
role=ROLE,
487487
train_instance_type=cpu_instance_type,
488488
framework_version=sklearn_full_version,
489-
py_version=PYTHON_VERSION,
489+
py_version=sklearn_full_py_version,
490490
sagemaker_session=sagemaker_session,
491491
hyperparameters={"epochs": 1},
492492
)

tests/integ/test_git.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def test_private_github(sagemaker_local_session, mxnet_full_version, mxnet_full_
133133

134134
@pytest.mark.local_mode
135135
@pytest.mark.skip("needs a secure authentication approach")
136-
def test_private_github_with_2fa(sagemaker_local_session, sklearn_full_version):
136+
def test_private_github_with_2fa(
137+
sagemaker_local_session, sklearn_full_version, sklearn_full_py_version
138+
):
137139
script_path = "mnist.py"
138140
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
139141
git_config = {
@@ -149,7 +151,7 @@ def test_private_github_with_2fa(sagemaker_local_session, sklearn_full_version):
149151
entry_point=script_path,
150152
role="SageMakerRole",
151153
source_dir=source_dir,
152-
py_version="py3", # Scikit-learn supports only Python 3
154+
py_version=sklearn_full_py_version,
153155
train_instance_count=1,
154156
train_instance_type="local",
155157
sagemaker_session=sagemaker_local_session,
@@ -187,7 +189,9 @@ def test_private_github_with_2fa(sagemaker_local_session, sklearn_full_version):
187189

188190

189191
@pytest.mark.local_mode
190-
def test_github_with_ssh_passphrase_not_configured(sagemaker_local_session, sklearn_full_version):
192+
def test_github_with_ssh_passphrase_not_configured(
193+
sagemaker_local_session, sklearn_full_version, sklearn_full_py_version
194+
):
191195
script_path = "mnist.py"
192196
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
193197
git_config = {
@@ -201,11 +205,11 @@ def test_github_with_ssh_passphrase_not_configured(sagemaker_local_session, skle
201205
entry_point=script_path,
202206
role="SageMakerRole",
203207
source_dir=source_dir,
204-
py_version="py3", # Scikit-learn supports only Python 3
205208
train_instance_count=1,
206209
train_instance_type="local",
207210
sagemaker_session=sagemaker_local_session,
208211
framework_version=sklearn_full_version,
212+
py_version=sklearn_full_py_version,
209213
hyperparameters={"epochs": 1},
210214
git_config=git_config,
211215
)

tests/integ/test_sklearn_train.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.sklearn import SKLearn
2222
from sagemaker.sklearn import SKLearnModel
2323
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base
24-
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
24+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2525
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2626

2727

@@ -30,14 +30,17 @@
3030
reason="This test has always failed, but the failure was masked by a bug. "
3131
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
3232
)
33-
def sklearn_training_job(sagemaker_session, sklearn_full_version, cpu_instance_type):
34-
return _run_mnist_training_job(sagemaker_session, cpu_instance_type, sklearn_full_version)
33+
def sklearn_training_job(
34+
sagemaker_session, sklearn_full_version, sklearn_full_py_version, cpu_instance_type
35+
):
36+
return _run_mnist_training_job(
37+
sagemaker_session, cpu_instance_type, sklearn_full_version, sklearn_full_py_version
38+
)
3539
sagemaker_session.boto_region_name
3640

3741

38-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only Python 3.")
3942
def test_training_with_additional_hyperparameters(
40-
sagemaker_session, sklearn_full_version, cpu_instance_type
43+
sagemaker_session, sklearn_full_version, sklearn_full_py_version, cpu_instance_type
4144
):
4245
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
4346
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -48,7 +51,7 @@ def test_training_with_additional_hyperparameters(
4851
role="SageMakerRole",
4952
train_instance_type=cpu_instance_type,
5053
framework_version=sklearn_full_version,
51-
py_version=PYTHON_VERSION,
54+
py_version=sklearn_full_py_version,
5255
sagemaker_session=sagemaker_session,
5356
hyperparameters={"epochs": 1},
5457
)
@@ -65,9 +68,8 @@ def test_training_with_additional_hyperparameters(
6568
return sklearn.latest_training_job.name
6669

6770

68-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only Python 3.")
6971
def test_training_with_network_isolation(
70-
sagemaker_session, sklearn_full_version, cpu_instance_type
72+
sagemaker_session, sklearn_full_version, sklearn_full_py_version, cpu_instance_type
7173
):
7274
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
7375
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -78,7 +80,7 @@ def test_training_with_network_isolation(
7880
role="SageMakerRole",
7981
train_instance_type=cpu_instance_type,
8082
framework_version=sklearn_full_version,
81-
py_version=PYTHON_VERSION,
83+
py_version=sklearn_full_py_version,
8284
sagemaker_session=sagemaker_session,
8385
hyperparameters={"epochs": 1},
8486
enable_network_isolation=True,
@@ -101,7 +103,6 @@ def test_training_with_network_isolation(
101103

102104
@pytest.mark.canary_quick
103105
@pytest.mark.regional_testing
104-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
105106
@pytest.mark.skip(
106107
reason="This test has always failed, but the failure was masked by a bug. "
107108
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
@@ -115,13 +116,16 @@ def test_attach_deploy(sklearn_training_job, sagemaker_session, cpu_instance_typ
115116
_predict_and_assert(predictor)
116117

117118

118-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
119119
@pytest.mark.skip(
120120
reason="This test has always failed, but the failure was masked by a bug. "
121121
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
122122
)
123123
def test_deploy_model(
124-
sklearn_training_job, sagemaker_session, cpu_instance_type, sklearn_full_version
124+
sklearn_training_job,
125+
sagemaker_session,
126+
cpu_instance_type,
127+
sklearn_full_version,
128+
sklearn_full_py_version,
125129
):
126130
endpoint_name = "test-sklearn-deploy-model-{}".format(sagemaker_timestamp())
127131
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -141,12 +145,13 @@ def test_deploy_model(
141145
_predict_and_assert(predictor)
142146

143147

144-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
145148
@pytest.mark.skip(
146149
reason="This test has always failed, but the failure was masked by a bug. "
147150
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
148151
)
149-
def test_async_fit(sagemaker_session, cpu_instance_type, sklearn_full_version):
152+
def test_async_fit(
153+
sagemaker_session, cpu_instance_type, sklearn_full_version, sklearn_full_py_version
154+
):
150155
endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp())
151156

152157
with timeout(minutes=5):
@@ -169,8 +174,9 @@ def test_async_fit(sagemaker_session, cpu_instance_type, sklearn_full_version):
169174
_predict_and_assert(predictor)
170175

171176

172-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
173-
def test_failed_training_job(sagemaker_session, sklearn_full_version, cpu_instance_type):
177+
def test_failed_training_job(
178+
sagemaker_session, sklearn_full_version, sklearn_full_py_version, cpu_instance_type
179+
):
174180
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
175181
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "failure_script.py")
176182
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
@@ -179,7 +185,7 @@ def test_failed_training_job(sagemaker_session, sklearn_full_version, cpu_instan
179185
entry_point=script_path,
180186
role="SageMakerRole",
181187
framework_version=sklearn_full_version,
182-
py_version=PYTHON_VERSION,
188+
py_version=sklearn_full_py_version,
183189
train_instance_count=1,
184190
train_instance_type=cpu_instance_type,
185191
sagemaker_session=sagemaker_session,
@@ -194,7 +200,9 @@ def test_failed_training_job(sagemaker_session, sklearn_full_version, cpu_instan
194200
sklearn.fit(train_input, job_name=job_name)
195201

196202

197-
def _run_mnist_training_job(sagemaker_session, instance_type, sklearn_full_version, wait=True):
203+
def _run_mnist_training_job(
204+
sagemaker_session, instance_type, sklearn_version, py_version, wait=True
205+
):
198206
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
199207

200208
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -204,8 +212,8 @@ def _run_mnist_training_job(sagemaker_session, instance_type, sklearn_full_versi
204212
sklearn = SKLearn(
205213
entry_point=script_path,
206214
role="SageMakerRole",
207-
framework_version=sklearn_full_version,
208-
py_version=PYTHON_VERSION,
215+
framework_version=sklearn_version,
216+
py_version=py_version,
209217
train_instance_type=instance_type,
210218
sagemaker_session=sagemaker_session,
211219
hyperparameters={"epochs": 1},

0 commit comments

Comments
 (0)