diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 17ad7a76f5..12eb30daaf 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -56,6 +56,7 @@ JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, + JUMPSTART_MODEL_HUB_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model @@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs( ): """Adds HubAccessConfig to kwargs inputs""" + dataset_uri = kwargs.specs.default_training_dataset_uri if isinstance(kwargs.inputs, str): - kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) + if dataset_uri is not None and dataset_uri == kwargs.inputs: + kwargs.inputs = TrainingInput( + s3_data=kwargs.inputs, hub_access_config=hub_access_config + ) elif isinstance(kwargs.inputs, TrainingInput): - kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): + kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) elif isinstance(kwargs.inputs, dict): for k, v in kwargs.inputs.items(): if isinstance(v, str): - kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config) + training_input = TrainingInput(s3_data=v) + if dataset_uri is not None and dataset_uri == v: + training_input.add_hub_access_config(hub_access_config=hub_access_config) + kwargs.inputs[k] = training_input elif isinstance(kwargs.inputs, TrainingInput): - kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): + kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) return kwargs @@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs( def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" - - if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)): + # hub_arn is by default None unless the user specifies the hub_name + # If no hub_name is specified, it is assumed the public hub + is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False + if ( + _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)) + or is_private_hub + ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index fd38868dcc..6ba5a37c3c 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("ValidationSupported") else None ) - self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "DefaultTrainingDatasetUri" + ) self.training_model_package_artifact_uri: Optional[str] = json_obj.get( "TrainingModelPackageArtifactUri" ) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 01b6c5fe87..8070b54e87 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response( specs["training_instance_type_variants"] = ( hub_model_document.training_instance_type_variants ) + if hub_model_document.default_training_dataset_uri: + _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.default_training_dataset_uri + ) + specs["default_training_dataset_key"] = default_training_dataset_key + specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 1bbc6198a2..75af019ca6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +from packaging import version PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" @@ -219,9 +220,12 @@ def get_hub_model_version( sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: - hub_content_summaries = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type - ).get("HubContentSummaries") + hub_content_summaries = _list_hub_content_versions_helper( + hub_name=hub_name, + hub_content_name=hub_model_name, + hub_content_type=hub_model_type, + sagemaker_session=sagemaker_session, + ) except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") @@ -238,13 +242,34 @@ def get_hub_model_version( raise +def _list_hub_content_versions_helper( + hub_name, hub_content_name, hub_content_type, sagemaker_session +): + all_hub_content_summaries = [] + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type + ) + all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + while "NextToken" in list_hub_content_versions_response: + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, + hub_content_name=hub_content_name, + hub_content_type=hub_content_type, + next_token=list_hub_content_versions_response["NextToken"], + ) + all_hub_content_summaries.extend( + list_hub_content_versions_response.get("HubContentSummaries") + ) + return all_hub_content_summaries + + def _get_hub_model_version_for_open_weight_version( hub_content_summaries: List[Any], hub_model_version: Optional[str] = None ) -> str: available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: - return str(max(available_model_versions)) + return str(max(version.parse(v) for v in available_model_versions)) try: spec = SpecifierSet(f"=={hub_model_version}") diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 349396205e..0cd4bcc902 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "hosting_neuron_model_version", "hub_content_type", "_is_hub_content", + "default_training_dataset_key", + "default_training_dataset_uri", ] _non_serializable_slots = ["_is_hub_content"] @@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) self.model_subscription_link = json_obj.get("model_subscription_link") + self.default_training_dataset_key: Optional[str] = json_obj.get( + "default_training_dataset_key" + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "default_training_dataset_uri" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index 1ffb1d8dc0..740d88e9c0 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"), - ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"), + ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"), ("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"), ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py new file mode 100644 index 0000000000..a6e33f1bdf --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -0,0 +1,204 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time + +import pytest +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub + +from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, + get_training_dataset_for_model_and_version, +) + +MAX_INIT_TIME_SECONDS = 5 + +TEST_MODEL_IDS = { + "huggingface-spc-bert-base-cased", + "meta-textgeneration-llama-2-7b", + "catboost-regression-model", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) + + +def test_jumpstart_hub_estimator(setup, add_model_references): + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_estimator_with_session(setup, add_model_references): + + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + sagemaker_session = get_sm_session() + + estimator = JumpStartEstimator( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + accept_eula=True, + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + }, + ) + + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + with pytest.raises(Exception): + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + +def test_instantiating_estimator(setup, add_model_references): + + model_id = "catboost-regression-model" + + start_time = time.perf_counter() + + JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + elapsed_time = time.perf_counter() - start_time + + assert elapsed_time <= MAX_INIT_TIME_SECONDS diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index a64db4a97d..c7e039693b 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -48,7 +48,10 @@ @with_exponential_backoff() def create_model_reference(hub_instance, model_arn): - hub_instance.create_model_reference(model_arn=model_arn) + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass @pytest.fixture(scope="session") diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 4021599120..0c9065feb5 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -15553,6 +15553,8 @@ }, "inference_enable_network_isolation": True, "training_enable_network_isolation": True, + "default_training_dataset_uri": None, + "default_training_dataset_key": "training-datasets/tf_flowers/", "resource_name_base": "pt-ic-mobilenet-v2", "hosting_eula_key": None, "hosting_model_package_arns": {}, diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index acce8ef4f1..0b5ef63947 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -378,6 +378,7 @@ def test_jumpstart_model_specs(): specs1.training_script_key == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) + assert specs1.default_training_dataset_key == "training-datasets/tf_flowers/" assert specs1.hyperparameters == [ JumpStartHyperparameter( {