38
38
from sagemaker .predictor import Predictor
39
39
from sagemaker .serve .save_retrive .version_1_0_0 .metadata .metadata import Metadata
40
40
from sagemaker .serve .spec .inference_spec import InferenceSpec
41
+ from sagemaker .serve .utils import task
42
+ from sagemaker .serve .utils .exceptions import TaskNotFoundException
41
43
from sagemaker .serve .utils .predictors import _get_local_mode_predictor
42
44
from sagemaker .serve .detector .image_detector import (
43
45
auto_detect_container ,
@@ -605,7 +607,7 @@ def build(
605
607
606
608
self .serve_settings = self ._get_serve_setting ()
607
609
608
- self ._is_custom_image_uri = self .image_uri is None
610
+ self ._is_custom_image_uri = self .image_uri is not None
609
611
610
612
if isinstance (self .model , str ):
611
613
if self ._is_jumpstart_model_id ():
@@ -616,7 +618,12 @@ def build(
616
618
hf_model_md = get_huggingface_model_metadata (
617
619
self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
618
620
)
619
- if hf_model_md .get ("pipeline_tag" ) == "text-generation" : # pylint: disable=R1705
621
+
622
+ model_task = hf_model_md .get ("pipeline_tag" )
623
+ if self .schema_builder is None and model_task :
624
+ self ._schema_builder_init (model_task )
625
+
626
+ if model_task == "text-generation" : # pylint: disable=R1705
620
627
return self ._build_for_tgi ()
621
628
else :
622
629
return self ._build_for_transformers ()
@@ -674,3 +681,18 @@ def validate(self, model_dir: str) -> Type[bool]:
674
681
"""
675
682
676
683
return get_metadata (model_dir )
684
+
685
+ def _schema_builder_init (self , model_task : str ):
686
+ """Initialize the schema builder
687
+
688
+ Args:
689
+ model_task (str): Required, the task name
690
+
691
+ Raises:
692
+ TaskNotFoundException: If the I/O schema for the given task is not found.
693
+ """
694
+ try :
695
+ sample_inputs , sample_outputs = task .retrieve_local_schemas (model_task )
696
+ self .schema_builder = SchemaBuilder (sample_inputs , sample_outputs )
697
+ except ValueError :
698
+ raise TaskNotFoundException (f"Schema builder for { model_task } could not be found." )
0 commit comments