@@ -155,6 +155,8 @@ def __init__(
155
155
entry_point : Optional [Union [str , PipelineVariable ]] = None ,
156
156
dependencies : Optional [List [Union [str ]]] = None ,
157
157
instance_groups : Optional [List [InstanceGroup ]] = None ,
158
+ training_repository_access_mode : Optional [Union [str , PipelineVariable ]] = None ,
159
+ training_repository_credentials_provider_arn : Optional [Union [str , PipelineVariable ]] = None ,
158
160
** kwargs ,
159
161
):
160
162
"""Initialize an ``EstimatorBase`` instance.
@@ -489,6 +491,18 @@ def __init__(
489
491
`Train Using a Heterogeneous Cluster
490
492
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
491
493
in the *Amazon SageMaker developer guide*.
494
+ training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
495
+ Docker image that contains the training algorithm (default: None).
496
+ Set this to one of the following values:
497
+ * 'Platform' - The training image is hosted in Amazon ECR.
498
+ * 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
499
+ When it is left as None, the behavior is the same as 'Platform' - the image is hosted
500
+ in ECR.
501
+ training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
502
+ (ARN) of an AWS Lambda function that provides credentials to authenticate to the
503
+ private Docker registry where your training image is hosted (default: None).
504
+ When it is set to None, SageMaker will not authenticate before pulling the image
505
+ from the private Docker registry.
492
506
"""
493
507
instance_count = renamed_kwargs (
494
508
"train_instance_count" , "instance_count" , instance_count , kwargs
@@ -536,7 +550,9 @@ def __init__(
536
550
self .dependencies = dependencies or []
537
551
self .uploaded_code = None
538
552
self .tags = add_jumpstart_tags (
539
- tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
553
+ tags = tags ,
554
+ training_model_uri = self .model_uri ,
555
+ training_script_uri = self .source_dir ,
540
556
)
541
557
if self .instance_type in ("local" , "local_gpu" ):
542
558
if self .instance_type == "local_gpu" and self .instance_count > 1 :
@@ -571,6 +587,12 @@ def __init__(
571
587
self .subnets = subnets
572
588
self .security_group_ids = security_group_ids
573
589
590
+ # training image configs
591
+ self .training_repository_access_mode = training_repository_access_mode
592
+ self .training_repository_credentials_provider_arn = (
593
+ training_repository_credentials_provider_arn
594
+ )
595
+
574
596
self .encrypt_inter_container_traffic = encrypt_inter_container_traffic
575
597
self .use_spot_instances = use_spot_instances
576
598
self .max_wait = max_wait
@@ -651,7 +673,8 @@ def _ensure_base_job_name(self):
651
673
self .base_job_name
652
674
or get_jumpstart_base_name_if_jumpstart_model (self .source_dir , self .model_uri )
653
675
or base_name_from_image (
654
- self .training_image_uri (), default_base_name = EstimatorBase .JOB_CLASS_NAME
676
+ self .training_image_uri (),
677
+ default_base_name = EstimatorBase .JOB_CLASS_NAME ,
655
678
)
656
679
)
657
680
@@ -1405,7 +1428,10 @@ def deploy(
1405
1428
self ._ensure_base_job_name ()
1406
1429
1407
1430
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model (
1408
- kwargs .get ("source_dir" ), self .source_dir , kwargs .get ("model_data" ), self .model_uri
1431
+ kwargs .get ("source_dir" ),
1432
+ self .source_dir ,
1433
+ kwargs .get ("model_data" ),
1434
+ self .model_uri ,
1409
1435
)
1410
1436
default_name = (
1411
1437
name_from_base (jumpstart_base_name )
@@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
1638
1664
init_params ["algorithm_arn" ] = job_details ["AlgorithmSpecification" ]["AlgorithmName" ]
1639
1665
elif "TrainingImage" in job_details ["AlgorithmSpecification" ]:
1640
1666
init_params ["image_uri" ] = job_details ["AlgorithmSpecification" ]["TrainingImage" ]
1667
+ if "TrainingImageConfig" in job_details ["AlgorithmSpecification" ]:
1668
+ init_params ["training_repository_access_mode" ] = job_details [
1669
+ "AlgorithmSpecification"
1670
+ ]["TrainingImageConfig" ].get ("TrainingRepositoryAccessMode" )
1671
+ init_params ["training_repository_credentials_provider_arn" ] = (
1672
+ job_details ["AlgorithmSpecification" ]["TrainingImageConfig" ]
1673
+ .get ("TrainingRepositoryAuthConfig" , {})
1674
+ .get ("TrainingRepositoryCredentialsProviderArn" )
1675
+ )
1641
1676
else :
1642
1677
raise RuntimeError (
1643
1678
"Invalid AlgorithmSpecification. Either TrainingImage or "
@@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
2118
2153
else :
2119
2154
train_args ["retry_strategy" ] = None
2120
2155
2156
+ if estimator .training_repository_access_mode is not None :
2157
+ training_image_config = {
2158
+ "TrainingRepositoryAccessMode" : estimator .training_repository_access_mode
2159
+ }
2160
+ if estimator .training_repository_credentials_provider_arn is not None :
2161
+ training_image_config ["TrainingRepositoryAuthConfig" ] = {}
2162
+ training_image_config ["TrainingRepositoryAuthConfig" ][
2163
+ "TrainingRepositoryCredentialsProviderArn"
2164
+ ] = estimator .training_repository_credentials_provider_arn
2165
+ train_args ["training_image_config" ] = training_image_config
2166
+
2121
2167
# encrypt_inter_container_traffic may be a pipeline variable place holder object
2122
2168
# which is parsed in execution time
2123
2169
if estimator .encrypt_inter_container_traffic :
@@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri):
2182
2228
2183
2229
@classmethod
2184
2230
def update (
2185
- cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2231
+ cls ,
2232
+ estimator ,
2233
+ profiler_rule_configs = None ,
2234
+ profiler_config = None ,
2235
+ resource_config = None ,
2186
2236
):
2187
2237
"""Update a running Amazon SageMaker training job.
2188
2238
@@ -2321,6 +2371,8 @@ def __init__(
2321
2371
entry_point : Optional [Union [str , PipelineVariable ]] = None ,
2322
2372
dependencies : Optional [List [str ]] = None ,
2323
2373
instance_groups : Optional [List [InstanceGroup ]] = None ,
2374
+ training_repository_access_mode : Optional [Union [str , PipelineVariable ]] = None ,
2375
+ training_repository_credentials_provider_arn : Optional [Union [str , PipelineVariable ]] = None ,
2324
2376
** kwargs ,
2325
2377
):
2326
2378
"""Initialize an ``Estimator`` instance.
@@ -2654,6 +2706,18 @@ def __init__(
2654
2706
`Train Using a Heterogeneous Cluster
2655
2707
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
2656
2708
in the *Amazon SageMaker developer guide*.
2709
+ training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
2710
+ Docker image that contains the training algorithm (default: None).
2711
+ Set this to one of the following values:
2712
+ * 'Platform' - The training image is hosted in Amazon ECR.
2713
+ * 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
2714
+ When it is left as None, the behavior is the same as 'Platform' - the image is hosted
2715
+ in ECR.
2716
+ training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
2717
+ (ARN) of an AWS Lambda function that provides credentials to authenticate to the
2718
+ private Docker registry where your training image is hosted (default: None).
2719
+ When it is set to None, SageMaker will not authenticate before pulling the image
2719
+ from the private Docker registry.
2657
2721
"""
2658
2722
self .image_uri = image_uri
2659
2723
self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -2698,6 +2762,8 @@ def __init__(
2698
2762
dependencies = dependencies ,
2699
2763
hyperparameters = hyperparameters ,
2700
2764
instance_groups = instance_groups ,
2765
+ training_repository_access_mode = training_repository_access_mode ,
2766
+ training_repository_credentials_provider_arn = training_repository_credentials_provider_arn , # noqa: E501 # pylint: disable=line-too-long
2701
2767
** kwargs ,
2702
2768
)
2703
2769
0 commit comments