Skip to content

Commit 5848cb8

Browse files
authored
Merge branch 'zwei' into fm-uri
2 parents ad61d98 + cb85792 commit 5848cb8

File tree

4 files changed

+37
-34
lines changed

4 files changed

+37
-34
lines changed

tests/conftest.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -197,16 +197,6 @@ def rl_ray_version(request):
197197
return request.param
198198

199199

200-
@pytest.fixture(scope="module")
201-
def chainer_full_version():
202-
return "5.0.0"
203-
204-
205-
@pytest.fixture(scope="module")
206-
def chainer_full_py_version():
207-
return "py3"
208-
209-
210200
@pytest.fixture(scope="module")
211201
def mxnet_full_version():
212202
return "1.6.0"
@@ -378,15 +368,28 @@ def _generate_all_framework_version_fixtures(metafunc):
378368
for fw in ("chainer", "tensorflow"):
379369
config = image_uris.config_for_framework(fw)
380370
if "scope" in config:
381-
_parametrize_framework_version_fixture(metafunc, "{}_version".format(fw), config)
371+
_parametrize_framework_version_fixtures(metafunc, fw, config)
382372
else:
383373
for image_scope in config.keys():
384-
_parametrize_framework_version_fixture(
385-
metafunc, "{}_{}_version".format(fw, image_scope), config[image_scope]
374+
_parametrize_framework_version_fixtures(
375+
metafunc, "{}_{}".format(fw, image_scope), config[image_scope]
386376
)
387377

388378

389-
def _parametrize_framework_version_fixture(metafunc, fixture_name, config):
379+
def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
380+
fixture_name = "{}_version".format(fixture_prefix)
390381
if fixture_name in metafunc.fixturenames:
391382
versions = list(config["versions"].keys()) + list(config.get("version_aliases", {}).keys())
392383
metafunc.parametrize(fixture_name, versions, scope="session")
384+
385+
latest_version = sorted(config["versions"].keys(), key=lambda v: Version(v))[-1]
386+
387+
fixture_name = "{}_latest_version".format(fixture_prefix)
388+
if fixture_name in metafunc.fixturenames:
389+
metafunc.parametrize(fixture_name, (latest_version,), scope="session")
390+
391+
fixture_name = "{}_latest_py_version".format(fixture_prefix)
392+
if fixture_name in metafunc.fixturenames:
393+
config = config["versions"]
394+
py_versions = config[latest_version].get("py_versions", config[latest_version].keys())
395+
metafunc.parametrize(fixture_name, (sorted(py_versions)[-1],), scope="session")

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
404404

405405
@pytest.mark.canary_quick
406406
def test_chainer_airflow_config_uploads_data_source_to_s3(
407-
sagemaker_local_session, cpu_instance_type, chainer_full_version, chainer_full_py_version
407+
sagemaker_local_session, cpu_instance_type, chainer_latest_version, chainer_latest_py_version
408408
):
409409
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
410410
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -415,8 +415,8 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
415415
role=ROLE,
416416
instance_count=SINGLE_INSTANCE_COUNT,
417417
instance_type="local",
418-
framework_version=chainer_full_version,
419-
py_version=chainer_full_py_version,
418+
framework_version=chainer_latest_version,
419+
py_version=chainer_latest_py_version,
420420
sagemaker_session=sagemaker_local_session,
421421
hyperparameters={"epochs": 1},
422422
use_mpi=True,

tests/integ/test_chainer.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,25 @@
2626

2727
@pytest.fixture(scope="module")
2828
def chainer_local_training_job(
29-
sagemaker_local_session, chainer_full_version, chainer_full_py_version
29+
sagemaker_local_session, chainer_latest_version, chainer_latest_py_version
3030
):
3131
return _run_mnist_training_job(
32-
sagemaker_local_session, "local", 1, chainer_full_version, chainer_full_py_version
32+
sagemaker_local_session, "local", 1, chainer_latest_version, chainer_latest_py_version
3333
)
3434

3535

3636
@pytest.mark.local_mode
3737
def test_distributed_cpu_training(
38-
sagemaker_local_session, chainer_full_version, chainer_full_py_version
38+
sagemaker_local_session, chainer_latest_version, chainer_latest_py_version
3939
):
4040
_run_mnist_training_job(
41-
sagemaker_local_session, "local", 2, chainer_full_version, chainer_full_py_version
41+
sagemaker_local_session, "local", 2, chainer_latest_version, chainer_latest_py_version
4242
)
4343

4444

4545
@pytest.mark.local_mode
4646
def test_training_with_additional_hyperparameters(
47-
sagemaker_local_session, chainer_full_version, chainer_full_py_version
47+
sagemaker_local_session, chainer_latest_version, chainer_latest_py_version
4848
):
4949
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
5050
data_path = os.path.join(DATA_DIR, "chainer_mnist")
@@ -54,8 +54,8 @@ def test_training_with_additional_hyperparameters(
5454
role="SageMakerRole",
5555
instance_count=1,
5656
instance_type="local",
57-
framework_version=chainer_full_version,
58-
py_version=chainer_full_py_version,
57+
framework_version=chainer_latest_version,
58+
py_version=chainer_latest_py_version,
5959
sagemaker_session=sagemaker_local_session,
6060
hyperparameters={"epochs": 1},
6161
use_mpi=True,
@@ -72,7 +72,7 @@ def test_training_with_additional_hyperparameters(
7272

7373
@pytest.mark.canary_quick
7474
def test_attach_deploy(
75-
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
75+
sagemaker_session, chainer_latest_version, chainer_latest_py_version, cpu_instance_type
7676
):
7777
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
7878
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -81,8 +81,8 @@ def test_attach_deploy(
8181
chainer = Chainer(
8282
entry_point=script_path,
8383
role="SageMakerRole",
84-
framework_version=chainer_full_version,
85-
py_version=chainer_full_py_version,
84+
framework_version=chainer_latest_version,
85+
py_version=chainer_latest_py_version,
8686
instance_count=1,
8787
instance_type=cpu_instance_type,
8888
sagemaker_session=sagemaker_session,
@@ -114,8 +114,8 @@ def test_attach_deploy(
114114
def test_deploy_model(
115115
chainer_local_training_job,
116116
sagemaker_local_session,
117-
chainer_full_version,
118-
chainer_full_py_version,
117+
chainer_latest_version,
118+
chainer_latest_py_version,
119119
):
120120
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
121121

@@ -124,8 +124,8 @@ def test_deploy_model(
124124
"SageMakerRole",
125125
entry_point=script_path,
126126
sagemaker_session=sagemaker_local_session,
127-
framework_version=chainer_full_version,
128-
py_version=chainer_full_py_version,
127+
framework_version=chainer_latest_version,
128+
py_version=chainer_latest_py_version,
129129
)
130130

131131
predictor = model.deploy(1, "local")

tests/integ/test_tuner.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ def test_tuning_tf_vpc_multi(
687687

688688
@pytest.mark.canary_quick
689689
def test_tuning_chainer(
690-
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
690+
sagemaker_session, chainer_latest_version, chainer_latest_py_version, cpu_instance_type
691691
):
692692
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
693693
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -696,8 +696,8 @@ def test_tuning_chainer(
696696
estimator = Chainer(
697697
entry_point=script_path,
698698
role="SageMakerRole",
699-
framework_version=chainer_full_version,
700-
py_version=chainer_full_py_version,
699+
framework_version=chainer_latest_version,
700+
py_version=chainer_latest_py_version,
701701
instance_count=1,
702702
instance_type=cpu_instance_type,
703703
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)