@@ -250,7 +250,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
250
250
default = None , metadata = {"help" : "Define sagemaker session for execution" }
251
251
)
252
252
name : Optional [str ] = field (
253
- default = "model-name-" + uuid .uuid1 ().hex ,
253
+ default_factory = lambda : "model-name-" + uuid .uuid1 ().hex ,
254
254
metadata = {"help" : "Define the model name" },
255
255
)
256
256
mode : Optional [Mode ] = field (
@@ -1130,7 +1130,7 @@ def build(
1130
1130
def _get_processing_unit (self ):
1131
1131
"""Detects if the resource requirements are intended for a CPU or GPU instance."""
1132
1132
# Assume custom orchestrator will be deployed as an endpoint to a CPU instance
1133
- if not self .resource_requirements :
1133
+ if not self .resource_requirements or not self . resource_requirements . num_accelerators :
1134
1134
return "cpu"
1135
1135
for ic in self .modelbuilder_list or []:
1136
1136
if ic .resource_requirements .num_accelerators > 0 :
@@ -1171,10 +1171,10 @@ def _get_ic_resource_requirements(self, mb: ModelBuilder = None) -> ModelBuilder
1171
1171
1172
1172
@_capture_telemetry ("build_custom_orchestrator" )
1173
1173
def _get_smd_image_uri (self , processing_unit : str = None ) -> str :
1174
- """Gets the SMD Inference URI.
1174
+ """Gets the SMD Inference Image URI.
1175
1175
1176
1176
Returns:
1177
- str: Pytorch DLC URI.
1177
+ str: SMD Inference Image URI.
1178
1178
"""
1179
1179
from sagemaker import image_uris
1180
1180
import sys
@@ -1183,10 +1183,10 @@ def _get_smd_image_uri(self, processing_unit: str = None) -> str:
1183
1183
from packaging .version import Version
1184
1184
1185
1185
formatted_py_version = f"py{ sys .version_info .major } { sys .version_info .minor } "
1186
-        if Version(f"{sys.version_info.major}.{sys.version_info.minor}") < Version("3.11.11"):
+        if Version(f"{sys.version_info.major}.{sys.version_info.minor}") < Version("3.12"):
1187
1187
raise ValueError (
1188
1188
f"Found Python version { formatted_py_version } but"
1189
- f"Custom orchestrator deployment requires Python version >= 3.11.11 ."
1189
+ f"Custom orchestrator deployment requires Python version >= 3.12 ."
1190
1190
)
1191
1191
1192
1192
INSTANCE_TYPES = {"cpu" : "ml.c5.xlarge" , "gpu" : "ml.g5.4xlarge" }
@@ -1957,7 +1957,7 @@ def deploy(
1957
1957
] = None ,
1958
1958
update_endpoint : Optional [bool ] = False ,
1959
1959
custom_orchestrator_instance_type : str = None ,
1960
- custom_orchestrator_initial_instance_count : int = 1 ,
1960
+ custom_orchestrator_initial_instance_count : int = None ,
1961
1961
** kwargs ,
1962
1962
) -> Union [Predictor , Transformer , List [Predictor ]]:
1963
1963
"""Deploys the built Model.
@@ -2054,13 +2054,14 @@ def deploy(
2054
2054
)
2055
2055
if self ._deployables .get ("CustomOrchestrator" , None ):
2056
2056
custom_orchestrator = self ._deployables .get ("CustomOrchestrator" )
2057
+ if not custom_orchestrator_instance_type and not instance_type :
2058
+ logger .warning (
2059
+ "Deploying custom orchestrator as an endpoint but no instance type was "
2060
+ "set. Defaulting to `ml.c5.xlarge`."
2061
+ )
2062
+ custom_orchestrator_instance_type = "ml.c5.xlarge"
2063
+ custom_orchestrator_initial_instance_count = 1
2057
2064
if custom_orchestrator ["Mode" ] == "Endpoint" :
2058
- if not custom_orchestrator_instance_type :
2059
- logger .warning (
2060
- "Deploying custom orchestrator as an endpoint but no instance type was "
2061
- "set. Defaulting to `ml.c5.xlarge`."
2062
- )
2063
- custom_orchestrator_instance_type = "ml.c5.xlarge"
2064
2065
logger .info (
2065
2066
"Deploying custom orchestrator on instance type %s." ,
2066
2067
custom_orchestrator_instance_type ,
@@ -2073,13 +2074,18 @@ def deploy(
2073
2074
)
2074
2075
)
2075
2076
elif custom_orchestrator ["Mode" ] == "InferenceComponent" :
2077
+ logger .info (
2078
+ "Deploying custom orchestrator as an inference component "
2079
+ f"to endpoint { endpoint_name } "
2080
+ )
2076
2081
predictors .append (
2077
2082
self ._deploy_for_ic (
2078
2083
ic_data = custom_orchestrator ,
2079
2084
container_timeout_in_seconds = container_timeout_in_second ,
2080
2085
instance_type = custom_orchestrator_instance_type or instance_type ,
2081
2086
initial_instance_count = custom_orchestrator_initial_instance_count
2082
2087
or initial_instance_count ,
2088
+ endpoint_name = endpoint_name ,
2083
2089
** kwargs ,
2084
2090
)
2085
2091
)
0 commit comments