Skip to content

Commit f90a68f

Browse files
committed
chore: working JumpStartEstimator and JumpStartModel
1 parent f120e0d commit f90a68f

15 files changed

+1111
-249
lines changed

src/sagemaker/deserializers.py

+7-3
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,14 @@
1414
from __future__ import absolute_import
1515

1616

17-
from typing import Any, List, Optional
17+
from typing import List, Optional
1818

19+
# base_deserializers was renamed from deserializers, so this import
20+
# is for backwards compatibility.
1921
from sagemaker.base_deserializers import * # noqa: F403, F401 # pylint: disable=W0614,W0401
2022

23+
from sagemaker.base_deserializers import BaseDeserializer
24+
2125
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
2226

2327

@@ -27,7 +31,7 @@ def retrieve(
2731
model_version: Optional[str] = None,
2832
tolerate_vulnerable_model: bool = False,
2933
tolerate_deprecated_model: bool = False,
30-
) -> List[Any]:
34+
) -> List[BaseDeserializer]:
3135
"""Retrieves the supported deserializers for the model matching the given arguments.
3236
3337
Args:
@@ -71,7 +75,7 @@ def retrieve_default(
7175
model_version: Optional[str] = None,
7276
tolerate_vulnerable_model: bool = False,
7377
tolerate_deprecated_model: bool = False,
74-
) -> Any:
78+
) -> BaseDeserializer:
7579
"""Retrieves the default deserializer for the model matching the given arguments.
7680
7781
Args:

src/sagemaker/environment_variables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def retrieve_default(
5050
(Default: False).
5151
use_case (EnvVariableUseCase): The use case for the environment variables. The
5252
`Model` class of the SageMaker Python SDK inserts environment variables
53-
that would be requiredwhen making the low-level AWS API call.
53+
that would be required when making the low-level AWS API call.
5454
(Default: EnvVariableUseCase.AWS_SDK).
5555
Returns:
5656
dict: The variables to use for the model.

src/sagemaker/jumpstart/estimator.py

+72-128
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,14 @@
1414
from __future__ import absolute_import
1515

1616

17-
from copy import deepcopy
18-
from typing import Any, Optional
19-
from sagemaker import (
20-
hyperparameters,
21-
image_uris,
22-
instance_types,
23-
metric_definitions,
24-
model_uris,
25-
script_uris,
26-
)
27-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
28-
from sagemaker.jumpstart.enums import JumpStartScriptScope
29-
from sagemaker.jumpstart.utils import update_dict_if_key_not_present
30-
from sagemaker.model import Estimator
17+
from typing import Dict, List, Optional
18+
19+
from sagemaker.estimator import Estimator
20+
21+
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
22+
23+
24+
from sagemaker.predictor import Predictor
3125

3226

3327
class JumpStartEstimator(Estimator):
@@ -39,128 +33,78 @@ class JumpStartEstimator(Estimator):
3933
def __init__(
4034
self,
4135
model_id: str,
42-
model_version: Optional[str] = "*",
43-
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
44-
kwargs_for_base_estimator_class: dict = {},
36+
model_version: Optional[str] = None,
37+
instance_type: Optional[str] = None,
38+
instance_count: Optional[int] = None,
39+
region: Optional[str] = None,
40+
image_uri: Optional[str] = None,
41+
model_uri: Optional[str] = None,
42+
source_dir: Optional[str] = None,
43+
entry_point: Optional[str] = None,
44+
hyperparameters: Optional[dict] = None,
45+
metric_definitions: Optional[List[dict]] = None,
46+
**kwargs,
4547
):
46-
self.model_id = model_id
47-
self.model_version = model_version
48-
self.kwargs_for_base_estimator_class = deepcopy(kwargs_for_base_estimator_class)
49-
50-
self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
51-
self.kwargs_for_base_estimator_class,
52-
"image_uri",
53-
image_uris.retrieve(
54-
region=None,
55-
framework=None,
56-
image_scope="training",
57-
model_id=model_id,
58-
model_version=model_version,
59-
instance_type=self.instance_type,
60-
),
61-
)
62-
63-
self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
64-
self.kwargs_for_base_estimator_class,
65-
"model_uri",
66-
model_uris.retrieve(
67-
script_scope=JumpStartScriptScope.TRAINING,
68-
model_id=model_id,
69-
model_version=model_version,
70-
),
71-
)
72-
73-
self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
74-
self.kwargs_for_base_estimator_class,
75-
"script_uri",
76-
script_uris.retrieve(
77-
script_scope=JumpStartScriptScope.TRAINING,
78-
model_id=model_id,
79-
model_version=model_version,
80-
),
81-
)
82-
83-
default_hyperparameters = hyperparameters.retrieve_default(
84-
region=region, model_id=model_id, model_version=model_version
48+
estimator_init_kwargs = get_init_kwargs(
49+
model_id=model_id,
50+
model_version=model_version,
51+
instance_type=instance_type,
52+
instance_count=instance_count,
53+
region=region,
54+
image_uri=image_uri,
55+
model_uri=model_uri,
56+
source_dir=source_dir,
57+
entry_point=entry_point,
58+
hyperparameters=hyperparameters,
59+
metric_definitions=metric_definitions,
60+
kwargs=kwargs,
8561
)
8662

87-
curr_hyperparameters = self.kwargs_for_base_estimator_class.get("hyperparameters", {})
88-
new_hyperparameters = deepcopy(curr_hyperparameters)
89-
90-
for key, value in default_hyperparameters:
91-
new_hyperparameters = update_dict_if_key_not_present(
92-
new_hyperparameters,
93-
key,
94-
value,
95-
)
63+
self.model_id = estimator_init_kwargs.model_id
64+
self.model_version = estimator_init_kwargs.model_version
65+
self.instance_type = estimator_init_kwargs.instance_type
66+
self.instance_count = estimator_init_kwargs.instance_count
67+
self.region = estimator_init_kwargs.region
9668

97-
if new_hyperparameters == {}:
98-
new_hyperparameters = None
69+
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
9970

100-
self.kwargs_for_base_estimator_class["hyperparameters"] = new_hyperparameters
71+
def fit(self, *largs, **kwargs) -> None:
10172

102-
default_metric_definitions = metric_definitions.retrieve_default(
103-
region=region, model_id=model_id, model_version=model_version
73+
estimator_fit_kwargs = get_fit_kwargs(
74+
model_id=self.model_id,
75+
model_version=self.model_version,
76+
instance_type=self.instance_type,
77+
instance_count=self.instance_count,
78+
region=self.region,
79+
kwargs=kwargs,
10480
)
10581

106-
curr_metric_definitions = self.kwargs_for_base_estimator_class.get("metric_definitions", [])
107-
new_metric_definitions = deepcopy(curr_metric_definitions)
108-
109-
for metric_definition in default_metric_definitions:
110-
if metric_definition["Name"] not in [
111-
definition["Name"] for definition in new_metric_definitions
112-
]:
113-
new_metric_definitions.append(metric_definition)
114-
115-
if new_metric_definitions == []:
116-
new_metric_definitions = None
82+
return super(JumpStartEstimator, self).fit(*largs, **estimator_fit_kwargs.to_kwargs_dict())
11783

118-
self.kwargs_for_base_estimator_class["metric_definitions"] = new_metric_definitions
119-
120-
# estimator_kwargs_to_add = _retrieve_kwargs(model_id=model_id, model_version=model_version, region=region)
121-
estimator_kwargs_to_add = {}
122-
123-
new_kwargs_for_base_estimator_class = deepcopy(self.kwargs_for_base_estimator_class)
124-
for key, value in estimator_kwargs_to_add:
125-
new_kwargs_for_base_estimator_class = update_dict_if_key_not_present(
126-
new_kwargs_for_base_estimator_class,
127-
key,
128-
value,
129-
)
130-
131-
self.kwargs_for_base_estimator_class = new_kwargs_for_base_estimator_class
132-
133-
self.kwargs_for_base_estimator_class["model_id"] = model_id
134-
self.kwargs_for_base_estimator_class["model_version"] = model_version
135-
136-
# self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
137-
# self.kwargs_for_base_estimator_class,
138-
# "predictor_cls",
139-
# JumpStartPredictor,
140-
# )
141-
142-
self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
143-
self.kwargs_for_base_estimator_class, "instance_count", 1
144-
)
145-
self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
146-
self.kwargs_for_base_estimator_class,
147-
"instance_type",
148-
instance_types.retrieve_default(
149-
region=region, model_id=model_id, model_version=model_version
150-
),
84+
def deploy(
85+
self,
86+
image_uri: Optional[str] = None,
87+
source_dir: Optional[str] = None,
88+
entry_point: Optional[str] = None,
89+
env: Optional[Dict[str, str]] = None,
90+
predictor_cls: Optional[Predictor] = None,
91+
initial_instance_count: Optional[int] = None,
92+
instance_type: Optional[str] = None,
93+
**kwargs,
94+
) -> None:
95+
96+
estimator_deploy_kwargs = get_deploy_kwargs(
97+
model_id=self.model_id,
98+
model_version=self.model_version,
99+
instance_type=instance_type,
100+
initial_instance_count=initial_instance_count,
101+
region=self.region,
102+
image_uri=image_uri,
103+
source_dir=source_dir,
104+
entry_point=entry_point,
105+
env=env,
106+
predictor_cls=predictor_cls,
107+
kwargs=kwargs,
151108
)
152109

153-
super(Estimator, self).__init__(**self.kwargs_for_base_estimator_class)
154-
155-
@staticmethod
156-
def _update_dict_if_key_not_present(
157-
dict_to_update: dict, key_to_add: Any, value_to_add: Any
158-
) -> dict:
159-
if key_to_add not in dict_to_update:
160-
dict_to_update[key_to_add] = value_to_add
161-
162-
return dict_to_update
163-
164-
def fit(self, *largs, **kwargs) -> None:
165-
166-
return super(Estimator, self).fit(*largs, **kwargs)
110+
return super(JumpStartEstimator, self).deploy(**estimator_deploy_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/factory/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)