Skip to content

Commit 6e28f8f

Browse files
committed
Feature: register proprietary models from jumpstart
1 parent b96c98e commit 6e28f8f

File tree

6 files changed

+51
-1
lines changed

6 files changed

+51
-1
lines changed

src/sagemaker/jumpstart/factory/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ def get_deploy_kwargs(
623623
def get_register_kwargs(
624624
model_id: str,
625625
model_version: Optional[str] = None,
626+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
626627
region: Optional[str] = None,
627628
tolerate_deprecated_model: Optional[bool] = None,
628629
tolerate_vulnerable_model: Optional[bool] = None,
@@ -656,6 +657,7 @@ def get_register_kwargs(
656657
register_kwargs = JumpStartModelRegisterKwargs(
657658
model_id=model_id,
658659
model_version=model_version,
660+
model_type=model_type,
659661
region=region,
660662
tolerate_deprecated_model=tolerate_deprecated_model,
661663
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -688,6 +690,7 @@ def get_register_kwargs(
688690
model_specs = verify_model_region_and_return_specs(
689691
model_id=model_id,
690692
version=model_version,
693+
model_type=model_type,
691694
region=region,
692695
scope=JumpStartScriptScope.INFERENCE,
693696
sagemaker_session=sagemaker_session,

src/sagemaker/jumpstart/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ def register(
796796
register_kwargs = get_register_kwargs(
797797
model_id=self.model_id,
798798
model_version=self.model_version,
799+
model_type=self.model_type,
799800
region=self.region,
800801
tolerate_deprecated_model=self.tolerate_deprecated_model,
801802
tolerate_vulnerable_model=self.tolerate_vulnerable_model,

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
20942094
"tolerate_deprecated_model",
20952095
"region",
20962096
"model_id",
2097+
"model_type",
20972098
"model_version",
20982099
"sagemaker_session",
20992100
"content_types",
@@ -2128,13 +2129,15 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
21282129
"model_id",
21292130
"model_version",
21302131
"sagemaker_session",
2132+
"model_type",
21312133
}
21322134

21332135
def __init__(
21342136
self,
21352137
model_id: str,
21362138
model_version: Optional[str] = None,
21372139
region: Optional[str] = None,
2140+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
21382141
tolerate_deprecated_model: Optional[bool] = None,
21392142
tolerate_vulnerable_model: Optional[bool] = None,
21402143
sagemaker_session: Optional[Any] = None,
@@ -2166,6 +2169,7 @@ def __init__(
21662169

21672170
self.model_id = model_id
21682171
self.model_version = model_version
2172+
self.model_type = model_type
21692173
self.region = region
21702174
self.image_uri = image_uri
21712175
self.sagemaker_session = sagemaker_session

src/sagemaker/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4545
load_sagemaker_config,
4646
)
47+
from sagemaker.jumpstart.enums import JumpStartModelType
4748
from sagemaker.model_card import (
4849
ModelCard,
4950
ModelPackageModelCard,
@@ -512,7 +513,7 @@ def register(
512513
if image_uri is not None:
513514
self.image_uri = image_uri
514515

515-
if model_package_group_name is None and model_package_name is None:
516+
if model_package_group_name is None and model_package_name is None and self.model_type is not JumpStartModelType.PROPRIETARY:
516517
# If model package group and model package name is not set
517518
# then register to auto-generated model package group
518519
model_package_group_name = utils.base_name_from_image(
@@ -542,6 +543,10 @@ def register(
542543
if self.model_data is not None:
543544
container_def["ModelDataUrl"] = self.model_data
544545

546+
if self.model_type is JumpStartModelType.PROPRIETARY:
547+
source_uri = self.model_package_arn
548+
model_package_group_name = self.model_id
549+
545550
model_pkg_args = sagemaker.get_model_package_args(
546551
self.content_types,
547552
self.response_types,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

+29
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,32 @@ def test_proprietary_jumpstart_model(setup):
291291
response = predictor.predict(payload)
292292

293293
assert response is not None
294+
295+
@pytest.mark.skipif(
    True,
    reason="Only enable if test account is subscribed to the proprietary model",
)
def test_register_proprietary_jumpstart_model(setup):
    """Register a proprietary JumpStart model, deploy the result, and invoke it.

    Skipped unconditionally; per the skip reason it should only be enabled when
    the test account is subscribed to the proprietary model.
    """
    model_id = "ai21-jurassic-2-light"

    model = JumpStartModel(
        model_id=model_id,
        model_version="2.0.004",
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    model_package = model.register()

    predictor = model_package.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
    )
    payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}

    try:
        response = predictor.predict(payload)
    finally:
        # Always tear the endpoint down: if predict() raises, the original
        # straight-line flow would skip cleanup and leave the endpoint running.
        predictor.delete_predictor()

    assert response is not None
322+

tests/unit/sagemaker/jumpstart/model/test_model.py

+8
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
473473
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
474474
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
475475
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
476+
@mock.patch("sagemaker.jumpstart.model.Model.register")
476477
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
477478
def test_proprietary_model_endpoint(
478479
self,
480+
mock_model_register: mock.Mock,
479481
mock_model_deploy: mock.Mock,
480482
mock_model_init: mock.Mock,
481483
mock_get_model_specs: mock.Mock,
@@ -507,8 +509,14 @@ def test_proprietary_model_endpoint(
507509
enable_network_isolation=False,
508510
)
509511

512+
model.register()
510513
model.deploy()
511514

515+
mock_model_register.assert_called_once_with(
516+
content_types=["application/json"],
517+
response_types=["application/json"],
518+
)
519+
512520
mock_model_deploy.assert_called_once_with(
513521
initial_instance_count=1,
514522
instance_type="ml.p4de.24xlarge",

0 commit comments

Comments
 (0)