-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: JumpStart alternative config parsing #4566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/sagemaker/jumpstart/types.py
Outdated
return json_obj | ||
|
||
@property | ||
def resolved_config(self) -> Dict[str, Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this logic won't work for dicts inside dicts. can you instead do:
def deep_override_dict(dict1: dict, dict2: dict) -> dict:
"""Overrides any overlapping contents of dict1 with the contents of dict2."""
flattened_dict = flatten(dict1)
flattened_dict.update(flatten(dict2))
return unflatten(flattened_dict) if flattened_dict else {}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remembered I used something similar before to this but right now I don't see a native Python supported method flatten
. There's something similar but doesn't seem it trivially support unflatten
: https://pandas.pydata.org/docs/reference/api/pandas.json_normalize.html
src/sagemaker/jumpstart/types.py
Outdated
class JumpStartPresetComponent(JumpStartMetadataBaseFields): | ||
"""Data class of JumpStart config component.""" | ||
|
||
slots = ["component_name", "component_override_fields"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is component_override_fields
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i guess this is fine, gives us flexibility in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is the fields that this particular component will override to the JumpStartMetadataBaseFields
. Each JumpStartPresetComponent
inherits the base class so contains all the base fields, but since we only want to override to these specific fields so I kept a separate list only for them.
continue | ||
return self.configs[config_name] | ||
|
||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type should be Optional
right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll address this in my next PR, thanks for catching
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, | ||
) -> Dict[str, List[JumpStartMetadataConfig]]: | ||
"""Returns metadata configs for the given model ID and region.""" | ||
model_specs = verify_model_region_and_return_specs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this block of code appears 3 times in get_jumpstart_configs
, get_benchmark_stats
, and get_config_names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea I think it's required because we use verify_model_region_and_return_specs
as the entry point to get specs in all these utils, and each function can be called standalone. But I think we can probably choose a better naming of verify_model_region_and_return_specs
and move outside of utils as it's used heavily everywhere.
@@ -1028,22 +1038,23 @@ class JumpStartPresetComponent(JumpStartMetadataBaseFields): | |||
|
|||
__slots__ = slots + JumpStartMetadataBaseFields.__slots__ | |||
|
|||
def __init__( | |||
def __init__( # pylint: disable=super-init-not-called |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why aren't we calling the superclass constructor? This isn't good OOP design IMO, it violates Liskov Substution Principle
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll update, I left the component to only have its specific fields but not construct all the parent class fields which are optional, but it should work with superclass constructor. Good callout
""" | ||
self.from_json(spec) | ||
self.from_json(fields) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how does this work, I don't mypy
should accept this signature given that from_json
accepts a Dict[str, Any]
and get passed an Optional[Dict[str, Any]]
. This could cause null pointer errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry this is confusing typing, the fields
object is Dict[str, Any]
but not optional. I think mypy
is not checking these typings, updated in my next PR
self.incremental_training_supported: bool = bool( | ||
json_obj.get("incremental_training_supported", False) | ||
json_obj.get("incremental_training_supported") | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all these types are de facto Optional[*]
now?
* add alternative config parsing * rename a few functions * refactor based on new mock * remove commented files * add interfaces and utils * various fixes and add some tests * add more tests * address comments * swap presets with configs * fix: docstyle * fix: doc * fix: docstyle * updates * typing
* add alternative config parsing * rename a few functions * refactor based on new mock * remove commented files * add interfaces and utils * various fixes and add some tests * add more tests * address comments * swap presets with configs * fix: docstyle * fix: doc * fix: docstyle * updates * typing
Issue #, if available:
Description of changes:
Initial round of reading and parsing the JumpStart alternative configs.
Class structures that goes from bottom to top:
Testing done:
Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_base
to create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.