17
17
from typing import Optional , Tuple
18
18
from sagemaker .jumpstart .constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19
19
20
- from sagemaker .jumpstart .utils import get_jumpstart_model_id_version_from_resource_arn
20
+ from sagemaker .jumpstart .utils import get_jumpstart_model_info_from_resource_arn
21
21
from sagemaker .session import Session
22
22
from sagemaker .utils import aws_partition
23
23
@@ -26,7 +26,7 @@ def get_model_info_from_endpoint(
26
26
endpoint_name : str ,
27
27
inference_component_name : Optional [str ] = None ,
28
28
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
29
- ) -> Tuple [str , str , Optional [str ], Optional [str ]]:
29
+ ) -> Tuple [str , str , Optional [str ], Optional [str ], Optional [ str ] ]:
30
30
"""Optionally inference component names, return the model ID, version and config name.
31
31
32
32
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
@@ -46,7 +46,8 @@ def get_model_info_from_endpoint(
46
46
(
47
47
model_id ,
48
48
model_version ,
49
- config_name ,
49
+ inference_config_name ,
50
+ training_config_name ,
50
51
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
51
52
inference_component_name , sagemaker_session
52
53
)
@@ -55,17 +56,29 @@ def get_model_info_from_endpoint(
55
56
(
56
57
model_id ,
57
58
model_version ,
58
- config_name ,
59
+ inference_config_name ,
60
+ training_config_name ,
59
61
inference_component_name ,
60
62
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
61
63
endpoint_name , sagemaker_session
62
64
)
63
65
64
66
else :
65
- model_id , model_version , config_name = _get_model_info_from_model_based_endpoint (
67
+ (
68
+ model_id ,
69
+ model_version ,
70
+ inference_config_name ,
71
+ training_config_name ,
72
+ ) = _get_model_info_from_model_based_endpoint (
66
73
endpoint_name , inference_component_name , sagemaker_session
67
74
)
68
- return model_id , model_version , inference_component_name , config_name
75
+ return (
76
+ model_id ,
77
+ model_version ,
78
+ inference_component_name ,
79
+ inference_config_name ,
80
+ training_config_name ,
81
+ )
69
82
70
83
71
84
def _get_model_info_from_inference_component_endpoint_without_inference_component_name (
@@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
125
138
f"inference-component/{ inference_component_name } "
126
139
)
127
140
128
- model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
129
- inference_component_arn , sagemaker_session
130
- )
141
+ (
142
+ model_id ,
143
+ model_version ,
144
+ inference_config_name ,
145
+ training_config_name ,
146
+ ) = get_jumpstart_model_info_from_resource_arn (inference_component_arn , sagemaker_session )
131
147
132
148
if not model_id :
133
149
raise ValueError (
@@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
136
152
"when retrieving default predictor for this inference component."
137
153
)
138
154
139
- return model_id , model_version , config_name
155
+ return model_id , model_version , inference_config_name , training_config_name
140
156
141
157
142
158
def _get_model_info_from_model_based_endpoint (
143
159
endpoint_name : str ,
144
160
inference_component_name : Optional [str ],
145
161
sagemaker_session : Session ,
146
- ) -> Tuple [str , str , Optional [str ]]:
162
+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
147
163
"""Returns the model ID, version and config name inferred from a model-based endpoint.
148
164
149
165
Raises:
@@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint(
163
179
164
180
endpoint_arn = f"arn:{ partition } :sagemaker:{ region } :{ account_id } :endpoint/{ endpoint_name } "
165
181
166
- model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
167
- endpoint_arn , sagemaker_session
168
- )
182
+ (
183
+ model_id ,
184
+ model_version ,
185
+ inference_config_name ,
186
+ training_config_name ,
187
+ ) = get_jumpstart_model_info_from_resource_arn (endpoint_arn , sagemaker_session )
169
188
170
189
if not model_id :
171
190
raise ValueError (
@@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint(
174
193
"predictor for this endpoint."
175
194
)
176
195
177
- return model_id , model_version , config_name
196
+ return model_id , model_version , inference_config_name , training_config_name
178
197
179
198
180
199
def get_model_info_from_training_job (
181
200
training_job_name : str ,
182
201
sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
183
- ) -> Tuple [str , str , Optional [str ]]:
202
+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
184
203
"""Returns the model ID and version and config name inferred from a training job.
185
204
186
205
Raises:
@@ -199,8 +218,9 @@ def get_model_info_from_training_job(
199
218
(
200
219
model_id ,
201
220
inferred_model_version ,
202
- config_name ,
203
- ) = get_jumpstart_model_id_version_from_resource_arn (training_job_arn , sagemaker_session )
221
+ inference_config_name ,
222
+ trainig_config_name ,
223
+ ) = get_jumpstart_model_info_from_resource_arn (training_job_arn , sagemaker_session )
204
224
205
225
model_version = inferred_model_version or None
206
226
@@ -211,4 +231,4 @@ def get_model_info_from_training_job(
211
231
"for this training job."
212
232
)
213
233
214
- return model_id , model_version , config_name
234
+ return model_id , model_version , inference_config_name , trainig_config_name
0 commit comments