Skip to content

feat: custom base job name for jumpstart models/estimators #2970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -773,11 +773,10 @@ Deployment may take about 5 minutes.
   instance_type=instance_type,
)

Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers
image, the corresponding Models and Endpoints display the “pytorch”
prefix when viewed in the AWS console. To verify that these models
were created successfully with your desired base model, refer to
the ``Tags`` section.
Because the model and script URIs are owned by JumpStart, the endpoint,
endpoint configuration, and model resources are, by default, given names
prefixed with ``sagemaker-jumpstart``. Refer to the model ``Tags`` to inspect
the JumpStart artifacts involved in the model creation.

Perform Inference
-----------------
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from sagemaker.job import _Job
from sagemaker.jumpstart.utils import (
add_jumpstart_tags,
get_jumpstart_base_name_if_jumpstart_model,
update_inference_tags_with_jumpstart_training_tags,
)
from sagemaker.local import LocalSession
Expand Down Expand Up @@ -570,7 +571,9 @@ def _ensure_base_job_name(self):
"""Set ``self.base_job_name`` if it is not set already."""
# honor supplied base_job_name or generate it
if self.base_job_name is None:
self.base_job_name = base_name_from_image(self.training_image_uri())
self.base_job_name = get_jumpstart_base_name_if_jumpstart_model(
self.source_dir, self.model_uri
) or base_name_from_image(self.training_image_uri())

def _get_or_create_name(self, name=None):
"""Generate a name based on the base job name or training image if needed.
Expand Down Expand Up @@ -1208,7 +1211,15 @@ def deploy(
is_serverless = serverless_inference_config is not None
self._ensure_latest_training_job()
self._ensure_base_job_name()
default_name = name_from_base(self.base_job_name)

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
)
default_name = (
name_from_base(jumpstart_base_name)
if jumpstart_base_name
else name_from_base(self.base_job_name)
)
endpoint_name = endpoint_name or default_name
model_name = model_name or default_name

Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,5 @@
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"

JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
13 changes: 13 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,19 @@ def add_single_jumpstart_tag(
return curr_tags


def get_jumpstart_base_name_if_jumpstart_model(
*uris,
) -> Optional[str]:
"""Return default JumpStart base name if a URI belongs to JumpStart.

If no URIs belong to JumpStart, return None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please address Mufis' comment

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did (*uris (Optional[str]): URI to test for association with JumpStart.). I believe you're looking at an old version.

"""
for uri in uris:
if is_jumpstart_model_uri(uri):
return constants.JUMPSTART_RESOURCE_BASE_NAME
return None


def add_jumpstart_tags(
tags: Optional[List[Dict[str, str]]] = None,
inference_model_uri: Optional[str] = None,
Expand Down
18 changes: 13 additions & 5 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from sagemaker.predictor import PredictorBase
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.transformer import Transformer
from sagemaker.jumpstart.utils import add_jumpstart_tags
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
from sagemaker.utils import unique_name_from_base
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.predictor_async import AsyncPredictor
Expand Down Expand Up @@ -514,7 +514,9 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
"""
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)

self._ensure_base_name_if_needed(container_def["Image"])
self._ensure_base_name_if_needed(
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
)
self._set_model_name_if_needed()

enable_network_isolation = self.enable_network_isolation()
Expand All @@ -529,10 +531,14 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
tags=tags,
)

def _ensure_base_name_if_needed(self, image_uri):
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
"""Create a base name from the image URI if there is no model name provided."""
if self.name is None:
self._base_name = self._base_name or utils.base_name_from_image(image_uri)
self._base_name = (
self._base_name
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
or utils.base_name_from_image(image_uri)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking: I would prefer to define a jumpstart_base_name variable here as well.


def _set_model_name_if_needed(self):
"""Generate a new model name if ``self._base_name`` is present."""
Expand Down Expand Up @@ -963,7 +969,9 @@ def deploy(

compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
if self._is_compiled_model and not is_serverless:
self._ensure_base_name_if_needed(self.image_uri)
self._ensure_base_name_if_needed(
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
)
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TRAINING_ENTRY_POINT_SCRIPT_NAME,
)
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.utils import name_from_base
from sagemaker.predictor import Predictor
from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
JUMPSTART_TAG,
Expand Down Expand Up @@ -106,19 +106,17 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
model_id=model_id, model_version=model_version, model_scope="inference"
)

endpoint_name = name_from_base(f"{model_id}-transfer-learning")

estimator.deploy(
predictor: Predictor = estimator.deploy(
initial_instance_count=instance_count,
instance_type=inference_instance_type,
entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
image_uri=image_uri,
source_dir=script_uri,
endpoint_name=endpoint_name,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)

endpoint_invoker = EndpointInvoker(
endpoint_name=endpoint_name,
endpoint_name=predictor.endpoint_name,
)

response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
JUMPSTART_BUCKET_NAME_SET,
JUMPSTART_REGION_NAME_SET,
JUMPSTART_RESOURCE_BASE_NAME,
JumpStartScriptScope,
)
from sagemaker.jumpstart.enums import JumpStartTag
Expand Down Expand Up @@ -874,3 +875,11 @@ def make_deprecated_spec(*largs, **kwargs):
"pytorch-eqa-bert-base-cased",
"*",
)


def test_get_jumpstart_base_name_if_jumpstart_model():
    """Check the base-name helper on all-JumpStart and all-foreign URI lists.

    A random, non-empty number of URIs produced by ``random_jumpstart_s3_uri``
    must yield ``JUMPSTART_RESOURCE_BASE_NAME``; a random (possibly empty)
    number of URIs in an unrelated bucket must yield ``None``.
    """
    jumpstart_uri_count = random.randint(1, 10)
    jumpstart_uris = [random_jumpstart_s3_uri("random_key") for _ in range(jumpstart_uri_count)]
    resolved_base_name = utils.get_jumpstart_base_name_if_jumpstart_model(*jumpstart_uris)
    assert JUMPSTART_RESOURCE_BASE_NAME == resolved_base_name

    # A count of zero is allowed here, which also covers the "no URIs" call.
    foreign_uris = ["s3://not-jumpstart-bucket/some-key"] * random.randint(0, 10)
    assert utils.get_jumpstart_base_name_if_jumpstart_model(*foreign_uris) is None
92 changes: 91 additions & 1 deletion tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sagemaker
from sagemaker.model import FrameworkModel, Model
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
from sagemaker.jumpstart.enums import JumpStartTag
from sagemaker.mxnet.model import MXNetModel
from sagemaker.pytorch.model import PyTorchModel
Expand Down Expand Up @@ -551,3 +551,93 @@ def test_all_framework_models_add_jumpstart_tags(

sagemaker_session.create_model.reset_mock()
sagemaker_session.endpoint_from_production_variants.reset_mock()


@patch("sagemaker.utils.repack_model")
def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_session):
    """Verify the default resource names of a script-mode ``Model``.

    A ``Model`` whose ``source_dir`` lives in a bucket taken from
    ``JUMPSTART_BUCKET_NAME_SET`` is deployed without explicit names; the name
    passed to ``create_model`` and the ``name`` keyword passed to
    ``endpoint_from_production_variants`` must both start with
    ``JUMPSTART_RESOURCE_BASE_NAME``. The same deployment with an unrelated
    ``source_dir`` must produce names that do not start with it.
    """
    jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
    t = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir=jumpstart_source_dir,
        image_uri=IMAGE_URI,
        model_data=MODEL_DATA,
    )
    t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
        JUMPSTART_RESOURCE_BASE_NAME
    )

    # Inspect the ``name`` keyword argument rather than the ``call`` object:
    # a ``call`` fabricates unknown attributes, so ``call.startswith(...)`` is
    # always truthy and would make this assertion vacuous.
    assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
        "name"
    ].startswith(JUMPSTART_RESOURCE_BASE_NAME)

    sagemaker_session.create_model.reset_mock()
    sagemaker_session.endpoint_from_production_variants.reset_mock()

    non_jumpstart_source_dir = "s3://blah/blah/blah"
    t = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir=non_jumpstart_source_dir,
        image_uri=IMAGE_URI,
        model_data=MODEL_DATA,
    )
    t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    assert not sagemaker_session.create_model.call_args_list[0][0][0].startswith(
        JUMPSTART_RESOURCE_BASE_NAME
    )

    assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
        "name"
    ].startswith(JUMPSTART_RESOURCE_BASE_NAME)


@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_all_framework_models_add_jumpstart_base_name(
    tar_and_upload_dir, repack_model, sagemaker_session
):
    """Verify framework models default to the JumpStart base name.

    Each framework model class is built with ``model_data`` in a bucket taken
    from ``JUMPSTART_BUCKET_NAME_SET`` and deployed without explicit names. The
    name passed to ``create_model`` and the ``name`` keyword passed to
    ``endpoint_from_production_variants`` must both start with
    ``JUMPSTART_RESOURCE_BASE_NAME``.

    ``patch`` decorators are applied bottom-up, so the first injected mock is
    ``tar_and_upload_dir`` and the second is ``repack_model``.
    """
    framework_model_classes_to_kwargs = {
        PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"},
        TensorFlowModel: {
            "framework_version": "2.3",
        },
        HuggingFaceModel: {
            "pytorch_version": "1.7.1",
            "py_version": "py36",
            "transformers_version": "4.6.1",
        },
        MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"},
        SKLearnModel: {
            "framework_version": "0.23-1",
        },
        XGBoostModel: {
            "framework_version": "1.3-1",
        },
    }
    jumpstart_model_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
        framework_model_class(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            model_data=jumpstart_model_dir,
            **kwargs,
        ).deploy(instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT)

        assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
            JUMPSTART_RESOURCE_BASE_NAME
        )

        # Inspect the ``name`` keyword argument rather than the ``call``
        # object: ``call.startswith(...)`` returns another truthy ``call`` and
        # would make this assertion pass unconditionally.
        assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
            "name"
        ].startswith(JUMPSTART_RESOURCE_BASE_NAME)

        # Clear recorded calls so index 0 refers to the next class's deployment.
        sagemaker_session.create_model.reset_mock()
        sagemaker_session.endpoint_from_production_variants.reset_mock()
Loading