Skip to content

Commit 7c87c52

Browse files
authored
Merge branch 'master-jumpstart-curated-hub' into feature/hubutil
2 parents 0937c74 + ec04711 commit 7c87c52

File tree

73 files changed

+4157
-2867
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+4157
-2867
lines changed

CHANGELOG.md

+2,285-2,269
Large diffs are not rendered by default.

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
recursive-include src/sagemaker *.py
22

33
include src/sagemaker/image_uri_config/*.json
4+
include src/sagemaker/serve/schema/*.json
45
include src/sagemaker/serve/requirements.txt
56
recursive-include requirements *
67

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.208.1.dev0
1+
2.209.1.dev0

src/sagemaker/accept_types.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported accept types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported accept types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_accept_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -72,6 +76,7 @@ def retrieve_default(
7276
region: Optional[str] = None,
7377
model_id: Optional[str] = None,
7478
model_version: Optional[str] = None,
79+
hub_arn: Optional[str] = None,
7580
tolerate_vulnerable_model: bool = False,
7681
tolerate_deprecated_model: bool = False,
7782
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +90,8 @@ def retrieve_default(
8590
retrieve the default accept type. (Default: None).
8691
model_version (str): The version of the model for which to retrieve the
8792
default accept type. (Default: None).
93+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
94+
model details from. (Default: None).
8895
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8996
specifications should be tolerated (exception not raised). If False, raises an
9097
exception if the script used by this version of the model has dependencies with known
@@ -108,10 +115,11 @@ def retrieve_default(
108115
)
109116

110117
return artifacts._retrieve_default_accept_type(
111-
model_id,
112-
model_version,
113-
region,
114-
tolerate_vulnerable_model,
115-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
116124
sagemaker_session=sagemaker_session,
117125
)

src/sagemaker/amazon/amazon_estimator.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,14 @@ def fit(
269269
if wait:
270270
self.latest_training_job.wait(logs=logs)
271271

272-
def record_set(self, train, labels=None, channel="train", encrypt=False):
272+
def record_set(
273+
self,
274+
train,
275+
labels=None,
276+
channel="train",
277+
encrypt=False,
278+
distribution="ShardedByS3Key",
279+
):
273280
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
274281
275282
For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -294,6 +301,8 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
294301
should be assigned to.
295302
encrypt (bool): Specifies whether the objects uploaded to S3 are
296303
encrypted on the server side using AES-256 (default: ``False``).
304+
distribution (str): The SageMaker TrainingJob channel s3 data
305+
distribution type (default: ``ShardedByS3Key``).
297306
298307
Returns:
299308
RecordSet: A RecordSet referencing the encoded, uploading training
@@ -316,6 +325,7 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
316325
num_records=train.shape[0],
317326
feature_dim=train.shape[1],
318327
channel=channel,
328+
distribution=distribution,
319329
)
320330

321331
def _get_default_mini_batch_size(self, num_records: int):
@@ -343,6 +353,7 @@ def __init__(
343353
feature_dim: int,
344354
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
345355
channel: Union[str, PipelineVariable] = "train",
356+
distribution: str = "ShardedByS3Key",
346357
):
347358
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.
348359
@@ -358,12 +369,15 @@ def __init__(
358369
single s3 manifest file, listing each s3 object to train on.
359370
channel (str or PipelineVariable): The SageMaker Training Job channel this RecordSet
360371
should be bound to
372+
distribution (str): The SageMaker TrainingJob S3 data distribution type.
373+
Valid values: 'ShardedByS3Key', 'FullyReplicated'.
361374
"""
362375
self.s3_data = s3_data
363376
self.feature_dim = feature_dim
364377
self.num_records = num_records
365378
self.s3_data_type = s3_data_type
366379
self.channel = channel
380+
self.distribution = distribution
367381

368382
def __repr__(self):
369383
"""Return an unambiguous representation of this RecordSet"""
@@ -377,7 +391,7 @@ def data_channel(self):
377391
def records_s3_input(self):
378392
"""Return a TrainingInput to represent the training data"""
379393
return TrainingInput(
380-
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
394+
self.s3_data, distribution=self.distribution, s3_data_type=self.s3_data_type
381395
)
382396

383397

src/sagemaker/content_types.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported content types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported content types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_content_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -72,6 +76,7 @@ def retrieve_default(
7276
region: Optional[str] = None,
7377
model_id: Optional[str] = None,
7478
model_version: Optional[str] = None,
79+
hub_arn: Optional[str] = None,
7580
tolerate_vulnerable_model: bool = False,
7681
tolerate_deprecated_model: bool = False,
7782
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +90,8 @@ def retrieve_default(
8590
retrieve the default content type. (Default: None).
8691
model_version (str): The version of the model for which to retrieve the
8792
default content type. (Default: None).
93+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
94+
model details from. (default: None).
8895
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8996
specifications should be tolerated (exception not raised). If False, raises an
9097
exception if the script used by this version of the model has dependencies with known
@@ -108,11 +115,12 @@ def retrieve_default(
108115
)
109116

110117
return artifacts._retrieve_default_content_type(
111-
model_id,
112-
model_version,
113-
region,
114-
tolerate_vulnerable_model,
115-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
116124
sagemaker_session=sagemaker_session,
117125
)
118126

src/sagemaker/deserializers.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def retrieve_options(
4242
region: Optional[str] = None,
4343
model_id: Optional[str] = None,
4444
model_version: Optional[str] = None,
45+
hub_arn: Optional[str] = None,
4546
tolerate_vulnerable_model: bool = False,
4647
tolerate_deprecated_model: bool = False,
4748
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -55,6 +56,8 @@ def retrieve_options(
5556
retrieve the supported deserializers. (Default: None).
5657
model_version (str): The version of the model for which to retrieve the
5758
supported deserializers. (Default: None).
59+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
60+
model details from. (Default: None).
5861
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5962
specifications should be tolerated (exception not raised). If False, raises an
6063
exception if the script used by this version of the model has dependencies with known
@@ -79,11 +82,12 @@ def retrieve_options(
7982
)
8083

8184
return artifacts._retrieve_deserializer_options(
82-
model_id,
83-
model_version,
84-
region,
85-
tolerate_vulnerable_model,
86-
tolerate_deprecated_model,
85+
model_id=model_id,
86+
model_version=model_version,
87+
hub_arn=hub_arn,
88+
region=region,
89+
tolerate_vulnerable_model=tolerate_vulnerable_model,
90+
tolerate_deprecated_model=tolerate_deprecated_model,
8791
sagemaker_session=sagemaker_session,
8892
)
8993

@@ -92,6 +96,7 @@ def retrieve_default(
9296
region: Optional[str] = None,
9397
model_id: Optional[str] = None,
9498
model_version: Optional[str] = None,
99+
hub_arn: Optional[str] = None,
95100
tolerate_vulnerable_model: bool = False,
96101
tolerate_deprecated_model: bool = False,
97102
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -105,6 +110,8 @@ def retrieve_default(
105110
retrieve the default deserializer. (Default: None).
106111
model_version (str): The version of the model for which to retrieve the
107112
default deserializer. (Default: None).
113+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
114+
model details from. (Default: None).
108115
tolerate_vulnerable_model (bool): True if vulnerable versions of model
109116
specifications should be tolerated (exception not raised). If False, raises an
110117
exception if the script used by this version of the model has dependencies with known
@@ -129,10 +136,11 @@ def retrieve_default(
129136
)
130137

131138
return artifacts._retrieve_default_deserializer(
132-
model_id,
133-
model_version,
134-
region,
135-
tolerate_vulnerable_model,
136-
tolerate_deprecated_model,
139+
model_id=model_id,
140+
model_version=model_version,
141+
hub_arn=hub_arn,
142+
region=region,
143+
tolerate_vulnerable_model=tolerate_vulnerable_model,
144+
tolerate_deprecated_model=tolerate_deprecated_model,
137145
sagemaker_session=sagemaker_session,
138146
)

src/sagemaker/image_uri_config/huggingface-llm-neuronx.json

+29
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,35 @@
6464
"container_version": {
6565
"inf2": "ubuntu22.04"
6666
}
67+
},
68+
"0.0.18": {
69+
"py_versions": [
70+
"py310"
71+
],
72+
"registries": {
73+
"ap-northeast-1": "763104351884",
74+
"ap-south-1": "763104351884",
75+
"ap-south-2": "772153158452",
76+
"ap-southeast-1": "763104351884",
77+
"ap-southeast-2": "763104351884",
78+
"ap-southeast-4": "457447274322",
79+
"eu-central-1": "763104351884",
80+
"eu-central-2": "380420809688",
81+
"eu-south-2": "503227376785",
82+
"eu-west-1": "763104351884",
83+
"eu-west-3": "763104351884",
84+
"il-central-1": "780543022126",
85+
"sa-east-1": "763104351884",
86+
"us-east-1": "763104351884",
87+
"us-east-2": "763104351884",
88+
"us-west-2": "763104351884",
89+
"ca-west-1": "204538143572"
90+
},
91+
"tag_prefix": "1.13.1-optimum0.0.18",
92+
"repository": "huggingface-pytorch-tgi-inference",
93+
"container_version": {
94+
"inf2": "ubuntu22.04"
95+
}
6796
}
6897
}
6998
}

0 commit comments

Comments
 (0)