Skip to content

Commit 5ecf9e4

Browse files
committed
update for resource requirements and model package
1 parent 082f727 commit 5ecf9e4

11 files changed

+103
-50
lines changed

src/sagemaker/accept_types.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported accept types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported accept types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_accept_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -111,11 +115,11 @@ def retrieve_default(
111115
)
112116

113117
return artifacts._retrieve_default_accept_type(
114-
model_id,
115-
model_version,
116-
hub_arn,
117-
region,
118-
tolerate_vulnerable_model,
119-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
120124
sagemaker_session=sagemaker_session,
121125
)

src/sagemaker/content_types.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported content types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported content types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_content_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -111,12 +115,12 @@ def retrieve_default(
111115
)
112116

113117
return artifacts._retrieve_default_content_type(
114-
model_id,
115-
model_version,
116-
hub_arn,
117-
region,
118-
tolerate_vulnerable_model,
119-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
120124
sagemaker_session=sagemaker_session,
121125
)
122126

src/sagemaker/deserializers.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def retrieve_options(
4242
region: Optional[str] = None,
4343
model_id: Optional[str] = None,
4444
model_version: Optional[str] = None,
45+
hub_arn: Optional[str] = None,
4546
tolerate_vulnerable_model: bool = False,
4647
tolerate_deprecated_model: bool = False,
4748
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -55,6 +56,8 @@ def retrieve_options(
5556
retrieve the supported deserializers. (Default: None).
5657
model_version (str): The version of the model for which to retrieve the
5758
supported deserializers. (Default: None).
59+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
60+
model details from. (Default: None).
5861
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5962
specifications should be tolerated (exception not raised). If False, raises an
6063
exception if the script used by this version of the model has dependencies with known
@@ -79,11 +82,12 @@ def retrieve_options(
7982
)
8083

8184
return artifacts._retrieve_deserializer_options(
82-
model_id,
83-
model_version,
84-
region,
85-
tolerate_vulnerable_model,
86-
tolerate_deprecated_model,
85+
model_id=model_id,
86+
model_version=model_version,
87+
hub_arn=hub_arn,
88+
region=region,
89+
tolerate_vulnerable_model=tolerate_vulnerable_model,
90+
tolerate_deprecated_model=tolerate_deprecated_model,
8791
sagemaker_session=sagemaker_session,
8892
)
8993

@@ -132,11 +136,11 @@ def retrieve_default(
132136
)
133137

134138
return artifacts._retrieve_default_deserializer(
135-
model_id,
136-
model_version,
137-
hub_arn,
138-
region,
139-
tolerate_vulnerable_model,
140-
tolerate_deprecated_model,
139+
model_id=model_id,
140+
model_version=model_version,
141+
hub_arn=hub_arn,
142+
region=region,
143+
tolerate_vulnerable_model=tolerate_vulnerable_model,
144+
tolerate_deprecated_model=tolerate_deprecated_model,
141145
sagemaker_session=sagemaker_session,
142146
)

src/sagemaker/jumpstart/artifacts/model_packages.py

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _retrieve_model_package_arn(
3131
model_version: str,
3232
instance_type: Optional[str],
3333
region: Optional[str],
34+
hub_arn: Optional[str] = None,
3435
scope: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def _retrieve_model_package_arn(
4647
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
4748
specific for the instance type.
4849
region (Optional[str]): Region for which to retrieve the model package arn.
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
scope (Optional[str]): Scope for which to retrieve the model package arn.
5053
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5154
specifications should be tolerated (exception not raised). If False, raises an
@@ -69,6 +72,7 @@ def _retrieve_model_package_arn(
6972
model_specs = verify_model_region_and_return_specs(
7073
model_id=model_id,
7174
version=model_version,
75+
hub_arn=hub_arn,
7276
scope=scope,
7377
region=region,
7478
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/predictors.py

+16
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _retrieve_default_serializer(
167167
def _retrieve_deserializer_options(
168168
model_id: str,
169169
model_version: str,
170+
hub_arn: Optional[str],
170171
region: Optional[str],
171172
tolerate_vulnerable_model: bool = False,
172173
tolerate_deprecated_model: bool = False,
@@ -179,6 +180,8 @@ def _retrieve_deserializer_options(
179180
retrieve the supported deserializers.
180181
model_version (str): Version of the JumpStart model for which to retrieve the
181182
supported deserializers.
183+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
184+
model details from. (default: None).
182185
region (Optional[str]): Region for which to retrieve deserializer options.
183186
tolerate_vulnerable_model (bool): True if vulnerable versions of model
184187
specifications should be tolerated (exception not raised). If False, raises an
@@ -198,6 +201,7 @@ def _retrieve_deserializer_options(
198201
supported_accept_types = _retrieve_supported_accept_types(
199202
model_id=model_id,
200203
model_version=model_version,
204+
hub_arn=hub_arn,
201205
region=region,
202206
tolerate_vulnerable_model=tolerate_vulnerable_model,
203207
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -224,6 +228,7 @@ def _retrieve_deserializer_options(
224228
def _retrieve_serializer_options(
225229
model_id: str,
226230
model_version: str,
231+
hub_arn: Optional[str],
227232
region: Optional[str],
228233
tolerate_vulnerable_model: bool = False,
229234
tolerate_deprecated_model: bool = False,
@@ -236,6 +241,8 @@ def _retrieve_serializer_options(
236241
retrieve the supported serializers.
237242
model_version (str): Version of the JumpStart model for which to retrieve the
238243
supported serializers.
244+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
245+
model details from. (default: None).
239246
region (Optional[str]): Region for which to retrieve serializer options.
240247
tolerate_vulnerable_model (bool): True if vulnerable versions of model
241248
specifications should be tolerated (exception not raised). If False, raises an
@@ -255,6 +262,7 @@ def _retrieve_serializer_options(
255262
supported_content_types = _retrieve_supported_content_types(
256263
model_id=model_id,
257264
model_version=model_version,
265+
hub_arn=hub_arn,
258266
region=region,
259267
tolerate_vulnerable_model=tolerate_vulnerable_model,
260268
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -386,6 +394,7 @@ def _retrieve_default_accept_type(
386394
def _retrieve_supported_accept_types(
387395
model_id: str,
388396
model_version: str,
397+
hub_arn: Optional[str],
389398
region: Optional[str],
390399
tolerate_vulnerable_model: bool = False,
391400
tolerate_deprecated_model: bool = False,
@@ -398,6 +407,8 @@ def _retrieve_supported_accept_types(
398407
retrieve the supported accept types.
399408
model_version (str): Version of the JumpStart model for which to retrieve the
400409
supported accept types.
410+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
411+
model details from. (default: None).
401412
region (Optional[str]): Region for which to retrieve accept type options.
402413
tolerate_vulnerable_model (bool): True if vulnerable versions of model
403414
specifications should be tolerated (exception not raised). If False, raises an
@@ -420,6 +431,7 @@ def _retrieve_supported_accept_types(
420431
model_specs = verify_model_region_and_return_specs(
421432
model_id=model_id,
422433
version=model_version,
434+
hub_arn=hub_arn,
423435
scope=JumpStartScriptScope.INFERENCE,
424436
region=region,
425437
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -435,6 +447,7 @@ def _retrieve_supported_accept_types(
435447
def _retrieve_supported_content_types(
436448
model_id: str,
437449
model_version: str,
450+
hub_arn: Optional[str],
438451
region: Optional[str],
439452
tolerate_vulnerable_model: bool = False,
440453
tolerate_deprecated_model: bool = False,
@@ -447,6 +460,8 @@ def _retrieve_supported_content_types(
447460
retrieve the supported content types.
448461
model_version (str): Version of the JumpStart model for which to retrieve the
449462
supported content types.
463+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
464+
model details from. (default: None).
450465
region (Optional[str]): Region for which to retrieve content type options.
451466
tolerate_vulnerable_model (bool): True if vulnerable versions of model
452467
specifications should be tolerated (exception not raised). If False, raises an
@@ -469,6 +484,7 @@ def _retrieve_supported_content_types(
469484
model_specs = verify_model_region_and_return_specs(
470485
model_id=model_id,
471486
version=model_version,
487+
hub_arn=hub_arn,
472488
scope=JumpStartScriptScope.INFERENCE,
473489
region=region,
474490
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/resource_requirements.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_default_resources(
3333
model_id: str,
3434
model_version: str,
3535
scope: str,
36+
hub_arn: Optional[str] = None,
3637
region: Optional[str] = None,
3738
tolerate_vulnerable_model: bool = False,
3839
tolerate_deprecated_model: bool = False,
@@ -47,6 +48,8 @@ def _retrieve_default_resources(
4748
default resource requirements.
4849
scope (str): The script type, i.e. what it is used for.
4950
Valid values: "training" and "inference".
51+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
52+
model details from. (default: None).
5053
region (Optional[str]): Region for which to retrieve default resource requirements.
5154
(Default: None).
5255
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -76,6 +79,7 @@ def _retrieve_default_resources(
7679
model_specs = verify_model_region_and_return_specs(
7780
model_id=model_id,
7881
version=model_version,
82+
hub_arn=hub_arn,
7983
scope=scope,
8084
region=region,
8185
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/factory/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
355355
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
356356
model_id=kwargs.model_id,
357357
model_version=kwargs.model_version,
358+
hub_arn=kwargs.hub_arn,
358359
instance_type=kwargs.instance_type,
359360
scope=JumpStartScriptScope.INFERENCE,
360361
region=kwargs.region,
@@ -595,6 +596,7 @@ def get_deploy_kwargs(
595596
def get_register_kwargs(
596597
model_id: str,
597598
model_version: Optional[str] = None,
599+
hub_arn: Optional[str] = None,
598600
region: Optional[str] = None,
599601
tolerate_deprecated_model: Optional[bool] = None,
600602
tolerate_vulnerable_model: Optional[bool] = None,
@@ -626,6 +628,7 @@ def get_register_kwargs(
626628
register_kwargs = JumpStartModelRegisterKwargs(
627629
model_id=model_id,
628630
model_version=model_version,
631+
hub_arn=hub_arn,
629632
region=region,
630633
tolerate_deprecated_model=tolerate_deprecated_model,
631634
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -656,6 +659,7 @@ def get_register_kwargs(
656659
model_specs = verify_model_region_and_return_specs(
657660
model_id=model_id,
658661
version=model_version,
662+
hub_arn=hub_arn,
659663
region=region,
660664
scope=JumpStartScriptScope.INFERENCE,
661665
sagemaker_session=sagemaker_session,

src/sagemaker/jumpstart/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ def register(
696696
register_kwargs = get_register_kwargs(
697697
model_id=self.model_id,
698698
model_version=self.model_version,
699+
hub_arn=self.hub_arn,
699700
region=self.region,
700701
tolerate_deprecated_model=self.tolerate_deprecated_model,
701702
tolerate_vulnerable_model=self.tolerate_vulnerable_model,

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -1657,6 +1657,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
16571657
"region",
16581658
"model_id",
16591659
"model_version",
1660+
"hub_arn",
16601661
"sagemaker_session",
16611662
"content_types",
16621663
"response_types",
@@ -1687,13 +1688,15 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
16871688
"region",
16881689
"model_id",
16891690
"model_version",
1691+
"hub_arn",
16901692
"sagemaker_session",
16911693
}
16921694

16931695
def __init__(
16941696
self,
16951697
model_id: str,
16961698
model_version: Optional[str] = None,
1699+
hub_arn: Optional[str] = None,
16971700
region: Optional[str] = None,
16981701
tolerate_deprecated_model: Optional[bool] = None,
16991702
tolerate_vulnerable_model: Optional[bool] = None,
@@ -1724,6 +1727,7 @@ def __init__(
17241727

17251728
self.model_id = model_id
17261729
self.model_version = model_version
1730+
self.hub_arn = hub_arn
17271731
self.region = region
17281732
self.image_uri = image_uri
17291733
self.sagemaker_session = sagemaker_session

0 commit comments

Comments
 (0)