diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 3cbd0ad8a7..fa40719c9f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config): raise ValueError( "File URIs are supported in local mode only. Please use a S3 URI instead." ) - config = _Job._load_config(inputs, estimator) current_hyperparameters = estimator.hyperparameters() diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 89779bef44..71678021d4 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -43,6 +43,8 @@ def __init__( attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, target_attribute_name: Optional[Union[str, PipelineVariable]] = None, shuffle_config: Optional["ShuffleConfig"] = None, + hub_access_config: Optional[dict] = None, + model_access_config: Optional[dict] = None, ): r"""Create a definition for input data used by an SageMaker training job. @@ -102,6 +104,13 @@ def __init__( shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables shuffling on this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + hub_access_config (dict): Specify the HubAccessConfig of a + Model Reference for which a training job is being created for. + model_access_config (dict): For models that require a Model Access Config, specify True + or False for to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ self.config = { "DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}} @@ -129,6 +138,27 @@ def __init__( self.config["TargetAttributeName"] = target_attribute_name if shuffle_config is not None: self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} + self.add_hub_access_config(hub_access_config) + self.add_model_access_config(model_access_config) + + def add_hub_access_config(self, hub_access_config=None): + """Add Hub Access Config to the channel's configuration. + + Args: + hub_access_config (dict): The HubAccessConfig to be added to the + channel's configuration. + """ + if hub_access_config is not None: + self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config + + def add_model_access_config(self, model_access_config=None): + """Add Model Access Config to the channel's configuration. + + Args: + model_access_config (dict): Whether model terms of use have been accepted. 
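# Illustrative sketch (not from the changeset above): how the two new TrainingInput
# arguments surface in the channel configuration. Bucket and ARN values are
# placeholders chosen for the example.
from sagemaker.inputs import TrainingInput

channel = TrainingInput(
    s3_data="s3://example-bucket/training-data/",        # assumed example prefix
    model_access_config={"AcceptEula": True},            # EULA must be accepted explicitly
    hub_access_config={"HubContentArn": "arn:aws:sagemaker:us-west-2:111122223333:hub-content/example"},
)
s3_data_source = channel.config["DataSource"]["S3DataSource"]
assert s3_data_source["ModelAccessConfig"] == {"AcceptEula": True}
assert "HubAccessConfig" in s3_data_source
# The same dicts can also be attached after construction via the add_* helpers defined
# above; both helpers are no-ops when passed None, so ungated channels are untouched.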
+ """ + if model_access_config is not None: + self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config class ShuffleConfig(object): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 210dd426c5..1ad7e3b981 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -65,6 +65,7 @@ def stop(self): @staticmethod def _load_config(inputs, estimator, expand_role=True, validate_uri=True): """Placeholder docstring""" + model_access_config, hub_access_config = _Job._get_access_configs(estimator) input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) role = ( estimator.sagemaker_session.expand_role(estimator.role) @@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): validate_uri, content_type="application/x-sagemaker-model", input_mode="File", + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if model_channel: input_config = [] if input_config is None else input_config input_config.append(model_channel) - if estimator.enable_network_isolation(): - code_channel = _Job._prepare_channel( - input_config, estimator.code_uri, estimator.code_channel_name, validate_uri - ) + code_channel = _Job._prepare_channel( + input_config, + estimator.code_uri, + estimator.code_channel_name, + validate_uri, + ) - if code_channel: - input_config = [] if input_config is None else input_config - input_config.append(code_channel) + if code_channel: + input_config = [] if input_config is None else input_config + input_config.append(code_channel) return { "input_config": input_config, @@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): "vpc_config": vpc_config, } + @staticmethod + def _get_access_configs(estimator): + """Return access configs from estimator object. + + JumpStartEstimator uses access configs which need to be added to the model channel, + so they are passed down to the job level. 
+ + Args: + estimator (EstimatorBase): estimator object with access config field if applicable + """ + model_access_config, hub_access_config = None, None + if hasattr(estimator, "model_access_config"): + model_access_config = estimator.model_access_config + if hasattr(estimator, "hub_access_config"): + hub_access_config = estimator.hub_access_config + return model_access_config, hub_access_config + @staticmethod def _format_inputs_to_input_config(inputs, validate_uri=True): """Placeholder docstring""" @@ -173,6 +195,8 @@ def _format_string_uri_input( input_mode=None, compression=None, target_attribute_name=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" s3_input_result = TrainingInput( @@ -181,6 +205,8 @@ def _format_string_uri_input( input_mode=input_mode, compression=compression, target_attribute_name=target_attribute_name, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input_result @@ -193,7 +219,11 @@ def _format_string_uri_input( ) if isinstance(uri_input, str): return s3_input_result - if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)): + if isinstance(uri_input, (file_input, FileSystemInput)): + return uri_input + if isinstance(uri_input, TrainingInput): + uri_input.add_hub_access_config(hub_access_config=hub_access_config) + uri_input.add_model_access_config(model_access_config=model_access_config) return uri_input if is_pipeline_variable(uri_input): return s3_input_result @@ -211,6 +241,8 @@ def _prepare_channel( validate_uri=True, content_type=None, input_mode=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" if not channel_uri: @@ -226,7 +258,12 @@ def _prepare_channel( raise ValueError("Duplicate channel {} not allowed.".format(channel_name)) channel_input = _Job._format_string_uri_input( - channel_uri, validate_uri, content_type, input_mode + channel_uri, + validate_uri, + content_type, + input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) channel = _Job._convert_input_to_channel(channel_name, channel_input) diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 90ee7dea8d..c1ad9710f1 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -29,6 +29,7 @@ get_region_fallback, verify_model_region_and_return_specs, ) +from sagemaker.s3_utils import is_s3_url from sagemaker.session import Session from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str: """Returns instance specific training artifact key or default one as fallback.""" instance_specific_training_artifact_key: Optional[str] = ( - model_specs.training_instance_type_variants.get_instance_specific_artifact_key( + model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key( instance_type=instance_type ) if instance_type @@ -185,8 +186,8 @@ def _retrieve_model_uri( os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE) or default_jumpstart_bucket ) - - model_s3_uri = f"s3://{bucket}/{model_artifact_key}" + if not is_s3_url(model_artifact_key): + model_s3_uri = 
f"s3://{bucket}/{model_artifact_key}" return model_s3_uri diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 50f197c30e..af2fb5bc54 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -41,6 +41,9 @@ validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + remove_env_var_from_estimator_kwargs_if_accept_eula_present, + get_model_access_config, + get_hub_access_config, ) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -619,6 +622,10 @@ def _validate_model_id_and_get_type_hook(): self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation self.config_name = estimator_init_kwargs.config_name self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) + # Access configs initialized to None, would be given a value when .fit() is called + # if applicable + self.model_access_config = None + self.hub_access_config = None super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -629,6 +636,7 @@ def fit( logs: Optional[str] = None, job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Start training job by calling base ``Estimator`` class ``fit`` method. @@ -679,8 +687,16 @@ def fit( is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. However, the value of `TrialComponentDisplayName` is honored for display in Studio. (Default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
""" - + self.model_access_config = get_model_access_config(accept_eula) + self.hub_access_config = get_hub_access_config( + hub_content_arn=self.init_kwargs.get("model_reference_arn", None) + ) estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, @@ -695,7 +711,9 @@ def fit( tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, config_name=self.config_name, + hub_access_config=self.hub_access_config, ) + remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 2a54d9c4de..17ad7a76f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -71,7 +71,6 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, - get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, get_top_ranked_config_name, update_dict_if_key_not_present, @@ -265,6 +264,7 @@ def get_fit_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, config_name: Optional[str] = None, + hub_access_config: Optional[Dict] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -301,10 +301,32 @@ def get_fit_kwargs( estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs) + estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs( + estimator_fit_kwargs, hub_access_config + ) return estimator_fit_kwargs +def _add_hub_access_config_to_kwargs_inputs( + kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None +): + """Adds HubAccessConfig to kwargs inputs""" + + if isinstance(kwargs.inputs, str): + kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, TrainingInput): + kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, dict): + for k, v in kwargs.inputs.items(): + if isinstance(v, str): + kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, TrainingInput): + kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, @@ -668,18 +690,6 @@ def _add_env_to_kwargs( value, ) - environment = getattr(kwargs, "environment", {}) or {} - if ( - environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) - and str(environment.get("accept_eula", "")).lower() != "true" - ): - model_specs = kwargs.specs - if model_specs.is_gated_model(): - raise ValueError( - "Need to define ‘accept_eula'='true' within Environment. 
" - f"{get_eula_message(model_specs, kwargs.region)}" - ) - return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 908241812e..349396205e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -619,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, property_name="artifact_key" ) + def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]: + """Returns instance specific training artifact key. + + Returns None if a model, instance type tuple does not have specific + training artifact key. + """ + + return self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_uri" + ) or self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_key" + ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: """Returns instance specific resource requirements. diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 23245b24e5..bd81226727 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1632,6 +1632,47 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str: return neo_bucket +def remove_env_var_from_estimator_kwargs_if_accept_eula_present( + init_kwargs: dict, accept_eula: Optional[bool] +): + """Remove env vars if access configs are used + + Args: + init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + if accept_eula is not None and init_kwargs["environment"]: + del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY] + + +def get_hub_access_config(hub_content_arn: Optional[str]): + """Get hub access config + + Args: + hub_content_arn (Optional[bool]): Arn of the model reference hub content + """ + if hub_content_arn is not None: + hub_access_config = {"HubContentArn": hub_content_arn} + else: + hub_access_config = None + + return hub_access_config + + +def get_model_access_config(accept_eula: Optional[bool]): + """Get access configs + + Args: + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + if accept_eula is not None: + model_access_config = {"AcceptEula": accept_eula} + else: + model_access_config = None + + return model_access_config + + def get_latest_version(versions: List[str]) -> Optional[str]: """Returns the latest version using sem-ver when possible.""" try: diff --git a/src/sagemaker/s3_utils.py b/src/sagemaker/s3_utils.py index e53cdbe02a..f59c8a299f 100644 --- a/src/sagemaker/s3_utils.py +++ b/src/sagemaker/s3_utils.py @@ -45,6 +45,19 @@ def parse_s3_url(url): return parsed_url.netloc, parsed_url.path.lstrip("/") +def is_s3_url(url): + """Returns True if url is an s3 url, False if not + + Args: + url (str): + + Returns: + bool: + """ + parsed_url = urlparse(url) + return parsed_url.scheme == "s3" + + def s3_path_join(*args, with_end_slash: bool = False): """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix). 
diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index e8e5cc0942..a64db4a97d 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -122,9 +122,10 @@ def test_jumpstart_hub_gated_model(setup, add_model_references): assert response is not None +@pytest.mark.skip(reason="blocking PR checks and release pipeline.") def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): - model_id = "meta-textgeneration-llama-2-7b" + model_id = "meta-textgeneration-llama-3-2-1b" hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 59f38bd189..4021599120 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -3059,7 +3059,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -3135,7 +3135,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -3214,13 +3214,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -5046,7 +5046,7 @@ "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m5": { "regional_properties": {"image_uri": "$cpu_ecr_uri_1"}, - "properties": {"artifact_key": "hello-world-1"}, + "properties": {"training_artifact_key": "hello-world-1"}, }, "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, @@ -17234,13 +17234,13 @@ "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -17249,13 +17249,13 @@ "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, 
"p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 1fd2a47aca..4a64b413f4 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -392,23 +392,6 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - with pytest.raises(ValueError) as e: - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - assert str(e.value) == ( - "Need to define ‘accept_eula'='true' within Environment. " - "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " - "license agreement (EULA). See " - "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" - " for terms of use." - ) - mock_estimator_init.reset_mock() estimator = JumpStartEstimator(model_id=model_id, environment={"accept_eula": "true"}) @@ -510,6 +493,151 @@ def test_gated_model_s3_uri( ], ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_s3_uri_with_eula_in_fit( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_timestamp: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_timestamp.return_value = "8675309" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + model_id, _ = "js-gated-artifact-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + mock_estimator_init.reset_mock() + + estimator = JumpStartEstimator(model_id=model_id) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + 
"meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "1", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "True", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "1", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role=execution_role, + sagemaker_session=sagemaker_session, + max_run=360000, + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + environment={ + "SageMakerGatedModelS3Uri": "s3://sagemaker-repository-pdx/" + "model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels, accept_eula=True) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, + wait=True, + job_name="meta-textgeneration-llama-2-7b-f-8675309", + ) + + assert hasattr(estimator, "model_access_config") + assert hasattr(estimator, "hub_access_config") + + assert estimator.model_access_config == {"AcceptEula": True} + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-f-8675309", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118", + wait=True, + model_data_download_timeout=3600, + container_startup_health_check_timeout=3600, + role=execution_role, + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-f-8675309", + use_compiled_model=False, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + @mock.patch( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @@ -1218,7 +1346,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["kwargs"]) - fit_args_to_skip: Set[str] = set() + fit_args_to_skip: Set[str] = set(["accept_eula"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Estimator.__init__ @@ -1243,8 +1371,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_fit = JumpStartEstimator.fit js_class_fit_args = set(signature(js_class_fit).parameters.keys()) - assert js_class_fit_args - parent_class_fit_args == set() - assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip + assert js_class_fit_args - 
parent_class_fit_args == fit_args_to_skip + assert parent_class_fit_args - js_class_fit_args == set() model_class_init = Model.__init__ model_class_init_args = set(signature(model_class_init).parameters.keys()) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index 11798bc854..ebd90d98d2 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -923,15 +923,13 @@ def test_hub_content_document_from_json_obj(): "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -940,15 +938,13 @@ def test_hub_content_document_from_json_obj(): "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, "p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index e687a1c4ac..75aa93a920 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -176,7 +176,7 @@ def test_retrieve_training_artifact_key(self): "image_uri": "$alias_ecr_uri_1", }, "properties": { - "artifact_key": "in/the/way", + "training_artifact_key": "in/the/way", }, } }, diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 3efa8c8c81..acce8ef4f1 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -117,7 +117,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -193,7 +193,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -272,13 +272,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + 
"training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -952,27 +952,35 @@ def test_jumpstart_hosting_prepacked_artifact_key_instance_variants(): def test_jumpstart_training_artifact_key_instance_variants(): assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g6.xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g6.xlarge" + ) == "path/to/training/artifact.tar.gz" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g4.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g4.9xlarge" + ) == "path/to/prepacked/training/artifact/prefix/number2/" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.9xlarge" + ) == "do/re/mi" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.12xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.12xlarge" + ) == "you/not/entertained" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key( + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( instance_type="ml.g9dsfsdfs.12xlarge" ) is None diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index 7d9c2b2c2f..133c31eb75 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -41,6 +41,8 @@ def test_training_input_all_arguments(): record_wrapping = "RecordIO" s3_data_type = "Manifestfile" input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -49,6 +51,8 @@ def test_training_input_all_arguments(): content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { "DataSource": { @@ -56,6 +60,8 @@ def test_training_input_all_arguments(): "S3DataDistributionType": distribution, "S3DataType": s3_data_type, "S3Uri": prefix, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, @@ -76,6 +82,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): s3_data_type = "Manifestfile" instance_groups = ["data-server"] input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -85,6 +93,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): record_wrapping=record_wrapping, s3_data_type=s3_data_type, instance_groups=instance_groups, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { @@ -94,6 +104,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): "S3DataType": s3_data_type, "S3Uri": prefix, "InstanceGroupNames": instance_groups, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index c93a381c11..dc21f50b68 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -206,6 +206,32 @@ def 
test_load_config_with_model_channel_no_inputs(estimator): assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME +def test_load_config_with_access_configs(estimator): + estimator.model_uri = MODEL_URI + estimator.model_channel_name = MODEL_CHANNEL_NAME + estimator.model_access_config = {"AcceptEula": True} + estimator.hub_access_config = {"HubContentArn": "dummy_arn"} + + config = _Job._load_config(inputs=None, estimator=estimator) + assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI + assert config["input_config"][0]["ChannelName"] == MODEL_CHANNEL_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE + assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE + assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == estimator.model_access_config + ) + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["HubAccessConfig"] + == estimator.hub_access_config + ) + + def test_load_config_with_code_channel(framework): inputs = TrainingInput(BUCKET_NAME) @@ -347,20 +373,43 @@ def test_format_record_set_list_input(): @pytest.mark.parametrize( - "channel_uri, channel_name, content_type, input_mode", + "channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config", [ - [MODEL_URI, MODEL_CHANNEL_NAME, "application/x-sagemaker-model", "File"], - [CODE_URI, CODE_CHANNEL_NAME, None, None], + [ + MODEL_URI, + MODEL_CHANNEL_NAME, + "application/x-sagemaker-model", + "File", + {"AcceptEula": True}, + None, + ], + [CODE_URI, CODE_CHANNEL_NAME, None, None, None, {"HubContentArn": "dummy_arn"}], ], ) -def test_prepare_channel(channel_uri, channel_name, content_type, input_mode): +def test_prepare_channel( + channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config +): channel = _Job._prepare_channel( - [], channel_uri, channel_name, content_type=content_type, input_mode=input_mode + [], + channel_uri, + channel_name, + content_type=content_type, + input_mode=input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) assert channel["DataSource"]["S3DataSource"]["S3Uri"] == channel_uri assert channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] == "FullyReplicated" assert channel["DataSource"]["S3DataSource"]["S3DataType"] == "S3Prefix" + if hub_access_config: + assert channel["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + else: + assert "HubAccessConfig" not in channel["DataSource"]["S3DataSource"] + if model_access_config: + assert channel["DataSource"]["S3DataSource"]["ModelAccessConfig"] == model_access_config + else: + assert "ModelAccessConfig" not in channel["DataSource"]["S3DataSource"] assert channel["ChannelName"] == channel_name assert "CompressionType" not in channel assert "RecordWrapperType" not in channel @@ -546,6 +595,23 @@ def test_format_string_uri_input_string(): assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs +def test_format_string_uri_input_string_with_access_configs(): + inputs = BUCKET_NAME + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = 
_Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_file_system_input(): file_system_id = "fs-fd85e556" file_system_type = "EFS" @@ -585,6 +651,26 @@ def test_format_string_uri_input(): ) +def test_format_string_uri_input_with_access_configs(): + inputs = TrainingInput(BUCKET_NAME) + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] + == inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ) + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_input_exception(): inputs = 1 diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index a226954986..b54552cacb 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -17,6 +17,7 @@ from mock import Mock from sagemaker import s3 +from sagemaker.s3_utils import is_s3_url BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -132,6 +133,34 @@ def test_parse_s3_url_fail(): assert "Expecting 's3' scheme" in str(error) +@pytest.mark.parametrize( + "input_url", + [ + ("s3://bucket/code_location"), + ("s3://bucket/code_location/sub_location"), + ("s3://bucket/code_location/sub_location/"), + ("s3://bucket/"), + ("s3://bucket"), + ], +) +def test_is_s3_url_true(input_url): + assert is_s3_url(input_url) is True + + +@pytest.mark.parametrize( + "input_url", + [ + ("bucket/code_location"), + ("bucket/code_location/sub_location"), + ("sub_location/"), + ("s3/bucket/"), + ("t3://bucket"), + ], +) +def test_is_s3_url_false(input_url): + assert is_s3_url(input_url) is False + + @pytest.mark.parametrize( "expected_output, input_args", [
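# Illustrative sketch (not from the changeset): the model-channel shape that the new
# test_load_config_with_access_configs (shown above) asserts once the estimator carries
# both access configs. The S3 URI and channel name stand in for the test fixtures.
expected_model_channel = {
    "ChannelName": "model",                                   # placeholder for MODEL_CHANNEL_NAME
    "DataSource": {
        "S3DataSource": {
            "S3DataType": "S3Prefix",
            "S3DataDistributionType": "FullyReplicated",
            "S3Uri": "s3://example-bucket/model.tar.gz",      # placeholder for MODEL_URI
            "ModelAccessConfig": {"AcceptEula": True},
            "HubAccessConfig": {"HubContentArn": "dummy_arn"},
        }
    },
    "ContentType": "application/x-sagemaker-model",           # set by _Job._prepare_channel for the model channel
    "InputMode": "File",
}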