Skip to content

Commit 0d25d89

Browse files
authored
Merge branch 'master' into master
2 parents 9bc4a4b + 30fe0ee commit 0d25d89

File tree

14 files changed

+443
-17
lines changed

14 files changed

+443
-17
lines changed

src/sagemaker/_studio.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def _find_config(working_dir=None):
6565
wd = Path(working_dir) if working_dir else Path.cwd()
6666

6767
path = None
68-
while path is None and not wd.match("/"):
68+
69+
# Get the root of the current working directory for both Windows and Unix-like systems
70+
root = Path(wd.anchor)
71+
while path is None and wd != root:
6972
candidate = wd / STUDIO_PROJECT_CONFIG
7073
if Path.exists(candidate):
7174
path = candidate

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@
8585
"2.2": "2.2.0",
8686
"2.3": "2.3.0",
8787
"2.4": "2.4.0",
88-
"2.5": "2.5.1"
88+
"2.5": "2.5.1",
89+
"2.6": "2.6.0"
8990
},
9091
"versions": {
9192
"0.4.0": {
@@ -1253,6 +1254,50 @@
12531254
"us-west-2": "763104351884"
12541255
},
12551256
"repository": "pytorch-inference"
1257+
},
1258+
"2.6.0": {
1259+
"py_versions": [
1260+
"py312"
1261+
],
1262+
"registries": {
1263+
"af-south-1": "626614931356",
1264+
"ap-east-1": "871362719292",
1265+
"ap-northeast-1": "763104351884",
1266+
"ap-northeast-2": "763104351884",
1267+
"ap-northeast-3": "364406365360",
1268+
"ap-south-1": "763104351884",
1269+
"ap-south-2": "772153158452",
1270+
"ap-southeast-1": "763104351884",
1271+
"ap-southeast-2": "763104351884",
1272+
"ap-southeast-3": "907027046896",
1273+
"ap-southeast-4": "457447274322",
1274+
"ap-southeast-5": "550225433462",
1275+
"ap-southeast-7": "590183813437",
1276+
"ca-central-1": "763104351884",
1277+
"ca-west-1": "204538143572",
1278+
"cn-north-1": "727897471807",
1279+
"cn-northwest-1": "727897471807",
1280+
"eu-central-1": "763104351884",
1281+
"eu-central-2": "380420809688",
1282+
"eu-north-1": "763104351884",
1283+
"eu-south-1": "692866216735",
1284+
"eu-south-2": "503227376785",
1285+
"eu-west-1": "763104351884",
1286+
"eu-west-2": "763104351884",
1287+
"eu-west-3": "763104351884",
1288+
"il-central-1": "780543022126",
1289+
"me-central-1": "914824155844",
1290+
"me-south-1": "217643126080",
1291+
"mx-central-1": "637423239942",
1292+
"sa-east-1": "763104351884",
1293+
"us-east-1": "763104351884",
1294+
"us-east-2": "763104351884",
1295+
"us-gov-east-1": "446045086412",
1296+
"us-gov-west-1": "442386744353",
1297+
"us-west-1": "763104351884",
1298+
"us-west-2": "763104351884"
1299+
},
1300+
"repository": "pytorch-inference"
12561301
}
12571302
}
12581303
},
@@ -1628,7 +1673,8 @@
16281673
"2.2": "2.2.0",
16291674
"2.3": "2.3.0",
16301675
"2.4": "2.4.0",
1631-
"2.5": "2.5.1"
1676+
"2.5": "2.5.1",
1677+
"2.6": "2.6.0"
16321678
},
16331679
"versions": {
16341680
"0.4.0": {
@@ -2801,6 +2847,50 @@
28012847
"us-west-2": "763104351884"
28022848
},
28032849
"repository": "pytorch-training"
2850+
},
2851+
"2.6.0": {
2852+
"py_versions": [
2853+
"py312"
2854+
],
2855+
"registries": {
2856+
"af-south-1": "626614931356",
2857+
"ap-east-1": "871362719292",
2858+
"ap-northeast-1": "763104351884",
2859+
"ap-northeast-2": "763104351884",
2860+
"ap-northeast-3": "364406365360",
2861+
"ap-south-1": "763104351884",
2862+
"ap-south-2": "772153158452",
2863+
"ap-southeast-1": "763104351884",
2864+
"ap-southeast-2": "763104351884",
2865+
"ap-southeast-3": "907027046896",
2866+
"ap-southeast-4": "457447274322",
2867+
"ap-southeast-5": "550225433462",
2868+
"ap-southeast-7": "590183813437",
2869+
"ca-central-1": "763104351884",
2870+
"ca-west-1": "204538143572",
2871+
"cn-north-1": "727897471807",
2872+
"cn-northwest-1": "727897471807",
2873+
"eu-central-1": "763104351884",
2874+
"eu-central-2": "380420809688",
2875+
"eu-north-1": "763104351884",
2876+
"eu-south-1": "692866216735",
2877+
"eu-south-2": "503227376785",
2878+
"eu-west-1": "763104351884",
2879+
"eu-west-2": "763104351884",
2880+
"eu-west-3": "763104351884",
2881+
"il-central-1": "780543022126",
2882+
"me-central-1": "914824155844",
2883+
"me-south-1": "217643126080",
2884+
"mx-central-1": "637423239942",
2885+
"sa-east-1": "763104351884",
2886+
"us-east-1": "763104351884",
2887+
"us-east-2": "763104351884",
2888+
"us-gov-east-1": "446045086412",
2889+
"us-gov-west-1": "442386744353",
2890+
"us-west-1": "763104351884",
2891+
"us-west-2": "763104351884"
2892+
},
2893+
"repository": "pytorch-training"
28042894
}
28052895
}
28062896
}

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
JUMPSTART_LOGGER,
5757
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5858
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59+
JUMPSTART_MODEL_HUB_NAME,
5960
)
6061
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6162
from sagemaker.jumpstart.factory import model
@@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs(
313314
):
314315
"""Adds HubAccessConfig to kwargs inputs"""
315316

317+
dataset_uri = kwargs.specs.default_training_dataset_uri
316318
if isinstance(kwargs.inputs, str):
317-
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
319+
if dataset_uri is not None and dataset_uri == kwargs.inputs:
320+
kwargs.inputs = TrainingInput(
321+
s3_data=kwargs.inputs, hub_access_config=hub_access_config
322+
)
318323
elif isinstance(kwargs.inputs, TrainingInput):
319-
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
324+
if (
325+
dataset_uri is not None
326+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
327+
):
328+
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
320329
elif isinstance(kwargs.inputs, dict):
321330
for k, v in kwargs.inputs.items():
322331
if isinstance(v, str):
323-
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
332+
training_input = TrainingInput(s3_data=v)
333+
if dataset_uri is not None and dataset_uri == v:
334+
training_input.add_hub_access_config(hub_access_config=hub_access_config)
335+
kwargs.inputs[k] = training_input
324336
elif isinstance(kwargs.inputs, TrainingInput):
325-
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
337+
if (
338+
dataset_uri is not None
339+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
340+
):
341+
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
326342

327343
return kwargs
328344

@@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs(
616632

617633
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
618634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
619-
620-
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
635+
# hub_arn is by default None unless the user specifies the hub_name
636+
# If no hub_name is specified, it is assumed the public hub
637+
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
638+
if (
639+
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
640+
or is_private_hub
641+
):
621642
default_model_uri = model_uris.retrieve(
622643
model_scope=JumpStartScriptScope.TRAINING,
623644
instance_type=kwargs.instance_type,

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
630630
if json_obj.get("ValidationSupported")
631631
else None
632632
)
633-
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
634633
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
635634
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
636635
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
671670
)
672671

673672
if self.training_supported:
673+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
674+
"DefaultTrainingDatasetUri"
675+
)
674676
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
675677
"TrainingModelPackageArtifactUri"
676678
)

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
279279
specs["training_instance_type_variants"] = (
280280
hub_model_document.training_instance_type_variants
281281
)
282+
if hub_model_document.default_training_dataset_uri:
283+
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
284+
hub_model_document.default_training_dataset_uri
285+
)
286+
specs["default_training_dataset_key"] = default_training_dataset_key
287+
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
282288
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
2121
from sagemaker.jumpstart import constants
2222
from packaging.specifiers import SpecifierSet, InvalidSpecifier
23+
from packaging import version
2324

2425
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
2526

@@ -162,9 +163,12 @@ def get_hub_model_version(
162163
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
163164

164165
try:
165-
hub_content_summaries = sagemaker_session.list_hub_content_versions(
166-
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
167-
).get("HubContentSummaries")
166+
hub_content_summaries = _list_hub_content_versions_helper(
167+
hub_name=hub_name,
168+
hub_content_name=hub_model_name,
169+
hub_content_type=hub_model_type,
170+
sagemaker_session=sagemaker_session,
171+
)
168172
except Exception as ex:
169173
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
170174

@@ -181,13 +185,34 @@ def get_hub_model_version(
181185
raise
182186

183187

188+
def _list_hub_content_versions_helper(
189+
hub_name, hub_content_name, hub_content_type, sagemaker_session
190+
):
191+
all_hub_content_summaries = []
192+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
193+
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
194+
)
195+
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
196+
while "NextToken" in list_hub_content_versions_response:
197+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
198+
hub_name=hub_name,
199+
hub_content_name=hub_content_name,
200+
hub_content_type=hub_content_type,
201+
next_token=list_hub_content_versions_response["NextToken"],
202+
)
203+
all_hub_content_summaries.extend(
204+
list_hub_content_versions_response.get("HubContentSummaries")
205+
)
206+
return all_hub_content_summaries
207+
208+
184209
def _get_hub_model_version_for_open_weight_version(
185210
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
186211
) -> str:
187212
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
188213

189214
if hub_model_version == "*" or hub_model_version is None:
190-
return str(max(available_model_versions))
215+
return str(max(version.parse(v) for v in available_model_versions))
191216

192217
try:
193218
spec = SpecifierSet(f"=={hub_model_version}")

src/sagemaker/jumpstart/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12791279
"hosting_neuron_model_version",
12801280
"hub_content_type",
12811281
"_is_hub_content",
1282+
"default_training_dataset_key",
1283+
"default_training_dataset_uri",
12821284
]
12831285

12841286
_non_serializable_slots = ["_is_hub_content"]
@@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
14621464
else None
14631465
)
14641466
self.model_subscription_link = json_obj.get("model_subscription_link")
1467+
self.default_training_dataset_key: Optional[str] = json_obj.get(
1468+
"default_training_dataset_key"
1469+
)
1470+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
1471+
"default_training_dataset_uri"
1472+
)
14651473

14661474
def to_json(self) -> Dict[str, Any]:
14671475
"""Returns json representation of JumpStartMetadataBaseFields object."""

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4949
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
50-
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
50+
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
5151
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5252
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
5353
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)