Skip to content

Commit aa0f46d

Browse files
authored
Merge branch 'master' into change/remove-setuptools-deprecation
2 parents afe043a + d355d5b commit aa0f46d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2104
-1597
lines changed

CHANGELOG.md

+37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,42 @@
11
# Changelog
22

3+
## v2.196.0 (2023-10-27)
4+
5+
### Features
6+
7+
* inference instance type conditioned on training instance type
8+
9+
### Bug Fixes and Other Changes
10+
11+
* improved jumpstart tagging
12+
13+
## v2.195.1 (2023-10-26)
14+
15+
### Bug Fixes and Other Changes
16+
17+
* Allow either instance_type or instance_group to be defined in…
18+
* enhance image_uris unit tests
19+
20+
## v2.195.0 (2023-10-25)
21+
22+
### Features
23+
24+
* jumpstart gated model artifacts
25+
* jumpstart extract generated text from response
26+
* jumpstart contruct payload utility
27+
28+
### Bug Fixes and Other Changes
29+
30+
* relax upper bound on urllib in local mode requirements
31+
* bump urllib3 version
32+
* allow smdistributed to be enabled with torch_distributed.
33+
* fix URL links
34+
35+
### Documentation Changes
36+
37+
* remove python 2 reference
38+
* update framework version links
39+
340
## v2.194.0 (2023-10-19)
441

542
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.194.1.dev0
1+
2.196.1.dev0

doc/overview.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ After you train a model, you can save it, and then serve the model as an endpoin
3232
Prepare a Training script
3333
=========================
3434

35-
Your training script must be a Python 2.7 or 3.6 compatible source file.
35+
Your training script must be a 3.6 compatible source file.
3636

3737
The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables, including the following:
3838

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
urllib3>=1.26.8,<1.26.15
1+
urllib3>=1.26.8,<3.0.0
22
docker>=5.0.2,<7.0.0
33
PyYAML>=5.4.1,<7

src/sagemaker/estimator.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from sagemaker.utils import instance_supports_kms
7272
from sagemaker.job import _Job
7373
from sagemaker.jumpstart.utils import (
74-
add_jumpstart_tags,
74+
add_jumpstart_uri_tags,
7575
get_jumpstart_base_name_if_jumpstart_model,
7676
update_inference_tags_with_jumpstart_training_tags,
7777
)
@@ -577,9 +577,7 @@ def __init__(
577577
self.entry_point = entry_point
578578
self.dependencies = dependencies or []
579579
self.uploaded_code: Optional[UploadedCode] = None
580-
self.tags = add_jumpstart_tags(
581-
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
582-
)
580+
583581
if self.instance_type in ("local", "local_gpu"):
584582
if self.instance_type == "local_gpu" and self.instance_count > 1:
585583
raise RuntimeError("Distributed Training in Local GPU is not supported")
@@ -592,6 +590,15 @@ def __init__(
592590
else:
593591
self.sagemaker_session = sagemaker_session or Session()
594592

593+
self.tags = (
594+
add_jumpstart_uri_tags(
595+
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
596+
)
597+
if getattr(self.sagemaker_session, "settings", None) is not None
598+
and self.sagemaker_session.settings.include_jumpstart_tags
599+
else tags
600+
)
601+
595602
self.base_job_name = base_job_name
596603
self._current_job_name = None
597604
if (
@@ -3818,6 +3825,7 @@ def _distribution_configuration(self, distribution):
38183825

38193826
mpi_enabled = False
38203827
smdataparallel_enabled = False
3828+
p5_enabled = False
38213829
if "instance_groups" in distribution:
38223830
distribution_config["sagemaker_distribution_instance_groups"] = distribution[
38233831
"instance_groups"
@@ -3862,10 +3870,11 @@ def _distribution_configuration(self, distribution):
38623870
elif isinstance(self.instance_type, str):
38633871
p5_enabled = "p5.48xlarge" in self.instance_type
38643872
else:
3865-
raise ValueError(
3866-
"Invalid object type for instance_type argument. Expected "
3867-
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868-
)
3873+
for instance in self.instance_groups:
3874+
if "p5.48xlarge" in instance._to_request_dict().get("InstanceType", ()):
3875+
p5_enabled = True
3876+
break
3877+
38693878
img_uri = "" if self.image_uri is None else self.image_uri
38703879
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38713880
if (

src/sagemaker/instance_types.py

+15
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def retrieve_default(
3333
tolerate_vulnerable_model: bool = False,
3434
tolerate_deprecated_model: bool = False,
3535
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
training_instance_type: Optional[str] = None,
3637
) -> str:
3738
"""Retrieves the default instance type for the model matching the given arguments.
3839
@@ -56,6 +57,11 @@ def retrieve_default(
5657
object, used for SageMaker interactions. If not
5758
specified, one is created using the default AWS configuration
5859
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
60+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
61+
instance type used for the training job that produced the fine-tuned weights.
62+
Optionally supply this to get a inference instance type conditioned
63+
on the training instance, to ensure compatability of training artifact to inference
64+
instance. (Default: None).
5965
Returns:
6066
str: The default instance type to use for the model.
6167
@@ -78,6 +84,7 @@ def retrieve_default(
7884
tolerate_vulnerable_model,
7985
tolerate_deprecated_model,
8086
sagemaker_session=sagemaker_session,
87+
training_instance_type=training_instance_type,
8188
)
8289

8390

@@ -89,6 +96,7 @@ def retrieve(
8996
tolerate_vulnerable_model: bool = False,
9097
tolerate_deprecated_model: bool = False,
9198
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
99+
training_instance_type: Optional[str] = None,
92100
) -> List[str]:
93101
"""Retrieves the supported training instance types for the model matching the given arguments.
94102
@@ -110,6 +118,12 @@ def retrieve(
110118
object, used for SageMaker interactions. If not
111119
specified, one is created using the default AWS configuration
112120
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
121+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
122+
instance type used for the training job that produced the fine-tuned weights.
123+
Optionally supply this to get a inference instance type conditioned
124+
on the training instance, to ensure compatability of training artifact to inference
125+
instance. (Default: None).
126+
113127
Returns:
114128
list: The supported instance types to use for the model.
115129
@@ -132,4 +146,5 @@ def retrieve(
132146
tolerate_vulnerable_model,
133147
tolerate_deprecated_model,
134148
sagemaker_session=sagemaker_session,
149+
training_instance_type=training_instance_type,
135150
)

src/sagemaker/jumpstart/artifacts/instance_types.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def _retrieve_default_instance_type(
3737
tolerate_vulnerable_model: bool = False,
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
training_instance_type: Optional[str] = None,
4041
) -> str:
4142
"""Retrieves the default instance type for the model.
4243
@@ -60,6 +61,11 @@ def _retrieve_default_instance_type(
6061
object, used for SageMaker interactions. If not
6162
specified, one is created using the default AWS configuration
6263
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
65+
instance type used for the training job that produced the fine-tuned weights.
66+
Optionally supply this to get a inference instance type conditioned
67+
on the training instance, to ensure compatability of training artifact to inference
68+
instance. (Default: None).
6369
Returns:
6470
str: the default instance type to use for the model or None.
6571
@@ -82,7 +88,21 @@ def _retrieve_default_instance_type(
8288
)
8389

8490
if scope == JumpStartScriptScope.INFERENCE:
85-
default_instance_type = model_specs.default_inference_instance_type
91+
instance_specific_default_instance_type = (
92+
(
93+
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
94+
training_instance_type
95+
)
96+
)
97+
if training_instance_type is not None
98+
and getattr(model_specs, "training_instance_type_variants", None) is not None
99+
else None
100+
)
101+
default_instance_type = (
102+
instance_specific_default_instance_type
103+
if instance_specific_default_instance_type is not None
104+
else model_specs.default_inference_instance_type
105+
)
86106
elif scope == JumpStartScriptScope.TRAINING:
87107
default_instance_type = model_specs.default_training_instance_type
88108
else:
@@ -103,6 +123,7 @@ def _retrieve_instance_types(
103123
tolerate_vulnerable_model: bool = False,
104124
tolerate_deprecated_model: bool = False,
105125
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126+
training_instance_type: Optional[str] = None,
106127
) -> List[str]:
107128
"""Retrieves the supported instance types for the model.
108129
@@ -126,6 +147,11 @@ def _retrieve_instance_types(
126147
object, used for SageMaker interactions. If not
127148
specified, one is created using the default AWS configuration
128149
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
151+
instance type used for the training job that produced the fine-tuned weights.
152+
Optionally supply this to get a inference instance type conditioned
153+
on the training instance, to ensure compatability of training artifact to inference
154+
instance. (Default: None).
129155
Returns:
130156
list: the supported instance types to use for the model or None.
131157
@@ -148,8 +174,24 @@ def _retrieve_instance_types(
148174
)
149175

150176
if scope == JumpStartScriptScope.INFERENCE:
151-
instance_types = model_specs.supported_inference_instance_types
177+
default_instance_types = model_specs.supported_inference_instance_types or []
178+
instance_specific_instance_types = (
179+
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
180+
training_instance_type
181+
)
182+
if training_instance_type is not None
183+
and getattr(model_specs, "training_instance_type_variants", None) is not None
184+
else []
185+
)
186+
instance_types = (
187+
instance_specific_instance_types
188+
if len(instance_specific_instance_types) > 0
189+
else default_instance_types
190+
)
191+
152192
elif scope == JumpStartScriptScope.TRAINING:
193+
if training_instance_type is not None:
194+
raise ValueError("Cannot use `training_instance_type` argument " "with training scope.")
153195
instance_types = model_specs.supported_training_instance_types
154196
else:
155197
raise NotImplementedError(

src/sagemaker/jumpstart/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@
154154
if region.gated_content_bucket is not None
155155
}
156156

157+
JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET = JUMPSTART_BUCKET_NAME_SET.union(
158+
JUMPSTART_GATED_BUCKET_NAME_SET
159+
)
160+
157161
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
158162

159163
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

src/sagemaker/jumpstart/enums.py

+3
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class JumpStartTag(str, Enum):
7676
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
7777
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"
7878

79+
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
80+
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
81+
7982

8083
class SerializerType(str, Enum):
8184
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,6 @@ def deploy(
988988
use_compiled_model (bool): Flag to select whether to use compiled
989989
(optimized) model. (Default: False).
990990
"""
991-
992991
self.orig_predictor_cls = predictor_cls
993992

994993
sagemaker_session = sagemaker_session or self.sagemaker_session
@@ -1039,6 +1038,7 @@ def deploy(
10391038
dependencies=dependencies,
10401039
git_config=git_config,
10411040
use_compiled_model=use_compiled_model,
1041+
training_instance_type=self.instance_type,
10421042
)
10431043

10441044
predictor = super(JumpStartEstimator, self).deploy(

src/sagemaker/jumpstart/factory/estimator.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@
6060
JumpStartModelInitKwargs,
6161
)
6262
from sagemaker.jumpstart.utils import (
63+
add_jumpstart_model_id_version_tags,
6364
update_dict_if_key_not_present,
6465
resolve_estimator_sagemaker_config_field,
66+
verify_model_region_and_return_specs,
6567
)
6668

6769

@@ -196,6 +198,7 @@ def get_init_kwargs(
196198
estimator_init_kwargs = _add_estimator_extra_kwargs(estimator_init_kwargs)
197199
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
198200
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
201+
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
199202

200203
return estimator_init_kwargs
201204

@@ -277,6 +280,7 @@ def get_deploy_kwargs(
277280
tolerate_vulnerable_model: Optional[bool] = None,
278281
use_compiled_model: Optional[bool] = None,
279282
model_name: Optional[str] = None,
283+
training_instance_type: Optional[str] = None,
280284
) -> JumpStartEstimatorDeployKwargs:
281285
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
282286

@@ -310,7 +314,7 @@ def get_deploy_kwargs(
310314
model_id=model_id,
311315
model_from_estimator=True,
312316
model_version=model_version,
313-
instance_type=model_deploy_kwargs.instance_type, # prevent excess logging
317+
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
314318
region=region,
315319
image_uri=image_uri,
316320
source_dir=source_dir,
@@ -330,6 +334,7 @@ def get_deploy_kwargs(
330334
git_config=git_config,
331335
tolerate_vulnerable_model=tolerate_vulnerable_model,
332336
tolerate_deprecated_model=tolerate_deprecated_model,
337+
training_instance_type=training_instance_type,
333338
)
334339

335340
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
@@ -439,6 +444,26 @@ def _add_instance_type_and_count_to_kwargs(
439444
return kwargs
440445

441446

447+
def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
448+
"""Sets tags in kwargs based on default or override, returns full kwargs."""
449+
450+
full_model_version = verify_model_region_and_return_specs(
451+
model_id=kwargs.model_id,
452+
version=kwargs.model_version,
453+
scope=JumpStartScriptScope.TRAINING,
454+
region=kwargs.region,
455+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
456+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
457+
sagemaker_session=kwargs.sagemaker_session,
458+
).version
459+
460+
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
461+
kwargs.tags = add_jumpstart_model_id_version_tags(
462+
kwargs.tags, kwargs.model_id, full_model_version
463+
)
464+
return kwargs
465+
466+
442467
def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
443468
"""Sets image uri in kwargs based on default or override, returns full kwargs."""
444469

0 commit comments

Comments
 (0)