Skip to content

Commit 220a01f

Browse files
authored
Add supported inference and incremental training configs (#4637)
* supported inference configs * add tests * format * tests * tests * address comments * format and address comments * updates * formt * format
1 parent bc51dc1 commit 220a01f

File tree

10 files changed

+354
-58
lines changed

10 files changed

+354
-58
lines changed

src/sagemaker/jumpstart/estimator.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -504,7 +504,7 @@ def __init__(
504504
enable_remote_debug (bool or PipelineVariable): Optional.
505505
Specifies whether RemoteDebug is enabled for the training job
506506
config_name (Optional[str]):
507-
Name of the JumpStart Model config to apply. (Default: None).
507+
Name of the training configuration to apply to the Estimator. (Default: None).
508508
509509
Raises:
510510
ValueError: If the model ID is not recognized by JumpStart.
@@ -686,6 +686,7 @@ def attach(
686686
model_version: Optional[str] = None,
687687
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
688688
model_channel_name: str = "model",
689+
config_name: Optional[str] = None,
689690
) -> "JumpStartEstimator":
690691
"""Attach to an existing training job.
691692
@@ -721,6 +722,8 @@ def attach(
721722
model data will be downloaded (default: 'model'). If no channel
722723
with the same name exists in the training job, this option will
723724
be ignored.
725+
config_name (str): Optional. Name of the training configuration to use
726+
when attaching to the training job. (Default: None).
724727
725728
Returns:
726729
Instance of the calling ``JumpStartEstimator`` Class with the attached
@@ -732,7 +735,6 @@ def attach(
732735
"""
733736
config_name = None
734737
if model_id is None:
735-
736738
model_id, model_version, _, config_name = get_model_info_from_training_job(
737739
training_job_name=training_job_name, sagemaker_session=sagemaker_session
738740
)
@@ -746,6 +748,9 @@ def attach(
746748
"tolerate_deprecated_model": True, # model is already trained
747749
}
748750

751+
if config_name:
752+
additional_kwargs.update({"config_name": config_name})
753+
749754
model_specs = verify_model_region_and_return_specs(
750755
model_id=model_id,
751756
version=model_version,
@@ -804,6 +809,7 @@ def deploy(
804809
dependencies: Optional[List[str]] = None,
805810
git_config: Optional[Dict[str, str]] = None,
806811
use_compiled_model: bool = False,
812+
inference_config_name: Optional[str] = None,
807813
) -> PredictorBase:
808814
"""Creates endpoint from training job.
809815
@@ -1039,6 +1045,8 @@ def deploy(
10391045
(Default: None).
10401046
use_compiled_model (bool): Flag to select whether to use compiled
10411047
(optimized) model. (Default: False).
1048+
inference_config_name (Optional[str]): Name of the inference configuration to
1049+
be used in the model. (Default: None).
10421050
"""
10431051
self.orig_predictor_cls = predictor_cls
10441052

@@ -1091,7 +1099,8 @@ def deploy(
10911099
git_config=git_config,
10921100
use_compiled_model=use_compiled_model,
10931101
training_instance_type=self.instance_type,
1094-
config_name=self.config_name,
1102+
training_config_name=self.config_name,
1103+
inference_config_name=inference_config_name,
10951104
)
10961105

10971106
predictor = super(JumpStartEstimator, self).deploy(
@@ -1108,7 +1117,7 @@ def deploy(
11081117
tolerate_deprecated_model=self.tolerate_deprecated_model,
11091118
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
11101119
sagemaker_session=self.sagemaker_session,
1111-
config_name=self.config_name,
1120+
config_name=estimator_deploy_kwargs.config_name,
11121121
)
11131122

11141123
# If a predictor class was passed, do not mutate predictor
@@ -1140,7 +1149,9 @@ def set_training_config(self, config_name: str) -> None:
11401149
config_name (str): The name of the config.
11411150
"""
11421151
self.__init__(
1143-
model_id=self.model_id, model_version=self.model_version, config_name=config_name
1152+
model_id=self.model_id,
1153+
model_version=self.model_version,
1154+
config_name=config_name,
11441155
)
11451156

11461157
def __str__(self) -> str:

src/sagemaker/jumpstart/factory/estimator.py

+31-4
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,7 @@ def get_init_kwargs(
207207
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
208208
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
209209
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
210+
estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs)
210211

211212
return estimator_init_kwargs
212213

@@ -291,7 +292,8 @@ def get_deploy_kwargs(
291292
use_compiled_model: Optional[bool] = None,
292293
model_name: Optional[str] = None,
293294
training_instance_type: Optional[str] = None,
294-
config_name: Optional[str] = None,
295+
training_config_name: Optional[str] = None,
296+
inference_config_name: Optional[str] = None,
295297
) -> JumpStartEstimatorDeployKwargs:
296298
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
297299

@@ -319,7 +321,8 @@ def get_deploy_kwargs(
319321
tolerate_vulnerable_model=tolerate_vulnerable_model,
320322
tolerate_deprecated_model=tolerate_deprecated_model,
321323
sagemaker_session=sagemaker_session,
322-
config_name=config_name,
324+
training_config_name=training_config_name,
325+
config_name=inference_config_name,
323326
)
324327

325328
model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs(
@@ -348,7 +351,7 @@ def get_deploy_kwargs(
348351
tolerate_deprecated_model=tolerate_deprecated_model,
349352
training_instance_type=training_instance_type,
350353
disable_instance_type_logging=True,
351-
config_name=config_name,
354+
config_name=model_deploy_kwargs.config_name,
352355
)
353356

354357
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
@@ -393,7 +396,7 @@ def get_deploy_kwargs(
393396
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
394397
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
395398
use_compiled_model=use_compiled_model,
396-
config_name=config_name,
399+
config_name=model_deploy_kwargs.config_name,
397400
)
398401

399402
return estimator_deploy_kwargs
@@ -793,3 +796,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
793796
setattr(kwargs, key, value)
794797

795798
return kwargs
799+
800+
801+
def _add_config_name_to_kwargs(
    kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
    """Sets the training config name in kwargs. Returns full kwargs.

    A config name already present in ``kwargs`` is kept. Otherwise, when the
    model specs expose training configs and their ranking yields a top entry,
    that entry's name is used. In every other case ``kwargs.config_name`` is
    left untouched.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        config_name=kwargs.config_name,
    )

    if not specs.training_configs:
        return kwargs

    # Resolve the ranking once rather than once for the check and again for the value.
    top_config = specs.training_configs.get_top_config_from_ranking()
    if top_config:
        kwargs.config_name = kwargs.config_name or top_config.config_name

    return kwargs

src/sagemaker/jumpstart/factory/model.py

+68-9
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
JumpStartModelDeployKwargs,
4343
JumpStartModelInitKwargs,
4444
JumpStartModelRegisterKwargs,
45+
JumpStartModelSpecs,
4546
)
4647
from sagemaker.jumpstart.utils import (
4748
add_jumpstart_model_info_tags,
@@ -548,7 +549,27 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
548549
return kwargs
549550

550551

551-
def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
552+
def _select_inference_config_from_training_config(
    specs: JumpStartModelSpecs, training_config_name: str
) -> Optional[str]:
    """Looks up the default inference config attached to a training config.

    Args:
        specs (JumpStartModelSpecs): The specs for the model.
        training_config_name (str): The name of the training config.

    Returns:
        Optional[str]: The ``default_inference_config`` of the named training
            config, or None when the specs carry no training configs or the
            name does not match any of them.
    """
    training_configs = specs.training_configs
    if not training_configs:
        return None

    matched_config = training_configs.configs.get(training_config_name)
    if not matched_config:
        return None

    return matched_config.default_inference_config
570+
571+
572+
def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
552573
"""Sets default config name to the kwargs. Returns full kwargs.
553574
554575
Raises:
@@ -566,13 +587,9 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
566587
model_type=kwargs.model_type,
567588
config_name=kwargs.config_name,
568589
)
569-
if (
570-
specs.inference_configs
571-
and specs.inference_configs.get_top_config_from_ranking().config_name
572-
):
573-
kwargs.config_name = (
574-
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
575-
)
590+
if specs.inference_configs:
591+
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
592+
kwargs.config_name = kwargs.config_name or default_config_name
576593

577594
if not kwargs.config_name:
578595
return kwargs
@@ -593,6 +610,42 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
593610
return kwargs
594611

595612

613+
def _add_config_name_to_deploy_kwargs(
    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelDeployKwargs:
    """Sets default inference config name to the deploy kwargs. Returns full kwargs.

    If a training_config_name is passed, the inference config is taken from
    that training config's default inference config. If no name results from
    that step (or no training config name was passed), the existing
    ``kwargs.config_name`` is kept, falling back to the top-ranked inference
    config when the specs provide one.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs to update in place.
        training_config_name (Optional[str]): Name of the training config the
            model was trained with. (Default: None).

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object that was passed in.

    Exceptions raised by ``verify_model_region_and_return_specs`` are not
    caught here and propagate to the caller.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    if training_config_name:
        # NOTE(review): this replaces whatever config name is already in kwargs,
        # including one supplied explicitly by the caller - confirm that this
        # precedence is intended.
        kwargs.config_name = _select_inference_config_from_training_config(
            specs=specs, training_config_name=training_config_name
        )

    if specs.inference_configs:
        # Only dereference the top-ranked config when the ranking yields one.
        top_config = specs.inference_configs.get_top_config_from_ranking()
        if top_config:
            kwargs.config_name = kwargs.config_name or top_config.config_name

    return kwargs
647+
648+
596649
def get_deploy_kwargs(
597650
model_id: str,
598651
model_version: Optional[str] = None,
@@ -623,6 +676,7 @@ def get_deploy_kwargs(
623676
resources: Optional[ResourceRequirements] = None,
624677
managed_instance_scaling: Optional[str] = None,
625678
endpoint_type: Optional[EndpointType] = None,
679+
training_config_name: Optional[str] = None,
626680
config_name: Optional[str] = None,
627681
) -> JumpStartModelDeployKwargs:
628682
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
@@ -664,6 +718,10 @@ def get_deploy_kwargs(
664718

665719
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
666720

721+
deploy_kwargs = _add_config_name_to_deploy_kwargs(
722+
kwargs=deploy_kwargs, training_config_name=training_config_name
723+
)
724+
667725
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
668726

669727
deploy_kwargs.initial_instance_count = initial_instance_count or 1
@@ -858,6 +916,7 @@ def get_init_kwargs(
858916
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
859917

860918
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
861-
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)
919+
920+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
862921

863922
return model_init_kwargs

src/sagemaker/jumpstart/types.py

+28-20
Original file line number | Diff line number | Diff line change
@@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10771077
"config_components",
10781078
"resolved_metadata_config",
10791079
"config_name",
1080+
"default_inference_config",
1081+
"default_incremental_trainig_config",
1082+
"supported_inference_configs",
1083+
"supported_incremental_training_configs",
10801084
]
10811085

10821086
def __init__(
    self,
    config_name: str,
    config: Dict[str, Any],
    base_fields: Dict[str, Any],
    config_components: Dict[str, JumpStartConfigComponent],
):
    """Initializes a JumpStartMetadataConfig object from its json representation.

    Args:
        config_name (str): Name of the config.
        config (Dict[str, Any]):
            Dictionary representation of the config. A falsy value (None or an
            empty dict) is treated as an empty config.
        base_fields (Dict[str, Any]):
            The default base fields that are used to construct the final resolved config.
        config_components (Dict[str, JumpStartConfigComponent]):
            The list of components that are used to construct the resolved config.
    """
    # Normalise once so every lookup below is safe; previously only the
    # benchmark_metrics lookup tolerated a falsy ``config``.
    config = config or {}
    raw_benchmark_metrics = config.get("benchmark_metrics")

    self.base_fields = base_fields
    self.config_components: Dict[str, JumpStartConfigComponent] = config_components
    self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
        {
            stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
            for stat_name, stats in raw_benchmark_metrics.items()
        }
        if raw_benchmark_metrics
        else None
    )
    self.resolved_metadata_config: Optional[Dict[str, Any]] = None
    self.config_name: Optional[str] = config_name
    self.default_inference_config: Optional[str] = config.get("default_inference_config")
    # The attribute spelling ("trainig") matches the name declared in ``__slots__``;
    # the JSON key it is read from is spelled "default_incremental_training_config".
    self.default_incremental_trainig_config: Optional[str] = config.get(
        "default_incremental_training_config"
    )
    self.supported_inference_configs: Optional[List[str]] = config.get(
        "supported_inference_configs"
    )
    self.supported_incremental_training_configs: Optional[List[str]] = config.get(
        "supported_incremental_training_configs"
    )
11041126

11051127
def to_json(self) -> Dict[str, Any]:
11061128
"""Returns json representation of JumpStartMetadataConfig object."""
@@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12551277
{
12561278
alias: JumpStartMetadataConfig(
12571279
alias,
1280+
config,
12581281
json_obj,
12591282
(
12601283
{
@@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12641287
if config and config.get("component_names")
12651288
else None
12661289
),
1267-
(
1268-
{
1269-
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
1270-
for stat_name, stats in config.get("benchmark_metrics").items()
1271-
}
1272-
if config and config.get("benchmark_metrics")
1273-
else None
1274-
),
12751290
)
12761291
for alias, config in json_obj["inference_configs"].items()
12771292
}
@@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13081323
{
13091324
alias: JumpStartMetadataConfig(
13101325
alias,
1326+
config,
13111327
json_obj,
13121328
(
13131329
{
@@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13171333
if config and config.get("component_names")
13181334
else None
13191335
),
1320-
(
1321-
{
1322-
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
1323-
for stat_name, stats in config.get("benchmark_metrics").items()
1324-
}
1325-
if config and config.get("benchmark_metrics")
1326-
else None
1327-
),
13281336
)
13291337
for alias, config in json_obj["training_configs"].items()
13301338
}

0 commit comments

Comments
 (0)