-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: JumpStart alternative config parsing #4566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
35f235a
d62a8fc
750b9f0
6db438e
21a0e60
76fee33
cd27c33
183a0c9
44f4c6f
decbe18
8dade8a
1481ba8
4d7a31a
decc22e
c19a28a
866956c
4850e4f
1a21b82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,9 +37,9 @@ | |
) | ||
from sagemaker.jumpstart.types import ( | ||
JumpStartBenchmarkStat, | ||
JumpStartMetadataConfig, | ||
JumpStartModelHeader, | ||
JumpStartModelSpecs, | ||
JumpStartPresetConfig, | ||
JumpStartVersionedModelId, | ||
) | ||
from sagemaker.session import Session | ||
|
@@ -882,15 +882,15 @@ def get_region_fallback( | |
return list(combined_regions)[0] | ||
|
||
|
||
def get_preset_names( | ||
def get_config_names( | ||
region: str, | ||
model_id: str, | ||
model_version: str, | ||
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, | ||
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, | ||
) -> List[str]: | ||
"""Returns a list of preset names for the given model ID and region.""" | ||
"""Returns a list of config names for the given model ID and region.""" | ||
model_specs = verify_model_region_and_return_specs( | ||
region=region, | ||
model_id=model_id, | ||
|
@@ -901,12 +901,13 @@ def get_preset_names( | |
) | ||
|
||
if scope == enums.JumpStartScriptScope.INFERENCE: | ||
presets = model_specs.inference_presets | ||
|
||
if scope == enums.JumpStartScriptScope.TRAINING: | ||
presets = model_specs.training_presets | ||
metadata_configs = model_specs.inference_configs | ||
elif scope == enums.JumpStartScriptScope.TRAINING: | ||
metadata_configs = model_specs.training_configs | ||
else: | ||
raise ValueError(f"Unknown script scope {scope}.") | ||
|
||
return list(presets.preset_configs.keys()) if presets else [] | ||
return list(metadata_configs.configs.keys()) if metadata_configs else [] | ||
|
||
|
||
def get_benchmark_stats( | ||
|
@@ -929,31 +930,34 @@ def get_benchmark_stats( | |
) | ||
|
||
if scope == enums.JumpStartScriptScope.INFERENCE: | ||
presets = model_specs.inference_presets | ||
metadata_configs = model_specs.inference_configs | ||
elif scope == enums.JumpStartScriptScope.TRAINING: | ||
presets = model_specs.training_presets | ||
metadata_configs = model_specs.training_configs | ||
else: | ||
raise ValueError(f"Unknown script scope {scope}.") | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if not config_names: | ||
config_names = presets.preset_configs.keys() if presets else [] | ||
config_names = metadata_configs.configs.keys() if metadata_configs else [] | ||
|
||
benchmark_stats = {} | ||
for config_name in config_names: | ||
if config_name not in presets.preset_configs: | ||
raise ValueError(f"Unknown preset config name: '{config_name}'") | ||
benchmark_stats[config_name] = presets.preset_configs.get(config_name).benchmark_metrics | ||
if config_name not in metadata_configs.configs: | ||
raise ValueError(f"Unknown config name: '{config_name}'") | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics | ||
|
||
return benchmark_stats | ||
|
||
|
||
def get_jumpstart_presets( | ||
def get_jumpstart_configs( | ||
region: str, | ||
model_id: str, | ||
model_version: str, | ||
config_names: Optional[List[str]] = None, | ||
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, | ||
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, | ||
) -> Dict[str, List[JumpStartPresetConfig]]: | ||
) -> Dict[str, List[JumpStartMetadataConfig]]: | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Returns metadata configs for the given model ID and region.""" | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model_specs = verify_model_region_and_return_specs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this block of code appears 3 times in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea I think it's required because we use |
||
region=region, | ||
model_id=model_id, | ||
|
@@ -964,15 +968,17 @@ def get_jumpstart_presets( | |
) | ||
|
||
if scope == enums.JumpStartScriptScope.INFERENCE: | ||
presets = model_specs.inference_presets | ||
metadata_configs = model_specs.inference_configs | ||
elif scope == enums.JumpStartScriptScope.TRAINING: | ||
presets = model_specs.training_presets | ||
metadata_configs = model_specs.training_configs | ||
else: | ||
raise ValueError(f"Unknown script scope {scope}.") | ||
|
||
if not config_names: | ||
config_names = presets.preset_configs.keys() | ||
config_names = metadata_configs.configs.keys() if metadata_configs else [] | ||
|
||
preset_configs = { | ||
config_name: presets.preset_configs[config_name] for config_name in config_names | ||
} | ||
|
||
return preset_configs | ||
return ( | ||
{config_name: metadata_configs.configs[config_name] for config_name in config_names} | ||
if metadata_configs | ||
else {} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add
Raises
section please