Skip to content

feat: add integ tests for training JumpStart models in private hub #5076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 10, 2025
33 changes: 27 additions & 6 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
JUMPSTART_LOGGER,
TRAINING_ENTRY_POINT_SCRIPT_NAME,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
JUMPSTART_MODEL_HUB_NAME,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.factory import model
Expand Down Expand Up @@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs(
):
"""Adds HubAccessConfig to kwargs inputs"""

dataset_uri = kwargs.specs.default_training_dataset_uri
if isinstance(kwargs.inputs, str):
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
if dataset_uri is not None and dataset_uri == kwargs.inputs:
kwargs.inputs = TrainingInput(
s3_data=kwargs.inputs, hub_access_config=hub_access_config
)
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
elif isinstance(kwargs.inputs, dict):
for k, v in kwargs.inputs.items():
if isinstance(v, str):
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
training_input = TrainingInput(s3_data=v)
if dataset_uri is not None and dataset_uri == v:
training_input.add_hub_access_config(hub_access_config=hub_access_config)
kwargs.inputs[k] = training_input
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)

return kwargs

Expand Down Expand Up @@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs(

def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
"""Sets model uri in kwargs based on default or override, returns full kwargs."""

if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
# hub_arn is by default None unless the user specifies the hub_name
# If no hub_name is specified, it is assumed the public hub
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
if (
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
or is_private_hub
):
default_model_uri = model_uris.retrieve(
model_scope=JumpStartScriptScope.TRAINING,
instance_type=kwargs.instance_type,
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if json_obj.get("ValidationSupported")
else None
)
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
Expand Down Expand Up @@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
)

if self.training_supported:
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"DefaultTrainingDatasetUri"
)
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
"TrainingModelPackageArtifactUri"
)
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
specs["training_instance_type_variants"] = (
hub_model_document.training_instance_type_variants
)
if hub_model_document.default_training_dataset_uri:
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.default_training_dataset_uri
)
specs["default_training_dataset_key"] = default_training_dataset_key
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
33 changes: 29 additions & 4 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
from sagemaker.jumpstart import constants
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from packaging import version

PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"

Expand Down Expand Up @@ -219,9 +220,12 @@ def get_hub_model_version(
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

try:
hub_content_summaries = sagemaker_session.list_hub_content_versions(
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
).get("HubContentSummaries")
hub_content_summaries = _list_hub_content_versions_helper(
hub_name=hub_name,
hub_content_name=hub_model_name,
hub_content_type=hub_model_type,
sagemaker_session=sagemaker_session,
)
except Exception as ex:
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")

Expand All @@ -238,13 +242,34 @@ def get_hub_model_version(
raise


def _list_hub_content_versions_helper(
hub_name, hub_content_name, hub_content_type, sagemaker_session
):
all_hub_content_summaries = []
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
)
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
while "NextToken" in list_hub_content_versions_response:
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
hub_name=hub_name,
hub_content_name=hub_content_name,
hub_content_type=hub_content_type,
next_token=list_hub_content_versions_response["NextToken"],
)
all_hub_content_summaries.extend(
list_hub_content_versions_response.get("HubContentSummaries")
)
return all_hub_content_summaries


def _get_hub_model_version_for_open_weight_version(
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]

if hub_model_version == "*" or hub_model_version is None:
return str(max(available_model_versions))
return str(max(version.parse(v) for v in available_model_versions))

try:
spec = SpecifierSet(f"=={hub_model_version}")
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"hosting_neuron_model_version",
"hub_content_type",
"_is_hub_content",
"default_training_dataset_key",
"default_training_dataset_uri",
]

_non_serializable_slots = ["_is_hub_content"]
Expand Down Expand Up @@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
else None
)
self.model_subscription_link = json_obj.get("model_subscription_link")
self.default_training_dataset_key: Optional[str] = json_obj.get(
"default_training_dataset_key"
)
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"default_training_dataset_uri"
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataBaseFields object."""
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import time

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket

from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
)

MAX_INIT_TIME_SECONDS = 5

TEST_MODEL_IDS = {
"huggingface-spc-bert-base-cased",
"meta-textgeneration-llama-2-7b",
"catboost-regression-model",
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
def add_model_references():
# Create Model References to test in Hub
hub_instance = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
)
for model in TEST_MODEL_IDS:
model_arn = get_public_hub_model_arn(hub_instance, model)
create_model_reference(hub_instance, model_arn)


def test_jumpstart_hub_estimator(setup, add_model_references):
model_id, model_version = "huggingface-spc-bert-base-cased", "*"

estimator = JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_estimator_with_session(setup, add_model_references):

model_id, model_version = "huggingface-spc-bert-base-cased", "*"

sagemaker_session = get_sm_session()

estimator = JumpStartEstimator(
model_id=model_id,
role=sagemaker_session.get_caller_identity_arn(),
sagemaker_session=sagemaker_session,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)

estimator.fit(
accept_eula=True,
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
},
)

predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

payload = {
"inputs": "some-payload",
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}

response = predictor.predict(payload, custom_attributes="accept_eula=true")

assert response is not None


def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
)
with pytest.raises(Exception):
estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)


def test_instantiating_estimator(setup, add_model_references):

model_id = "catboost-regression-model"

start_time = time.perf_counter()

JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

elapsed_time = time.perf_counter() - start_time

assert elapsed_time <= MAX_INIT_TIME_SECONDS
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@

@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
hub_instance.create_model_reference(model_arn=model_arn)
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
Expand Down
Loading