21
21
from botocore .config import Config
22
22
23
23
from sagemaker import Session , utils
24
- from sagemaker .chainer import Chainer
25
24
from sagemaker .local import LocalSession
26
- from sagemaker .mxnet import MXNet
27
- from sagemaker .pytorch import PyTorch
28
25
from sagemaker .rl import RLEstimator
29
- from sagemaker .sklearn .defaults import SKLEARN_VERSION
30
- from sagemaker .tensorflow import TensorFlow
31
- from sagemaker .tensorflow .defaults import LATEST_VERSION , LATEST_SERVING_VERSION
32
26
33
27
DEFAULT_REGION = "us-west-2"
34
28
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
44
38
45
39
NO_T2_REGIONS = ["eu-north-1" , "ap-east-1" , "me-south-1" ]
46
40
41
+ # TODO: refactor handling of versions, repo, image uris, validations for all frameworks
42
+ TENSORFLOW_LATEST_VERSION = "2.2.0"
43
+ TENSORFLOW_LATEST_1X_VERSION = "1.15.2"
44
+
47
45
48
46
def pytest_addoption (parser ):
49
47
parser .addoption ("--sagemaker-client-config" , action = "store" , default = None )
50
48
parser .addoption ("--sagemaker-runtime-config" , action = "store" , default = None )
51
49
parser .addoption ("--boto-config" , action = "store" , default = None )
52
- parser .addoption ("--chainer-full-version" , action = "store" , default = Chainer . LATEST_VERSION )
53
- parser .addoption ("--mxnet-full-version" , action = "store" , default = MXNet . LATEST_VERSION )
50
+ parser .addoption ("--chainer-full-version" , action = "store" , default = "5.0.0" )
51
+ parser .addoption ("--mxnet-full-version" , action = "store" , default = "1.6.0" )
54
52
parser .addoption ("--ei-mxnet-full-version" , action = "store" , default = "1.5.1" )
55
- parser .addoption ("--pytorch-full-version" , action = "store" , default = PyTorch . LATEST_VERSION )
53
+ parser .addoption ("--pytorch-full-version" , action = "store" , default = "1.5.0" )
56
54
parser .addoption (
57
55
"--rl-coach-mxnet-full-version" ,
58
56
action = "store" ,
@@ -64,10 +62,10 @@ def pytest_addoption(parser):
64
62
parser .addoption (
65
63
"--rl-ray-full-version" , action = "store" , default = RLEstimator .RAY_LATEST_VERSION
66
64
)
67
- parser .addoption ("--sklearn-full-version" , action = "store" , default = SKLEARN_VERSION )
65
+ parser .addoption ("--sklearn-full-version" , action = "store" , default = "0.20.0" )
68
66
parser .addoption ("--tf-full-version" , action = "store" )
69
67
parser .addoption ("--ei-tf-full-version" , action = "store" )
70
- parser .addoption ("--xgboost-full-version" , action = "store" , default = SKLEARN_VERSION )
68
+ parser .addoption ("--xgboost-full-version" , action = "store" , default = "1.0-1" )
71
69
72
70
73
71
def pytest_configure (config ):
@@ -291,7 +289,27 @@ def sklearn_full_version(request):
291
289
return request .config .getoption ("--sklearn-full-version" )
292
290
293
291
294
- @pytest .fixture (scope = "module" , params = [TensorFlow ._LATEST_1X_VERSION , LATEST_VERSION ])
292
+ @pytest .fixture (scope = "module" , params = [TENSORFLOW_LATEST_VERSION ])
293
+ def tf_latest_version (request ):
294
+ return request .param
295
+
296
+
297
+ @pytest .fixture (scope = "module" )
298
+ def tf_latest_py_version ():
299
+ return "py37"
300
+
301
+
302
+ @pytest .fixture (scope = "module" , params = [TENSORFLOW_LATEST_1X_VERSION ])
303
+ def tf_latest_1x_version (request ):
304
+ return request .param
305
+
306
+
307
+ @pytest .fixture (scope = "module" )
308
+ def tf_latest_serving_version ():
309
+ return "2.1.0"
310
+
311
+
312
+ @pytest .fixture (scope = "module" , params = [TENSORFLOW_LATEST_VERSION , TENSORFLOW_LATEST_1X_VERSION ])
295
313
def tf_full_version (request ):
296
314
tf_version = request .config .getoption ("--tf-full-version" )
297
315
if tf_version is None :
@@ -301,7 +319,7 @@ def tf_full_version(request):
301
319
302
320
303
321
@pytest .fixture (scope = "module" )
304
- def tf_full_py_version (tf_full_version , request ):
322
+ def tf_full_py_version (tf_full_version , tf_latest_version , tf_latest_1x_version ):
305
323
"""fixture to match tf_full_version
306
324
307
325
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
@@ -312,11 +330,18 @@ def tf_full_py_version(tf_full_version, request):
312
330
version = [int (val ) for val in tf_full_version .split ("." )]
313
331
if version < [1 , 11 ]:
314
332
return "py2"
315
- if tf_full_version in [TensorFlow . _LATEST_1X_VERSION , LATEST_VERSION ]:
333
+ if tf_full_version in [tf_latest_version , tf_latest_1x_version ]:
316
334
return "py37"
317
335
return "py3"
318
336
319
337
338
+ @pytest .fixture (scope = "module" )
339
+ def tf_serving_version (tf_full_version , tf_latest_version , tf_latest_serving_version ):
340
+ if tf_full_version == tf_latest_version :
341
+ return tf_latest_serving_version
342
+ return tf_full_version
343
+
344
+
320
345
@pytest .fixture (scope = "module" , params = ["1.15.0" , "2.0.0" ])
321
346
def ei_tf_full_version (request ):
322
347
tf_ei_version = request .config .getoption ("--ei-tf-full-version" )
@@ -384,10 +409,3 @@ def pytest_generate_tests(metafunc):
384
409
@pytest .fixture (scope = "module" )
385
410
def xgboost_full_version (request ):
386
411
return request .config .getoption ("--xgboost-full-version" )
387
-
388
-
389
- @pytest .fixture (scope = "module" )
390
- def tf_serving_version (tf_full_version ):
391
- if tf_full_version == LATEST_VERSION :
392
- return LATEST_SERVING_VERSION
393
- return tf_full_version
0 commit comments