Skip to content

Commit dfc6eee

Browse files
evakravibencrabtreenavinsonijeniyatDewen Qi
authored
feat: custom base job name for jumpstart models/estimators (#2970)
Co-authored-by: Ben Crabtree <[email protected]> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Jeniya Tabassum <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: qidewenwhen <[email protected]> Co-authored-by: Qingzi-Lan <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Miyoung <[email protected]> Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent 34cb1c7 commit dfc6eee

File tree

10 files changed

+311
-23
lines changed

10 files changed

+311
-23
lines changed

doc/overview.rst

+4-5
Original file line numberDiff line numberDiff line change
@@ -773,11 +773,10 @@ Deployment may take about 5 minutes.
773773
   instance_type=instance_type,
774774
)
775775
776-
Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers
777-
image, the corresponding Models and Endpoints display the “pytorch”
778-
prefix when viewed in the AWS console. To verify that these models
779-
were created successfully with your desired base model, refer to
780-
the ``Tags`` section.
776+
Because the model and script URIs are distributed by SageMaker JumpStart,
777+
the endpoint, endpoint configuration, and model resources are prefixed by default with
778+
``sagemaker-jumpstart``. Refer to the model ``Tags`` to inspect the
779+
JumpStart artifacts involved in the model creation.
781780

782781
Perform Inference
783782
-----------------

src/sagemaker/estimator.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from sagemaker.job import _Job
5151
from sagemaker.jumpstart.utils import (
5252
add_jumpstart_tags,
53+
get_jumpstart_base_name_if_jumpstart_model,
5354
update_inference_tags_with_jumpstart_training_tags,
5455
)
5556
from sagemaker.local import LocalSession
@@ -569,8 +570,11 @@ def prepare_workflow_for_training(self, job_name=None):
569570
def _ensure_base_job_name(self):
    """Set ``self.base_job_name`` if it is not set already.

    A base job name that is already set is honored as-is. Otherwise the
    JumpStart base name is used when ``self.source_dir`` or ``self.model_uri``
    is recognized as a JumpStart URI, and the name is derived from the
    training image URI as the last fallback.
    """
    if not self.base_job_name:
        # JumpStart artifacts take precedence over the image-derived name.
        jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
            self.source_dir, self.model_uri
        )
        self.base_job_name = jumpstart_base_name or base_name_from_image(
            self.training_image_uri()
        )
574578

575579
def _get_or_create_name(self, name=None):
576580
"""Generate a name based on the base job name or training image if needed.
@@ -1208,7 +1212,15 @@ def deploy(
12081212
is_serverless = serverless_inference_config is not None
12091213
self._ensure_latest_training_job()
12101214
self._ensure_base_job_name()
1211-
default_name = name_from_base(self.base_job_name)
1215+
1216+
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
1217+
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
1218+
)
1219+
default_name = (
1220+
name_from_base(jumpstart_base_name)
1221+
if jumpstart_base_name
1222+
else name_from_base(self.base_job_name)
1223+
)
12121224
endpoint_name = endpoint_name or default_name
12131225
model_name = model_name or default_name
12141226

src/sagemaker/jumpstart/cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _get_manifest_key_from_model_id_semantic_version(
229229
)
230230

231231
else:
232-
possible_model_ids = [header.model_id for header in manifest.values()]
232+
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
233233
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234234
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235235

src/sagemaker/jumpstart/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,5 @@
124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125125

126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127+
128+
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"

src/sagemaker/jumpstart/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,22 @@ def add_single_jumpstart_tag(
232232
return curr_tags
233233

234234

235+
def get_jumpstart_base_name_if_jumpstart_model(
    *uris: Optional[str],
) -> Optional[str]:
    """Return the default JumpStart base name if any URI belongs to JumpStart.

    Args:
        *uris (Optional[str]): URIs to test for association with JumpStart.

    Returns:
        Optional[str]: ``constants.JUMPSTART_RESOURCE_BASE_NAME`` when at least
            one of the URIs is recognized by ``is_jumpstart_model_uri``,
            otherwise None (this includes the case where no URIs are given).
    """
    # ``any`` stops at the first matching URI, just like an early return in a loop.
    if any(is_jumpstart_model_uri(uri) for uri in uris):
        return constants.JUMPSTART_RESOURCE_BASE_NAME
    return None
249+
250+
235251
def add_jumpstart_tags(
236252
tags: Optional[List[Dict[str, str]]] = None,
237253
inference_model_uri: Optional[str] = None,

src/sagemaker/model.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sagemaker.predictor import PredictorBase
3434
from sagemaker.serverless import ServerlessInferenceConfig
3535
from sagemaker.transformer import Transformer
36-
from sagemaker.jumpstart.utils import add_jumpstart_tags
36+
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3737
from sagemaker.utils import unique_name_from_base
3838
from sagemaker.async_inference import AsyncInferenceConfig
3939
from sagemaker.predictor_async import AsyncPredictor
@@ -517,7 +517,9 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
517517
"""
518518
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
519519

520-
self._ensure_base_name_if_needed(container_def["Image"])
520+
self._ensure_base_name_if_needed(
521+
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
522+
)
521523
self._set_model_name_if_needed()
522524

523525
enable_network_isolation = self.enable_network_isolation()
@@ -532,10 +534,17 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
532534
tags=tags,
533535
)
534536

535-
def _ensure_base_name_if_needed(self, image_uri, script_uri=None, model_uri=None):
    """Create a base name if there is no model name provided.

    An existing ``self._base_name`` is kept. Otherwise, if a JumpStart script
    or model URI is used, the JumpStart base name is selected; failing that,
    the base name is derived from the image URI.

    Args:
        image_uri (str): Image URI used to derive the base name as the last
            fallback.
        script_uri (Optional[str]): Script (source directory) URI to test for
            association with JumpStart. Defaults to None so that call sites
            passing only ``image_uri`` keep working.
        model_uri (Optional[str]): Model artifact URI to test for association
            with JumpStart. Defaults to None for the same reason.
    """
    # An explicit model name always wins; no base name is needed in that case.
    if self.name is not None:
        return

    self._base_name = (
        self._base_name
        or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
        or utils.base_name_from_image(image_uri)
    )
539548

540549
def _set_model_name_if_needed(self):
541550
"""Generate a new model name if ``self._base_name`` is present."""
@@ -966,7 +975,9 @@ def deploy(
966975

967976
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
968977
if self._is_compiled_model and not is_serverless:
969-
self._ensure_base_name_if_needed(self.image_uri)
978+
self._ensure_base_name_if_needed(
979+
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
980+
)
970981
if self._base_name is not None:
971982
self._base_name = "-".join((self._base_name, compiled_model_suffix))
972983

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
TRAINING_ENTRY_POINT_SCRIPT_NAME,
2222
)
2323
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24-
from sagemaker.utils import name_from_base
24+
from sagemaker.predictor import Predictor
2525
from tests.integ.sagemaker.jumpstart.constants import (
2626
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
2727
JUMPSTART_TAG,
@@ -106,19 +106,17 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
106106
model_id=model_id, model_version=model_version, model_scope="inference"
107107
)
108108

109-
endpoint_name = name_from_base(f"{model_id}-transfer-learning")
110-
111-
estimator.deploy(
109+
predictor: Predictor = estimator.deploy(
112110
initial_instance_count=instance_count,
113111
instance_type=inference_instance_type,
114112
entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
115113
image_uri=image_uri,
116114
source_dir=script_uri,
117-
endpoint_name=endpoint_name,
115+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
118116
)
119117

120118
endpoint_invoker = EndpointInvoker(
121-
endpoint_name=endpoint_name,
119+
endpoint_name=predictor.endpoint_name,
122120
)
123121

124122
response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])

tests/unit/sagemaker/jumpstart/test_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
2121
JUMPSTART_BUCKET_NAME_SET,
2222
JUMPSTART_REGION_NAME_SET,
23+
JUMPSTART_RESOURCE_BASE_NAME,
2324
JumpStartScriptScope,
2425
)
2526
from sagemaker.jumpstart.enums import JumpStartTag
@@ -874,3 +875,23 @@ def make_deprecated_spec(*largs, **kwargs):
874875
"pytorch-eqa-bert-base-cased",
875876
"*",
876877
)
878+
879+
880+
def test_get_jumpstart_base_name_if_jumpstart_model():
    """Check the JumpStart base name is returned iff some URI is a JumpStart URI."""
    non_jumpstart_uri = "s3://not-jumpstart-bucket/some-key"

    # Only JumpStart URIs.
    only_jumpstart = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]
    result = utils.get_jumpstart_base_name_if_jumpstart_model(*only_jumpstart)
    assert JUMPSTART_RESOURCE_BASE_NAME == result

    # Only non-JumpStart URIs (possibly none at all).
    only_foreign = [non_jumpstart_uri] * random.randint(0, 10)
    result = utils.get_jumpstart_base_name_if_jumpstart_model(*only_foreign)
    assert result is None

    # A single JumpStart URI at the end of the argument list.
    jumpstart_last = [non_jumpstart_uri] * random.randint(1, 10)
    jumpstart_last.append(random_jumpstart_s3_uri("random_key"))
    result = utils.get_jumpstart_base_name_if_jumpstart_model(*jumpstart_last)
    assert JUMPSTART_RESOURCE_BASE_NAME == result

    # A single JumpStart URI surrounded by non-JumpStart URIs.
    jumpstart_middle = [non_jumpstart_uri] * random.randint(1, 10)
    jumpstart_middle.append(random_jumpstart_s3_uri("random_key"))
    jumpstart_middle.extend(
        ["s3://not-jumpstart-bucket/some-key-2"] * random.randint(1, 10)
    )
    result = utils.get_jumpstart_base_name_if_jumpstart_model(*jumpstart_middle)
    assert JUMPSTART_RESOURCE_BASE_NAME == result

tests/unit/sagemaker/model/test_model.py

+91-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import sagemaker
2020
from sagemaker.model import FrameworkModel, Model
2121
from sagemaker.huggingface.model import HuggingFaceModel
22-
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET
22+
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
2323
from sagemaker.jumpstart.enums import JumpStartTag
2424
from sagemaker.mxnet.model import MXNetModel
2525
from sagemaker.pytorch.model import PyTorchModel
@@ -569,3 +569,93 @@ def test_all_framework_models_add_jumpstart_tags(
569569

570570
sagemaker_session.create_model.reset_mock()
571571
sagemaker_session.endpoint_from_production_variants.reset_mock()
572+
573+
574+
@patch("sagemaker.utils.repack_model")
def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_session):
    """Check that a JumpStart ``source_dir`` selects the JumpStart base name.

    The model and endpoint names must start with the JumpStart base name when
    the source directory is a JumpStart URI, and must not when it is not.
    """
    jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
    t = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir=jumpstart_source_dir,
        image_uri=IMAGE_URI,
        model_data=MODEL_DATA,
    )
    t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
        JUMPSTART_RESOURCE_BASE_NAME
    )

    # Inspect the ``name`` keyword argument. Calling ``.startswith`` directly on
    # the recorded ``call`` object returns another (truthy) ``call`` object, so
    # such an assertion could never fail.
    assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
        "name"
    ].startswith(JUMPSTART_RESOURCE_BASE_NAME)

    sagemaker_session.create_model.reset_mock()
    sagemaker_session.endpoint_from_production_variants.reset_mock()

    non_jumpstart_source_dir = "s3://blah/blah/blah"
    t = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir=non_jumpstart_source_dir,
        image_uri=IMAGE_URI,
        model_data=MODEL_DATA,
    )
    t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    assert not sagemaker_session.create_model.call_args_list[0][0][0].startswith(
        JUMPSTART_RESOURCE_BASE_NAME
    )

    assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
        "name"
    ].startswith(JUMPSTART_RESOURCE_BASE_NAME)
617+
618+
619+
@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_all_framework_models_add_jumpstart_base_name(
    tar_and_upload_dir, repack_model, sagemaker_session
):
    """Check that every framework model uses the JumpStart base name.

    Each framework model is deployed with JumpStart ``model_data``; the created
    model and endpoint names must start with the JumpStart base name.

    Stacked ``patch`` decorators inject their mocks bottom-up, so the
    ``tar_and_upload_dir`` mock is the first positional argument.
    """
    framework_model_classes_to_kwargs = {
        PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"},
        TensorFlowModel: {
            "framework_version": "2.3",
        },
        HuggingFaceModel: {
            "pytorch_version": "1.7.1",
            "py_version": "py36",
            "transformers_version": "4.6.1",
        },
        MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"},
        SKLearnModel: {
            "framework_version": "0.23-1",
        },
        XGBoostModel: {
            "framework_version": "1.3-1",
        },
    }
    jumpstart_model_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
        framework_model_class(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            model_data=jumpstart_model_dir,
            **kwargs,
        ).deploy(instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT)

        assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
            JUMPSTART_RESOURCE_BASE_NAME
        )

        # Inspect the ``name`` keyword argument. Calling ``.startswith`` directly
        # on the recorded ``call`` object returns another (truthy) ``call``
        # object, so such an assertion could never fail.
        assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
            "name"
        ].startswith(JUMPSTART_RESOURCE_BASE_NAME)

        # Reset so index 0 refers to the next framework's calls.
        sagemaker_session.create_model.reset_mock()
        sagemaker_session.endpoint_from_production_variants.reset_mock()

0 commit comments

Comments
 (0)