14
14
from __future__ import absolute_import
15
15
16
16
17
- from copy import deepcopy
18
- from typing import Any , Optional
19
- from sagemaker import (
20
- hyperparameters ,
21
- image_uris ,
22
- instance_types ,
23
- metric_definitions ,
24
- model_uris ,
25
- script_uris ,
26
- )
27
- from sagemaker .jumpstart .constants import JUMPSTART_DEFAULT_REGION_NAME
28
- from sagemaker .jumpstart .enums import JumpStartScriptScope
29
- from sagemaker .jumpstart .utils import update_dict_if_key_not_present
30
- from sagemaker .model import Estimator
17
+ from typing import Dict , List , Optional
18
+
19
+ from sagemaker .estimator import Estimator
20
+
21
+ from sagemaker .jumpstart .factory .estimator import get_deploy_kwargs , get_fit_kwargs , get_init_kwargs
22
+
23
+
24
+ from sagemaker .predictor import Predictor
31
25
32
26
33
27
class JumpStartEstimator (Estimator ):
@@ -39,128 +33,78 @@ class JumpStartEstimator(Estimator):
39
33
def __init__ (
40
34
self ,
41
35
model_id : str ,
42
- model_version : Optional [str ] = "*" ,
43
- region : Optional [str ] = JUMPSTART_DEFAULT_REGION_NAME ,
44
- kwargs_for_base_estimator_class : dict = {},
36
+ model_version : Optional [str ] = None ,
37
+ instance_type : Optional [str ] = None ,
38
+ instance_count : Optional [int ] = None ,
39
+ region : Optional [str ] = None ,
40
+ image_uri : Optional [str ] = None ,
41
+ model_uri : Optional [str ] = None ,
42
+ source_dir : Optional [str ] = None ,
43
+ entry_point : Optional [str ] = None ,
44
+ hyperparameters : Optional [dict ] = None ,
45
+ metric_definitions : Optional [List [dict ]] = None ,
46
+ ** kwargs ,
45
47
):
46
- self .model_id = model_id
47
- self .model_version = model_version
48
- self .kwargs_for_base_estimator_class = deepcopy (kwargs_for_base_estimator_class )
49
-
50
- self .kwargs_for_base_estimator_class = update_dict_if_key_not_present (
51
- self .kwargs_for_base_estimator_class ,
52
- "image_uri" ,
53
- image_uris .retrieve (
54
- region = None ,
55
- framework = None ,
56
- image_scope = "training" ,
57
- model_id = model_id ,
58
- model_version = model_version ,
59
- instance_type = self .instance_type ,
60
- ),
61
- )
62
-
63
- self .kwargs_for_base_estimator_class = update_dict_if_key_not_present (
64
- self .kwargs_for_base_estimator_class ,
65
- "model_uri" ,
66
- model_uris .retrieve (
67
- script_scope = JumpStartScriptScope .TRAINING ,
68
- model_id = model_id ,
69
- model_version = model_version ,
70
- ),
71
- )
72
-
73
- self .kwargs_for_base_estimator_class = update_dict_if_key_not_present (
74
- self .kwargs_for_base_estimator_class ,
75
- "script_uri" ,
76
- script_uris .retrieve (
77
- script_scope = JumpStartScriptScope .TRAINING ,
78
- model_id = model_id ,
79
- model_version = model_version ,
80
- ),
81
- )
82
-
83
- default_hyperparameters = hyperparameters .retrieve_default (
84
- region = region , model_id = model_id , model_version = model_version
48
+ estimator_init_kwargs = get_init_kwargs (
49
+ model_id = model_id ,
50
+ model_version = model_version ,
51
+ instance_type = instance_type ,
52
+ instance_count = instance_count ,
53
+ region = region ,
54
+ image_uri = image_uri ,
55
+ model_uri = model_uri ,
56
+ source_dir = source_dir ,
57
+ entry_point = entry_point ,
58
+ hyperparameters = hyperparameters ,
59
+ metric_definitions = metric_definitions ,
60
+ kwargs = kwargs ,
85
61
)
86
62
87
- curr_hyperparameters = self .kwargs_for_base_estimator_class .get ("hyperparameters" , {})
88
- new_hyperparameters = deepcopy (curr_hyperparameters )
89
-
90
- for key , value in default_hyperparameters :
91
- new_hyperparameters = update_dict_if_key_not_present (
92
- new_hyperparameters ,
93
- key ,
94
- value ,
95
- )
63
+ self .model_id = estimator_init_kwargs .model_id
64
+ self .model_version = estimator_init_kwargs .model_version
65
+ self .instance_type = estimator_init_kwargs .instance_type
66
+ self .instance_count = estimator_init_kwargs .instance_count
67
+ self .region = estimator_init_kwargs .region
96
68
97
- if new_hyperparameters == {}:
98
- new_hyperparameters = None
69
+ super (JumpStartEstimator , self ).__init__ (** estimator_init_kwargs .to_kwargs_dict ())
99
70
100
- self . kwargs_for_base_estimator_class [ "hyperparameters" ] = new_hyperparameters
71
+ def fit ( self , * largs , ** kwargs ) -> None :
101
72
102
- default_metric_definitions = metric_definitions .retrieve_default (
103
- region = region , model_id = model_id , model_version = model_version
73
+ estimator_fit_kwargs = get_fit_kwargs (
74
+ model_id = self .model_id ,
75
+ model_version = self .model_version ,
76
+ instance_type = self .instance_type ,
77
+ instance_count = self .instance_count ,
78
+ region = self .region ,
79
+ kwargs = kwargs ,
104
80
)
105
81
106
- curr_metric_definitions = self .kwargs_for_base_estimator_class .get ("metric_definitions" , [])
107
- new_metric_definitions = deepcopy (curr_metric_definitions )
108
-
109
- for metric_definition in default_metric_definitions :
110
- if metric_definition ["Name" ] not in [
111
- definition ["Name" ] for definition in new_metric_definitions
112
- ]:
113
- new_metric_definitions .append (metric_definition )
114
-
115
- if new_metric_definitions == []:
116
- new_metric_definitions = None
82
+ return super (JumpStartEstimator , self ).fit (* largs , ** estimator_fit_kwargs .to_kwargs_dict ())
117
83
118
- self .kwargs_for_base_estimator_class ["metric_definitions" ] = new_metric_definitions
119
-
120
- # estimator_kwargs_to_add = _retrieve_kwargs(model_id=model_id, model_version=model_version, region=region)
121
- estimator_kwargs_to_add = {}
122
-
123
- new_kwargs_for_base_estimator_class = deepcopy (self .kwargs_for_base_estimator_class )
124
- for key , value in estimator_kwargs_to_add :
125
- new_kwargs_for_base_estimator_class = update_dict_if_key_not_present (
126
- new_kwargs_for_base_estimator_class ,
127
- key ,
128
- value ,
129
- )
130
-
131
- self .kwargs_for_base_estimator_class = new_kwargs_for_base_estimator_class
132
-
133
- self .kwargs_for_base_estimator_class ["model_id" ] = model_id
134
- self .kwargs_for_base_estimator_class ["model_version" ] = model_version
135
-
136
- # self.kwargs_for_base_estimator_class = update_dict_if_key_not_present(
137
- # self.kwargs_for_base_estimator_class,
138
- # "predictor_cls",
139
- # JumpStartPredictor,
140
- # )
141
-
142
- self .kwargs_for_base_estimator_class = update_dict_if_key_not_present (
143
- self .kwargs_for_base_estimator_class , "instance_count" , 1
144
- )
145
- self .kwargs_for_base_estimator_class = update_dict_if_key_not_present (
146
- self .kwargs_for_base_estimator_class ,
147
- "instance_type" ,
148
- instance_types .retrieve_default (
149
- region = region , model_id = model_id , model_version = model_version
150
- ),
84
+ def deploy (
85
+ self ,
86
+ image_uri : Optional [str ] = None ,
87
+ source_dir : Optional [str ] = None ,
88
+ entry_point : Optional [str ] = None ,
89
+ env : Optional [Dict [str , str ]] = None ,
90
+ predictor_cls : Optional [Predictor ] = None ,
91
+ initial_instance_count : Optional [int ] = None ,
92
+ instance_type : Optional [str ] = None ,
93
+ ** kwargs ,
94
+ ) -> None :
95
+
96
+ estimator_deploy_kwargs = get_deploy_kwargs (
97
+ model_id = self .model_id ,
98
+ model_version = self .model_version ,
99
+ instance_type = instance_type ,
100
+ initial_instance_count = initial_instance_count ,
101
+ region = self .region ,
102
+ image_uri = image_uri ,
103
+ source_dir = source_dir ,
104
+ entry_point = entry_point ,
105
+ env = env ,
106
+ predictor_cls = predictor_cls ,
107
+ kwargs = kwargs ,
151
108
)
152
109
153
- super (Estimator , self ).__init__ (** self .kwargs_for_base_estimator_class )
154
-
155
- @staticmethod
156
- def _update_dict_if_key_not_present (
157
- dict_to_update : dict , key_to_add : Any , value_to_add : Any
158
- ) -> dict :
159
- if key_to_add not in dict_to_update :
160
- dict_to_update [key_to_add ] = value_to_add
161
-
162
- return dict_to_update
163
-
164
- def fit (self , * largs , ** kwargs ) -> None :
165
-
166
- return super (Estimator , self ).fit (* largs , ** kwargs )
110
+ return super (JumpStartEstimator , self ).deploy (** estimator_deploy_kwargs .to_kwargs_dict ())
0 commit comments