From a0fd751724e73cc67b3d62e7ccfd0e7ab736adf2 Mon Sep 17 00:00:00 2001
From: Haotian An
Date: Wed, 1 May 2024 03:13:08 +0000
Subject: [PATCH 01/10] supported inference configs

---
 src/sagemaker/jumpstart/estimator.py | 26 +++--
 src/sagemaker/jumpstart/factory/estimator.py | 104 +++++++++++++++---
 src/sagemaker/jumpstart/types.py | 60 ++++++----
 tests/unit/sagemaker/jumpstart/constants.py | 26 +++++
 .../jumpstart/estimator/test_estimator.py | 55 +++++++--
 .../jumpstart/test_notebook_utils.py | 2 +-
 tests/unit/sagemaker/jumpstart/test_types.py | 7 ++
 7 files changed, 222 insertions(+), 58 deletions(-)

diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py
index 4939be4041..0713d3a0fa 100644
--- a/src/sagemaker/jumpstart/estimator.py
+++ b/src/sagemaker/jumpstart/estimator.py
@@ -111,7 +111,8 @@ def __init__(
         container_arguments: Optional[List[str]] = None,
         disable_output_compression: Optional[bool] = None,
         enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
-        config_name: Optional[str] = None,
+        training_config_name: Optional[str] = None,
+        inference_config_name: Optional[str] = None,
     ):
         """Initializes a ``JumpStartEstimator``.
 
@@ -503,8 +504,11 @@ def __init__(
                 to Amazon S3 without compression after training finishes.
             enable_remote_debug (bool or PipelineVariable): Optional.
                 Specifies whether RemoteDebug is enabled for the training job
-            config_name (Optional[str]):
-                Name of the JumpStart Model config to apply. (Default: None).
+            training_config_name (Optional[str]):
+                Name of the training configuration to apply to the Estimator. (Default: None).
+            inference_config_name (Optional[str]):
+                Name of the inference configuration to apply to the Estimator,
+                to be used when deploying the fine-tuned model. (Default: None).
 
         Raises:
             ValueError: If the model ID is not recognized by JumpStart.
@@ -583,7 +587,8 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - config_name=config_name, + training_config_name=training_config_name, + inference_config_name=inference_config_name, ) self.model_id = estimator_init_kwargs.model_id @@ -597,7 +602,8 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation - self.config_name = estimator_init_kwargs.config_name + self.training_config_name = estimator_init_kwargs.training_config_name + self.inference_config_name = estimator_init_kwargs.inference_config_name self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -673,7 +679,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, - config_name=self.config_name, + config_name=self.training_config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -1091,7 +1097,7 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, - config_name=self.config_name, + config_name=self.inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1108,7 +1114,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - config_name=self.config_name, + config_name=self.inference_config_name, ) # If a predictor class was passed, do not mutate predictor @@ -1140,7 +1146,9 @@ def set_training_config(self, config_name: str) -> None: config_name (str): The name of the config. 
""" self.__init__( - model_id=self.model_id, model_version=self.model_version, config_name=config_name + model_id=self.model_id, + model_version=self.model_version, + training_config_name=config_name, ) def __str__(self) -> str: diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 604b20bc81..9b9959f00b 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,6 +29,7 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -130,7 +131,8 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - config_name: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -189,7 +191,8 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - config_name=config_name, + training_config_name=training_config_name, + inference_config_name=inference_config_name, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) @@ -207,6 +210,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs) return estimator_init_kwargs @@ -449,7 +453,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + config_name=kwargs.training_config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -473,7 +477,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + config_name=kwargs.training_config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -481,7 +485,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags, kwargs.model_id, full_model_version, - config_name=kwargs.config_name, + config_name=kwargs.training_config_name, scope=JumpStartScriptScope.TRAINING, ) return kwargs @@ -500,7 +504,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + config_name=kwargs.training_config_name, ) return kwargs @@ -526,7 +530,7 @@ def _add_model_uri_to_kwargs(kwargs: 
JumpStartEstimatorInitKwargs) -> JumpStartE
             sagemaker_session=kwargs.sagemaker_session,
             region=kwargs.region,
             instance_type=kwargs.instance_type,
-            config_name=kwargs.config_name,
+            config_name=kwargs.training_config_name,
         )
 
     if (
@@ -539,7 +543,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
             tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
             tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
             sagemaker_session=kwargs.sagemaker_session,
-            config_name=kwargs.config_name,
+            config_name=kwargs.training_config_name,
         )
     ):
         JUMPSTART_LOGGER.warning(
@@ -575,7 +579,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart
             tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
             region=kwargs.region,
             sagemaker_session=kwargs.sagemaker_session,
-            config_name=kwargs.config_name,
+            config_name=kwargs.training_config_name,
         )
 
     return kwargs
@@ -596,7 +600,7 @@ def _add_env_to_kwargs(
         sagemaker_session=kwargs.sagemaker_session,
         script=JumpStartScriptScope.TRAINING,
         instance_type=kwargs.instance_type,
-        config_name=kwargs.config_name,
+        config_name=kwargs.training_config_name,
     )
 
     model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
@@ -607,7 +611,7 @@ def _add_env_to_kwargs(
         tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
         tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
         sagemaker_session=kwargs.sagemaker_session,
-        config_name=kwargs.config_name,
+        config_name=kwargs.training_config_name,
    )
 
     if model_package_artifact_uri:
@@ -635,7 +639,7 @@ def _add_env_to_kwargs(
             tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
             tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
             sagemaker_session=kwargs.sagemaker_session,
-            config_name=kwargs.config_name,
+            config_name=kwargs.training_config_name,
         )
         if model_specs.is_gated_model():
             raise ValueError(
@@ -696,7 +700,7 @@ def _add_hyperparameters_to_kwargs(
         tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
         sagemaker_session=kwargs.sagemaker_session,
         instance_type=kwargs.instance_type,
-        config_name=kwargs.config_name,
+        config_name=kwargs.training_config_name,
     )
 
     for key, value in default_hyperparameters.items():
@@ -730,7 +734,7 @@ def _add_metric_definitions_to_kwargs(
             tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
             sagemaker_session=kwargs.sagemaker_session,
             instance_type=kwargs.instance_type,
-            config_name=kwargs.config_name,
+            config_name=kwargs.training_config_name,
         )
         or []
     )
@@ -760,7 +764,7 @@ def _add_estimator_extra_kwargs(
         tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
         tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
         sagemaker_session=kwargs.sagemaker_session,
-        config_name=kwargs.config_name,
+        config_name=kwargs.training_config_name,
     )
 
     for key, value in estimator_kwargs_to_add.items():
@@ -793,3 +797,73 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
         setattr(kwargs, key, value)
 
     return kwargs
+
+
+def _add_config_name_to_kwargs(
+    kwargs: JumpStartEstimatorInitKwargs,
+) -> JumpStartEstimatorInitKwargs:
+    """Sets default config name in kwargs based on default or override, returns full kwargs."""
+
+    specs = verify_model_region_and_return_specs(
+        model_id=kwargs.model_id,
+        version=kwargs.model_version,
+        scope=JumpStartScriptScope.TRAINING,
+        region=kwargs.region,
+        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+        sagemaker_session=kwargs.sagemaker_session,
+
model_type=kwargs.model_type, + config_name=kwargs.training_config_name, + ) + + if kwargs.base_job_name: + _, _, _, base_training_config_name = get_model_info_from_training_job( + training_job_name=kwargs.base_job_name, sagemaker_session=kwargs.sagemaker_session + ) + + kwargs.training_config_name = ( + kwargs.training_config_name + or specs.training_configs.configs.get( + base_training_config_name + ).default_incremental_trainig_config + or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config + ) + + if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name: + kwargs.training_config_name = ( + kwargs.training_config_name + or specs.training_configs.get_top_config_from_ranking().config_name + ) + + kwargs.inference_config_name = ( + kwargs.inference_config_name + or specs.training_configs.configs.get( + kwargs.training_config_name + ).default_inference_config + ) + + if ( + kwargs.inference_config_name + and kwargs.inference_config_name + not in specs.training_configs.configs.get( + kwargs.training_config_name + ).supported_inference_configs + ): + raise ValueError( + f"Inference config {kwargs.inference_config_name} is not supported for model {kwargs.model_id}." + ) + + if not kwargs.training_config_name: + return kwargs + + resolved_config = specs.training_configs.configs[ + kwargs.training_config_name + ].resolved_config + supported_instance_types = resolved_config.get("supported_training_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + raise ValueError( + f"Instance type {kwargs.instance_type} " + f"is not supported for config {kwargs.training_config_name}." + ) + + return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e0a0f9bea7..cf147d6b57 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): "config_components", "resolved_metadata_config", "config_name", + "default_inference_config", + "default_incremental_trainig_config", + "supported_inference_configs", + "supported_incremental_training_configs", ] def __init__( self, config_name: str, + config: Dict[str, Any], base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], ): """Initializes a JumpStartMetadataConfig object from its json representation. Args: + config_name (str): Name of the config, + config (Dict[str, Any]): + Dictionary representation of the config. base_fields (Dict[str, Any]): The default base fields that are used to construct the final resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): - The dictionary of benchmark metrics with name being the key. 
""" self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( + { + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() + } + if config and config.get("benchmark_metrics") + else None + ) self.resolved_metadata_config: Optional[Dict[str, Any]] = None self.config_name: Optional[str] = config_name + self.default_inference_config: Optional[str] = config.get("default_inference_config") + self.default_incremental_trainig_config: Optional[str] = config.get( + "default_incremental_training_config" + ) + self.supported_inference_configs: Optional[List[str]] = config.get( + "supported_inference_configs" + ) + self.supported_incremental_training_configs: Optional[List[str]] = config.get( + "supported_incremental_training_configs" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: { alias: JumpStartMetadataConfig( alias, + config, json_obj, ( { @@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["inference_configs"].items() } @@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: { alias: JumpStartMetadataConfig( alias, + config, json_obj, ( { @@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["training_configs"].items() } @@ -1766,7 +1774,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", - "config_name", + "training_config_name", + "inference_config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1776,7 +1785,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", - "config_name", + "training_config_name", + "inference_config_name", } def __init__( @@ -1835,7 +1845,8 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - config_name: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1895,7 +1906,8 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug - self.config_name = config_name + self.training_config_name = training_config_name + self.inference_config_name = inference_config_name class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/tests/unit/sagemaker/jumpstart/constants.py 
b/tests/unit/sagemaker/jumpstart/constants.py index 90f037daea..3815bfc9ef 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7752,6 +7752,10 @@ "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "neuron-training-budget": { "benchmark_metrics": { @@ -7759,24 +7763,43 @@ "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], }, "component_names": ["gpu-training"], + "default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, "gpu-training-budget": { "benchmark_metrics": { "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, }, "training_config_components": { "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -7788,6 +7811,7 @@ }, }, "gpu-training": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", "training_instance_type_variants": { @@ -7804,6 +7828,7 @@ }, }, "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", "training_instance_type_variants": { @@ -7817,6 +7842,7 @@ }, }, "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", "training_instance_type_variants": { diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py 
b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 2202b15ece..cd1d329e0a 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -682,7 +682,6 @@ def test_estimator_use_kwargs(self): "input_mode": "File", "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", - "base_job_name": "Optional[str] = None", "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [], @@ -1117,7 +1116,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", - "config_name", + "training_config_name", + "inference_config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1884,14 +1884,17 @@ def test_estimator_initialization_with_config_name( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, config_name="neuron-training") + estimator = JumpStartEstimator( + model_id=model_id, + training_config_name="gpu-training", + ) mock_estimator_init.assert_called_once_with( instance_type="ml.p2.xlarge", instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" - "neuron-training/model/", + "gpu-training/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -1901,7 +1904,7 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"}, ], enable_network_isolation=False, ) @@ -1934,16 +1937,16 @@ def test_estimator_set_config_name( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id) + estimator = JumpStartEstimator(model_id=model_id, training_config_name="gpu-training") - estimator.set_training_config(config_name="neuron-training") + estimator.set_training_config(config_name="gpu-training-budget") mock_estimator_init.assert_called_with( instance_type="ml.p2.xlarge", instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" - "neuron-training/model/", + "gpu-training-budget/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -1953,7 +1956,7 @@ def test_estimator_set_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"}, ], enable_network_isolation=False, ) @@ -1962,6 +1965,40 @@ def test_estimator_set_config_name( mock_estimator_fit.assert_called_once_with(wait=True) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + 
@mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_default_inference_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, training_config_name="gpu-training") + + assert estimator.inference_config_name == "gpu-inference" + assert estimator.training_config_name == "gpu-training" + + estimator.set_training_config("gpu-training-budget") + + assert estimator.inference_config_name == "gpu-inference-budget" + assert estimator.training_config_name == "gpu-training-budget" + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c00d271ef1..24dbce583d 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -228,7 +228,7 @@ def test_list_jumpstart_models_simple_case( patched_get_model_specs.assert_not_called() @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), + datetime.datetime.now() < datetime.datetime(year=2024, month=6, day=1), reason="Contact JumpStart team to fix flaky test.", ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index c52bf76f4e..9b4f5a26a2 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1193,8 +1193,15 @@ def test_training_configs_parsing(): assert config.config_components["neuron-training"] == JumpStartConfigComponent( "neuron-training", { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1" + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { From 7f08766d90beb24150f55d9fb478a6488f7c2f09 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 13:55:38 +0000 Subject: [PATCH 02/10] add tests --- src/sagemaker/jumpstart/estimator.py | 8 ++-- src/sagemaker/jumpstart/factory/estimator.py | 6 +-- .../jumpstart/estimator/test_estimator.py | 47 ++++++++++++++++--- tests/unit/sagemaker/jumpstart/test_types.py | 2 +- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 0713d3a0fa..73b66aa3f1 100644 --- 
a/src/sagemaker/jumpstart/estimator.py
+++ b/src/sagemaker/jumpstart/estimator.py
@@ -111,7 +111,7 @@ def __init__(
         container_arguments: Optional[List[str]] = None,
         disable_output_compression: Optional[bool] = None,
         enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
-        training_config_name: Optional[str] = None,
+        config_name: Optional[str] = None,
         inference_config_name: Optional[str] = None,
     ):
         """Initializes a ``JumpStartEstimator``.
@@ -504,7 +504,7 @@ def __init__(
                 to Amazon S3 without compression after training finishes.
             enable_remote_debug (bool or PipelineVariable): Optional.
                 Specifies whether RemoteDebug is enabled for the training job
-            training_config_name (Optional[str]):
+            config_name (Optional[str]):
                 Name of the training configuration to apply to the Estimator. (Default: None).
             inference_config_name (Optional[str]):
                 Name of the inference configuration to apply to the Estimator,
@@ -587,7 +587,7 @@ def _validate_model_id_and_get_type_hook():
             disable_output_compression=disable_output_compression,
             enable_infra_check=enable_infra_check,
             enable_remote_debug=enable_remote_debug,
-            training_config_name=training_config_name,
+            training_config_name=config_name,
             inference_config_name=inference_config_name,
         )
 
@@ -1148,7 +1148,7 @@ def set_training_config(self, config_name: str) -> None:
         self.__init__(
             model_id=self.model_id,
             model_version=self.model_version,
-            training_config_name=config_name,
+            config_name=config_name,
         )
 
     def __str__(self) -> str:
diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py
index 9b9959f00b..07ebbcc1bd 100644
--- a/src/sagemaker/jumpstart/factory/estimator.py
+++ b/src/sagemaker/jumpstart/factory/estimator.py
@@ -812,7 +812,6 @@ def _add_config_name_to_kwargs(
         tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
         tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
         sagemaker_session=kwargs.sagemaker_session,
-        model_type=kwargs.model_type,
         config_name=kwargs.training_config_name,
     )
 
@@ -826,7 +825,7 @@ def _add_config_name_to_kwargs(
         or specs.training_configs.configs.get(
             base_training_config_name
         ).default_incremental_trainig_config
-        or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config
+        or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config  # pylint: disable=c0301
     )
 
     if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name:
@@ -850,7 +849,8 @@ def _add_config_name_to_kwargs(
         ).supported_inference_configs
     ):
         raise ValueError(
-            f"Inference config {kwargs.inference_config_name} is not supported for model {kwargs.model_id}."
+            f"Inference config {kwargs.inference_config_name}"
+            f"is not supported for model {kwargs.model_id}."
) if not kwargs.training_config_name: diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index cd1d329e0a..fa272ed7cf 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1116,7 +1116,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", - "training_config_name", + "config_name", "inference_config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1886,7 +1886,7 @@ def test_estimator_initialization_with_config_name( estimator = JumpStartEstimator( model_id=model_id, - training_config_name="gpu-training", + config_name="gpu-training", ) mock_estimator_init.assert_called_once_with( @@ -1937,7 +1937,7 @@ def test_estimator_set_config_name( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, training_config_name="gpu-training") + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") estimator.set_training_config(config_name="gpu-training-budget") @@ -1969,11 +1969,9 @@ def test_estimator_set_config_name( @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") - @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_estimator_default_inference_config( self, - mock_estimator_init: mock.Mock, mock_estimator_fit: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, @@ -1989,7 +1987,7 @@ def test_estimator_default_inference_config( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, training_config_name="gpu-training") + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") assert estimator.inference_config_name == "gpu-inference" assert estimator.training_config_name == "gpu-training" @@ -1999,6 +1997,43 @@ def test_estimator_default_inference_config( assert estimator.inference_config_name == "gpu-inference-budget" assert estimator.training_config_name == "gpu-training-budget" + @mock.patch("sagemaker.jumpstart.factory.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_incremental_training_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + ): + mock_get_model_info_from_training_job.return_value = ( + "js-trainable-model-prepacked", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + 
mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, base_job_name="base_job") + + assert estimator.inference_config_name == "gpu-inference-budget" + assert estimator.training_config_name == "gpu-training-budget" + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 9b4f5a26a2..b7e5f16ae1 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1200,7 +1200,7 @@ def test_training_configs_parsing(): "framework": "huggingface", "framework_version": "2.0.0", "py_version": "py310", - "huggingface_transformers_version": "4.28.1" + "huggingface_transformers_version": "4.28.1", }, "training_instance_type_variants": { "regional_aliases": { From 1a1a6bb43ce1f0e451dae80b8c820c2776a459b5 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 14:07:01 +0000 Subject: [PATCH 03/10] format --- src/sagemaker/jumpstart/factory/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 07ebbcc1bd..ca0fbade7a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -825,7 +825,7 @@ def _add_config_name_to_kwargs( or specs.training_configs.configs.get( base_training_config_name ).default_incremental_trainig_config - or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config # pylint: disable=c0301 + or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config # noqa E501 # pylint: disable=c0301 ) if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name: From edf61af9a0ce9d0346d58131f021b7144fa915b7 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 14:10:20 +0000 Subject: [PATCH 04/10] tests --- .../sagemaker/jumpstart/test_notebook_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 24dbce583d..50f35d19bb 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -227,10 +227,6 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=6, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_script_filter( @@ -240,7 +236,7 @@ def test_list_jumpstart_models_script_filter( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) patched_get_manifest.side_effect = ( - lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region) ) manifest_length = len(get_prototype_manifest()) @@ -248,15 +244,15 @@ def test_list_jumpstart_models_script_filter( for val in vals: kwargs = 
{"filter": f"training_supported == {val}"} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * manifest_length + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -273,7 +269,7 @@ def test_list_jumpstart_models_script_filter( ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"), ("xgboost-classification-model", "1.0.0"), ] - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -282,7 +278,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": f"training_supported not in {vals}"} models = list_jumpstart_models(**kwargs) assert [] == models - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") From 4ec00fdf7f91f7d3401b8a08412a3217a6f13d21 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 14:26:48 +0000 Subject: [PATCH 05/10] tests --- .../jumpstart/estimator/test_estimator.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index fa272ed7cf..84b17ffd11 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -2034,6 +2034,67 @@ def test_estimator_incremental_training_config( assert estimator.inference_config_name == "gpu-inference-budget" assert estimator.training_config_name == "gpu-training-budget" + @mock.patch("sagemaker.jumpstart.factory.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_deploy_with_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_info_from_training_job.return_value = ( + "js-trainable-model-prepacked", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = 
default_predictor
+
+        model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+        mock_session.return_value = sagemaker_session
+
+        estimator = JumpStartEstimator(model_id=model_id, base_job_name="base_job")
+
+        assert estimator.inference_config_name == "gpu-inference-budget"
+        assert estimator.training_config_name == "gpu-training-budget"
+
+        estimator.deploy()
+
+        mock_estimator_deploy.assert_called_once_with(
+            instance_type="ml.p2.xlarge",
+            initial_instance_count=1,
+            image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3",
+            source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
+            "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
+            entry_point="inference.py",
+            predictor_cls=Predictor,
+            wait=True,
+            role="fake role! do not use!",
+            use_compiled_model=False,
+            enable_network_isolation=False,
+            tags=[
+                {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+                {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+                {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"},
+            ],
+        )
+
 
 def test_jumpstart_estimator_requires_model_id():
     with pytest.raises(ValueError):

From 634073609e42d4f59456e16f6a3273bfeb336848 Mon Sep 17 00:00:00 2001
From: Haotian An
Date: Wed, 1 May 2024 21:49:31 +0000
Subject: [PATCH 06/10] address comments

---
 src/sagemaker/jumpstart/estimator.py | 27 +++---
 src/sagemaker/jumpstart/factory/estimator.py | 94 +++++--------------
 src/sagemaker/jumpstart/factory/model.py | 71 +++++++++++++-
 src/sagemaker/jumpstart/types.py | 12 +--
 tests/integ/sagemaker/jumpstart/constants.py | 1 +
 .../estimator/test_jumpstart_estimator.py | 2 +-
 .../jumpstart/estimator/test_estimator.py | 74 ++++++++++-----
 7 files changed, 165 insertions(+), 116 deletions(-)

diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py
index 73b66aa3f1..3132ea4d26 100644
--- a/src/sagemaker/jumpstart/estimator.py
+++ b/src/sagemaker/jumpstart/estimator.py
@@ -112,7 +112,6 @@ def __init__(
         disable_output_compression: Optional[bool] = None,
         enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
         config_name: Optional[str] = None,
-        inference_config_name: Optional[str] = None,
     ):
         """Initializes a ``JumpStartEstimator``.
 
@@ -506,9 +505,6 @@ def __init__(
                 Specifies whether RemoteDebug is enabled for the training job
             config_name (Optional[str]):
                 Name of the training configuration to apply to the Estimator. (Default: None).
-            inference_config_name (Optional[str]):
-                Name of the inference configuration to apply to the Estimator,
-                to be used when deploying the fine-tuned model. (Default: None).
 
         Raises:
             ValueError: If the model ID is not recognized by JumpStart.
@@ -587,8 +583,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - training_config_name=config_name, - inference_config_name=inference_config_name, + config_name=config_name, ) self.model_id = estimator_init_kwargs.model_id @@ -602,8 +597,7 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation - self.training_config_name = estimator_init_kwargs.training_config_name - self.inference_config_name = estimator_init_kwargs.inference_config_name + self.config_name = estimator_init_kwargs.config_name self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -679,7 +673,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, - config_name=self.training_config_name, + config_name=self.config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -692,6 +686,7 @@ def attach( model_version: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", + config_name: Optional[str] = None, ) -> "JumpStartEstimator": """Attach to an existing training job. @@ -727,6 +722,8 @@ def attach( model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. + config_name (str): Optional. Name of the training configuration to use + when attaching to the training job. (Default: None). Returns: Instance of the calling ``JumpStartEstimator`` Class with the attached @@ -738,7 +735,6 @@ def attach( """ config_name = None if model_id is None: - model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -752,6 +748,9 @@ def attach( "tolerate_deprecated_model": True, # model is already trained } + if config_name: + additional_kwargs.update({"config_name": config_name}) + model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -810,6 +809,7 @@ def deploy( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, use_compiled_model: bool = False, + inference_config_name: Optional[str] = None, ) -> PredictorBase: """Creates endpoint from training job. @@ -1045,6 +1045,8 @@ def deploy( (Default: None). use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). + inference_config_name (Optional[str]): Name of the inference configuration to + be used in the model. (Default: None). 
""" self.orig_predictor_cls = predictor_cls @@ -1097,7 +1099,8 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, - config_name=self.inference_config_name, + training_config_name=self.config_name, + inference_config_name=inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1114,7 +1117,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - config_name=self.inference_config_name, + config_name=estimator_deploy_kwargs.config_name, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index ca0fbade7a..8a28a69aaa 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -131,8 +131,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - training_config_name: Optional[str] = None, - inference_config_name: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -191,8 +190,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - training_config_name=training_config_name, - inference_config_name=inference_config_name, + config_name=config_name, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) @@ -295,7 +293,8 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, - config_name: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -323,7 +322,8 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, - config_name=config_name, + training_config_name=training_config_name, + config_name=inference_config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -352,7 +352,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, - config_name=config_name, + config_name=model_deploy_kwargs.config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -397,7 +397,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, - config_name=config_name, + config_name=model_deploy_kwargs.config_name, ) return estimator_deploy_kwargs @@ -453,7 +453,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + 
config_name=kwargs.config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -477,7 +477,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -485,7 +485,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags, kwargs.model_id, full_model_version, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, scope=JumpStartScriptScope.TRAINING, ) return kwargs @@ -504,7 +504,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) return kwargs @@ -530,7 +530,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE sagemaker_session=kwargs.sagemaker_session, region=kwargs.region, instance_type=kwargs.instance_type, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) if ( @@ -543,7 +543,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) ): JUMPSTART_LOGGER.warning( @@ -579,7 +579,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) return kwargs @@ -600,7 +600,7 @@ def _add_env_to_kwargs( sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( @@ -611,7 +611,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -639,7 +639,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) if model_specs.is_gated_model(): raise ValueError( @@ -700,7 +700,7 @@ def _add_hyperparameters_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -734,7 +734,7 @@ def _add_metric_definitions_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, 
sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) or [] ) @@ -764,7 +764,7 @@ def _add_estimator_extra_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) for key, value in estimator_kwargs_to_add.items(): @@ -812,58 +812,12 @@ def _add_config_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.training_config_name, + config_name=kwargs.config_name, ) - if kwargs.base_job_name: - _, _, _, base_training_config_name = get_model_info_from_training_job( - training_job_name=kwargs.base_job_name, sagemaker_session=kwargs.sagemaker_session - ) - - kwargs.training_config_name = ( - kwargs.training_config_name - or specs.training_configs.configs.get( - base_training_config_name - ).default_incremental_trainig_config - or specs.training_configs.get_top_config_from_ranking().default_incremental_trainig_config # noqa E501 # pylint: disable=c0301 - ) - if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name: - kwargs.training_config_name = ( - kwargs.training_config_name - or specs.training_configs.get_top_config_from_ranking().config_name - ) - - kwargs.inference_config_name = ( - kwargs.inference_config_name - or specs.training_configs.configs.get( - kwargs.training_config_name - ).default_inference_config + kwargs.config_name = ( + kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name ) - if ( - kwargs.inference_config_name - and kwargs.inference_config_name - not in specs.training_configs.configs.get( - kwargs.training_config_name - ).supported_inference_configs - ): - raise ValueError( - f"Inference config {kwargs.inference_config_name}" - f"is not supported for model {kwargs.model_id}." - ) - - if not kwargs.training_config_name: - return kwargs - - resolved_config = specs.training_configs.configs[ - kwargs.training_config_name - ].resolved_config - supported_instance_types = resolved_config.get("supported_training_instance_types", []) - if kwargs.instance_type not in supported_instance_types: - raise ValueError( - f"Instance type {kwargs.instance_type} " - f"is not supported for config {kwargs.training_config_name}." - ) - return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 54301973e8..f44c715882 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -42,6 +42,7 @@ JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, + JumpStartModelSpecs, ) from sagemaker.jumpstart.utils import ( add_jumpstart_model_info_tags, @@ -548,7 +549,28 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs -def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: +def _select_inference_config_from_training_config( + specs: JumpStartModelSpecs, training_config_name: str +) -> Optional[str]: + """Selects the inference config from the training config. + + Args: + specs (JumpStartModelSpecs): The specs for the model. + training_config_name (str): The name of the training config. 
+ + Returns: + str: The name of the inference config. + """ + if ( + specs.training_configs + and specs.training_configs.configs.get(training_config_name).default_inference_config + ): + return specs.training_configs.configs.get(training_config_name).default_inference_config + + return None + + +def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets default config name to the kwargs. Returns full kwargs. Raises: @@ -592,6 +614,45 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod return kwargs +def _add_config_name_to_deploy_kwargs( + kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + If a training_config_name is passed, then choose the inference config + based on the supported inference configs in that training config. + + Raises: + ValueError: If the instance_type is not supported with the current config. + """ + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + + if training_config_name: + kwargs.config_name = _select_inference_config_from_training_config( + specs=specs, training_config_name=training_config_name + ) + + if ( + specs.inference_configs + and specs.inference_configs.get_top_config_from_ranking().config_name + ): + kwargs.config_name = ( + kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name + ) + + return kwargs + def get_deploy_kwargs( model_id: str, @@ -623,6 +684,7 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + training_config_name: Optional[str] = None, config_name: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -664,6 +726,10 @@ def get_deploy_kwargs( deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) + deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -858,6 +924,7 @@ def get_init_kwargs( model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs) + + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) return model_init_kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cf147d6b57..0a586f60aa 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1774,8 +1774,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", - "training_config_name", - "inference_config_name", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1785,8 +1784,7 @@ class 
JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", - "training_config_name", - "inference_config_name", + "config_name", } def __init__( @@ -1845,8 +1843,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - training_config_name: Optional[str] = None, - inference_config_name: Optional[str] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1906,8 +1903,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug - self.training_config_name = training_config_name - self.inference_config_name = inference_config_name + self.config_name = config_name class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index f5ffbf7a3a..b839866b1f 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -48,6 +48,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "3.*"): ("training-datasets/sec_amazon/"), + ("meta-textgeneration-llama-2-7b", "4.*"): ("training-datasets/sec_amazon/"), ("meta-textgenerationneuron-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), } diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index b7aec3b555..fb3d1ebd1f 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -140,7 +140,7 @@ def test_gated_model_training_v1(setup): def test_gated_model_training_v2(setup): model_id = "meta-textgeneration-llama-2-7b" - model_version = "3.*" # model artifacts retrieved from jumpstart-private-cache-* buckets + model_version = "4.*" # model artifacts retrieved from jumpstart-private-cache-* buckets estimator = JumpStartEstimator( model_id=model_id, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 84b17ffd11..bba98c7ff3 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1117,7 +1117,6 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "tolerate_vulnerable_model", "tolerate_deprecated_model", "config_name", - "inference_config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1139,7 +1138,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_deploy = JumpStartEstimator.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - { + assert js_class_deploy_args - parent_class_deploy_args - { + "inference_config_name" + } == model_class_init_args - { "model_data", "self", "name", @@ -1968,15 +1969,18 @@ def test_estimator_set_config_name( 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_estimator_default_inference_config( self, mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, ): + mock_estimator_deploy.return_value = default_predictor mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) @@ -1989,15 +1993,31 @@ def test_estimator_default_inference_config( estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") - assert estimator.inference_config_name == "gpu-inference" - assert estimator.training_config_name == "gpu-training" + assert estimator.config_name == "gpu-training" - estimator.set_training_config("gpu-training-budget") + estimator.deploy() - assert estimator.inference_config_name == "gpu-inference-budget" - assert estimator.training_config_name == "gpu-training-budget" + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! 
do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"}, + ], + ) - @mock.patch("sagemaker.jumpstart.factory.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -2012,9 +2032,10 @@ def test_estimator_incremental_training_config( mock_session: mock.Mock, mock_get_manifest: mock.Mock, mock_get_model_info_from_training_job: mock.Mock, + mock_attach: mock.Mock, ): mock_get_model_info_from_training_job.return_value = ( - "js-trainable-model-prepacked", + "pytorch-eqa-bert-base-cased", "1.0.0", None, "gpu-training-budget", @@ -2029,12 +2050,27 @@ def test_estimator_incremental_training_config( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, base_job_name="base_job") + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", sagemaker_session=mock_session + ) - assert estimator.inference_config_name == "gpu-inference-budget" - assert estimator.training_config_name == "gpu-training-budget" + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "pytorch-eqa-bert-base-cased", + "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, + "config_name": "gpu-training-budget", + }, + ) - @mock.patch("sagemaker.jumpstart.factory.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -2050,15 +2086,8 @@ def test_estimator_deploy_with_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_model_info_from_training_job: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_get_model_info_from_training_job.return_value = ( - "js-trainable-model-prepacked", - "1.0.0", - None, - "gpu-training-budget", - ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) @@ -2069,10 +2098,9 @@ def test_estimator_deploy_with_config( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, base_job_name="base_job") + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget") - assert estimator.inference_config_name == "gpu-inference-budget" - assert estimator.training_config_name == "gpu-training-budget" + assert estimator.config_name == "gpu-training-budget" estimator.deploy() From dc33402e3b05e8c37d799e63235cabb5b80e0f1b Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 
1 May 2024 22:26:24 +0000 Subject: [PATCH 07/10] format and address comments --- src/sagemaker/jumpstart/factory/estimator.py | 3 +-- src/sagemaker/jumpstart/factory/model.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 8a28a69aaa..9177265d74 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,7 +29,6 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base -from sagemaker.jumpstart.session_utils import get_model_info_from_training_job from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -815,7 +814,7 @@ def _add_config_name_to_kwargs( config_name=kwargs.config_name, ) - if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name: + if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): kwargs.config_name = ( kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f44c715882..e04f672df8 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -590,7 +590,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta ) if ( specs.inference_configs - and specs.inference_configs.get_top_config_from_ranking().config_name + and specs.inference_configs.get_top_config_from_ranking() ): kwargs.config_name = ( kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name @@ -614,6 +614,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta return kwargs + def _add_config_name_to_deploy_kwargs( kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None ) -> JumpStartModelInitKwargs: @@ -645,7 +646,7 @@ def _add_config_name_to_deploy_kwargs( if ( specs.inference_configs - and specs.inference_configs.get_top_config_from_ranking().config_name + and specs.inference_configs.get_top_config_from_ranking() ): kwargs.config_name = ( kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name From 096de7b860f1ec9f391b13139b389ecde846cba0 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 22:28:12 +0000 Subject: [PATCH 08/10] updates --- src/sagemaker/jumpstart/factory/model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index e04f672df8..f8bcb23090 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -592,9 +592,8 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta specs.inference_configs and specs.inference_configs.get_top_config_from_ranking() ): - kwargs.config_name = ( - kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name - ) + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name if not kwargs.config_name: return kwargs @@ -648,9 +647,9 @@ def _add_config_name_to_deploy_kwargs( specs.inference_configs 
and specs.inference_configs.get_top_config_from_ranking() ): - kwargs.config_name = ( - kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name - ) + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name + return kwargs From fd720d8284b47c0961ed50c338cb1f477fdff17d Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 1 May 2024 23:40:45 +0000 Subject: [PATCH 09/10] formt --- src/sagemaker/jumpstart/factory/model.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f8bcb23090..62712f60c7 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -588,10 +588,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta model_type=kwargs.model_type, config_name=kwargs.config_name, ) - if ( - specs.inference_configs - and specs.inference_configs.get_top_config_from_ranking() - ): + if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking(): default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name kwargs.config_name = kwargs.config_name or default_config_name @@ -643,13 +640,9 @@ def _add_config_name_to_deploy_kwargs( specs=specs, training_config_name=training_config_name ) - if ( - specs.inference_configs - and specs.inference_configs.get_top_config_from_ranking() - ): + if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking(): default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name kwargs.config_name = kwargs.config_name or default_config_name - return kwargs From cf0e8abbf25d475f037fed25fec56a3f89951caa Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 2 May 2024 14:28:03 +0000 Subject: [PATCH 10/10] format --- src/sagemaker/jumpstart/factory/model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 62712f60c7..79a7b18788 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -561,11 +561,10 @@ def _select_inference_config_from_training_config( Returns: str: The name of the inference config. 
""" - if ( - specs.training_configs - and specs.training_configs.configs.get(training_config_name).default_inference_config - ): - return specs.training_configs.configs.get(training_config_name).default_inference_config + if specs.training_configs: + resolved_training_config = specs.training_configs.configs.get(training_config_name) + if resolved_training_config: + return resolved_training_config.default_inference_config return None @@ -588,7 +587,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta model_type=kwargs.model_type, config_name=kwargs.config_name, ) - if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking(): + if specs.inference_configs: default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name kwargs.config_name = kwargs.config_name or default_config_name @@ -640,7 +639,7 @@ def _add_config_name_to_deploy_kwargs( specs=specs, training_config_name=training_config_name ) - if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking(): + if specs.inference_configs: default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name kwargs.config_name = kwargs.config_name or default_config_name