-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: JumpStart alternative config parsing #4566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
35f235a
d62a8fc
750b9f0
6db438e
21a0e60
76fee33
cd27c33
183a0c9
44f4c6f
decbe18
8dade8a
1481ba8
4d7a31a
decc22e
c19a28a
866956c
4850e4f
1a21b82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
# language governing permissions and limitations under the License. | ||
"""This module stores types related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
from abc import abstractmethod | ||
from copy import deepcopy | ||
from enum import Enum | ||
from typing import Any, Dict, List, Optional, Set, Union | ||
|
@@ -736,8 +737,66 @@ def _get_regional_property( | |
return alias_value | ||
|
||
|
||
class JumpStartModelSpecs(JumpStartDataHolderType): | ||
"""Data class JumpStart model specs.""" | ||
class JumpStartBenchmarkStat(JumpStartDataHolderType): | ||
"""Data class JumpStart benchmark stats.""" | ||
|
||
__slots__ = ["name", "value", "unit"] | ||
|
||
def __init__(self, spec: Dict[str, Any]): | ||
"""Initializes a JumpStartBenchmarkStat object | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
spec (Dict[str, Any]): Dictionary representation of benchmark stat. | ||
""" | ||
self.from_json(spec) | ||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
"""Sets fields in object based on json. | ||
|
||
Args: | ||
json_obj (Dict[str, Any]): Dictionary representation of benchmark stats. | ||
""" | ||
self.name: str = json_obj["name"] | ||
self.instance_type: str = json_obj["instance_type"] | ||
self.latency_ms: int = json_obj["latency_ms"] | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartBenchmarkStat object.""" | ||
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
return json_obj | ||
|
||
|
||
class JumpStartPresetRanking(JumpStartDataHolderType): | ||
"""Data class JumpStart preset ranking.""" | ||
|
||
__slots__ = ["name", "description", "ranking"] | ||
|
||
def __init__(self, spec: Optional[Dict[str, Any]]): | ||
"""Initializes a JumpStartPresetRanking object. | ||
|
||
Args: | ||
spec (Dict[str, Any]): Dictionary representation of training preset ranking. | ||
""" | ||
self.from_json(spec) | ||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
"""Sets fields in object based on json. | ||
|
||
Args: | ||
json_obj (Dict[str, Any]): Dictionary representation of preset ranking. | ||
""" | ||
self.name: str = json_obj["name"] | ||
self.description: str = json_obj["description"] | ||
self.ranking: List[str] = json_obj["ranking"] | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartPresetRanking object.""" | ||
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
return json_obj | ||
|
||
|
||
class JumpStartMetadataBaseFields(JumpStartDataHolderType): | ||
"""Data class JumpStart metadata base fields that can be overridden with configuration changes.""" | ||
|
||
__slots__ = [ | ||
"model_id", | ||
|
@@ -794,13 +853,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType): | |
"model_subscription_link", | ||
] | ||
|
||
def __init__(self, spec: Dict[str, Any]): | ||
"""Initializes a JumpStartModelSpecs object from its json representation. | ||
def __init__(self, fields: Optional[Dict[str, Any]]): | ||
"""Initializes a JumpStartMetadataFields object. | ||
|
||
Args: | ||
spec (Dict[str, Any]): Dictionary representation of spec. | ||
fields (Dict[str, Any]): Dictionary representation of metadata fields. | ||
""" | ||
self.from_json(spec) | ||
self.from_json(fields) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how does this work, I don't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry this is confusing typing, the |
||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
"""Sets fields in object based on json of header. | ||
|
@@ -944,6 +1003,273 @@ def to_json(self) -> Dict[str, Any]: | |
json_obj[att] = cur_val | ||
return json_obj | ||
|
||
|
||
class JumpStartPresetComponent(JumpStartMetadataBaseFields): | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Data class of JumpStart config component.""" | ||
|
||
slots = ["component_name", "component_override_fields"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i guess this is fine, gives us flexibility in the future There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is the fields that this particular component will override to the |
||
|
||
# List of fields that is not allowed to override to JumpStartMetadataBaseFields | ||
# TODO: finalize the fields that can be overrided | ||
OVERRIDING_DENY_LIST = [] | ||
|
||
__slots__ = slots + JumpStartMetadataBaseFields.__slots__ | ||
|
||
def __init__(self, component: Optional[Dict[str, Any]], spec: Optional[Dict[str, Any]]): | ||
"""Initializes a JumpStartInferenceComponent object from its json representation. | ||
|
||
Args: | ||
component (Dict[str, Any]): | ||
Dictionary representation of the config component that can override to the metadata base fields. | ||
spec (Dict[str, Any]): | ||
Dictionary representation of the original metadata base fields. | ||
""" | ||
super().from_json(spec) | ||
self.component_name: str = component["component_name"] | ||
self.component_overrides: JumpStartMetadataBaseFields = JumpStartMetadataBaseFields( | ||
component | ||
) | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartInferenceComponent object.""" | ||
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return json_obj | ||
|
||
@staticmethod | ||
def _override_merge(source: Dict[str, Any], override: Dict[str, Any]): | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Deep merge two dictionaries recursively, with override overriding source. | ||
""" | ||
for key, value in source.items(): | ||
if ( | ||
key not in JumpStartPresetComponent.OVERRIDING_DENY_LIST | ||
and key in override | ||
and isinstance(value, dict) | ||
and isinstance(override[key], dict) | ||
): | ||
JumpStartPresetComponent._override_merge(value, override[key]) | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif key in override: | ||
source[key] = override[key] | ||
for key in override.keys(): | ||
if key not in source: | ||
source[key] = override[key] | ||
|
||
return source | ||
|
||
def resolve(self) -> None: | ||
"""Resolves the JumpStartMetadataBaseFields with the current component, by overriding corresponding fields.""" | ||
merged_fields = JumpStartPresetComponent._override_merge( | ||
self.to_json(), self.component_overrides.to_json() | ||
) | ||
super().from_json(merged_fields) | ||
|
||
|
||
class JumpStartPresetConfig(JumpStartDataHolderType): | ||
"""Data class of JumpStart Inference config.""" | ||
|
||
__slots__ = ["benchmark_metrics", "preset_components", "resolved_preset_config"] | ||
|
||
def __init__( | ||
self, | ||
preset_components: Dict[str, JumpStartPresetComponent], | ||
benchmark_metrics: Dict[str, JumpStartBenchmarkStat], | ||
): | ||
"""Initializes a JumpStartInferencePresetConfig object from its json representation. | ||
|
||
Args: | ||
preset_components (Dict[str, JumpStartPresetComponent]): | ||
The list of components that are used to construct the resolved config. | ||
benchmark_metrics (Dict[str, JumpStartBenchmarkStat]): | ||
The dictionary of benchmark metrics with name being the key. | ||
""" | ||
self.preset_components: Dict[str, JumpStartPresetComponent] = preset_components | ||
self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics | ||
self.resolved_preset_config: Optional[JumpStartPresetConfig] = None | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartInferencePresetConfig object.""" | ||
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
Captainia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return json_obj | ||
|
||
def resolve(self) -> JumpStartPresetComponent: | ||
"""Returns the final config that is resolved from the list of components.""" | ||
if self.resolved_preset_config: | ||
return self.resolved_preset_config | ||
|
||
# TODO: return the best preset config and resolve the config object | ||
# Using the first component as the preset | ||
component = self.preset_components.values()[0] if self.preset_components else None | ||
component.resolve() | ||
self.resolved_preset_config = component | ||
return self.resolved_preset_config | ||
|
||
|
||
class JumpStartPresetConfigs(JumpStartDataHolderType): | ||
"""Data class to hold the set of JumpStart Preset configs.""" | ||
|
||
__slots__ = [ | ||
"preset_configs", | ||
"preset_config_rankings", | ||
] | ||
|
||
def __init__( | ||
self, | ||
preset_configs: Dict[str, JumpStartPresetConfig], | ||
preset_config_rankings: JumpStartPresetRanking, | ||
): | ||
"""Initializes a JumpStartInferencePresetConfig object from its json representation. | ||
|
||
Args: | ||
preset_configs (Dict[str, JumpStartPresetConfig]): | ||
List of preset configs that the current model has. | ||
preset_config_rankings (JumpStartPresetRanking): | ||
Preset ranking class represents the ranking of the presets in the model. | ||
""" | ||
self.preset_configs = preset_configs | ||
self.preset_config_rankings = preset_config_rankings | ||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
"""Sets fields in object based on json. | ||
|
||
Args: | ||
json_obj (Dict[str, Any]): Dictionary representation of inference config. | ||
""" | ||
if json_obj is None: | ||
return | ||
|
||
self.preset_configs: str = { | ||
alias: JumpStartPresetConfig(preset_config) | ||
for alias, preset_config in json_obj["preset_configs"].items() | ||
} | ||
self.preset_config_rankings: str = json_obj["preset_config_rankings"] | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartInferencePresetConfig object.""" | ||
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
return json_obj | ||
|
||
def get_resolved_config(self) -> JumpStartPresetConfig: | ||
"""Resolves the preset config.""" | ||
# TODO: return the best preset config based on config rankings | ||
return self.preset_configs[0] if self.preset_configs else None | ||
|
||
|
||
class JumpStartModelSpecs(JumpStartMetadataBaseFields): | ||
"""Data class JumpStart model specs.""" | ||
|
||
slots = [ | ||
"inference_presets", | ||
"inference_preset_components", | ||
"inference_preset_rankings", | ||
"training_presets", | ||
"training_preset_components", | ||
"training_preset_rankings", | ||
] | ||
|
||
__slots__ = JumpStartMetadataBaseFields.__slots__ + slots | ||
|
||
def __init__(self, spec: Dict[str, Any]): | ||
"""Initializes a JumpStartModelSpecs object from its json representation. | ||
|
||
Args: | ||
spec (Dict[str, Any]): Dictionary representation of spec. | ||
""" | ||
self.from_json(spec) | ||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
"""Sets fields in object based on json of header. | ||
|
||
Args: | ||
json_obj (Dict[str, Any]): Dictionary representation of spec. | ||
""" | ||
super().from_json(json_obj) | ||
self.inference_preset_components: Optional[Dict[str, JumpStartPresetComponent]] = ( | ||
{ | ||
alias: JumpStartPresetComponent(component, json_obj) | ||
for alias, component in json_obj["inference_preset_components"].items() | ||
} | ||
if json_obj.get("inference_preset_components") | ||
else None | ||
) | ||
self.inference_preset_rankings: Optional[Dict[str, JumpStartPresetRanking]] = ( | ||
{ | ||
alias: JumpStartPresetRanking(ranking) | ||
for alias, ranking in json_obj["inference_preset_rankings"].items() | ||
} | ||
if json_obj.get("inference_preset_rankings") | ||
else None | ||
) | ||
self.inference_presets: Optional[JumpStartPresetConfig] = ( | ||
JumpStartPresetConfigs( | ||
( | ||
{ | ||
alias: JumpStartPresetConfig( | ||
{ | ||
component_name: self.inference_preset_components[component_name] | ||
for component_name in config.get("component_names") | ||
}, | ||
{ | ||
stat_name: JumpStartBenchmarkStat(stat) | ||
for stat_name, stat in config.get("benchmark_metrics").items() | ||
}, | ||
) | ||
for alias, config in json_obj["inference_configs"].items() | ||
} | ||
if json_obj.get("inference_configs") | ||
else None | ||
), | ||
self.inference_preset_rankings, | ||
) | ||
if "inference_configs" in json_obj | ||
else None | ||
) | ||
|
||
if self.training_supported: | ||
self.training_preset_components: Optional[Dict[str, JumpStartPresetComponent]] = ( | ||
{ | ||
alias: JumpStartPresetComponent(component) | ||
for alias, component in json_obj["training_preset_components"].items() | ||
} | ||
if json_obj.get("training_preset_components") | ||
else None | ||
) | ||
self.training_preset_rankings: Optional[Dict[str, JumpStartPresetRanking]] = ( | ||
{ | ||
alias: JumpStartPresetRanking(ranking) | ||
for alias, ranking in json_obj["training_preset_rankings"].items() | ||
} | ||
if json_obj.get("training_preset_rankings") | ||
else None | ||
) | ||
self.training_presets: Optional[JumpStartPresetConfig] = ( | ||
JumpStartPresetConfigs( | ||
( | ||
{ | ||
alias: JumpStartPresetConfig( | ||
{ | ||
component_name: self.training_preset_components[component_name] | ||
for component_name in config.get("component_names") | ||
}, | ||
{ | ||
stat_name: JumpStartBenchmarkStat(stat) | ||
for stat_name, stat in config.get("benchmark_metrics").items() | ||
}, | ||
) | ||
for alias, config in json_obj["training_configs"].items() | ||
} | ||
if json_obj.get("training_configs") | ||
else None | ||
), | ||
self.training_preset_rankings, | ||
) | ||
if "training_configs" in json_obj | ||
else None | ||
) | ||
self.model_subscription_link = json_obj.get("model_subscription_link") | ||
|
||
if self.inference_presets and self.inference_presets.get_resolved_config(): | ||
super().from_json(self.inference_presets.get_resolved_config().to_json()) | ||
|
||
def supports_prepacked_inference(self) -> bool: | ||
"""Returns True if the model has a prepacked inference artifact.""" | ||
return getattr(self, "hosting_prepacked_artifact_key", None) is not None | ||
|
Uh oh!
There was an error while loading. Please reload this page.