Skip to content

Commit c0cc779

Browse files
author
Deng
committed
address comments: accelerator_type, script name, framework and python version check
1 parent adbf8db commit c0cc779

File tree

7 files changed

+22
-22
lines changed

7 files changed

+22
-22
lines changed

doc/using_pytorch.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ In order to attach an Elastic Inference accelerator to your endpoint provide the
259259
260260
predictor = pytorch_estimator.deploy(instance_type='ml.m4.xlarge',
261261
initial_instance_count=1,
262-
accelerator_type='ml.eia1.medium')
262+
accelerator_type='ml.eia2.medium')
263263
264264
The SageMaker PyTorch Model Server
265265
==================================

src/sagemaker/fw_utils.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
"tensorflow-serving",
5959
"mxnet",
6060
"mxnet-serving",
61-
"pytorch",
6261
"pytorch-serving",
6362
]
6463
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
@@ -90,7 +89,7 @@
9089
"mxnet-serving-eia": [1, 4, 1],
9190
"pytorch": [1, 2, 0],
9291
"pytorch-serving": [1, 2, 0],
93-
"pytorch-serving-eia": {"py3": [1, 3, 1]},
92+
"pytorch-serving-eia": [1, 3, 1],
9493
}
9594

9695
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
@@ -126,8 +125,6 @@ def _is_dlc_version(framework, framework_version, py_version):
126125
"""
127126
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
128127
if isinstance(lowest_version_list, dict):
129-
if py_version not in lowest_version_list:
130-
raise ValueError("{} is not supported in {}.".format(framework, py_version))
131128
lowest_version_list = lowest_version_list[py_version]
132129

133130
if lowest_version_list:

src/sagemaker/pytorch/model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,15 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
136136
For example, 'ml.p2.xlarge'.
137137
accelerator_type (str): The Elastic Inference accelerator type to
138138
deploy to the instance for loading and making inferences to the
139-
model. Currently supported with PyTorch 1.3.1 Python 3.
139+
model.
140140
141141
Returns:
142142
dict[str, str]: A container definition object usable with the
143143
CreateModel API.
144144
"""
145+
if accelerator_type and self.py_version == "py2":
146+
raise ValueError("PyTorch EIA is not supported in Python 2.")
147+
145148
deploy_image = self.image
146149
if not deploy_image:
147150
region_name = self.sagemaker_session.boto_session.region_name
@@ -169,7 +172,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
169172
(cpu/gpu/family-specific optimized).
170173
accelerator_type (str): The Elastic Inference accelerator type to
171174
deploy to the instance for loading and making inferences to the
172-
model. Currently supported with PyTorch 1.3.1 Python 3.
175+
model.
173176
174177
Returns:
175178
str: The appropriate image URI based on the given parameters.

tests/integ/test_pytorch_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
EIA_DIR = os.path.join(DATA_DIR, "pytorch_eia")
3131
EIA_MODEL = os.path.join(EIA_DIR, "model_mnist.tar.gz")
32-
EIA_SCRIPT = os.path.join(EIA_DIR, "mnist.py")
32+
EIA_SCRIPT = os.path.join(EIA_DIR, "empty_inference_script.py")
3333

3434

3535
@pytest.fixture(scope="module", name="pytorch_training_job")

tests/unit/test_fw_utils.py

-14
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,6 @@ def test_pytorch_eia_images():
328328
)
329329

330330

331-
def test_pytorch_eia_py2_error():
332-
error_message = "pytorch-serving-eia is not supported in py2."
333-
with pytest.raises(ValueError) as error:
334-
fw_utils.create_image_uri(
335-
"us-east-1",
336-
"pytorch-serving",
337-
"ml.c4.2xlarge",
338-
"1.3.1",
339-
"py2",
340-
accelerator_type="ml.eia1.large",
341-
)
342-
assert error_message in str(error)
343-
344-
345331
def test_create_image_uri_override_account():
346332
image_uri = fw_utils.create_image_uri(
347333
"us-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"

tests/unit/test_pytorch.py

+14
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,20 @@ def test_model_image_accelerator(sagemaker_session):
352352
assert isinstance(predictor, PyTorchPredictor)
353353

354354

355+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
356+
def test_model_image_accelerator(sagemaker_session):
357+
with pytest.raises(ValueError) as error:
358+
model = PyTorchModel(
359+
MODEL_DATA,
360+
role=ROLE,
361+
entry_point=SCRIPT_PATH,
362+
sagemaker_session=sagemaker_session,
363+
py_version="py2",
364+
)
365+
model.deploy(1, CPU, accelerator_type=ACCELERATOR_TYPE)
366+
assert "PyTorch EIA is not supported in Python 2." in str(error)
367+
368+
355369
def test_train_image_default(sagemaker_session):
356370
pytorch = PyTorch(
357371
entry_point=SCRIPT_PATH,

0 commit comments

Comments
 (0)