21
21
22
22
from sagemaker .pytorch .estimator import PyTorch
23
23
from sagemaker .pytorch .model import PyTorchModel
24
+ from sagemaker .pytorch .defaults import LATEST_PY2_VERSION
24
25
from sagemaker .utils import sagemaker_timestamp
25
26
26
27
MNIST_DIR = os .path .join (DATA_DIR , "pytorch_mnist" )
@@ -38,7 +39,10 @@ def fixture_training_job(sagemaker_session, pytorch_full_version, cpu_instance_t
38
39
39
40
@pytest .mark .canary_quick
40
41
@pytest .mark .regional_testing
41
- @pytest .mark .skipif (PYTHON_VERSION == "py2" , reason = "PyTorch Inference not supporting Python2." )
42
+ @pytest .mark .skipif (
43
+ PYTHON_VERSION == "py2" ,
44
+ reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
45
+ )
42
46
def test_sync_fit_deploy (pytorch_training_job , sagemaker_session , cpu_instance_type ):
43
47
# TODO: add tests against local mode when it's ready to be used
44
48
endpoint_name = "test-pytorch-sync-fit-attach-deploy{}" .format (sagemaker_timestamp ())
@@ -55,7 +59,11 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
55
59
assert output .shape == (batch_size , 10 )
56
60
57
61
58
- @pytest .mark .skipif (PYTHON_VERSION == "py2" , reason = "PyTorch Inference not supporting Python2." )
62
+ @pytest .mark .local_mode
63
+ @pytest .mark .skipif (
64
+ PYTHON_VERSION == "py2" ,
65
+ reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
66
+ )
59
67
def test_fit_deploy (sagemaker_local_session , pytorch_full_version ):
60
68
pytorch = PyTorch (
61
69
entry_point = MNIST_SCRIPT ,
0 commit comments