Skip to content

Commit 21f0a6a

Browse files
author
Deng
committed
update pytorch py2 test skip reason
1 parent ea781ed commit 21f0a6a

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

tests/integ/test_pytorch_train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from sagemaker.pytorch.estimator import PyTorch
2323
from sagemaker.pytorch.model import PyTorchModel
24+
from sagemaker.pytorch.defaults import LATEST_PY2_VERSION
2425
from sagemaker.utils import sagemaker_timestamp
2526

2627
MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
@@ -38,7 +39,10 @@ def fixture_training_job(sagemaker_session, pytorch_full_version, cpu_instance_t
3839

3940
@pytest.mark.canary_quick
4041
@pytest.mark.regional_testing
41-
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch Inference not supporting Python2.")
42+
@pytest.mark.skipif(
43+
PYTHON_VERSION == "py2",
44+
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
45+
)
4246
def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_type):
4347
# TODO: add tests against local mode when it's ready to be used
4448
endpoint_name = "test-pytorch-sync-fit-attach-deploy{}".format(sagemaker_timestamp())
@@ -55,7 +59,11 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
5559
assert output.shape == (batch_size, 10)
5660

5761

58-
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch Inference not supporting Python2.")
62+
@pytest.mark.local_mode
63+
@pytest.mark.skipif(
64+
PYTHON_VERSION == "py2",
65+
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
66+
)
5967
def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
6068
pytorch = PyTorch(
6169
entry_point=MNIST_SCRIPT,

0 commit comments

Comments
 (0)