@@ -56,7 +56,8 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
56
56
self .deserializer = deserializer
57
57
self .content_type = content_type or getattr (serializer , 'content_type' , None )
58
58
self .accept = accept or getattr (deserializer , 'accept' , None )
59
- self ._model_names = self ._get_model_names ()
59
+ self ._endpoint_config_name = self ._get_endpoint_config_name ()
60
+ self ._model_names = self ._endpoint_config_desc_and_model_names ()
60
61
61
62
def predict (self , data , initial_args = None ):
62
63
"""Return the inference from the specified endpoint.
@@ -134,15 +135,14 @@ def delete_model(self):
134
135
for model_name in self ._model_names :
135
136
self .sagemaker_session .delete_model (model_name )
136
137
137
- def _get_endpoint_config_desc (self ):
138
+ def _get_endpoint_config_name (self ):
138
139
endpoint_desc = self .sagemaker_session .sagemaker_client .describe_endpoint (EndpointName = self .endpoint )
139
- self ._endpoint_config_name = endpoint_desc ['EndpointConfigName' ]
140
+ endpoint_config_name = endpoint_desc ['EndpointConfigName' ]
141
+ return endpoint_config_name
142
+
143
+ def _endpoint_config_desc_and_model_names (self ):
140
144
endpoint_config = self .sagemaker_session .sagemaker_client .describe_endpoint_config (
141
145
EndpointConfigName = self ._endpoint_config_name )
142
- return endpoint_config
143
-
144
- def _get_model_names (self ):
145
- endpoint_config = self ._get_endpoint_config_desc ()
146
146
production_variants = endpoint_config ['ProductionVariants' ]
147
147
return map (lambda d : d ['ModelName' ], production_variants )
148
148
0 commit comments