Skip to content

Commit e66e77e

Browse files
cj-zhang authored and Pravali Uppugunduri committed
Bugfixes from e2e testing. (aws#1670)
1 parent 99b6c35 commit e66e77e

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

src/sagemaker/serve/builder/model_builder.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
250250
default=None, metadata={"help": "Define sagemaker session for execution"}
251251
)
252252
name: Optional[str] = field(
253-
default="model-name-" + uuid.uuid1().hex,
253+
default_factory=lambda: "model-name-" + uuid.uuid1().hex,
254254
metadata={"help": "Define the model name"},
255255
)
256256
mode: Optional[Mode] = field(
@@ -1130,7 +1130,7 @@ def build(
11301130
def _get_processing_unit(self):
11311131
"""Detects if the resource requirements are intended for a CPU or GPU instance."""
11321132
# Assume custom orchestrator will be deployed as an endpoint to a CPU instance
1133-
if not self.resource_requirements:
1133+
if not self.resource_requirements or not self.resource_requirements.num_accelerators:
11341134
return "cpu"
11351135
for ic in self.modelbuilder_list or []:
11361136
if ic.resource_requirements.num_accelerators > 0:
@@ -1171,10 +1171,10 @@ def _get_ic_resource_requirements(self, mb: ModelBuilder = None) -> ModelBuilder
11711171

11721172
@_capture_telemetry("build_custom_orchestrator")
11731173
def _get_smd_image_uri(self, processing_unit: str = None) -> str:
1174-
"""Gets the SMD Inference URI.
1174+
"""Gets the SMD Inference Image URI.
11751175
11761176
Returns:
1177-
str: Pytorch DLC URI.
1177+
str: SMD Inference Image URI.
11781178
"""
11791179
from sagemaker import image_uris
11801180
import sys
@@ -1183,10 +1183,10 @@ def _get_smd_image_uri(self, processing_unit: str = None) -> str:
11831183
from packaging.version import Version
11841184

11851185
formatted_py_version = f"py{sys.version_info.major}{sys.version_info.minor}"
1186-
if Version(f"{sys.version_info.major}{sys.version_info.minor}") < Version("3.11.11"):
1186+
if Version(f"{sys.version_info.major}{sys.version_info.minor}") < Version("3.12"):
11871187
raise ValueError(
11881188
f"Found Python version {formatted_py_version} but"
1189-
f"Custom orchestrator deployment requires Python version >= 3.11.11."
1189+
f"Custom orchestrator deployment requires Python version >= 3.12."
11901190
)
11911191

11921192
INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"}
@@ -1957,7 +1957,7 @@ def deploy(
19571957
] = None,
19581958
update_endpoint: Optional[bool] = False,
19591959
custom_orchestrator_instance_type: str = None,
1960-
custom_orchestrator_initial_instance_count: int = 1,
1960+
custom_orchestrator_initial_instance_count: int = None,
19611961
**kwargs,
19621962
) -> Union[Predictor, Transformer, List[Predictor]]:
19631963
"""Deploys the built Model.
@@ -2054,13 +2054,14 @@ def deploy(
20542054
)
20552055
if self._deployables.get("CustomOrchestrator", None):
20562056
custom_orchestrator = self._deployables.get("CustomOrchestrator")
2057+
if not custom_orchestrator_instance_type and not instance_type:
2058+
logger.warning(
2059+
"Deploying custom orchestrator as an endpoint but no instance type was "
2060+
"set. Defaulting to `ml.c5.xlarge`."
2061+
)
2062+
custom_orchestrator_instance_type = "ml.c5.xlarge"
2063+
custom_orchestrator_initial_instance_count = 1
20572064
if custom_orchestrator["Mode"] == "Endpoint":
2058-
if not custom_orchestrator_instance_type:
2059-
logger.warning(
2060-
"Deploying custom orchestrator as an endpoint but no instance type was "
2061-
"set. Defaulting to `ml.c5.xlarge`."
2062-
)
2063-
custom_orchestrator_instance_type = "ml.c5.xlarge"
20642065
logger.info(
20652066
"Deploying custom orchestrator on instance type %s.",
20662067
custom_orchestrator_instance_type,
@@ -2073,13 +2074,18 @@ def deploy(
20732074
)
20742075
)
20752076
elif custom_orchestrator["Mode"] == "InferenceComponent":
2077+
logger.info(
2078+
"Deploying custom orchestrator as an inference component "
2079+
f"to endpoint {endpoint_name}"
2080+
)
20762081
predictors.append(
20772082
self._deploy_for_ic(
20782083
ic_data=custom_orchestrator,
20792084
container_timeout_in_seconds=container_timeout_in_second,
20802085
instance_type=custom_orchestrator_instance_type or instance_type,
20812086
initial_instance_count=custom_orchestrator_initial_instance_count
20822087
or initial_instance_count,
2088+
endpoint_name=endpoint_name,
20832089
**kwargs,
20842090
)
20852091
)

src/sagemaker/serve/model_server/smd/custom_execution_inference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,6 @@ async def handler(request):
6767
:return: outputs to be send back to client
6868
"""
6969
if asyncio.iscoroutinefunction(custom_orchestrator.handle):
70-
return await custom_orchestrator.handle(request)
70+
return await custom_orchestrator.handle(request.body)
7171
else:
72-
return custom_orchestrator.handle(request)
72+
return custom_orchestrator.handle(request.body)

src/sagemaker/serve/spec/inference_base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@ def __init__(self):
2424
@property
2525
def client(self):
2626
"""Boto3 SageMaker runtime client to use with custom orchestrator"""
27-
if not hasattr(self, "_client"):
27+
if not hasattr(self, "_client") or not self._client:
2828
from boto3 import Session
2929

3030
self._client = Session().client("sagemaker-runtime")
31-
3231
return self._client
3332

3433
@abstractmethod

0 commit comments

Comments (0)