Skip to content

Commit 48205ad

Browse files
committed
Integ test updates
1 parent cffe46a commit 48205ad

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/sagemaker/serve/builder/tei_builder.py

+11 -7
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@
1616
from typing import Type
1717
from abc import ABC, abstractmethod
1818

19+
from sagemaker import image_uris
1920
from sagemaker.model import Model
2021
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
2122

22-
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
23+
from sagemaker.huggingface import HuggingFaceModel
2324
from sagemaker.serve.utils.local_hardware import (
2425
_get_nb_instance,
2526
)
@@ -84,11 +85,16 @@ def _set_to_tgi(self):
8485
logger.warning(messaging)
8586
self.model_server = ModelServer.TGI
8687

87-
def _create_tei_model(self) -> Type[Model]:
88+
def _create_tei_model(self, **kwargs) -> Type[Model]:
8889
"""Placeholder docstring"""
90+
if self.nb_instance_type and "instance_type" not in kwargs:
91+
kwargs.update({"instance_type": self.nb_instance_type})
92+
8993
if not self.image_uri:
90-
self.image_uri = get_huggingface_llm_image_uri(
91-
"huggingface-tei", session=self.sagemaker_session
94+
self.image_uri = image_uris.retrieve(
95+
"huggingface-tei",
96+
image_scope="inference",
97+
instance_type=kwargs.get("instance_type")
9298
)
9399

94100
pysdk_model = HuggingFaceModel(
@@ -164,9 +170,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
164170
if "endpoint_logging" not in kwargs:
165171
kwargs["endpoint_logging"] = True
166172

167-
if self.nb_instance_type and "instance_type" not in kwargs:
168-
kwargs.update({"instance_type": self.nb_instance_type})
169-
elif not self.nb_instance_type and "instance_type" not in kwargs:
173+
if not self.nb_instance_type and "instance_type" not in kwargs:
170174
raise ValueError(
171175
"Instance type must be provided when deploying " "to SageMaker Endpoint mode."
172176
)

src/sagemaker/serve/model_server/tgi/server.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def _invoke_tgi_serving(self, request: object, content_type: str, accept: str):
7474
"""Placeholder docstring"""
7575
try:
7676
response = requests.post(
77-
f"http://{get_docker_host()}:8080/generate",
77+
f"http://{get_docker_host()}:8080/invocations",
7878
data=request,
7979
headers={"Content-Type": content_type, "Accept": accept},
8080
timeout=600,

0 commit comments

Comments
 (0)