Skip to content

Commit d7983aa

Browse files
ShiboXingShibo Xing
authored and
Namrata Madan
committed
feature: add p2 deprecation for PT>=1.13 (aws#3567)
* feat: (squash and push to use git hooks) add p2 deprecation for retrieving pytorch img uri style: format _validate_instance_deprecation using black doc: add docstring for _validate_instance_deprecation doc: update deprecation error string doc: update docstring for _validate_instance_deprecation * fix: use Version packaging to compare version * test: add E1136 to pylintrc disable list Co-authored-by: Shibo Xing <[email protected]>
1 parent 989c074 commit d7983aa

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ disable=
9494
useless-object-inheritance, # TODO: Enable this check and fix code once Python 2 is no longer supported.
9595
super-with-arguments,
9696
raise-missing-from,
97+
E1136,
9798

9899
[REPORTS]
99100
# Set the output format. Available formats are text, parseable, colorized, msvs

src/sagemaker/image_uris.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import re
2020
from typing import Optional
21+
from packaging.version import Version
2122

2223
from sagemaker import utils
2324
from sagemaker.jumpstart.utils import is_jumpstart_model_input
@@ -232,6 +233,7 @@ def retrieve(
232233

233234
if repo == f"{framework}-inference-graviton":
234235
container_version = f"{container_version}-sagemaker"
236+
_validate_instance_deprecation(framework, instance_type, version)
235237

236238
tag = _get_image_tag(
237239
container_version,
@@ -365,6 +367,20 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
365367
return config if "scope" in config else config[image_scope]
366368

367369

370+
def _validate_instance_deprecation(framework, instance_type, version):
371+
"""Check if instance type is deprecated for a certain framework with a certain version"""
372+
if (
373+
framework == "pytorch"
374+
and _get_instance_type_family(instance_type) == "p2"
375+
and Version(version) >= Version("1.13")
376+
):
377+
raise ValueError(
378+
"P2 instances have been deprecated for sagemaker jobs with PyTorch 1.13 and above. "
379+
"For information about supported instance types please refer to "
380+
"https://aws.amazon.com/sagemaker/pricing/"
381+
)
382+
383+
368384
def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
369385
"""Validate if framework is supported for the instance_type"""
370386
if (

0 commit comments

Comments
 (0)