Skip to content

Commit afe043a

Browse files
authored
Merge branch 'master' into change/remove-setuptools-deprecation
2 parents 5719166 + 610d00f commit afe043a

File tree

17 files changed

+524
-22
lines changed

17 files changed

+524
-22
lines changed

requirements/extras/test_requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pandas>=1.3.5,<1.5
2424
scikit-learn==1.3.0
2525
cloudpickle==2.2.1
2626
scipy==1.10.1
27-
urllib3>=1.26.8,<1.26.15
27+
urllib3>=1.26.8,<3.0.0
2828
docker>=5.0.2,<7.0.0
2929
PyYAML==6.0
3030
pyspark==3.3.1

src/sagemaker/chainer/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
framework_version (str): Chainer version you want to use for
109109
executing your model training code. Defaults to ``None``. Required unless
110110
``image_uri`` is provided. List of supported versions:
111-
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
111+
https://sagemaker.readthedocs.io/en/stable/frameworks/chainer/using_chainer.html#using-chainer-with-the-sagemaker-python-sdk.
112112
image_uri (str): If specified, the estimator will use this image
113113
for training and hosting, instead of selecting the appropriate
114114
SageMaker official image based on framework_version and

src/sagemaker/estimator.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from sagemaker.workflow import is_pipeline_variable
103103
from sagemaker.workflow.entities import PipelineVariable
104+
from sagemaker.workflow.parameters import ParameterString
104105
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
105106

106107
logger = logging.getLogger(__name__)
@@ -3198,6 +3199,7 @@ class Framework(EstimatorBase):
31983199
"""
31993200

32003201
_framework_name = None
3202+
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121", "2.0-gpu-py310-cu121")
32013203

32023204
def __init__(
32033205
self,
@@ -3843,16 +3845,43 @@ def _distribution_configuration(self, distribution):
38433845
"custom_mpi_options", ""
38443846
)
38453847

3846-
if get_mp_parameters(distribution):
3847-
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3848-
3849-
elif "modelparallel" in distribution.get("smdistributed", {}):
3850-
raise ValueError("Cannot use Model Parallelism without MPI enabled!")
3851-
38523848
if "smdistributed" in distribution:
38533849
# smdistributed strategy selected
3850+
if get_mp_parameters(distribution):
3851+
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3852+
# first make sure torch_distributed is enabled if instance type is p5
3853+
torch_distributed_enabled = False
3854+
if "torch_distributed" in distribution:
3855+
torch_distributed_enabled = distribution.get("torch_distributed").get(
3856+
"enabled", False
3857+
)
38543858
smdistributed = distribution["smdistributed"]
38553859
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
3860+
if isinstance(self.instance_type, ParameterString):
3861+
p5_enabled = "p5.48xlarge" in self.instance_type.default_value
3862+
elif isinstance(self.instance_type, str):
3863+
p5_enabled = "p5.48xlarge" in self.instance_type
3864+
else:
3865+
raise ValueError(
3866+
"Invalid object type for instance_type argument. Expected "
3867+
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868+
)
3869+
img_uri = "" if self.image_uri is None else self.image_uri
3870+
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
3871+
if (
3872+
unsupported_image in img_uri and not torch_distributed_enabled
3873+
): # disabling DLC images with CUDA12
3874+
raise ValueError(
3875+
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
3876+
"(Could be due to CUDA version being greater than 11.)"
3877+
)
3878+
if (
3879+
not torch_distributed_enabled and p5_enabled
3880+
): # disabling p5 when torch distributed is disabled
3881+
raise ValueError(
3882+
"SMModelParallel and SMDataParallel currently do not support p5 instances."
3883+
)
3884+
# smdistributed strategy selected with supported instance type
38563885
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
38573886
distribution_config[self.INSTANCE_TYPE] = self.instance_type
38583887
if smdataparallel_enabled:

src/sagemaker/huggingface/model.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,10 @@ def __init__(
138138
unless ``image_uri`` is provided.
139139
tensorflow_version (str): TensorFlow version you want to use for
140140
executing your inference code. Defaults to ``None``. Required unless
141-
``pytorch_version`` is provided. List of supported versions:
142-
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
141+
``pytorch_version`` is provided. The current supported version is ``2.4.1``.
143142
pytorch_version (str): PyTorch version you want to use for
144143
executing your inference code. Defaults to ``None``. Required unless
145-
``tensorflow_version`` is provided. List of supported versions:
146-
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
144+
``tensorflow_version`` is provided. The current supported versions are ``1.7.1`` and ``1.6.0``.
147145
py_version (str): Python version you want to use for executing your
148146
model training code. Defaults to ``None``. Required unless
149147
``image_uri`` is provided.

src/sagemaker/jumpstart/accessors.py

+11
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class JumpStartModelsAccessor(object):
127127
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
128128

129129
_content_bucket: Optional[str] = None
130+
_gated_content_bucket: Optional[str] = None
130131

131132
_cache_kwargs: Dict[str, Any] = {}
132133

@@ -140,6 +141,16 @@ def get_jumpstart_content_bucket() -> Optional[str]:
140141
"""Returns JumpStart content bucket."""
141142
return JumpStartModelsAccessor._content_bucket
142143

144+
@staticmethod
145+
def set_jumpstart_gated_content_bucket(gated_content_bucket: str) -> None:
146+
"""Sets JumpStart gated content bucket."""
147+
JumpStartModelsAccessor._gated_content_bucket = gated_content_bucket
148+
149+
@staticmethod
150+
def get_jumpstart_gated_content_bucket() -> Optional[str]:
151+
"""Returns JumpStart gated content bucket."""
152+
return JumpStartModelsAccessor._gated_content_bucket
153+
143154
@staticmethod
144155
def _validate_and_mutate_region_cache_kwargs(
145156
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None

src/sagemaker/jumpstart/artifacts/model_uris.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from sagemaker.jumpstart.utils import (
2727
get_jumpstart_content_bucket,
28+
get_jumpstart_gated_content_bucket,
2829
verify_model_region_and_return_specs,
2930
)
3031
from sagemaker.session import Session
@@ -157,9 +158,16 @@ def _retrieve_model_uri(
157158

158159
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
159160

160-
bucket = os.environ.get(
161-
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
162-
) or get_jumpstart_content_bucket(region)
161+
default_jumpstart_bucket: str = (
162+
get_jumpstart_gated_content_bucket(region)
163+
if model_specs.gated_bucket
164+
else get_jumpstart_content_bucket(region)
165+
)
166+
167+
bucket = (
168+
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
169+
or default_jumpstart_bucket
170+
)
163171

164172
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
165173

src/sagemaker/jumpstart/constants.py

+25
Original file line numberDiff line numberDiff line change
@@ -38,82 +38,102 @@
3838
JumpStartLaunchedRegionInfo(
3939
region_name="us-west-2",
4040
content_bucket="jumpstart-cache-prod-us-west-2",
41+
gated_content_bucket="jumpstart-private-cache-prod-us-west-2",
4142
),
4243
JumpStartLaunchedRegionInfo(
4344
region_name="us-east-1",
4445
content_bucket="jumpstart-cache-prod-us-east-1",
46+
gated_content_bucket="jumpstart-private-cache-prod-us-east-1",
4547
),
4648
JumpStartLaunchedRegionInfo(
4749
region_name="us-east-2",
4850
content_bucket="jumpstart-cache-prod-us-east-2",
51+
gated_content_bucket="jumpstart-private-cache-prod-us-east-2",
4952
),
5053
JumpStartLaunchedRegionInfo(
5154
region_name="eu-west-1",
5255
content_bucket="jumpstart-cache-prod-eu-west-1",
56+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-1",
5357
),
5458
JumpStartLaunchedRegionInfo(
5559
region_name="eu-central-1",
5660
content_bucket="jumpstart-cache-prod-eu-central-1",
61+
gated_content_bucket="jumpstart-private-cache-prod-eu-central-1",
5762
),
5863
JumpStartLaunchedRegionInfo(
5964
region_name="eu-north-1",
6065
content_bucket="jumpstart-cache-prod-eu-north-1",
66+
gated_content_bucket="jumpstart-private-cache-prod-eu-north-1",
6167
),
6268
JumpStartLaunchedRegionInfo(
6369
region_name="me-south-1",
6470
content_bucket="jumpstart-cache-prod-me-south-1",
71+
gated_content_bucket="jumpstart-private-cache-prod-me-south-1",
6572
),
6673
JumpStartLaunchedRegionInfo(
6774
region_name="ap-south-1",
6875
content_bucket="jumpstart-cache-prod-ap-south-1",
76+
gated_content_bucket="jumpstart-private-cache-prod-ap-south-1",
6977
),
7078
JumpStartLaunchedRegionInfo(
7179
region_name="eu-west-3",
7280
content_bucket="jumpstart-cache-prod-eu-west-3",
81+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-3",
7382
),
7483
JumpStartLaunchedRegionInfo(
7584
region_name="af-south-1",
7685
content_bucket="jumpstart-cache-prod-af-south-1",
86+
gated_content_bucket="jumpstart-private-cache-prod-af-south-1",
7787
),
7888
JumpStartLaunchedRegionInfo(
7989
region_name="sa-east-1",
8090
content_bucket="jumpstart-cache-prod-sa-east-1",
91+
gated_content_bucket="jumpstart-private-cache-prod-sa-east-1",
8192
),
8293
JumpStartLaunchedRegionInfo(
8394
region_name="ap-east-1",
8495
content_bucket="jumpstart-cache-prod-ap-east-1",
96+
gated_content_bucket="jumpstart-private-cache-prod-ap-east-1",
8597
),
8698
JumpStartLaunchedRegionInfo(
8799
region_name="ap-northeast-2",
88100
content_bucket="jumpstart-cache-prod-ap-northeast-2",
101+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
89102
),
90103
JumpStartLaunchedRegionInfo(
91104
region_name="eu-west-2",
92105
content_bucket="jumpstart-cache-prod-eu-west-2",
106+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-2",
93107
),
94108
JumpStartLaunchedRegionInfo(
95109
region_name="eu-south-1",
96110
content_bucket="jumpstart-cache-prod-eu-south-1",
111+
gated_content_bucket="jumpstart-private-cache-prod-eu-south-1",
97112
),
98113
JumpStartLaunchedRegionInfo(
99114
region_name="ap-northeast-1",
100115
content_bucket="jumpstart-cache-prod-ap-northeast-1",
116+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
101117
),
102118
JumpStartLaunchedRegionInfo(
103119
region_name="us-west-1",
104120
content_bucket="jumpstart-cache-prod-us-west-1",
121+
gated_content_bucket="jumpstart-private-cache-prod-us-west-1",
105122
),
106123
JumpStartLaunchedRegionInfo(
107124
region_name="ap-southeast-1",
108125
content_bucket="jumpstart-cache-prod-ap-southeast-1",
126+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
109127
),
110128
JumpStartLaunchedRegionInfo(
111129
region_name="ap-southeast-2",
112130
content_bucket="jumpstart-cache-prod-ap-southeast-2",
131+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
113132
),
114133
JumpStartLaunchedRegionInfo(
115134
region_name="ca-central-1",
116135
content_bucket="jumpstart-cache-prod-ca-central-1",
136+
gated_content_bucket="jumpstart-private-cache-prod-ca-central-1",
117137
),
118138
JumpStartLaunchedRegionInfo(
119139
region_name="cn-north-1",
@@ -128,6 +148,11 @@
128148
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
129149

130150
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
151+
JUMPSTART_GATED_BUCKET_NAME_SET = {
152+
region.gated_content_bucket
153+
for region in JUMPSTART_LAUNCHED_REGIONS
154+
if region.gated_content_bucket is not None
155+
}
131156

132157
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
133158

src/sagemaker/jumpstart/types.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,21 @@ class JumpStartS3FileType(str, Enum):
107107
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
108108
"""Data class for launched region info."""
109109

110-
__slots__ = ["content_bucket", "region_name"]
110+
__slots__ = ["content_bucket", "region_name", "gated_content_bucket"]
111111

112-
def __init__(self, content_bucket: str, region_name: str):
112+
def __init__(
113+
self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None
114+
):
113115
"""Instantiates JumpStartLaunchedRegionInfo object.
114116
115117
Args:
116118
content_bucket (str): Name of JumpStart s3 content bucket associated with region.
117119
region_name (str): Name of JumpStart launched region.
120+
gated_content_bucket (Optional[str[]): Name of JumpStart gated s3 content bucket
121+
optionally associated with region.
118122
"""
119123
self.content_bucket = content_bucket
124+
self.gated_content_bucket = gated_content_bucket
120125
self.region_name = region_name
121126

122127

@@ -691,6 +696,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
691696
"hosting_instance_type_variants",
692697
"training_instance_type_variants",
693698
"default_payloads",
699+
"gated_bucket",
694700
]
695701

696702
def __init__(self, spec: Dict[str, Any]):
@@ -767,6 +773,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
767773
if json_obj.get("default_payloads")
768774
else None
769775
)
776+
self.gated_bucket = json_obj.get("gated_bucket", False)
770777
self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
771778
self.inference_enable_network_isolation: bool = json_obj.get(
772779
"inference_enable_network_isolation", False

src/sagemaker/jumpstart/utils.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,64 @@ def get_jumpstart_launched_regions_message() -> str:
6363
return f"JumpStart is available in {formatted_launched_regions_str} regions."
6464

6565

66+
def get_jumpstart_gated_content_bucket(
67+
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
68+
) -> str:
69+
"""Returns regionalized private content bucket name for JumpStart.
70+
71+
Raises:
72+
ValueError: If JumpStart is not launched in ``region`` or private content
73+
unavailable in that region.
74+
"""
75+
76+
old_gated_content_bucket: Optional[
77+
str
78+
] = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
79+
80+
info_logs: List[str] = []
81+
82+
gated_bucket_to_return: Optional[str] = None
83+
if (
84+
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
85+
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
86+
):
87+
gated_bucket_to_return = os.environ[
88+
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE
89+
]
90+
info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'")
91+
else:
92+
try:
93+
gated_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
94+
region
95+
].gated_content_bucket
96+
if gated_bucket_to_return is None:
97+
raise ValueError(
98+
f"No private content bucket for JumpStart exists in {region} region."
99+
)
100+
except KeyError:
101+
formatted_launched_regions_str = get_jumpstart_launched_regions_message()
102+
raise ValueError(
103+
f"Unable to get private content bucket for JumpStart in {region} region. "
104+
f"{formatted_launched_regions_str}"
105+
)
106+
107+
accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)
108+
109+
if gated_bucket_to_return != old_gated_content_bucket:
110+
accessors.JumpStartModelsAccessor.reset_cache()
111+
for info_log in info_logs:
112+
constants.JUMPSTART_LOGGER.info(info_log)
113+
114+
return gated_bucket_to_return
115+
116+
66117
def get_jumpstart_content_bucket(
67118
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
68119
) -> str:
69120
"""Returns regionalized content bucket name for JumpStart.
70121
71122
Raises:
72-
RuntimeError: If JumpStart is not launched in ``region``.
123+
ValueError: If JumpStart is not launched in ``region``.
73124
"""
74125

75126
old_content_bucket: Optional[

src/sagemaker/mxnet/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
framework_version (str): MXNet version you want to use for executing
7878
your model training code. Defaults to `None`. Required unless
7979
``image_uri`` is provided. List of supported versions.
80-
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
80+
https://aws.amazon.com/releasenotes/available-deep-learning-containers-images/.
8181
py_version (str): Python version you want to use for executing your
8282
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
8383
unless ``image_uri`` is provided.

src/sagemaker/pytorch/estimator.py

+3
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326326
if self.instance_type is not None:
327327
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
328328
elif torch_distributed_enabled:
329+
if "smdistributed" in distribution:
330+
# Enable torch_distributed for smdistributed.
331+
distribution_config = self._distribution_configuration(distribution=distribution)
329332
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
330333
if self.instance_type is not None:
331334
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type

0 commit comments

Comments
 (0)