@@ -78,6 +78,25 @@ def _prepare_for_mode(self):
78
78
"""Abstract method"""
79
79
80
80
def _create_transformers_model (self ) -> Type [Model ]:
81
+ """Initializes HF model with or without image_uri"""
82
+ if self .image_uri is None :
83
+ pysdk_model = self ._get_hf_metadata_create_model ()
84
+ else :
85
+ pysdk_model = HuggingFaceModel (
86
+ image_uri = self .image_uri ,
87
+ vpc_config = self .vpc_config ,
88
+ env = self .env_vars ,
89
+ role = self .role_arn ,
90
+ sagemaker_session = self .sagemaker_session ,
91
+ )
92
+
93
+ logger .info ("Detected %s. Proceeding with the the deployment." , self .image_uri )
94
+
95
+ self ._original_deploy = pysdk_model .deploy
96
+ pysdk_model .deploy = self ._transformers_model_builder_deploy_wrapper
97
+ return pysdk_model
98
+
99
+ def _get_hf_metadata_create_model (self ) -> Type [Model ]:
81
100
"""Initializes the model after fetching image
82
101
83
102
1. Get the metadata for deciding framework
@@ -132,22 +151,21 @@ def _create_transformers_model(self) -> Type[Model]:
132
151
vpc_config = self .vpc_config ,
133
152
)
134
153
135
- if not self . image_uri and self .mode == Mode .LOCAL_CONTAINER :
154
+ if self .mode == Mode .LOCAL_CONTAINER :
136
155
self .image_uri = pysdk_model .serving_image_uri (
137
156
self .sagemaker_session .boto_region_name , "local"
138
157
)
139
- elif not self . image_uri :
158
+ else :
140
159
self .image_uri = pysdk_model .serving_image_uri (
141
160
self .sagemaker_session .boto_region_name , self .instance_type
142
161
)
143
162
144
- logger .info ("Detected %s. Proceeding with the the deployment." , self .image_uri )
163
+ if pysdk_model is None or self .image_uri is None :
164
+ raise ValueError ("PySDK model unable to be created, try overriding image_uri" )
145
165
146
166
if not pysdk_model .image_uri :
147
167
pysdk_model .image_uri = self .image_uri
148
168
149
- self ._original_deploy = pysdk_model .deploy
150
- pysdk_model .deploy = self ._transformers_model_builder_deploy_wrapper
151
169
return pysdk_model
152
170
153
171
@_capture_telemetry ("transformers.deploy" )
0 commit comments