Skip to content

Commit 393fa75

Browse files
committed
add session value for HF model
1 parent fe4bdd3 commit 393fa75

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/sagemaker/huggingface/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.predictor import Predictor
2727
from sagemaker.serializers import JSONSerializer
28+
from sagemaker.session import Session
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -169,6 +170,7 @@ def __init__(
169170
super(HuggingFaceModel, self).__init__(
170171
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
171172
)
173+
self.sagemaker_session = self.sagemaker_session or Session()
172174

173175
self.model_server_workers = model_server_workers
174176

@@ -262,11 +264,12 @@ def deploy(
262264
is not None. Otherwise, return None.
263265
"""
264266

265-
if instance_type.startswith("ml.inf") and not self.image_uri:
267+
if not self.image_uri and instance_type.startswith("ml.inf"):
266268
self.image_uri = self.serving_image_uri(
267269
region_name=self.sagemaker_session.boto_session.region_name,
268270
instance_type=instance_type,
269271
)
272+
270273
return super(HuggingFaceModel, self).deploy(
271274
initial_instance_count,
272275
instance_type,

0 commit comments

Comments
 (0)