Skip to content

Commit 4f0761e

Browse files
committed
infra: use fixture for Chainer and XGBoost Python version, clean up remaining version fixtures
1 parent 6edac7c commit 4f0761e

File tree

6 files changed

+57
-62
lines changed

6 files changed

+57
-62
lines changed

tests/conftest.py

+28-37
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,6 @@ def pytest_addoption(parser):
4444
parser.addoption("--sagemaker-client-config", action="store", default=None)
4545
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
4646
parser.addoption("--boto-config", action="store", default=None)
47-
parser.addoption("--chainer-full-version", action="store", default="5.0.0")
48-
parser.addoption("--ei-mxnet-full-version", action="store", default="1.5.1")
49-
parser.addoption(
50-
"--rl-coach-mxnet-full-version",
51-
action="store",
52-
default=RLEstimator.COACH_LATEST_VERSION_MXNET,
53-
)
54-
parser.addoption(
55-
"--rl-coach-tf-full-version", action="store", default=RLEstimator.COACH_LATEST_VERSION_TF
56-
)
57-
parser.addoption(
58-
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
59-
)
60-
parser.addoption("--ei-tf-full-version", action="store")
61-
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
6247

6348

6449
def pytest_configure(config):
@@ -248,8 +233,13 @@ def rl_ray_version(request):
248233

249234

250235
@pytest.fixture(scope="module")
251-
def chainer_full_version(request):
252-
return request.config.getoption("--chainer-full-version")
236+
def chainer_full_version():
237+
return "5.0.0"
238+
239+
240+
@pytest.fixture(scope="module")
241+
def chainer_full_py_version():
242+
return "py3"
253243

254244

255245
@pytest.fixture(scope="module")
@@ -263,8 +253,8 @@ def mxnet_full_py_version():
263253

264254

265255
@pytest.fixture(scope="module")
266-
def ei_mxnet_full_version(request):
267-
return request.config.getoption("--ei-mxnet-full-version")
256+
def ei_mxnet_full_version():
257+
return "1.5.1"
268258

269259

270260
@pytest.fixture(scope="module")
@@ -283,18 +273,18 @@ def pytorch_full_ei_version():
283273

284274

285275
@pytest.fixture(scope="module")
286-
def rl_coach_mxnet_full_version(request):
287-
return request.config.getoption("--rl-coach-mxnet-full-version")
276+
def rl_coach_mxnet_full_version():
277+
return RLEstimator.COACH_LATEST_VERSION_MXNET
288278

289279

290280
@pytest.fixture(scope="module")
291-
def rl_coach_tf_full_version(request):
292-
return request.config.getoption("--rl-coach-tf-full-version")
281+
def rl_coach_tf_full_version():
282+
return RLEstimator.COACH_LATEST_VERSION_TF
293283

294284

295285
@pytest.fixture(scope="module")
296-
def rl_ray_full_version(request):
297-
return request.config.getoption("--rl-ray-full-version")
286+
def rl_ray_full_version():
287+
return RLEstimator.RAY_LATEST_VERSION
298288

299289

300290
@pytest.fixture(scope="module")
@@ -347,13 +337,19 @@ def tf_full_py_version(tf_full_version):
347337
return "py37"
348338

349339

350-
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
351-
def ei_tf_full_version(request):
352-
tf_ei_version = request.config.getoption("--ei-tf-full-version")
353-
if tf_ei_version is None:
354-
return request.param
355-
else:
356-
tf_ei_version
340+
@pytest.fixture(scope="module")
341+
def ei_tf_full_version():
342+
return "2.0.0"
343+
344+
345+
@pytest.fixture(scope="module")
346+
def xgboost_full_version():
347+
return "1.0-1"
348+
349+
350+
@pytest.fixture(scope="module")
351+
def xgboost_full_py_version():
352+
return "py3"
357353

358354

359355
@pytest.fixture(scope="session")
@@ -409,8 +405,3 @@ def pytest_generate_tests(metafunc):
409405
):
410406
params.append("ml.p2.xlarge")
411407
metafunc.parametrize("instance_type", params, scope="session")
412-
413-
414-
@pytest.fixture(scope="module")
415-
def xgboost_full_version(request):
416-
return request.config.getoption("--xgboost-full-version")

tests/integ/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616
import os
17-
import sys
1817

1918
import boto3
2019

@@ -23,7 +22,6 @@
2322
TUNING_DEFAULT_TIMEOUT_MINUTES = 20
2423
TRANSFORM_DEFAULT_TIMEOUT_MINUTES = 20
2524
AUTO_ML_DEFAULT_TIMEMOUT_MINUTES = 60
26-
PYTHON_VERSION = "py{}".format(sys.version_info.major)
2725

2826
# these regions have some p2 and p3 instances, but not enough for continuous testing
2927
HOSTING_NO_P2_REGIONS = [

tests/integ/test_airflow_config.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from sagemaker.utils import sagemaker_timestamp
4646
from sagemaker.workflow import airflow as sm_airflow
4747
from sagemaker.xgboost import XGBoost
48-
from tests.integ import datasets, DATA_DIR, PYTHON_VERSION
48+
from tests.integ import datasets, DATA_DIR
4949
from tests.integ.record_set import prepare_record_set_from_local_files
5050
from tests.integ.timeout import timeout
5151

@@ -404,7 +404,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
404404

405405
@pytest.mark.canary_quick
406406
def test_chainer_airflow_config_uploads_data_source_to_s3(
407-
sagemaker_local_session, cpu_instance_type, chainer_full_version
407+
sagemaker_local_session, cpu_instance_type, chainer_full_version, chainer_full_py_version
408408
):
409409
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
410410
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -416,7 +416,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
416416
train_instance_count=SINGLE_INSTANCE_COUNT,
417417
train_instance_type="local",
418418
framework_version=chainer_full_version,
419-
py_version=PYTHON_VERSION,
419+
py_version=chainer_full_py_version,
420420
sagemaker_session=sagemaker_local_session,
421421
hyperparameters={"epochs": 1},
422422
use_mpi=True,
@@ -545,20 +545,19 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
545545

546546

547547
@pytest.mark.canary_quick
548-
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="XGBoost container does not support Python 2.")
549548
def test_xgboost_airflow_config_uploads_data_source_to_s3(
550-
sagemaker_session, cpu_instance_type, xgboost_full_version
549+
sagemaker_session, cpu_instance_type, xgboost_full_version, xgboost_full_py_version
551550
):
552551
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
553552
xgboost = XGBoost(
554553
entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
555554
framework_version=xgboost_full_version,
555+
py_version=xgboost_full_py_version,
556556
role=ROLE,
557557
sagemaker_session=sagemaker_session,
558558
train_instance_type=cpu_instance_type,
559559
train_instance_count=SINGLE_INSTANCE_COUNT,
560560
base_job_name="XGBoost job",
561-
py_version=PYTHON_VERSION,
562561
)
563562

564563
training_config = _build_airflow_workflow(

tests/integ/test_chainer_train.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.chainer.estimator import Chainer
2121
from sagemaker.chainer.model import ChainerModel
2222
from sagemaker.utils import unique_name_from_base
23-
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
23+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2424
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2525

2626

@@ -35,7 +35,9 @@ def test_distributed_cpu_training(sagemaker_local_session, chainer_full_version)
3535

3636

3737
@pytest.mark.local_mode
38-
def test_training_with_additional_hyperparameters(sagemaker_local_session, chainer_full_version):
38+
def test_training_with_additional_hyperparameters(
39+
sagemaker_local_session, chainer_full_version, chainer_full_py_version
40+
):
3941
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
4042
data_path = os.path.join(DATA_DIR, "chainer_mnist")
4143

@@ -45,7 +47,7 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
4547
train_instance_count=1,
4648
train_instance_type="local",
4749
framework_version=chainer_full_version,
48-
py_version=PYTHON_VERSION,
50+
py_version=chainer_full_py_version,
4951
sagemaker_session=sagemaker_local_session,
5052
hyperparameters={"epochs": 1},
5153
use_mpi=True,
@@ -62,7 +64,9 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
6264

6365
@pytest.mark.canary_quick
6466
@pytest.mark.regional_testing
65-
def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_type):
67+
def test_attach_deploy(
68+
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
69+
):
6670
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
6771
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
6872
data_path = os.path.join(DATA_DIR, "chainer_mnist")
@@ -71,7 +75,7 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
7175
entry_point=script_path,
7276
role="SageMakerRole",
7377
framework_version=chainer_full_version,
74-
py_version=PYTHON_VERSION,
78+
py_version=chainer_full_py_version,
7579
train_instance_count=1,
7680
train_instance_type=cpu_instance_type,
7781
sagemaker_session=sagemaker_session,
@@ -100,7 +104,12 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
100104

101105

102106
@pytest.mark.local_mode
103-
def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chainer_full_version):
107+
def test_deploy_model(
108+
chainer_local_training_job,
109+
sagemaker_local_session,
110+
chainer_full_version,
111+
chainer_full_py_version,
112+
):
104113
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
105114

106115
model = ChainerModel(
@@ -109,7 +118,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
109118
entry_point=script_path,
110119
sagemaker_session=sagemaker_local_session,
111120
framework_version=chainer_full_version,
112-
py_version=PYTHON_VERSION,
121+
py_version=chainer_full_py_version,
113122
)
114123

115124
predictor = model.deploy(1, "local")
@@ -120,7 +129,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
120129

121130

122131
def _run_mnist_training_job(
123-
sagemaker_session, instance_type, instance_count, chainer_full_version, wait=True
132+
sagemaker_session, instance_type, instance_count, chainer_version, py_version, wait=True
124133
):
125134
script_path = (
126135
os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -133,8 +142,8 @@ def _run_mnist_training_job(
133142
chainer = Chainer(
134143
entry_point=script_path,
135144
role="SageMakerRole",
136-
framework_version=chainer_full_version,
137-
py_version=PYTHON_VERSION,
145+
framework_version=chainer_version,
146+
py_version=py_version,
138147
train_instance_count=instance_count,
139148
train_instance_type=instance_type,
140149
sagemaker_session=sagemaker_session,

tests/integ/test_rl.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020
from sagemaker.rl import RLEstimator, RLFramework, RLToolkit
2121
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base
22-
from tests.integ import DATA_DIR, PYTHON_VERSION
22+
from tests.integ import DATA_DIR
2323
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2424

2525

2626
@pytest.mark.canary_quick
27-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
2827
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instance_type):
2928
estimator = _test_coach(
3029
sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version, cpu_instance_type
@@ -52,7 +51,6 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instanc
5251
assert 0 < action[0][1] < 1
5352

5453

55-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
5654
def test_coach_tf(sagemaker_session, rl_coach_tf_full_version, cpu_instance_type):
5755
estimator = _test_coach(
5856
sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version, cpu_instance_type
@@ -98,7 +96,6 @@ def _test_coach(sagemaker_session, rl_framework, rl_coach_version, cpu_instance_
9896

9997

10098
@pytest.mark.canary_quick
101-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
10299
def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
103100
source_dir = os.path.join(DATA_DIR, "ray_cartpole")
104101
cartpole = "train_ray.py"

tests/integ/test_tuner.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
datasets,
4646
vpc_test_utils,
4747
DATA_DIR,
48-
PYTHON_VERSION,
4948
TUNING_DEFAULT_TIMEOUT_MINUTES,
5049
)
5150
from tests.integ.record_set import prepare_record_set_from_local_files
@@ -687,7 +686,9 @@ def test_tuning_tf_vpc_multi(
687686

688687

689688
@pytest.mark.canary_quick
690-
def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_type):
689+
def test_tuning_chainer(
690+
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
691+
):
691692
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
692693
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
693694
data_path = os.path.join(DATA_DIR, "chainer_mnist")
@@ -696,7 +697,7 @@ def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_ty
696697
entry_point=script_path,
697698
role="SageMakerRole",
698699
framework_version=chainer_full_version,
699-
py_version=PYTHON_VERSION,
700+
py_version=chainer_full_py_version,
700701
train_instance_count=1,
701702
train_instance_type=cpu_instance_type,
702703
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)