Skip to content

Commit d203f9d

Browse files
authored
fix: skip pytorch ei test in unsupported regions (aws#1337)
* fix: skip pytorch ei test in unsupported regions
1 parent 5a3d1ad commit d203f9d

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tests/integ/test_pytorch_train.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,23 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import os
16-
1715
import numpy
16+
import os
1817
import pytest
19-
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
20-
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
21-
18+
from sagemaker.pytorch.defaults import LATEST_PY2_VERSION
2219
from sagemaker.pytorch.estimator import PyTorch
2320
from sagemaker.pytorch.model import PyTorchModel
24-
from sagemaker.pytorch.defaults import LATEST_PY2_VERSION
2521
from sagemaker.utils import sagemaker_timestamp
2622

23+
from tests.integ import (
24+
test_region,
25+
DATA_DIR,
26+
PYTHON_VERSION,
27+
TRAINING_DEFAULT_TIMEOUT_MINUTES,
28+
EI_SUPPORTED_REGIONS,
29+
)
30+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
31+
2732
MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
2833
MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py")
2934

@@ -120,6 +125,9 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
120125

121126

122127
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch EIA does not support Python 2.")
128+
@pytest.mark.skipif(
129+
test_region() not in EI_SUPPORTED_REGIONS, reason="EI isn't supported in that specific region."
130+
)
123131
def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
124132
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
125133
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
@@ -134,7 +142,7 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
134142
predictor = pytorch.deploy(
135143
initial_instance_count=1,
136144
instance_type=cpu_instance_type,
137-
accelerator_type="ml.eia2.medium",
145+
accelerator_type="ml.eia1.medium",
138146
endpoint_name=endpoint_name,
139147
)
140148

0 commit comments

Comments
 (0)