Skip to content

feat: add integ tests for training JumpStart models in private hub #5076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 10, 2025
24 changes: 19 additions & 5 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,31 @@ def _add_hub_access_config_to_kwargs_inputs(
kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
):
"""Adds HubAccessConfig to kwargs inputs"""

dataset_uri = kwargs.specs.default_training_dataset_uri
if isinstance(kwargs.inputs, str):
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
if dataset_uri is not None and dataset_uri == kwargs.inputs:
kwargs.inputs = TrainingInput(
s3_data=kwargs.inputs, hub_access_config=hub_access_config
)
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
elif isinstance(kwargs.inputs, dict):
for k, v in kwargs.inputs.items():
if isinstance(v, str):
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
training_input = TrainingInput(s3_data=v)
if dataset_uri is not None and dataset_uri == v:
training_input.add_hub_access_config(hub_access_config=hub_access_config)
kwargs.inputs[k] = training_input
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)

return kwargs

Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
specs["training_instance_type_variants"] = (
hub_model_document.training_instance_type_variants
)
if hub_model_document.default_training_dataset_uri:
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.default_training_dataset_uri
)
specs["default_training_dataset_key"] = default_training_dataset_key
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"hosting_neuron_model_version",
"hub_content_type",
"_is_hub_content",
"default_training_dataset_key",
"default_training_dataset_uri",
]

_non_serializable_slots = ["_is_hub_content"]
Expand Down Expand Up @@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
else None
)
self.model_subscription_link = json_obj.get("model_subscription_link")
self.default_training_dataset_key: Optional[str] = json_obj.get(
"default_training_dataset_key"
)
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"default_training_dataset_uri"
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataBaseFields object."""
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import time

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket

from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
)

MAX_INIT_TIME_SECONDS = 5

TEST_MODEL_IDS = {
"huggingface-spc-bert-base-cased",
"meta-textgeneration-llama-2-7b",
"catboost-regression-model",
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
hub_instance.create_model_reference(model_arn=model_arn)


@pytest.fixture(scope="session")
def add_model_references():
# Create Model References to test in Hub
hub_instance = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
)
for model in TEST_MODEL_IDS:
model_arn = get_public_hub_model_arn(hub_instance, model)
create_model_reference(hub_instance, model_arn)


def test_jumpstart_hub_estimator(setup, add_model_references):

model_id, model_version = "huggingface-spc-bert-base-cased", "*"

sagemaker_session = get_sm_session()

estimator = JumpStartEstimator(
model_id=model_id,
role=sagemaker_session.get_caller_identity_arn(),
sagemaker_session=sagemaker_session,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references):
model_id, model_version = "huggingface-spc-bert-base-cased", "*"

sagemaker_session = get_sm_session()

estimator = JumpStartEstimator(
model_id=model_id,
role=sagemaker_session.get_caller_identity_arn(),
sagemaker_session=sagemaker_session,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
accept_eula=True,
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
},
)

estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)
with pytest.raises(Exception):
estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)


def test_instantiating_estimator(setup, add_model_references):

model_id = "catboost-regression-model"

start_time = time.perf_counter()

JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

elapsed_time = time.perf_counter() - start_time

assert elapsed_time <= MAX_INIT_TIME_SECONDS
Loading