@@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
124
124
into a stream. All translations between the server and the client are handled
125
125
automatically with the specified input and output.
126
126
model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
127
- inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
128
- ``inference_spec`` is required for the model builder to build the artifact.
127
+ inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
128
+ is required for the model builder to build the artifact.
129
129
inference_spec (InferenceSpec): The inference spec file with your customized
130
130
``invoke`` and ``load`` functions.
131
131
image_uri (Optional[str]): The container image uri (which is derived from a
@@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
145
145
to the model server). Possible values for this argument are
146
146
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
147
147
``TRITON``, and ``TGI``.
148
+ model_metadata (Optional[Dict[str, Any]]): Dictionary used to override the HuggingFace
149
+ model metadata. Currently ``HF_TASK`` is overridable.
148
150
"""
149
151
150
152
model_path : Optional [str ] = field (
@@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
241
243
model_server : Optional [ModelServer ] = field (
242
244
default = None , metadata = {"help" : "Define the model server to deploy to." }
243
245
)
246
+ model_metadata : Optional [Dict [str , Any ]] = field (
247
+ default = None ,
248
+ metadata = {"help" : "Define the model metadata to override, currently supports `HF_TASK`" },
249
+ )
244
250
245
251
def _build_validations (self ):
246
252
"""Placeholder docstring"""
@@ -616,6 +622,9 @@ def build( # pylint: disable=R0911
616
622
self ._is_custom_image_uri = self .image_uri is not None
617
623
618
624
if isinstance (self .model , str ):
625
+ model_task = None
626
+ if self .model_metadata :
627
+ model_task = self .model_metadata .get ("HF_TASK" )
619
628
if self ._is_jumpstart_model_id ():
620
629
return self ._build_for_jumpstart ()
621
630
if self ._is_djl (): # pylint: disable=R1705
@@ -625,10 +634,10 @@ def build( # pylint: disable=R0911
625
634
self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
626
635
)
627
636
628
- model_task = hf_model_md .get ("pipeline_tag" )
629
- if self .schema_builder is None and model_task :
637
+ if model_task is None :
638
+ model_task = hf_model_md .get ("pipeline_tag" )
639
+ if self .schema_builder is None and model_task is not None :
630
640
self ._schema_builder_init (model_task )
631
-
632
641
if model_task == "text-generation" : # pylint: disable=R1705
633
642
return self ._build_for_tgi ()
634
643
elif self ._can_fit_on_single_gpu ():
0 commit comments