diff --git a/tests/conftest.py b/tests/conftest.py index 7c379a9f06..ddecdd687a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,11 @@ from sagemaker import Session from sagemaker.local import LocalSession +from sagemaker.chainer.defaults import CHAINER_VERSION +from sagemaker.pytorch.defaults import PYTORCH_VERSION +from sagemaker.mxnet.defaults import MXNET_VERSION +from sagemaker.tensorflow.defaults import TF_VERSION + DEFAULT_REGION = 'us-west-2' @@ -91,21 +96,21 @@ def chainer_version(request): return request.param -@pytest.fixture(scope='module', params=['1.4.1', '1.5.0', '1.6.0', '1.7.0', '1.8.0']) +@pytest.fixture(scope='module', params=[TF_VERSION]) def tf_full_version(request): return request.param -@pytest.fixture(scope='module', params=['0.12.1', '1.0.0', '1.1.0', '1.2.1']) +@pytest.fixture(scope='module', params=[MXNET_VERSION]) def mxnet_full_version(request): return request.param -@pytest.fixture(scope='module', params=["0.4.0"]) +@pytest.fixture(scope='module', params=[PYTORCH_VERSION]) def pytorch_full_version(request): return request.param -@pytest.fixture(scope='module', params=['4.0.0', '4.1.0']) +@pytest.fixture(scope='module', params=[CHAINER_VERSION]) def chainer_full_version(request): return request.param