Skip to content

Commit e89ecc6

Browse files
committed
change: add PyTorch configuration for image_uris.retrieve()
1 parent 8df2583 commit e89ecc6

File tree

11 files changed

+724
-156
lines changed

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 458 additions & 0 deletions
Large diffs are not rendered by default.

src/sagemaker/image_uris.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ def config_for_framework(framework):
103103
def _validate_version_and_set_if_needed(version, config, framework):
104104
"""Checks if the framework/algorithm version is one of the supported versions."""
105105
available_versions = list(config["versions"].keys())
106+
aliased_versions = list(config.get("version_aliases", {}).keys())
106107

107-
if len(available_versions) == 1:
108+
if len(available_versions) == 1 and version not in aliased_versions:
108109
log_message = "Defaulting to the only supported framework/algorithm version: {}.".format(
109110
available_versions[0]
110111
)
@@ -115,8 +116,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
115116

116117
return available_versions[0]
117118

118-
available_versions += list(config.get("version_aliases", {}).keys())
119-
_validate_arg("{} version".format(framework), version, available_versions)
119+
_validate_arg("{} version".format(framework), version, available_versions + aliased_versions)
120120

121121
return version
122122

tests/conftest.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,25 @@ def mxnet_py_version(request):
120120
return request.param
121121

122122

123-
@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"])
124-
def pytorch_version(request):
125-
return request.param
123+
@pytest.fixture(scope="module", params=["py2", "py3"])
124+
def pytorch_training_py_version(pytorch_training_version, request):
125+
if Version(pytorch_training_version) < Version("1.5.0"):
126+
return request.param
127+
else:
128+
return "py3"
126129

127130

128131
@pytest.fixture(scope="module", params=["py2", "py3"])
129-
def pytorch_py_version(request):
130-
return request.param
132+
def pytorch_inference_py_version(pytorch_inference_version, request):
133+
if Version(pytorch_inference_version) < Version("1.4.0"):
134+
return request.param
135+
else:
136+
return "py3"
137+
138+
139+
@pytest.fixture(scope="module")
140+
def pytorch_eia_py_version():
141+
return "py3"
131142

132143

133144
@pytest.fixture(scope="module", params=["0.20.0"])
@@ -176,21 +187,6 @@ def rl_ray_version(request):
176187
return request.param
177188

178189

179-
@pytest.fixture(scope="module")
180-
def pytorch_full_version():
181-
return "1.5.0"
182-
183-
184-
@pytest.fixture(scope="module")
185-
def pytorch_full_py_version():
186-
return "py3"
187-
188-
189-
@pytest.fixture(scope="module")
190-
def pytorch_full_ei_version():
191-
return "1.3.1"
192-
193-
194190
@pytest.fixture(scope="module")
195191
def rl_coach_mxnet_full_version():
196192
return RLEstimator.COACH_LATEST_VERSION_MXNET
@@ -314,7 +310,7 @@ def pytest_generate_tests(metafunc):
314310

315311

316312
def _generate_all_framework_version_fixtures(metafunc):
317-
for fw in ("chainer", "mxnet", "tensorflow", "xgboost"):
313+
for fw in ("chainer", "mxnet", "pytorch", "tensorflow", "xgboost"):
318314
config = image_uris.config_for_framework(fw)
319315
if "scope" in config:
320316
_parametrize_framework_version_fixtures(metafunc, fw, config)

tests/integ/test_airflow_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,17 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(
578578

579579
@pytest.mark.canary_quick
580580
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
581-
sagemaker_session, cpu_instance_type, pytorch_full_version, pytorch_full_py_version
581+
sagemaker_session,
582+
cpu_instance_type,
583+
pytorch_training_latest_version,
584+
pytorch_training_latest_py_version,
582585
):
583586
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
584587
estimator = PyTorch(
585588
entry_point=PYTORCH_MNIST_SCRIPT,
586589
role=ROLE,
587-
framework_version=pytorch_full_version,
588-
py_version=pytorch_full_py_version,
590+
framework_version=pytorch_training_latest_version,
591+
py_version=pytorch_training_latest_py_version,
589592
instance_count=2,
590593
instance_type=cpu_instance_type,
591594
hyperparameters={"epochs": 6, "backend": "gloo"},

tests/integ/test_git.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,18 @@
5050

5151

5252
@pytest.mark.local_mode
53-
def test_github(sagemaker_local_session, pytorch_full_version, pytorch_full_py_version):
53+
def test_github(
54+
sagemaker_local_session, pytorch_training_latest_version, pytorch_training_latest_py_version
55+
):
5456
script_path = "mnist.py"
5557
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
5658

5759
pytorch = PyTorch(
5860
entry_point=script_path,
5961
role="SageMakerRole",
6062
source_dir="pytorch",
61-
framework_version=pytorch_full_version,
62-
py_version=pytorch_full_py_version,
63+
framework_version=pytorch_training_latest_version,
64+
py_version=pytorch_training_latest_py_version,
6365
instance_count=1,
6466
instance_type="local",
6567
sagemaker_session=sagemaker_local_session,

tests/integ/test_pytorch.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,17 @@
3838

3939
@pytest.fixture(scope="module", name="pytorch_training_job")
4040
def fixture_training_job(
41-
sagemaker_session, pytorch_full_version, pytorch_full_py_version, cpu_instance_type
41+
sagemaker_session,
42+
pytorch_training_latest_version,
43+
pytorch_training_latest_py_version,
44+
cpu_instance_type,
4245
):
4346
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
4447
pytorch = _get_pytorch_estimator(
45-
sagemaker_session, pytorch_full_version, pytorch_full_py_version, cpu_instance_type
48+
sagemaker_session,
49+
pytorch_training_latest_version,
50+
pytorch_training_latest_py_version,
51+
cpu_instance_type,
4652
)
4753

4854
pytorch.fit({"training": _upload_training_data(pytorch)})
@@ -66,12 +72,14 @@ def test_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_type):
6672

6773

6874
@pytest.mark.local_mode
69-
def test_local_fit_deploy(sagemaker_local_session, pytorch_full_version, pytorch_full_py_version):
75+
def test_local_fit_deploy(
76+
sagemaker_local_session, pytorch_training_latest_version, pytorch_training_latest_py_version
77+
):
7078
pytorch = PyTorch(
7179
entry_point=MNIST_SCRIPT,
7280
role="SageMakerRole",
73-
framework_version=pytorch_full_version,
74-
py_version=pytorch_full_py_version,
81+
framework_version=pytorch_training_latest_version,
82+
py_version=pytorch_training_latest_py_version,
7583
instance_count=1,
7684
instance_type="local",
7785
sagemaker_session=sagemaker_local_session,
@@ -94,8 +102,8 @@ def test_deploy_model(
94102
pytorch_training_job,
95103
sagemaker_session,
96104
cpu_instance_type,
97-
pytorch_full_version,
98-
pytorch_full_py_version,
105+
pytorch_inference_latest_version,
106+
pytorch_inference_latest_py_version,
99107
):
100108
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
101109

@@ -108,8 +116,8 @@ def test_deploy_model(
108116
model_data,
109117
"SageMakerRole",
110118
entry_point=MNIST_SCRIPT,
111-
framework_version=pytorch_full_version,
112-
py_version=pytorch_full_py_version,
119+
framework_version=pytorch_inference_latest_version,
120+
py_version=pytorch_inference_latest_py_version,
113121
sagemaker_session=sagemaker_session,
114122
)
115123
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -122,7 +130,10 @@ def test_deploy_model(
122130

123131

124132
def test_deploy_packed_model_with_entry_point_name(
125-
sagemaker_session, cpu_instance_type, pytorch_full_version, pytorch_full_py_version
133+
sagemaker_session,
134+
cpu_instance_type,
135+
pytorch_inference_latest_version,
136+
pytorch_inference_latest_py_version,
126137
):
127138
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
128139

@@ -132,8 +143,8 @@ def test_deploy_packed_model_with_entry_point_name(
132143
model_data,
133144
"SageMakerRole",
134145
entry_point="mnist.py",
135-
framework_version=pytorch_full_version,
136-
py_version=pytorch_full_py_version,
146+
framework_version=pytorch_inference_latest_version,
147+
py_version=pytorch_inference_latest_py_version,
137148
sagemaker_session=sagemaker_session,
138149
)
139150
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -149,16 +160,19 @@ def test_deploy_packed_model_with_entry_point_name(
149160
test_region() not in EI_SUPPORTED_REGIONS, reason="EI isn't supported in that specific region."
150161
)
151162
def test_deploy_model_with_accelerator(
152-
sagemaker_session, cpu_instance_type, pytorch_full_ei_version, pytorch_full_py_version
163+
sagemaker_session,
164+
cpu_instance_type,
165+
pytorch_eia_latest_ei_version,
166+
pytorch_eia_latest_py_version,
153167
):
154168
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
155169
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
156170
pytorch = PyTorchModel(
157171
model_data,
158172
"SageMakerRole",
159173
entry_point=EIA_SCRIPT,
160-
framework_version=pytorch_full_ei_version,
161-
py_version=pytorch_full_py_version,
174+
framework_version=pytorch_eia_latest_ei_version,
175+
py_version=pytorch_eia_latest_py_version,
162176
sagemaker_session=sagemaker_session,
163177
)
164178
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):

tests/integ/test_transformer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def test_attach_transform_kmeans(sagemaker_session, cpu_instance_type):
154154

155155
def test_transform_pytorch_vpc_custom_model_bucket(
156156
sagemaker_session,
157-
pytorch_full_version,
158-
pytorch_full_py_version,
157+
pytorch_inference_latest_version,
158+
pytorch_inference_latest_py_version,
159159
cpu_instance_type,
160160
custom_bucket_name,
161161
):
@@ -174,8 +174,8 @@ def test_transform_pytorch_vpc_custom_model_bucket(
174174
model_data=model_data,
175175
entry_point=os.path.join(data_dir, "mnist.py"),
176176
role="SageMakerRole",
177-
framework_version=pytorch_full_version,
178-
py_version=pytorch_full_py_version,
177+
framework_version=pytorch_inference_latest_version,
178+
py_version=pytorch_inference_latest_py_version,
179179
sagemaker_session=sagemaker_session,
180180
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
181181
code_location="s3://{}".format(custom_bucket_name),

tests/integ/test_tuner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,10 @@ def test_tuning_chainer(
771771
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
772772
)
773773
def test_attach_tuning_pytorch(
774-
sagemaker_session, cpu_instance_type, pytorch_full_version, pytorch_full_py_version
774+
sagemaker_session,
775+
cpu_instance_type,
776+
pytorch_training_latest_version,
777+
pytorch_training_latest_py_version,
775778
):
776779
mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
777780
mnist_script = os.path.join(mnist_dir, "mnist.py")
@@ -780,8 +783,8 @@ def test_attach_tuning_pytorch(
780783
entry_point=mnist_script,
781784
role="SageMakerRole",
782785
instance_count=1,
783-
framework_version=pytorch_full_version,
784-
py_version=pytorch_full_py_version,
786+
framework_version=pytorch_training_latest_version,
787+
py_version=pytorch_training_latest_py_version,
785788
instance_type=cpu_instance_type,
786789
sagemaker_session=sagemaker_session,
787790
)

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,96 @@ def _expected_mxnet_inference_uri(
285285
)
286286

287287

288+
def test_pytorch_training(pytorch_training_version, pytorch_training_py_version):
289+
_test_image_uris(
290+
"pytorch",
291+
pytorch_training_version,
292+
pytorch_training_py_version,
293+
"training",
294+
_expected_pytorch_training_uri,
295+
{"pytorch_version": pytorch_training_version, "py_version": pytorch_training_py_version},
296+
)
297+
298+
299+
def _expected_pytorch_training_uri(pytorch_version, py_version, processor="cpu", region=REGION):
300+
version = Version(pytorch_version)
301+
if version < Version("1.2"):
302+
repo = "sagemaker-pytorch"
303+
else:
304+
repo = "pytorch-training"
305+
306+
return expected_uris.framework_uri(
307+
repo,
308+
pytorch_version,
309+
_sagemaker_or_dlc_account(repo, region),
310+
py_version=py_version,
311+
processor=processor,
312+
region=region,
313+
)
314+
315+
316+
def test_pytorch_inference(pytorch_inference_version, pytorch_inference_py_version):
317+
_test_image_uris(
318+
"pytorch",
319+
pytorch_inference_version,
320+
pytorch_inference_py_version,
321+
"inference",
322+
_expected_pytorch_inference_uri,
323+
{"pytorch_version": pytorch_inference_version, "py_version": pytorch_inference_py_version},
324+
)
325+
326+
327+
def _expected_pytorch_inference_uri(pytorch_version, py_version, processor="cpu", region=REGION):
328+
version = Version(pytorch_version)
329+
if version < Version("1.2"):
330+
repo = "sagemaker-pytorch"
331+
else:
332+
repo = "pytorch-inference"
333+
334+
return expected_uris.framework_uri(
335+
repo,
336+
pytorch_version,
337+
_sagemaker_or_dlc_account(repo, region),
338+
py_version=py_version,
339+
processor=processor,
340+
region=region,
341+
)
342+
343+
344+
def test_pytorch_eia(pytorch_eia_version, pytorch_eia_py_version):
345+
base_args = {
346+
"framework": "pytorch",
347+
"version": pytorch_eia_version,
348+
"py_version": pytorch_eia_py_version,
349+
"image_scope": "inference",
350+
"instance_type": "ml.c4.xlarge",
351+
"accelerator_type": "ml.eia1.medium",
352+
}
353+
354+
uri = image_uris.retrieve(region=REGION, **base_args)
355+
356+
expected = expected_uris.framework_uri(
357+
"pytorch-inference-eia",
358+
pytorch_eia_version,
359+
DLC_ACCOUNT,
360+
py_version=pytorch_eia_py_version,
361+
region=REGION,
362+
)
363+
assert expected == uri
364+
365+
for region, account in DLC_ALTERNATE_REGION_ACCOUNTS.items():
366+
uri = image_uris.retrieve(region=region, **base_args)
367+
368+
expected = expected_uris.framework_uri(
369+
"pytorch-inference-eia",
370+
pytorch_eia_version,
371+
account,
372+
py_version=pytorch_eia_py_version,
373+
region=region,
374+
)
375+
assert expected == uri
376+
377+
288378
def _sagemaker_or_dlc_account(repo, region):
289379
if repo.startswith("sagemaker"):
290380
return (

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ def test_retrieve_aliased_version(config_for_framework):
127127
)
128128
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:{}-cpu-py3".format(version) == uri
129129

130+
del config["versions"]["1.1.0"]
131+
uri = image_uris.retrieve(
132+
framework="useless-string",
133+
version=version,
134+
py_version="py3",
135+
instance_type="ml.c4.xlarge",
136+
region="us-west-2",
137+
image_scope="training",
138+
)
139+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:{}-cpu-py3".format(version) == uri
140+
130141

131142
@patch("sagemaker.image_uris.config_for_framework")
132143
def test_retrieve_default_version_if_possible(config_for_framework, caplog):

0 commit comments

Comments
 (0)