Skip to content

Commit edc98b9

Browse files
Captainiabenieric
authored and committed
Add supported inference and incremental training configs (aws#4637)
* supported inference configs * add tests * format * tests * tests * address comments * format and address comments * updates * formt * format
1 parent 92219f2 commit edc98b9

File tree

8 files changed

+359
-52
lines changed

8 files changed

+359
-52
lines changed

src/sagemaker/jumpstart/estimator.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def __init__(
507507
enable_session_tag_chaining (bool or PipelineVariable): Optional.
508508
Specifies whether SessionTagChaining is enabled for the training job
509509
config_name (Optional[str]):
510-
Name of the JumpStart Model config to apply. (Default: None).
510+
Name of the training configuration to apply to the Estimator. (Default: None).
511511
512512
Raises:
513513
ValueError: If the model ID is not recognized by JumpStart.
@@ -690,6 +690,7 @@ def attach(
690690
model_version: Optional[str] = None,
691691
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
692692
model_channel_name: str = "model",
693+
config_name: Optional[str] = None,
693694
) -> "JumpStartEstimator":
694695
"""Attach to an existing training job.
695696
@@ -725,6 +726,8 @@ def attach(
725726
model data will be downloaded (default: 'model'). If no channel
726727
with the same name exists in the training job, this option will
727728
be ignored.
729+
config_name (str): Optional. Name of the training configuration to use
730+
when attaching to the training job. (Default: None).
728731
729732
Returns:
730733
Instance of the calling ``JumpStartEstimator`` Class with the attached
@@ -736,7 +739,6 @@ def attach(
736739
"""
737740
config_name = None
738741
if model_id is None:
739-
740742
model_id, model_version, _, config_name = get_model_info_from_training_job(
741743
training_job_name=training_job_name, sagemaker_session=sagemaker_session
742744
)
@@ -750,6 +752,9 @@ def attach(
750752
"tolerate_deprecated_model": True, # model is already trained
751753
}
752754

755+
if config_name:
756+
additional_kwargs.update({"config_name": config_name})
757+
753758
model_specs = verify_model_region_and_return_specs(
754759
model_id=model_id,
755760
version=model_version,
@@ -808,6 +813,7 @@ def deploy(
808813
dependencies: Optional[List[str]] = None,
809814
git_config: Optional[Dict[str, str]] = None,
810815
use_compiled_model: bool = False,
816+
inference_config_name: Optional[str] = None,
811817
) -> PredictorBase:
812818
"""Creates endpoint from training job.
813819
@@ -1043,6 +1049,8 @@ def deploy(
10431049
(Default: None).
10441050
use_compiled_model (bool): Flag to select whether to use compiled
10451051
(optimized) model. (Default: False).
1052+
inference_config_name (Optional[str]): Name of the inference configuration to
1053+
be used in the model. (Default: None).
10461054
"""
10471055
self.orig_predictor_cls = predictor_cls
10481056

@@ -1095,7 +1103,8 @@ def deploy(
10951103
git_config=git_config,
10961104
use_compiled_model=use_compiled_model,
10971105
training_instance_type=self.instance_type,
1098-
config_name=self.config_name,
1106+
training_config_name=self.config_name,
1107+
inference_config_name=inference_config_name,
10991108
)
11001109

11011110
predictor = super(JumpStartEstimator, self).deploy(
@@ -1112,7 +1121,7 @@ def deploy(
11121121
tolerate_deprecated_model=self.tolerate_deprecated_model,
11131122
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
11141123
sagemaker_session=self.sagemaker_session,
1115-
config_name=self.config_name,
1124+
config_name=estimator_deploy_kwargs.config_name,
11161125
)
11171126

11181127
# If a predictor class was passed, do not mutate predictor
@@ -1144,7 +1153,9 @@ def set_training_config(self, config_name: str) -> None:
11441153
config_name (str): The name of the config.
11451154
"""
11461155
self.__init__(
1147-
model_id=self.model_id, model_version=self.model_version, config_name=config_name
1156+
model_id=self.model_id,
1157+
model_version=self.model_version,
1158+
config_name=config_name,
11481159
)
11491160

11501161
def __str__(self) -> str:

src/sagemaker/jumpstart/factory/estimator.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def get_init_kwargs(
209209
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
210210
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
211211
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
212+
estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs)
212213

213214
return estimator_init_kwargs
214215

@@ -293,7 +294,8 @@ def get_deploy_kwargs(
293294
use_compiled_model: Optional[bool] = None,
294295
model_name: Optional[str] = None,
295296
training_instance_type: Optional[str] = None,
296-
config_name: Optional[str] = None,
297+
training_config_name: Optional[str] = None,
298+
inference_config_name: Optional[str] = None,
297299
) -> JumpStartEstimatorDeployKwargs:
298300
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
299301

@@ -321,7 +323,8 @@ def get_deploy_kwargs(
321323
tolerate_vulnerable_model=tolerate_vulnerable_model,
322324
tolerate_deprecated_model=tolerate_deprecated_model,
323325
sagemaker_session=sagemaker_session,
324-
config_name=config_name,
326+
training_config_name=training_config_name,
327+
config_name=inference_config_name,
325328
)
326329

327330
model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs(
@@ -350,7 +353,7 @@ def get_deploy_kwargs(
350353
tolerate_deprecated_model=tolerate_deprecated_model,
351354
training_instance_type=training_instance_type,
352355
disable_instance_type_logging=True,
353-
config_name=config_name,
356+
config_name=model_deploy_kwargs.config_name,
354357
)
355358

356359
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
@@ -395,7 +398,7 @@ def get_deploy_kwargs(
395398
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
396399
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
397400
use_compiled_model=use_compiled_model,
398-
config_name=config_name,
401+
config_name=model_deploy_kwargs.config_name,
399402
)
400403

401404
return estimator_deploy_kwargs
@@ -795,3 +798,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
795798
setattr(kwargs, key, value)
796799

797800
return kwargs
801+
802+
803+
def _add_config_name_to_kwargs(
    kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
    """Sets the training config name in kwargs based on default or override.

    An explicitly supplied ``kwargs.config_name`` is always kept. When no name was
    supplied and the model specs expose training configs with a top-ranked entry,
    the name of that top-ranked training config is used instead.

    Args:
        kwargs (JumpStartEstimatorInitKwargs): Estimator init kwargs, updated in place.

    Returns:
        JumpStartEstimatorInitKwargs: The same kwargs object, with ``config_name``
        filled in when a default could be resolved.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        config_name=kwargs.config_name,
    )

    # Resolve the ranking once and reuse it for both the check and the value.
    top_config = (
        specs.training_configs.get_top_config_from_ranking() if specs.training_configs else None
    )
    if top_config:
        kwargs.config_name = kwargs.config_name or top_config.config_name

    return kwargs

src/sagemaker/jumpstart/factory/model.py

+68-9
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
JumpStartModelDeployKwargs,
4343
JumpStartModelInitKwargs,
4444
JumpStartModelRegisterKwargs,
45+
JumpStartModelSpecs,
4546
)
4647
from sagemaker.jumpstart.utils import (
4748
add_jumpstart_model_info_tags,
@@ -548,7 +549,27 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
548549
return kwargs
549550

550551

551-
def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
552+
def _select_inference_config_from_training_config(
553+
specs: JumpStartModelSpecs, training_config_name: str
554+
) -> Optional[str]:
555+
"""Selects the inference config from the training config.
556+
557+
Args:
558+
specs (JumpStartModelSpecs): The specs for the model.
559+
training_config_name (str): The name of the training config.
560+
561+
Returns:
562+
str: The name of the inference config.
563+
"""
564+
if specs.training_configs:
565+
resolved_training_config = specs.training_configs.configs.get(training_config_name)
566+
if resolved_training_config:
567+
return resolved_training_config.default_inference_config
568+
569+
return None
570+
571+
572+
def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
552573
"""Sets default config name to the kwargs. Returns full kwargs.
553574
554575
Raises:
@@ -566,13 +587,9 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
566587
model_type=kwargs.model_type,
567588
config_name=kwargs.config_name,
568589
)
569-
if (
570-
specs.inference_configs
571-
and specs.inference_configs.get_top_config_from_ranking().config_name
572-
):
573-
kwargs.config_name = (
574-
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
575-
)
590+
if specs.inference_configs:
591+
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
592+
kwargs.config_name = kwargs.config_name or default_config_name
576593

577594
if not kwargs.config_name:
578595
return kwargs
@@ -593,6 +610,42 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
593610
return kwargs
594611

595612

613+
def _add_config_name_to_deploy_kwargs(
    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelDeployKwargs:
    """Sets the inference config name on the deploy kwargs. Returns full kwargs.

    The config name is resolved in this order, stopping at the first match:

    1. An explicit ``kwargs.config_name`` is kept as-is.
    2. If ``training_config_name`` is passed, the default inference config
       declared by that training config is used.
    3. Otherwise the top-ranked inference config of the model specs is used.

    Args:
        kwargs (JumpStartModelDeployKwargs): Model deploy kwargs, updated in place.
        training_config_name (Optional[str]): Name of the training config the model
            was trained with, if any. (Default: None).

    Returns:
        JumpStartModelDeployKwargs: The same kwargs object, with ``config_name``
        filled in when one could be resolved.

    Note:
        Any exception raised by ``verify_model_region_and_return_specs`` propagates
        to the caller unchanged.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    # Only derive a name from the training config when the caller did not ask for a
    # specific inference config; otherwise the explicit request would be discarded.
    if training_config_name and not kwargs.config_name:
        kwargs.config_name = _select_inference_config_from_training_config(
            specs=specs, training_config_name=training_config_name
        )

    # Last resort: fall back to the top-ranked inference config, if there is one.
    if not kwargs.config_name and specs.inference_configs:
        top_config = specs.inference_configs.get_top_config_from_ranking()
        if top_config:
            kwargs.config_name = top_config.config_name

    return kwargs
647+
648+
596649
def get_deploy_kwargs(
597650
model_id: str,
598651
model_version: Optional[str] = None,
@@ -623,6 +676,7 @@ def get_deploy_kwargs(
623676
resources: Optional[ResourceRequirements] = None,
624677
managed_instance_scaling: Optional[str] = None,
625678
endpoint_type: Optional[EndpointType] = None,
679+
training_config_name: Optional[str] = None,
626680
config_name: Optional[str] = None,
627681
) -> JumpStartModelDeployKwargs:
628682
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
@@ -664,6 +718,10 @@ def get_deploy_kwargs(
664718

665719
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
666720

721+
deploy_kwargs = _add_config_name_to_deploy_kwargs(
722+
kwargs=deploy_kwargs, training_config_name=training_config_name
723+
)
724+
667725
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
668726

669727
deploy_kwargs.initial_instance_count = initial_instance_count or 1
@@ -858,6 +916,7 @@ def get_init_kwargs(
858916
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
859917

860918
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
861-
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)
919+
920+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
862921

863922
return model_init_kwargs

src/sagemaker/jumpstart/types.py

+28-20
Original file line numberDiff line numberDiff line change
@@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10771077
"config_components",
10781078
"resolved_metadata_config",
10791079
"config_name",
1080+
"default_inference_config",
1081+
"default_incremental_trainig_config",
1082+
"supported_inference_configs",
1083+
"supported_incremental_training_configs",
10801084
]
10811085

10821086
def __init__(
    self,
    config_name: str,
    config: Dict[str, Any],
    base_fields: Dict[str, Any],
    config_components: Dict[str, JumpStartConfigComponent],
):
    """Initializes a JumpStartMetadataConfig object from its json representation.

    Args:
        config_name (str): Name of the config.
        config (Dict[str, Any]):
            Dictionary representation of the config. A missing or empty value is
            treated as an empty dictionary.
        base_fields (Dict[str, Any]):
            The default base fields that are used to construct the final resolved config.
        config_components (Dict[str, JumpStartConfigComponent]):
            The list of components that are used to construct the resolved config.
    """
    # Normalize once so every lookup below is safe even when no config dict is
    # supplied, instead of guarding only the benchmark metrics lookup.
    config = config or {}
    raw_benchmark_metrics = config.get("benchmark_metrics")

    self.base_fields = base_fields
    self.config_components: Dict[str, JumpStartConfigComponent] = config_components
    self.benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = (
        {
            stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
            for stat_name, stats in raw_benchmark_metrics.items()
        }
        if raw_benchmark_metrics
        else None
    )
    self.resolved_metadata_config: Optional[Dict[str, Any]] = None
    self.config_name: Optional[str] = config_name
    self.default_inference_config: Optional[str] = config.get("default_inference_config")
    # NOTE: the attribute name (including its spelling) must match the entry
    # declared in this class's __slots__ list.
    self.default_incremental_trainig_config: Optional[str] = config.get(
        "default_incremental_training_config"
    )
    self.supported_inference_configs: Optional[List[str]] = config.get(
        "supported_inference_configs"
    )
    self.supported_incremental_training_configs: Optional[List[str]] = config.get(
        "supported_incremental_training_configs"
    )
11041126

11051127
def to_json(self) -> Dict[str, Any]:
11061128
"""Returns json representation of JumpStartMetadataConfig object."""
@@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12551277
{
12561278
alias: JumpStartMetadataConfig(
12571279
alias,
1280+
config,
12581281
json_obj,
12591282
(
12601283
{
@@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12641287
if config and config.get("component_names")
12651288
else None
12661289
),
1267-
(
1268-
{
1269-
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
1270-
for stat_name, stats in config.get("benchmark_metrics").items()
1271-
}
1272-
if config and config.get("benchmark_metrics")
1273-
else None
1274-
),
12751290
)
12761291
for alias, config in json_obj["inference_configs"].items()
12771292
}
@@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13081323
{
13091324
alias: JumpStartMetadataConfig(
13101325
alias,
1326+
config,
13111327
json_obj,
13121328
(
13131329
{
@@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13171333
if config and config.get("component_names")
13181334
else None
13191335
),
1320-
(
1321-
{
1322-
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
1323-
for stat_name, stats in config.get("benchmark_metrics").items()
1324-
}
1325-
if config and config.get("benchmark_metrics")
1326-
else None
1327-
),
13281336
)
13291337
for alias, config in json_obj["training_configs"].items()
13301338
}

0 commit comments

Comments
 (0)