Skip to content

Commit c6581ff

Browse files
authored
feat: use Neo bucket in speculative decoding data source (aws#1479)
* Use Neo bucket in speculative decoding data source * address comments * format * address comments * add buckets to regional config * remove opt-in regions for neo buckets
1 parent f55e3c9 commit c6581ff

File tree

8 files changed

+267
-66
lines changed

8 files changed

+267
-66
lines changed

src/sagemaker/jumpstart/constants.py

+20
Original file line numberDiff line numberDiff line change
@@ -44,31 +44,37 @@
4444
region_name="us-west-2",
4545
content_bucket="jumpstart-cache-prod-us-west-2",
4646
gated_content_bucket="jumpstart-private-cache-prod-us-west-2",
47+
neo_content_bucket="sagemaker-sd-models-prod-us-west-2",
4748
),
4849
JumpStartLaunchedRegionInfo(
4950
region_name="us-east-1",
5051
content_bucket="jumpstart-cache-prod-us-east-1",
5152
gated_content_bucket="jumpstart-private-cache-prod-us-east-1",
53+
neo_content_bucket="sagemaker-sd-models-prod-us-east-1",
5254
),
5355
JumpStartLaunchedRegionInfo(
5456
region_name="us-east-2",
5557
content_bucket="jumpstart-cache-prod-us-east-2",
5658
gated_content_bucket="jumpstart-private-cache-prod-us-east-2",
59+
neo_content_bucket="sagemaker-sd-models-prod-us-east-2",
5760
),
5861
JumpStartLaunchedRegionInfo(
5962
region_name="eu-west-1",
6063
content_bucket="jumpstart-cache-prod-eu-west-1",
6164
gated_content_bucket="jumpstart-private-cache-prod-eu-west-1",
65+
neo_content_bucket="sagemaker-sd-models-prod-eu-west-1",
6266
),
6367
JumpStartLaunchedRegionInfo(
6468
region_name="eu-central-1",
6569
content_bucket="jumpstart-cache-prod-eu-central-1",
6670
gated_content_bucket="jumpstart-private-cache-prod-eu-central-1",
71+
neo_content_bucket="sagemaker-sd-models-prod-eu-central-1",
6772
),
6873
JumpStartLaunchedRegionInfo(
6974
region_name="eu-north-1",
7075
content_bucket="jumpstart-cache-prod-eu-north-1",
7176
gated_content_bucket="jumpstart-private-cache-prod-eu-north-1",
77+
neo_content_bucket="sagemaker-sd-models-prod-eu-north-1",
7278
),
7379
JumpStartLaunchedRegionInfo(
7480
region_name="me-south-1",
@@ -84,11 +90,13 @@
8490
region_name="ap-south-1",
8591
content_bucket="jumpstart-cache-prod-ap-south-1",
8692
gated_content_bucket="jumpstart-private-cache-prod-ap-south-1",
93+
neo_content_bucket="sagemaker-sd-models-prod-ap-south-1",
8794
),
8895
JumpStartLaunchedRegionInfo(
8996
region_name="eu-west-3",
9097
content_bucket="jumpstart-cache-prod-eu-west-3",
9198
gated_content_bucket="jumpstart-private-cache-prod-eu-west-3",
99+
neo_content_bucket="sagemaker-sd-models-prod-eu-west-3",
92100
),
93101
JumpStartLaunchedRegionInfo(
94102
region_name="af-south-1",
@@ -99,6 +107,7 @@
99107
region_name="sa-east-1",
100108
content_bucket="jumpstart-cache-prod-sa-east-1",
101109
gated_content_bucket="jumpstart-private-cache-prod-sa-east-1",
110+
neo_content_bucket="sagemaker-sd-models-prod-sa-east-1",
102111
),
103112
JumpStartLaunchedRegionInfo(
104113
region_name="ap-east-1",
@@ -109,21 +118,25 @@
109118
region_name="ap-northeast-2",
110119
content_bucket="jumpstart-cache-prod-ap-northeast-2",
111120
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
121+
neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-2",
112122
),
113123
JumpStartLaunchedRegionInfo(
114124
region_name="ap-northeast-3",
115125
content_bucket="jumpstart-cache-prod-ap-northeast-3",
116126
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-3",
127+
neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-3",
117128
),
118129
JumpStartLaunchedRegionInfo(
119130
region_name="ap-southeast-3",
120131
content_bucket="jumpstart-cache-prod-ap-southeast-3",
121132
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-3",
133+
neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-3",
122134
),
123135
JumpStartLaunchedRegionInfo(
124136
region_name="eu-west-2",
125137
content_bucket="jumpstart-cache-prod-eu-west-2",
126138
gated_content_bucket="jumpstart-private-cache-prod-eu-west-2",
139+
neo_content_bucket="sagemaker-sd-models-prod-eu-west-2",
127140
),
128141
JumpStartLaunchedRegionInfo(
129142
region_name="eu-south-1",
@@ -134,26 +147,31 @@
134147
region_name="ap-northeast-1",
135148
content_bucket="jumpstart-cache-prod-ap-northeast-1",
136149
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
150+
neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-1",
137151
),
138152
JumpStartLaunchedRegionInfo(
139153
region_name="us-west-1",
140154
content_bucket="jumpstart-cache-prod-us-west-1",
141155
gated_content_bucket="jumpstart-private-cache-prod-us-west-1",
156+
neo_content_bucket="sagemaker-sd-models-prod-us-west-1",
142157
),
143158
JumpStartLaunchedRegionInfo(
144159
region_name="ap-southeast-1",
145160
content_bucket="jumpstart-cache-prod-ap-southeast-1",
146161
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
162+
neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-1",
147163
),
148164
JumpStartLaunchedRegionInfo(
149165
region_name="ap-southeast-2",
150166
content_bucket="jumpstart-cache-prod-ap-southeast-2",
151167
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
168+
neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-2",
152169
),
153170
JumpStartLaunchedRegionInfo(
154171
region_name="ca-central-1",
155172
content_bucket="jumpstart-cache-prod-ca-central-1",
156173
gated_content_bucket="jumpstart-private-cache-prod-ca-central-1",
174+
neo_content_bucket="sagemaker-sd-models-prod-ca-central-1",
157175
),
158176
JumpStartLaunchedRegionInfo(
159177
region_name="cn-north-1",
@@ -184,6 +202,7 @@
184202
)
185203

186204
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
205+
NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
187206

188207
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
189208
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
@@ -201,6 +220,7 @@
201220
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
202221
)
203222
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"
223+
ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE = "AWS_NEO_CONTENT_BUCKET_OVERRIDE"
204224

205225
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
206226

src/sagemaker/jumpstart/factory/model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from sagemaker.jumpstart.utils import (
4747
add_jumpstart_model_info_tags,
4848
get_default_jumpstart_session_with_user_agent_suffix,
49+
get_neo_content_bucket,
4950
update_dict_if_key_not_present,
5051
resolve_model_sagemaker_config_field,
5152
verify_model_region_and_return_specs,
@@ -631,14 +632,16 @@ def _add_additional_model_data_sources_to_kwargs(
631632
model_type=kwargs.model_type,
632633
config_name=kwargs.config_name,
633634
)
634-
635-
additional_data_sources = specs.get_additional_s3_data_sources()
635+
# Append speculative decoding data source from metadata
636+
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
637+
for data_source in speculative_decoding_data_sources:
638+
data_source.s3_data_source.set_bucket(get_neo_content_bucket())
636639
api_shape_additional_model_data_sources = (
637640
[
638641
camel_case_to_pascal_case(data_source.to_json())
639-
for data_source in additional_data_sources
642+
for data_source in speculative_decoding_data_sources
640643
]
641-
if specs.get_additional_s3_data_sources()
644+
if specs.get_speculative_decoding_s3_data_sources()
642645
else None
643646
)
644647

src/sagemaker/jumpstart/types.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18-
from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict
18+
from sagemaker.utils import (
19+
S3_PREFIX,
20+
get_instance_type_family,
21+
format_tags,
22+
Tags,
23+
deep_override_dict,
24+
)
1925
from sagemaker.model_metrics import ModelMetrics
2026
from sagemaker.metadata_properties import MetadataProperties
2127
from sagemaker.drift_check_baselines import DriftCheckBaselines
@@ -116,10 +122,14 @@ class JumpStartS3FileType(str, Enum):
116122
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
117123
"""Data class for launched region info."""
118124

119-
__slots__ = ["content_bucket", "region_name", "gated_content_bucket"]
125+
__slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"]
120126

121127
def __init__(
122-
self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None
128+
self,
129+
content_bucket: str,
130+
region_name: str,
131+
gated_content_bucket: Optional[str] = None,
132+
neo_content_bucket: Optional[str] = None,
123133
):
124134
"""Instantiates JumpStartLaunchedRegionInfo object.
125135
@@ -128,10 +138,13 @@ def __init__(
128138
region_name (str): Name of JumpStart launched region.
129139
gated_content_bucket (Optional[str[]): Name of JumpStart gated s3 content bucket
130140
optionally associated with region.
141+
neo_content_bucket (Optional[str]): Name of Neo service s3 content bucket
142+
optionally associated with region.
131143
"""
132144
self.content_bucket = content_bucket
133145
self.gated_content_bucket = gated_content_bucket
134146
self.region_name = region_name
147+
self.neo_content_bucket = neo_content_bucket
135148

136149

137150
class JumpStartModelHeader(JumpStartDataHolderType):
@@ -848,6 +861,21 @@ def to_json(self) -> Dict[str, Any]:
848861
json_obj[att] = cur_val
849862
return json_obj
850863

864+
def set_bucket(self, bucket: str) -> None:
865+
"""Sets bucket name from S3 URI."""
866+
867+
if self.s3_uri.startswith(S3_PREFIX):
868+
s3_path = self.s3_uri[len(S3_PREFIX) :]
869+
old_bucket = s3_path.split("/")[0]
870+
key = s3_path[len(old_bucket) :]
871+
self.s3_uri = f"{S3_PREFIX}{bucket}{key}" # pylint: disable=W0201
872+
return
873+
874+
if not bucket.endswith("/"):
875+
bucket += "/"
876+
877+
self.s3_uri = f"{S3_PREFIX}{bucket}{self.s3_uri}" # pylint: disable=W0201
878+
851879

852880
class AdditionalModelDataSource(JumpStartDataHolderType):
853881
"""Data class of additional model data source mirrors CreateModel API."""
@@ -1638,8 +1666,10 @@ def supports_incremental_training(self) -> bool:
16381666
"""Returns True if the model supports incremental training."""
16391667
return self.incremental_training_supported
16401668

1641-
def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:
1669+
def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]:
16421670
"""Returns data sources for speculative decoding."""
1671+
if not self.hosting_additional_data_sources:
1672+
return []
16431673
return self.hosting_additional_data_sources.speculative_decoding or []
16441674

16451675
def get_additional_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:

src/sagemaker/jumpstart/utils.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def get_jumpstart_content_bucket(
156156
except KeyError:
157157
formatted_launched_regions_str = get_jumpstart_launched_regions_message()
158158
raise ValueError(
159-
f"Unable to get content bucket for JumpStart in {region} region. "
159+
f"Unable to get content bucket for Neo in {region} region. "
160160
f"{formatted_launched_regions_str}"
161161
)
162162

@@ -170,6 +170,34 @@ def get_jumpstart_content_bucket(
170170
return bucket_to_return
171171

172172

173+
def get_neo_content_bucket(
174+
region: str = constants.NEO_DEFAULT_REGION_NAME,
175+
) -> str:
176+
"""Returns the regionalized S3 bucket name for Neo service.
177+
178+
Raises:
179+
ValueError: If Neo is not launched in ``region``.
180+
"""
181+
182+
bucket_to_return: Optional[str] = None
183+
if (
184+
constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE in os.environ
185+
and len(os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE]) > 0
186+
):
187+
bucket_to_return = os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE]
188+
info_log = f"Using Neo bucket override: '{bucket_to_return}'"
189+
constants.JUMPSTART_LOGGER.info(info_log)
190+
else:
191+
try:
192+
bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
193+
region
194+
].neo_content_bucket
195+
except KeyError:
196+
raise ValueError(f"Unable to get content bucket for Neo in {region} region.")
197+
198+
return bucket_to_return
199+
200+
173201
def get_formatted_manifest(
174202
manifest: List[Dict],
175203
) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]:

tests/unit/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -7806,6 +7806,7 @@
78067806
},
78077807
},
78087808
"gpu-accelerated": {
7809+
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
78097810
"hosting_instance_type_variants": {
78107811
"regional_aliases": {
78117812
"us-west-2": {

tests/unit/sagemaker/jumpstart/model/test_model.py

+66
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,72 @@ def test_model_set_deployment_config(
16871687
endpoint_logging=False,
16881688
)
16891689

1690+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
1691+
@mock.patch(
1692+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1693+
)
1694+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1695+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1696+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
1697+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1698+
def test_model_deployment_config_additional_model_data_source(
1699+
self,
1700+
mock_model_init: mock.Mock,
1701+
mock_model_deploy: mock.Mock,
1702+
mock_get_model_specs: mock.Mock,
1703+
mock_session: mock.Mock,
1704+
mock_get_manifest: mock.Mock,
1705+
):
1706+
mock_session.return_value = sagemaker_session
1707+
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
1708+
mock_get_manifest.side_effect = (
1709+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
1710+
)
1711+
mock_model_deploy.return_value = default_predictor
1712+
1713+
model_id, _ = "pytorch-eqa-bert-base-cased", "*"
1714+
1715+
model = JumpStartModel(model_id=model_id, config_name="gpu-accelerated")
1716+
1717+
mock_model_init.assert_called_once_with(
1718+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
1719+
"pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04",
1720+
model_data="s3://jumpstart-cache-prod-us-west-2/pytorch-infer/"
1721+
"infer-pytorch-eqa-bert-base-cased.tar.gz",
1722+
source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
1723+
"pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
1724+
entry_point="inference.py",
1725+
predictor_cls=Predictor,
1726+
role=execution_role,
1727+
sagemaker_session=sagemaker_session,
1728+
enable_network_isolation=False,
1729+
additional_model_data_sources=[
1730+
{
1731+
"ChannelName": "draft_model_name",
1732+
"S3DataSource": {
1733+
"CompressionType": "None",
1734+
"S3DataType": "S3Prefix",
1735+
"S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/",
1736+
"ModelAccessConfig": {"AcceptEula": False},
1737+
},
1738+
}
1739+
],
1740+
)
1741+
1742+
model.deploy()
1743+
1744+
mock_model_deploy.assert_called_once_with(
1745+
initial_instance_count=1,
1746+
instance_type="ml.p2.xlarge",
1747+
tags=[
1748+
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1749+
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1750+
{"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-accelerated"},
1751+
],
1752+
wait=True,
1753+
endpoint_logging=False,
1754+
)
1755+
16901756
@mock.patch(
16911757
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
16921758
)

0 commit comments

Comments
 (0)