19
19
import os
20
20
import re
21
21
import copy
22
+ from typing import List , Dict
22
23
23
24
import sagemaker
24
25
from sagemaker import (
38
39
from sagemaker .async_inference import AsyncInferenceConfig
39
40
from sagemaker .predictor_async import AsyncPredictor
40
41
from sagemaker .workflow import is_pipeline_variable
42
+ from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
41
43
42
44
LOGGER = logging .getLogger ("sagemaker" )
43
45
@@ -289,6 +291,7 @@ def __init__(
289
291
self .uploaded_code = None
290
292
self .repacked_model_data = None
291
293
294
+ @runnable_by_pipeline
292
295
def register (
293
296
self ,
294
297
content_types ,
@@ -310,12 +313,12 @@ def register(
310
313
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311
314
312
315
Args:
313
- content_types (list): The supported MIME types for the input data (default: None) .
314
- response_types (list): The supported MIME types for the output data (default: None) .
316
+ content_types (list): The supported MIME types for the input data.
317
+ response_types (list): The supported MIME types for the output data.
315
318
inference_instances (list): A list of the instance types that are used to
316
- generate inferences in real-time (default: None) .
319
+ generate inferences in real-time.
317
320
transform_instances (list): A list of the instance types on which a transformation
318
- job can be run or on which an endpoint can be deployed (default: None) .
321
+ job can be run or on which an endpoint can be deployed.
319
322
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320
323
using `model_package_name` makes the Model Package un-versioned (default: None).
321
324
model_package_group_name (str): Model Package Group name, exclusive to
@@ -366,12 +369,50 @@ def register(
366
369
model_package = self .sagemaker_session .create_model_package_from_containers (
367
370
** model_pkg_args
368
371
)
372
+ if isinstance (self .sagemaker_session , PipelineSession ):
373
+ return None
369
374
return ModelPackage (
370
375
role = self .role ,
371
376
model_data = self .model_data ,
372
377
model_package_arn = model_package .get ("ModelPackageArn" ),
373
378
)
374
379
380
@runnable_by_pipeline
def create(
    self,
    instance_type: str = None,
    accelerator_type: str = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
    tags: List[Dict[str, str]] = None,
):
    """Create a SageMaker Model Entity.

    Public entry point that forwards all of its arguments, unchanged, to
    ``_create_sagemaker_model``.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not (default: None).
        accelerator_type (str): Type of Elastic Inference accelerator to
            attach to an endpoint for model loading and inference, for
            example, 'ml.eia1.medium'. If not specified, no Elastic
            Inference accelerator will be attached to the endpoint
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance
            type is not provided in serverless inference, so this is used to
            find image URIs (default: None).
        tags (List[Dict[str, str]]): The list of tags to add to the model
            (default: None). Example::

                tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]

            For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

    Note:
        The method body itself returns nothing. What a caller receives may
        be altered by the ``runnable_by_pipeline`` decorator -- confirm
        against its definition.
    """
    # TODO: we should replace _create_sagemaker_model() with create()
    # Collect the forwarded arguments by keyword so the call does not depend
    # on the positional order of the private helper's parameters.
    model_kwargs = {
        "instance_type": instance_type,
        "accelerator_type": accelerator_type,
        "tags": tags,
        "serverless_inference_config": serverless_inference_config,
    }
    self._create_sagemaker_model(**model_kwargs)
415
+
375
416
def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
376
417
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
377
418
@@ -455,6 +496,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
455
496
if repack and self .model_data is not None and self .entry_point is not None :
456
497
if is_pipeline_variable (self .model_data ):
457
498
# model is not yet there, defer repacking to later during pipeline execution
499
+ if not isinstance (self .sagemaker_session , PipelineSession ):
500
+ # TODO: link the doc in the warning once ready
501
+ logging .warning (
502
+ "The model_data is a Pipeline variable of type %s, "
503
+ "which should be used under `PipelineSession` and "
504
+ "leverage `ModelStep` to create or register model. "
505
+ "Otherwise some functionalities e.g. "
506
+ "runtime repack may be missing" ,
507
+ type (self .model_data ),
508
+ )
509
+ return
510
+ self .sagemaker_session .context .need_runtime_repack .add (id (self ))
511
+ # Add the uploaded_code and repacked_model_data to update the container env
512
+ self .repacked_model_data = self .model_data
513
+ self .uploaded_code = fw_utils .UploadedCode (
514
+ s3_prefix = self .repacked_model_data ,
515
+ script_name = os .path .basename (self .entry_point ),
516
+ )
458
517
return
459
518
if local_code and self .model_data .startswith ("file://" ):
460
519
repacked_model_data = self .model_data
@@ -538,22 +597,29 @@ def _create_sagemaker_model(
538
597
serverless_inference_config = serverless_inference_config ,
539
598
)
540
599
541
- self ._ensure_base_name_if_needed (
542
- image_uri = container_def ["Image" ], script_uri = self .source_dir , model_uri = self .model_data
543
- )
544
- self ._set_model_name_if_needed ()
600
+ if not isinstance (self .sagemaker_session , PipelineSession ):
601
+ # _base_name, model_name are not needed under PipelineSession.
602
+ # the model_data may be Pipeline variable
603
+ # which may break the _base_name generation
604
+ self ._ensure_base_name_if_needed (
605
+ image_uri = container_def ["Image" ],
606
+ script_uri = self .source_dir ,
607
+ model_uri = self .model_data ,
608
+ )
609
+ self ._set_model_name_if_needed ()
545
610
546
611
enable_network_isolation = self .enable_network_isolation ()
547
612
548
613
self ._init_sagemaker_session_if_does_not_exist (instance_type )
549
- self . sagemaker_session . create_model (
550
- self .name ,
551
- self .role ,
552
- container_def ,
614
+ create_model_args = dict (
615
+ name = self .name ,
616
+ role = self .role ,
617
+ container_defs = container_def ,
553
618
vpc_config = self .vpc_config ,
554
619
enable_network_isolation = enable_network_isolation ,
555
620
tags = tags ,
556
621
)
622
+ self .sagemaker_session .create_model (** create_model_args )
557
623
558
624
def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
559
625
"""Create a base name from the image URI if there is no model name provided.
0 commit comments