22
22
from sagemaker .utils import aws_partition
23
23
24
24
25
- def get_model_id_version_from_endpoint (
25
+ def get_model_info_from_endpoint (
26
26
endpoint_name : str ,
27
27
inference_component_name : Optional [str ] = None ,
28
28
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
29
- ) -> Tuple [str , str , Optional [str ]]:
30
- """Given an endpoint and optionally inference component names, return the model ID and version .
29
+ ) -> Tuple [str , str , Optional [str ], Optional [ str ] ]:
30
+ """Optionally inference component names, return the model ID, version and config name .
31
31
32
32
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
33
33
and version. A third string element is included in the tuple for any inferred inference
@@ -46,30 +46,32 @@ def get_model_id_version_from_endpoint(
46
46
(
47
47
model_id ,
48
48
model_version ,
49
- ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
49
+ config_name ,
50
+ ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name ( # noqa E501 # pylint: disable=c0301
50
51
inference_component_name , sagemaker_session
51
52
)
52
53
53
54
else :
54
55
(
55
56
model_id ,
56
57
model_version ,
58
+ config_name ,
57
59
inference_component_name ,
58
- ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
60
+ ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name ( # noqa E501 # pylint: disable=c0301
59
61
endpoint_name , sagemaker_session
60
62
)
61
63
62
64
else :
63
- model_id , model_version = _get_model_id_version_from_model_based_endpoint (
65
+ model_id , model_version , config_name = _get_model_info_from_model_based_endpoint (
64
66
endpoint_name , inference_component_name , sagemaker_session
65
67
)
66
- return model_id , model_version , inference_component_name
68
+ return model_id , model_version , inference_component_name , config_name
67
69
68
70
69
- def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name (
71
+ def _get_model_info_from_inference_component_endpoint_without_inference_component_name (
70
72
endpoint_name : str , sagemaker_session : Session
71
- ) -> Tuple [str , str , str ]:
72
- """Given an endpoint name, derives the model ID, version, and inferred inference component name.
73
+ ) -> Tuple [str , str , str , str ]:
74
+ """Derives the model ID, version, config name and inferred inference component name.
73
75
74
76
This function assumes the endpoint corresponds to an inference-component-based endpoint.
75
77
An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
98
100
)
99
101
inference_component_name = inference_component_names [0 ]
100
102
return (
101
- * _get_model_id_version_from_inference_component_endpoint_with_inference_component_name (
103
+ * _get_model_info_from_inference_component_endpoint_with_inference_component_name (
102
104
inference_component_name , sagemaker_session
103
105
),
104
106
inference_component_name ,
105
107
)
106
108
107
109
108
- def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name (
110
+ def _get_model_info_from_inference_component_endpoint_with_inference_component_name (
109
111
inference_component_name : str , sagemaker_session : Session
110
112
):
111
113
"""Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
123
125
f"inference-component/{ inference_component_name } "
124
126
)
125
127
126
- model_id , model_version = get_jumpstart_model_id_version_from_resource_arn (
128
+ model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
127
129
inference_component_arn , sagemaker_session
128
130
)
129
131
@@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
134
136
"when retrieving default predictor for this inference component."
135
137
)
136
138
137
- return model_id , model_version
139
+ return model_id , model_version , config_name
138
140
139
141
140
- def _get_model_id_version_from_model_based_endpoint (
142
+ def _get_model_info_from_model_based_endpoint (
141
143
endpoint_name : str ,
142
144
inference_component_name : Optional [str ],
143
145
sagemaker_session : Session ,
144
- ) -> Tuple [str , str ]:
145
- """Returns the model ID and version inferred from a model-based endpoint.
146
+ ) -> Tuple [str , str , Optional [ str ] ]:
147
+ """Returns the model ID, version and config name inferred from a model-based endpoint.
146
148
147
149
Raises:
148
150
ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(
161
163
162
164
endpoint_arn = f"arn:{ partition } :sagemaker:{ region } :{ account_id } :endpoint/{ endpoint_name } "
163
165
164
- model_id , model_version = get_jumpstart_model_id_version_from_resource_arn (
166
+ model_id , model_version , config_name = get_jumpstart_model_id_version_from_resource_arn (
165
167
endpoint_arn , sagemaker_session
166
168
)
167
169
@@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
172
174
"predictor for this endpoint."
173
175
)
174
176
175
- return model_id , model_version
177
+ return model_id , model_version , config_name
176
178
177
179
178
- def get_model_id_version_from_training_job (
180
+ def get_model_info_from_training_job (
179
181
training_job_name : str ,
180
182
sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
181
- ) -> Tuple [str , str ]:
182
- """Returns the model ID and version inferred from a training job.
183
+ ) -> Tuple [str , str , Optional [ str ] ]:
184
+ """Returns the model ID and version and config name inferred from a training job.
183
185
184
186
Raises:
185
187
ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
194
196
f"arn:{ partition } :sagemaker:{ region } :{ account_id } :training-job/{ training_job_name } "
195
197
)
196
198
197
- model_id , inferred_model_version = get_jumpstart_model_id_version_from_resource_arn (
198
- training_job_arn , sagemaker_session
199
- )
199
+ (
200
+ model_id ,
201
+ inferred_model_version ,
202
+ config_name ,
203
+ ) = get_jumpstart_model_id_version_from_resource_arn (training_job_arn , sagemaker_session )
200
204
201
205
model_version = inferred_model_version or None
202
206
@@ -207,4 +211,4 @@ def get_model_id_version_from_training_job(
207
211
"for this training job."
208
212
)
209
213
210
- return model_id , model_version
214
+ return model_id , model_version , config_name
0 commit comments