Skip to content

Commit b7245bf

Browse files
YingqiColJoseJuan98
authored andcommitted
feature: Add TrainingImageConfig support for SageMaker training jobs (aws#3603)
1 parent 23361f9 commit b7245bf

File tree

4 files changed

+334
-34
lines changed

4 files changed

+334
-34
lines changed

src/sagemaker/estimator.py

+70-4
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def __init__(
155155
entry_point: Optional[Union[str, PipelineVariable]] = None,
156156
dependencies: Optional[List[Union[str]]] = None,
157157
instance_groups: Optional[List[InstanceGroup]] = None,
158+
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
159+
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
158160
**kwargs,
159161
):
160162
"""Initialize an ``EstimatorBase`` instance.
@@ -489,6 +491,18 @@ def __init__(
489491
`Train Using a Heterogeneous Cluster
490492
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
491493
in the *Amazon SageMaker developer guide*.
494+
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
495+
Docker image that contains the training algorithm (default: None).
496+
Set this to one of the following values:
497+
* 'Platform' - The training image is hosted in Amazon ECR.
498+
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
499+
When it's default to None, its behavior will be same as 'Platform' - image is hosted
500+
in ECR.
501+
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
502+
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
503+
private Docker registry where your training image is hosted (default: None).
504+
When it's set to None, SageMaker will not do authentication before pulling the image
505+
in the private Docker registry.
492506
"""
493507
instance_count = renamed_kwargs(
494508
"train_instance_count", "instance_count", instance_count, kwargs
@@ -536,7 +550,9 @@ def __init__(
536550
self.dependencies = dependencies or []
537551
self.uploaded_code = None
538552
self.tags = add_jumpstart_tags(
539-
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
553+
tags=tags,
554+
training_model_uri=self.model_uri,
555+
training_script_uri=self.source_dir,
540556
)
541557
if self.instance_type in ("local", "local_gpu"):
542558
if self.instance_type == "local_gpu" and self.instance_count > 1:
@@ -571,6 +587,12 @@ def __init__(
571587
self.subnets = subnets
572588
self.security_group_ids = security_group_ids
573589

590+
# training image configs
591+
self.training_repository_access_mode = training_repository_access_mode
592+
self.training_repository_credentials_provider_arn = (
593+
training_repository_credentials_provider_arn
594+
)
595+
574596
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
575597
self.use_spot_instances = use_spot_instances
576598
self.max_wait = max_wait
@@ -651,7 +673,8 @@ def _ensure_base_job_name(self):
651673
self.base_job_name
652674
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
653675
or base_name_from_image(
654-
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
676+
self.training_image_uri(),
677+
default_base_name=EstimatorBase.JOB_CLASS_NAME,
655678
)
656679
)
657680

@@ -1405,7 +1428,10 @@ def deploy(
14051428
self._ensure_base_job_name()
14061429

14071430
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
1408-
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
1431+
kwargs.get("source_dir"),
1432+
self.source_dir,
1433+
kwargs.get("model_data"),
1434+
self.model_uri,
14091435
)
14101436
default_name = (
14111437
name_from_base(jumpstart_base_name)
@@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
16381664
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
16391665
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
16401666
init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"]
1667+
if "TrainingImageConfig" in job_details["AlgorithmSpecification"]:
1668+
init_params["training_repository_access_mode"] = job_details[
1669+
"AlgorithmSpecification"
1670+
]["TrainingImageConfig"].get("TrainingRepositoryAccessMode")
1671+
init_params["training_repository_credentials_provider_arn"] = (
1672+
job_details["AlgorithmSpecification"]["TrainingImageConfig"]
1673+
.get("TrainingRepositoryAuthConfig", {})
1674+
.get("TrainingRepositoryCredentialsProviderArn")
1675+
)
16411676
else:
16421677
raise RuntimeError(
16431678
"Invalid AlgorithmSpecification. Either TrainingImage or "
@@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
21182153
else:
21192154
train_args["retry_strategy"] = None
21202155

2156+
if estimator.training_repository_access_mode is not None:
2157+
training_image_config = {
2158+
"TrainingRepositoryAccessMode": estimator.training_repository_access_mode
2159+
}
2160+
if estimator.training_repository_credentials_provider_arn is not None:
2161+
training_image_config["TrainingRepositoryAuthConfig"] = {}
2162+
training_image_config["TrainingRepositoryAuthConfig"][
2163+
"TrainingRepositoryCredentialsProviderArn"
2164+
] = estimator.training_repository_credentials_provider_arn
2165+
train_args["training_image_config"] = training_image_config
2166+
21212167
# encrypt_inter_container_traffic may be a pipeline variable place holder object
21222168
# which is parsed in execution time
21232169
if estimator.encrypt_inter_container_traffic:
@@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri):
21822228

21832229
@classmethod
21842230
def update(
2185-
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2231+
cls,
2232+
estimator,
2233+
profiler_rule_configs=None,
2234+
profiler_config=None,
2235+
resource_config=None,
21862236
):
21872237
"""Update a running Amazon SageMaker training job.
21882238
@@ -2321,6 +2371,8 @@ def __init__(
23212371
entry_point: Optional[Union[str, PipelineVariable]] = None,
23222372
dependencies: Optional[List[str]] = None,
23232373
instance_groups: Optional[List[InstanceGroup]] = None,
2374+
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
2375+
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
23242376
**kwargs,
23252377
):
23262378
"""Initialize an ``Estimator`` instance.
@@ -2654,6 +2706,18 @@ def __init__(
26542706
`Train Using a Heterogeneous Cluster
26552707
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
26562708
in the *Amazon SageMaker developer guide*.
2709+
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
2710+
Docker image that contains the training algorithm (default: None).
2711+
Set this to one of the following values:
2712+
* 'Platform' - The training image is hosted in Amazon ECR.
2713+
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
2714+
When it's default to None, its behavior will be same as 'Platform' - image is hosted
2715+
in ECR.
2716+
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
2717+
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
2718+
private Docker registry where your training image is hosted (default: None).
2719+
When it's set to None, SageMaker will not do authentication before pulling the image
2720+
in the private Docker registry.
26572721
"""
26582722
self.image_uri = image_uri
26592723
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -2698,6 +2762,8 @@ def __init__(
26982762
dependencies=dependencies,
26992763
hyperparameters=hyperparameters,
27002764
instance_groups=instance_groups,
2765+
training_repository_access_mode=training_repository_access_mode,
2766+
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
27012767
**kwargs,
27022768
)
27032769

src/sagemaker/session.py

+50
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ def train( # noqa: C901
488488
metric_definitions,
489489
enable_network_isolation=False,
490490
image_uri=None,
491+
training_image_config=None,
491492
algorithm_arn=None,
492493
encrypt_inter_container_traffic=False,
493494
use_spot_instances=False,
@@ -548,6 +549,28 @@ def train( # noqa: C901
548549
enable_network_isolation (bool): Whether to request for the training job to run with
549550
network isolation or not.
550551
image_uri (str): Docker image containing training code.
552+
training_image_config(dict): Training image configuration.
553+
Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
554+
'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
555+
For example,
556+
557+
.. code:: python
558+
559+
training_image_config = {
560+
"TrainingRepositoryAccessMode": "Vpc",
561+
"TrainingRepositoryAuthConfig": {
562+
"TrainingRepositoryCredentialsProviderArn":
563+
"arn:aws:lambda:us-west-2:1234567890:function:test"
564+
},
565+
}
566+
567+
If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed
568+
through a private Docker registry in customer Vpc. If it's set to Platform or None,
569+
the training image is accessed through ECR.
570+
If TrainingRepositoryCredentialsProviderArn is provided, the credentials to
571+
authenticate to the private Docker registry will be retrieved from this AWS Lambda
572+
function. (default: ``None``). When it's set to None, SageMaker will not do
573+
authentication before pulling the image in the private Docker registry.
551574
algorithm_arn (str): Algorithm Arn from Marketplace.
552575
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
553576
containers is encrypted for the training job (default: ``False``).
@@ -606,6 +629,7 @@ def train( # noqa: C901
606629
metric_definitions=metric_definitions,
607630
enable_network_isolation=enable_network_isolation,
608631
image_uri=image_uri,
632+
training_image_config=training_image_config,
609633
algorithm_arn=algorithm_arn,
610634
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
611635
use_spot_instances=use_spot_instances,
@@ -644,6 +668,7 @@ def _get_train_request( # noqa: C901
644668
metric_definitions,
645669
enable_network_isolation=False,
646670
image_uri=None,
671+
training_image_config=None,
647672
algorithm_arn=None,
648673
encrypt_inter_container_traffic=False,
649674
use_spot_instances=False,
@@ -704,6 +729,28 @@ def _get_train_request( # noqa: C901
704729
enable_network_isolation (bool): Whether to request for the training job to run with
705730
network isolation or not.
706731
image_uri (str): Docker image containing training code.
732+
training_image_config(dict): Training image configuration.
733+
Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
734+
'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
735+
For example,
736+
737+
.. code:: python
738+
739+
training_image_config = {
740+
"TrainingRepositoryAccessMode": "Vpc",
741+
"TrainingRepositoryAuthConfig": {
742+
"TrainingRepositoryCredentialsProviderArn":
743+
"arn:aws:lambda:us-west-2:1234567890:function:test"
744+
},
745+
}
746+
747+
If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed
748+
through a private Docker registry in customer Vpc. If it's set to Platform or None,
749+
the training image is accessed through ECR.
750+
If TrainingRepositoryCredentialsProviderArn is provided, the credentials to
751+
authenticate to the private Docker registry will be retrieved from this AWS Lambda
752+
function. (default: ``None``). When it's set to None, SageMaker will not do
753+
authentication before pulling the image in the private Docker registry.
707754
algorithm_arn (str): Algorithm Arn from Marketplace.
708755
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
709756
containers is encrypted for the training job (default: ``False``).
@@ -768,6 +815,9 @@ def _get_train_request( # noqa: C901
768815
if image_uri is not None:
769816
train_request["AlgorithmSpecification"]["TrainingImage"] = image_uri
770817

818+
if training_image_config is not None:
819+
train_request["AlgorithmSpecification"]["TrainingImageConfig"] = training_image_config
820+
771821
if algorithm_arn is not None:
772822
train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn
773823

tests/unit/test_estimator.py

+88
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@
9393
CODECOMMIT_BRANCH = "master"
9494
REPO_DIR = "/tmp/repo_dir"
9595
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
96+
TRAINING_REPOSITORY_ACCESS_MODE = "VPC"
97+
TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN = "arn:aws:lambda:us-west-2:1234567890:function:test"
9698

9799
DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}
98100

@@ -391,6 +393,70 @@ def test_framework_with_keep_alive_period(sagemaker_session):
391393
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS
392394

393395

396+
def test_framework_with_both_training_repository_config(sagemaker_session):
397+
f = DummyFramework(
398+
entry_point=SCRIPT_PATH,
399+
role=ROLE,
400+
sagemaker_session=sagemaker_session,
401+
instance_groups=[
402+
InstanceGroup("group1", "ml.c4.xlarge", 1),
403+
InstanceGroup("group2", "ml.m4.xlarge", 2),
404+
],
405+
training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE,
406+
training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN,
407+
)
408+
f.fit("s3://mydata")
409+
sagemaker_session.train.assert_called_once()
410+
_, args = sagemaker_session.train.call_args
411+
assert (
412+
args["training_image_config"]["TrainingRepositoryAccessMode"]
413+
== TRAINING_REPOSITORY_ACCESS_MODE
414+
)
415+
assert (
416+
args["training_image_config"]["TrainingRepositoryAuthConfig"][
417+
"TrainingRepositoryCredentialsProviderArn"
418+
]
419+
== TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN
420+
)
421+
422+
423+
def test_framework_with_training_repository_access_mode(sagemaker_session):
424+
f = DummyFramework(
425+
entry_point=SCRIPT_PATH,
426+
role=ROLE,
427+
sagemaker_session=sagemaker_session,
428+
instance_groups=[
429+
InstanceGroup("group1", "ml.c4.xlarge", 1),
430+
InstanceGroup("group2", "ml.m4.xlarge", 2),
431+
],
432+
training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE,
433+
)
434+
f.fit("s3://mydata")
435+
sagemaker_session.train.assert_called_once()
436+
_, args = sagemaker_session.train.call_args
437+
assert (
438+
args["training_image_config"]["TrainingRepositoryAccessMode"]
439+
== TRAINING_REPOSITORY_ACCESS_MODE
440+
)
441+
assert "TrainingRepositoryAuthConfig" not in args["training_image_config"]
442+
443+
444+
def test_framework_without_training_repository_config(sagemaker_session):
445+
f = DummyFramework(
446+
entry_point=SCRIPT_PATH,
447+
role=ROLE,
448+
sagemaker_session=sagemaker_session,
449+
instance_groups=[
450+
InstanceGroup("group1", "ml.c4.xlarge", 1),
451+
InstanceGroup("group2", "ml.m4.xlarge", 2),
452+
],
453+
)
454+
f.fit("s3://mydata")
455+
sagemaker_session.train.assert_called_once()
456+
_, args = sagemaker_session.train.call_args
457+
assert args.get("training_image_config") is None
458+
459+
394460
def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
395461
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
396462
base_config=rule_configs.stalled_training_rule(),
@@ -3763,6 +3829,28 @@ def test_prepare_init_params_from_job_description_with_retry_strategy():
37633829
assert init_params["max_retry_attempts"] == 2
37643830

37653831

3832+
def test_prepare_init_params_from_job_description_with_training_image_config():
3833+
job_description = RETURNED_JOB_DESCRIPTION.copy()
3834+
job_description["AlgorithmSpecification"]["TrainingImageConfig"] = {
3835+
"TrainingRepositoryAccessMode": "Vpc",
3836+
"TrainingRepositoryAuthConfig": {
3837+
"TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567890:function:test"
3838+
},
3839+
}
3840+
3841+
init_params = EstimatorBase._prepare_init_params_from_job_description(
3842+
job_details=job_description
3843+
)
3844+
3845+
assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
3846+
assert init_params["instance_count"] == 1
3847+
assert init_params["training_repository_access_mode"] == "Vpc"
3848+
assert (
3849+
init_params["training_repository_credentials_provider_arn"]
3850+
== "arn:aws:lambda:us-west-2:1234567890:function:test"
3851+
)
3852+
3853+
37663854
def test_prepare_init_params_from_job_description_with_invalid_training_job():
37673855

37683856
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()

0 commit comments

Comments
 (0)