Skip to content

Commit 1151215

Browse files
author
Deng
committed
check framework and py version in accelerator validate method
1 parent e96e629 commit 1151215

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

src/sagemaker/fw_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"mxnet-serving",
6161
"pytorch-serving",
6262
]
63+
PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"]
6364
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
6465
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"}
6566
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"}
@@ -215,6 +216,7 @@ def create_image_uri(
215216

216217
if _accelerator_type_valid_for_framework(
217218
framework=framework,
219+
py_version=py_version,
218220
accelerator_type=accelerator_type,
219221
optimized_families=optimized_families,
220222
):
@@ -267,17 +269,23 @@ def create_image_uri(
267269

268270

269271
def _accelerator_type_valid_for_framework(
270-
framework, accelerator_type=None, optimized_families=None
272+
framework, py_version, accelerator_type=None, optimized_families=None
271273
):
272274
"""
273275
Args:
274276
framework:
277+
py_version:
275278
accelerator_type:
276279
optimized_families:
277280
"""
278281
if accelerator_type is None:
279282
return False
280283

284+
if py_version == "py2" and framework in PY2_RESTRICTED_EIA_FRAMEWORKS:
285+
raise ValueError(
286+
"{} is not supported with Amazon Elastic Inference in Python 2.".format(framework)
287+
)
288+
281289
if framework not in VALID_EIA_FRAMEWORKS:
282290
raise ValueError(
283291
"{} is not supported with Amazon Elastic Inference. Currently only "

src/sagemaker/pytorch/model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,6 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
142142
dict[str, str]: A container definition object usable with the
143143
CreateModel API.
144144
"""
145-
if accelerator_type and self.py_version == "py2":
146-
raise ValueError("PyTorch EIA is not supported in Python 2.")
147-
148145
deploy_image = self.image
149146
if not deploy_image:
150147
region_name = self.sagemaker_session.boto_session.region_name

tests/unit/test_fw_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,20 @@ def test_pytorch_eia_images():
328328
)
329329

330330

331+
def test_pytorch_eia_py2_error():
332+
error_message = "pytorch-serving is not supported with Amazon Elastic Inference in Python 2."
333+
with pytest.raises(ValueError) as error:
334+
fw_utils.create_image_uri(
335+
"us-east-1",
336+
"pytorch-serving",
337+
"ml.c4.2xlarge",
338+
"1.3.1",
339+
"py2",
340+
accelerator_type="ml.eia1.large",
341+
)
342+
assert error_message in str(error)
343+
344+
331345
def test_create_image_uri_override_account():
332346
image_uri = fw_utils.create_image_uri(
333347
"us-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"

tests/unit/test_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,9 @@ def test_model_image_accelerator(sagemaker_session):
355355
py_version="py2",
356356
)
357357
model.deploy(1, CPU, accelerator_type=ACCELERATOR_TYPE)
358-
assert "PyTorch EIA is not supported in Python 2." in str(error)
358+
assert "pytorch-serving is not supported with Amazon Elastic Inference in Python 2." in str(
359+
error
360+
)
359361

360362

361363
def test_train_image_default(sagemaker_session):

0 commit comments

Comments
 (0)