Skip to content

feat: JumpStart alternative config parsing #4566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 332 additions & 6 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module stores types related to SageMaker JumpStart."""
from __future__ import absolute_import
from abc import abstractmethod
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -736,8 +737,66 @@ def _get_regional_property(
return alias_value


class JumpStartModelSpecs(JumpStartDataHolderType):
"""Data class JumpStart model specs."""
class JumpStartBenchmarkStat(JumpStartDataHolderType):
    """Data class JumpStart benchmark stats."""

    # Every attribute assigned in ``from_json`` has to be listed here: ``to_json``
    # serializes by iterating ``__slots__``, so an unlisted attribute is silently
    # dropped from the output (and cannot be assigned at all if the base class
    # provides no instance ``__dict__``).
    # NOTE(review): "value" and "unit" are kept so existing references keep working,
    # but ``from_json`` never populates them -- confirm the intended benchmark schema.
    __slots__ = ["name", "value", "unit", "instance_type", "latency_ms"]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a JumpStartBenchmarkStat object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of benchmark stat.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of benchmark stats.

        Raises:
            KeyError: If "name", "instance_type" or "latency_ms" is missing.
        """
        self.name: str = json_obj["name"]
        self.instance_type: str = json_obj["instance_type"]
        self.latency_ms: int = json_obj["latency_ms"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartBenchmarkStat object.

        Only slots that have actually been assigned are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj


class JumpStartPresetRanking(JumpStartDataHolderType):
    """Data class JumpStart preset ranking."""

    __slots__ = ["name", "description", "ranking"]

    # ``spec`` is not optional: ``from_json`` subscripts it unconditionally, so
    # passing None would raise TypeError.
    def __init__(self, spec: Dict[str, Any]):
        """Initializes a JumpStartPresetRanking object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of preset ranking.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of preset ranking.

        Raises:
            KeyError: If "name", "description" or "ranking" is missing.
        """
        self.name: str = json_obj["name"]
        self.description: str = json_obj["description"]
        self.ranking: List[str] = json_obj["ranking"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartPresetRanking object.

        Only slots that have actually been assigned are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj


class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"""Data class JumpStart metadata base fields that can be overridden with configuration changes."""

__slots__ = [
"model_id",
Expand Down Expand Up @@ -794,13 +853,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
"model_subscription_link",
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartModelSpecs object from its json representation.
def __init__(self, fields: Optional[Dict[str, Any]]):
"""Initializes a JumpStartMetadataFields object.

Args:
spec (Dict[str, Any]): Dictionary representation of spec.
fields (Dict[str, Any]): Dictionary representation of metadata fields.
"""
self.from_json(spec)
self.from_json(fields)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this work? I don't think mypy should accept this signature, given that from_json accepts a Dict[str, Any] and gets passed an Optional[Dict[str, Any]]. This could cause null pointer errors.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this typing is confusing: the fields object is a Dict[str, Any], not an Optional. I think mypy is not checking these typings; this is updated in my next PR.


def from_json(self, json_obj: Dict[str, Any]) -> None:
"""Sets fields in object based on json of header.
Expand Down Expand Up @@ -944,6 +1003,273 @@ def to_json(self) -> Dict[str, Any]:
json_obj[att] = cur_val
return json_obj


class JumpStartPresetComponent(JumpStartMetadataBaseFields):
"""Data class of JumpStart config component."""

# Slot names must match the attributes assigned in __init__ ("component_name" and
# "component_overrides"); to_json only serializes attributes listed in __slots__.
slots = ["component_name", "component_overrides"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is component_override_fields?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is fine; it gives us flexibility in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the fields that this particular component will override in JumpStartMetadataBaseFields. Each JumpStartPresetComponent inherits from the base class, so it contains all the base fields; but since we only want to override these specific fields, I kept a separate list just for them.


# List of fields that are not allowed to override JumpStartMetadataBaseFields
# TODO: finalize the fields that can be overridden
OVERRIDING_DENY_LIST = []

__slots__ = slots + JumpStartMetadataBaseFields.__slots__

def __init__(self, component: Optional[Dict[str, Any]], spec: Optional[Dict[str, Any]]):
    """Builds a preset component from its override dictionary and the base spec.

    Args:
        component (Dict[str, Any]):
            Dictionary representation of the config component whose entries can
            override the metadata base fields. It is read for "component_name".
        spec (Dict[str, Any]):
            Dictionary representation of the original metadata base fields.
    """
    # Populate this instance with the original (un-overridden) base fields first.
    super().from_json(spec)

    name = component["component_name"]
    self.component_name: str = name

    # The same component dictionary is parsed as base fields so that ``resolve``
    # can later merge it over the values loaded from ``spec``.
    self.component_overrides: JumpStartMetadataBaseFields = JumpStartMetadataBaseFields(component)

def to_json(self) -> Dict[str, Any]:
    """Returns a dictionary holding every slot of this component that is currently set."""
    serialized: Dict[str, Any] = {}
    for slot_name in self.__slots__:
        # Slots that were never assigned are left out of the output.
        if hasattr(self, slot_name):
            serialized[slot_name] = getattr(self, slot_name)
    return serialized

@staticmethod
def _override_merge(source: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Deep merges ``override`` into ``source`` recursively, with override winning.

    Keys listed in ``JumpStartPresetComponent.OVERRIDING_DENY_LIST`` are never
    taken from ``override``, at any nesting level; the value in ``source`` (or
    its absence) is kept as is.

    Args:
        source (Dict[str, Any]): Dictionary to merge into. It is modified in place.
        override (Dict[str, Any]): Dictionary whose entries take precedence.

    Returns:
        Dict[str, Any]: The same ``source`` object, after merging.
    """
    deny_list = JumpStartPresetComponent.OVERRIDING_DENY_LIST
    for key, override_value in override.items():
        if key in deny_list:
            # Denied keys must not be overridden at all -- not merged,
            # not replaced and not newly added.
            continue
        current_value = source.get(key)
        if key in source and isinstance(current_value, dict) and isinstance(override_value, dict):
            # Both sides are dictionaries: merge them instead of replacing wholesale.
            JumpStartPresetComponent._override_merge(current_value, override_value)
        else:
            source[key] = override_value

    return source

def resolve(self) -> None:
    """Applies this component's overrides on top of its own base fields.

    The current json form of this object is merged with the json form of
    ``component_overrides`` and the result is loaded back through the base
    class ``from_json``.
    """
    base_fields = self.to_json()
    override_fields = self.component_overrides.to_json()
    merged_fields = JumpStartPresetComponent._override_merge(base_fields, override_fields)
    super().from_json(merged_fields)


class JumpStartPresetConfig(JumpStartDataHolderType):
    """Data class of JumpStart preset config."""

    __slots__ = ["benchmark_metrics", "preset_components", "resolved_preset_config"]

    def __init__(
        self,
        preset_components: Dict[str, JumpStartPresetComponent],
        benchmark_metrics: Dict[str, JumpStartBenchmarkStat],
    ):
        """Initializes a JumpStartPresetConfig object.

        Args:
            preset_components (Dict[str, JumpStartPresetComponent]):
                The components that are used to construct the resolved config,
                keyed by component name.
            benchmark_metrics (Dict[str, JumpStartBenchmarkStat]):
                The dictionary of benchmark metrics with name being the key.
        """
        self.preset_components: Dict[str, JumpStartPresetComponent] = preset_components
        self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics
        # Cache for ``resolve``; it stores the resolved component, not a config.
        self.resolved_preset_config: Optional[JumpStartPresetComponent] = None

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartPresetConfig object.

        Only slots that have actually been assigned are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    def resolve(self) -> Optional[JumpStartPresetComponent]:
        """Returns the final config that is resolved from the components.

        The result is cached, so components are resolved at most once.

        Returns:
            Optional[JumpStartPresetComponent]: The resolved component, or None
            when this config has no components.
        """
        if self.resolved_preset_config is not None:
            return self.resolved_preset_config

        if not self.preset_components:
            # Nothing to resolve; leave the cache empty so a later call can retry.
            return None

        # TODO: return the best preset config and resolve the config object
        # Using the first component as the preset. Dict views cannot be indexed,
        # so take the first value through an iterator.
        component = next(iter(self.preset_components.values()))
        component.resolve()
        self.resolved_preset_config = component
        return self.resolved_preset_config


class JumpStartPresetConfigs(JumpStartDataHolderType):
    """Data class to hold the set of JumpStart Preset configs."""

    __slots__ = [
        "preset_configs",
        "preset_config_rankings",
    ]

    def __init__(
        self,
        preset_configs: Dict[str, JumpStartPresetConfig],
        preset_config_rankings: JumpStartPresetRanking,
    ):
        """Initializes a JumpStartPresetConfigs object.

        Args:
            preset_configs (Dict[str, JumpStartPresetConfig]):
                Preset configs that the current model has, keyed by alias.
            preset_config_rankings (JumpStartPresetRanking):
                Preset ranking class represents the ranking of the presets in the model.
        """
        self.preset_configs = preset_configs
        self.preset_config_rankings = preset_config_rankings

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Does nothing when ``json_obj`` is None.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of preset configs.
        """
        if json_obj is None:
            return

        # NOTE(review): JumpStartPresetConfig.__init__ requires two arguments
        # (preset_components, benchmark_metrics), so this single-argument call
        # raises TypeError for any non-empty "preset_configs". The json schema is
        # not visible here, so the call is left unchanged -- confirm and fix.
        self.preset_configs = {
            alias: JumpStartPresetConfig(preset_config)
            for alias, preset_config in json_obj["preset_configs"].items()
        }
        self.preset_config_rankings = json_obj["preset_config_rankings"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartPresetConfigs object.

        Only slots that have actually been assigned are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    def get_resolved_config(self) -> Optional[JumpStartPresetConfig]:
        """Resolves the preset config.

        Returns:
            Optional[JumpStartPresetConfig]: The first preset config, or None when
            there are no preset configs.
        """
        if not self.preset_configs:
            return None

        # TODO: return the best preset config based on config rankings
        # ``preset_configs`` is keyed by alias, so ``[0]`` would be a key lookup
        # rather than "first element"; take the first value through an iterator.
        return next(iter(self.preset_configs.values()))


class JumpStartModelSpecs(JumpStartMetadataBaseFields):
"""Data class JumpStart model specs."""

# Slot names added by this class on top of the inherited metadata base fields.
slots = [
"inference_presets",
"inference_preset_components",
"inference_preset_rankings",
"training_presets",
"training_preset_components",
"training_preset_rankings",
]

# NOTE(review): presumably the inherited slot names are repeated here so that code
# iterating ``__slots__`` (as the to_json methods in this file do) also sees the
# base fields -- confirm against the base class's to_json.
__slots__ = JumpStartMetadataBaseFields.__slots__ + slots

def __init__(self, spec: Dict[str, Any]):
    """Builds a JumpStartModelSpecs object by loading ``spec`` through ``from_json``.

    Args:
        spec (Dict[str, Any]): Dictionary representation of spec.
    """
    # All field parsing is delegated to from_json.
    self.from_json(spec)

def from_json(self, json_obj: Dict[str, Any]) -> None:
    """Sets fields in object based on json of header.

    Loads the metadata base fields first, then the optional inference preset
    components, rankings and configs, and -- only when ``training_supported``
    is truthy -- their training counterparts. If an inference preset config
    can be resolved, its json form is loaded through the base ``from_json``
    as a final step.

    Args:
        json_obj (Dict[str, Any]): Dictionary representation of spec.
    """
    super().from_json(json_obj)
    self.inference_preset_components: Optional[Dict[str, JumpStartPresetComponent]] = (
        {
            alias: JumpStartPresetComponent(component, json_obj)
            for alias, component in json_obj["inference_preset_components"].items()
        }
        if json_obj.get("inference_preset_components")
        else None
    )
    self.inference_preset_rankings: Optional[Dict[str, JumpStartPresetRanking]] = (
        {
            alias: JumpStartPresetRanking(ranking)
            for alias, ranking in json_obj["inference_preset_rankings"].items()
        }
        if json_obj.get("inference_preset_rankings")
        else None
    )
    self.inference_presets: Optional[JumpStartPresetConfigs] = (
        JumpStartPresetConfigs(
            (
                {
                    alias: JumpStartPresetConfig(
                        {
                            component_name: self.inference_preset_components[component_name]
                            for component_name in config.get("component_names")
                        },
                        {
                            stat_name: JumpStartBenchmarkStat(stat)
                            for stat_name, stat in config.get("benchmark_metrics").items()
                        },
                    )
                    for alias, config in json_obj["inference_configs"].items()
                }
                if json_obj.get("inference_configs")
                else None
            ),
            self.inference_preset_rankings,
        )
        if "inference_configs" in json_obj
        else None
    )

    if self.training_supported:
        self.training_preset_components: Optional[Dict[str, JumpStartPresetComponent]] = (
            {
                # The component constructor needs the base spec as its second
                # argument, exactly as in the inference branch above.
                alias: JumpStartPresetComponent(component, json_obj)
                for alias, component in json_obj["training_preset_components"].items()
            }
            if json_obj.get("training_preset_components")
            else None
        )
        self.training_preset_rankings: Optional[Dict[str, JumpStartPresetRanking]] = (
            {
                alias: JumpStartPresetRanking(ranking)
                for alias, ranking in json_obj["training_preset_rankings"].items()
            }
            if json_obj.get("training_preset_rankings")
            else None
        )
        self.training_presets: Optional[JumpStartPresetConfigs] = (
            JumpStartPresetConfigs(
                (
                    {
                        alias: JumpStartPresetConfig(
                            {
                                component_name: self.training_preset_components[component_name]
                                for component_name in config.get("component_names")
                            },
                            {
                                stat_name: JumpStartBenchmarkStat(stat)
                                for stat_name, stat in config.get("benchmark_metrics").items()
                            },
                        )
                        for alias, config in json_obj["training_configs"].items()
                    }
                    if json_obj.get("training_configs")
                    else None
                ),
                self.training_preset_rankings,
            )
            if "training_configs" in json_obj
            else None
        )
    # NOTE(review): assumed to sit outside the training_supported branch; the
    # original indentation was lost -- confirm.
    self.model_subscription_link = json_obj.get("model_subscription_link")

    if self.inference_presets and self.inference_presets.get_resolved_config():
        super().from_json(self.inference_presets.get_resolved_config().to_json())

def supports_prepacked_inference(self) -> bool:
    """Tells whether a prepacked inference artifact key is set on this model."""
    # A missing attribute and an explicit None both count as "not supported".
    artifact_key = getattr(self, "hosting_prepacked_artifact_key", None)
    return artifact_key is not None
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7499,6 +7499,12 @@
"resource_name_base": "dfsdfsds",
"hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360},
"dynamic_container_deployment_supported": True,
"inference_presets": None,
"inference_preset_components": None,
"training_presets": None,
"training_preset_components": None,
"inference_preset_rankings": None,
"training_preset_rankings": None,
}

BASE_HEADER = {
Expand Down
Loading