54
54
from sagemaker .serve .validations .check_image_and_hardware_type import (
55
55
validate_image_uri_and_hardware ,
56
56
)
57
+ from sagemaker .workflow .entities import PipelineVariable
57
58
from sagemaker .huggingface .llm_utils import get_huggingface_model_metadata
58
59
59
60
logger = logging .getLogger (__name__ )
@@ -81,7 +82,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
81
82
82
83
* ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint
83
84
* ``Mode.LOCAL_CONTAINER``: Launch locally with a container
84
-
85
85
shared_libs (List[str]): Any shared libraries you want to bring into
86
86
the model packaging.
87
87
dependencies (Optional[Dict[str, Any]]): The dependencies of the model
@@ -122,6 +122,15 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
122
122
``invoke`` and ``load`` functions.
123
123
image_uri (Optional[str]): The container image uri (which is derived from a
124
124
SageMaker-based container).
125
+ image_config (dict[str, str] or dict[str, PipelineVariable]): Specifies
126
+ whether the image of model container is pulled from ECR, or private
127
+ registry in your VPC. By default it is set to pull model container
128
+ image from ECR. (default: None).
129
+ vpc_config (Optional[Dict[str, List[Union[str, PipelineVariable]]]]):
130
+ The VpcConfig set on the model (default: None)
131
+ * 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids.
132
+ * 'SecurityGroupIds' (List[Union[str, PipelineVariable]]): List of security group
133
+ ids.
125
134
model_server (Optional[ModelServer]): The model server to which to deploy.
126
135
You need to provide this argument when you specify an ``image_uri``
127
136
in order for model builder to build the artifacts correctly (according
@@ -204,6 +213,23 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
204
213
image_uri : Optional [str ] = field (
205
214
default = None , metadata = {"help" : "Define the container image uri" }
206
215
)
216
+ image_config : Optional [Dict [str , Union [str , PipelineVariable ]]] = field (
217
+ default = None ,
218
+ metadata = {
219
+ "help" : "Specifies whether the image of model container is pulled from ECR,"
220
+ " or private registry in your VPC. By default it is set to pull model "
221
+ "container image from ECR. (default: None)."
222
+ },
223
+ )
224
+ vpc_config : Optional [Dict [str , List [Union [str , PipelineVariable ]]]] = field (
225
+ default = None ,
226
+ metadata = {
227
+ "help" : "The VpcConfig set on the model (default: None)."
228
+ "* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids."
229
+ "* ''SecurityGroupIds'' (List[Union[str, PipelineVariable]]): List of"
230
+ " security group ids."
231
+ },
232
+ )
207
233
model_server : Optional [ModelServer ] = field (
208
234
default = None , metadata = {"help" : "Define the model server to deploy to." }
209
235
)
@@ -386,6 +412,8 @@ def _create_model(self):
386
412
# TODO: we should create model as per the framework
387
413
self .pysdk_model = Model (
388
414
image_uri = self .image_uri ,
415
+ image_config = self .image_config ,
416
+ vpc_config = self .vpc_config ,
389
417
model_data = self .s3_upload_path ,
390
418
role = self .serve_settings .role_arn ,
391
419
env = self .env_vars ,
@@ -543,15 +571,16 @@ def build(
543
571
self ,
544
572
mode : Type [Mode ] = None ,
545
573
role_arn : str = None ,
546
- sagemaker_session : str = None ,
574
+ sagemaker_session : Optional [ Session ] = None ,
547
575
) -> Type [Model ]:
548
576
"""Create a deployable ``Model`` instance with ``ModelBuilder``.
549
577
550
578
Args:
551
579
mode (Type[Mode], optional): The mode. Defaults to ``None``.
552
580
role_arn (str, optional): The IAM role arn. Defaults to ``None``.
553
- sagemaker_session (str, optional): The SageMaker session to use
554
- for the execution. Defaults to ``None``.
581
+ sagemaker_session (Optional[Session]): Session object which manages interactions
582
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
583
+ function creates one using the default AWS configuration chain.
555
584
556
585
Returns:
557
586
Type[Model]: A deployable ``Model`` object.
@@ -562,10 +591,7 @@ def build(
562
591
self .mode = mode
563
592
if role_arn :
564
593
self .role_arn = role_arn
565
- if sagemaker_session :
566
- self .sagemaker_session = sagemaker_session
567
- elif not self .sagemaker_session :
568
- self .sagemaker_session = Session ()
594
+ self .sagemaker_session = sagemaker_session or Session ()
569
595
570
596
self .sagemaker_session .settings ._local_download_dir = self .model_path
571
597
@@ -607,7 +633,7 @@ def save(
607
633
self ,
608
634
save_path : Optional [str ] = None ,
609
635
s3_path : Optional [str ] = None ,
610
- sagemaker_session : Optional [str ] = None ,
636
+ sagemaker_session : Optional [Session ] = None ,
611
637
role_arn : Optional [str ] = None ,
612
638
) -> Type [Model ]:
613
639
"""WARNING: This function is experimental and not intended for production use.
@@ -618,7 +644,7 @@ def save(
618
644
save_path (Optional[str]): The path where you want to save resources.
619
645
s3_path (Optional[str]): The path where you want to upload resources.
620
646
"""
621
- self .sagemaker_session = sagemaker_session if sagemaker_session else Session ()
647
+ self .sagemaker_session = sagemaker_session or Session ()
622
648
623
649
if role_arn :
624
650
self .role_arn = role_arn
0 commit comments