Skip to content

Commit c639b19

Browse files
committed
feat: custom base job name for jumpstart models/estimators
1 parent 26f2446 commit c639b19

File tree

5 files changed

+45
-12
lines changed

5 files changed

+45
-12
lines changed

src/sagemaker/estimator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from sagemaker.job import _Job
5151
from sagemaker.jumpstart.utils import (
5252
add_jumpstart_tags,
53+
get_jumpstart_base_name_if_jumpstart_model,
5354
update_inference_tags_with_jumpstart_training_tags,
5455
)
5556
from sagemaker.local import LocalSession
@@ -570,7 +571,9 @@ def _ensure_base_job_name(self):
570571
"""Set ``self.base_job_name`` if it is not set already."""
571572
# honor supplied base_job_name or generate it
572573
if self.base_job_name is None:
573-
self.base_job_name = base_name_from_image(self.training_image_uri())
574+
self.base_job_name = get_jumpstart_base_name_if_jumpstart_model(
575+
self.source_dir, self.model_uri
576+
) or base_name_from_image(self.training_image_uri())
574577

575578
def _get_or_create_name(self, name=None):
576579
"""Generate a name based on the base job name or training image if needed.
@@ -1208,7 +1211,15 @@ def deploy(
12081211
is_serverless = serverless_inference_config is not None
12091212
self._ensure_latest_training_job()
12101213
self._ensure_base_job_name()
1211-
default_name = name_from_base(self.base_job_name)
1214+
1215+
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
1216+
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
1217+
)
1218+
default_name = (
1219+
name_from_base(jumpstart_base_name)
1220+
if jumpstart_base_name
1221+
else name_from_base(self.base_job_name)
1222+
)
12121223
endpoint_name = endpoint_name or default_name
12131224
model_name = model_name or default_name
12141225

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,5 @@
124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125125

126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127+
128+
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"

src/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,19 @@ def add_single_jumpstart_tag(
232232
return curr_tags
233233

234234

235+
def get_jumpstart_base_name_if_jumpstart_model(
236+
*uris,
237+
) -> Optional[str]:
238+
"""Return default JumpStart base name if a URI belongs to JumpStart.
239+
240+
If no URIs belong to JumpStart, return None.
241+
"""
242+
for uri in uris:
243+
if is_jumpstart_model_uri(uri):
244+
return constants.JUMPSTART_RESOURCE_BASE_NAME
245+
return None
246+
247+
235248
def add_jumpstart_tags(
236249
tags: Optional[List[Dict[str, str]]] = None,
237250
inference_model_uri: Optional[str] = None,

src/sagemaker/model.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sagemaker.predictor import PredictorBase
3434
from sagemaker.serverless import ServerlessInferenceConfig
3535
from sagemaker.transformer import Transformer
36-
from sagemaker.jumpstart.utils import add_jumpstart_tags
36+
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3737
from sagemaker.utils import unique_name_from_base
3838
from sagemaker.async_inference import AsyncInferenceConfig
3939
from sagemaker.predictor_async import AsyncPredictor
@@ -514,7 +514,9 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
514514
"""
515515
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
516516

517-
self._ensure_base_name_if_needed(container_def["Image"])
517+
self._ensure_base_name_if_needed(
518+
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
519+
)
518520
self._set_model_name_if_needed()
519521

520522
enable_network_isolation = self.enable_network_isolation()
@@ -529,10 +531,14 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
529531
tags=tags,
530532
)
531533

532-
def _ensure_base_name_if_needed(self, image_uri):
534+
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
533535
"""Create a base name from the image URI if there is no model name provided."""
534536
if self.name is None:
535-
self._base_name = self._base_name or utils.base_name_from_image(image_uri)
537+
self._base_name = (
538+
self._base_name
539+
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
540+
or utils.base_name_from_image(image_uri)
541+
)
536542

537543
def _set_model_name_if_needed(self):
538544
"""Generate a new model name if ``self._base_name`` is present."""
@@ -963,7 +969,9 @@ def deploy(
963969

964970
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
965971
if self._is_compiled_model and not is_serverless:
966-
self._ensure_base_name_if_needed(self.image_uri)
972+
self._ensure_base_name_if_needed(
973+
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
974+
)
967975
if self._base_name is not None:
968976
self._base_name = "-".join((self._base_name, compiled_model_suffix))
969977

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TRAINING_ENTRY_POINT_SCRIPT_NAME,
2222
)
2323
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.predictor import Predictor
2425
from sagemaker.utils import name_from_base
2526
from tests.integ.sagemaker.jumpstart.constants import (
2627
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
@@ -106,19 +107,17 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
106107
model_id=model_id, model_version=model_version, model_scope="inference"
107108
)
108109

109-
endpoint_name = name_from_base(f"{model_id}-transfer-learning")
110-
111-
estimator.deploy(
110+
predictor: Predictor = estimator.deploy(
112111
initial_instance_count=instance_count,
113112
instance_type=inference_instance_type,
114113
entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
115114
image_uri=image_uri,
116115
source_dir=script_uri,
117-
endpoint_name=endpoint_name,
116+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
118117
)
119118

120119
endpoint_invoker = EndpointInvoker(
121-
endpoint_name=endpoint_name,
120+
endpoint_name=predictor.endpoint_name,
122121
)
123122

124123
response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])

0 commit comments

Comments
 (0)