Skip to content

feature: Add TrainingImageConfig support for SageMaker training jobs #3603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def __init__(
entry_point: Optional[Union[str, PipelineVariable]] = None,
dependencies: Optional[List[Union[str]]] = None,
instance_groups: Optional[List[InstanceGroup]] = None,
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -489,6 +491,18 @@ def __init__(
`Train Using a Heterogeneous Cluster
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
in the *Amazon SageMaker developer guide*.
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
Docker image that contains the training algorithm (default: None).
Set this to one of the following values:
* 'Platform' - The training image is hosted in Amazon ECR.
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
When left as None (the default), the behavior is the same as 'Platform': the image is
hosted in Amazon ECR.
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
private Docker registry where your training image is hosted (default: None).
When it's set to None, SageMaker does not authenticate before pulling the image
from the private Docker registry.
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -536,7 +550,9 @@ def __init__(
self.dependencies = dependencies or []
self.uploaded_code = None
self.tags = add_jumpstart_tags(
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
tags=tags,
training_model_uri=self.model_uri,
training_script_uri=self.source_dir,
)
if self.instance_type in ("local", "local_gpu"):
if self.instance_type == "local_gpu" and self.instance_count > 1:
Expand Down Expand Up @@ -571,6 +587,12 @@ def __init__(
self.subnets = subnets
self.security_group_ids = security_group_ids

# training image configs
self.training_repository_access_mode = training_repository_access_mode
self.training_repository_credentials_provider_arn = (
training_repository_credentials_provider_arn
)

self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
self.use_spot_instances = use_spot_instances
self.max_wait = max_wait
Expand Down Expand Up @@ -651,7 +673,8 @@ def _ensure_base_job_name(self):
self.base_job_name
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
or base_name_from_image(
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
self.training_image_uri(),
default_base_name=EstimatorBase.JOB_CLASS_NAME,
)
)

Expand Down Expand Up @@ -1405,7 +1428,10 @@ def deploy(
self._ensure_base_job_name()

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
kwargs.get("source_dir"),
self.source_dir,
kwargs.get("model_data"),
self.model_uri,
)
default_name = (
name_from_base(jumpstart_base_name)
Expand Down Expand Up @@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"]
if "TrainingImageConfig" in job_details["AlgorithmSpecification"]:
init_params["training_repository_access_mode"] = job_details[
"AlgorithmSpecification"
]["TrainingImageConfig"].get("TrainingRepositoryAccessMode")
init_params["training_repository_credentials_provider_arn"] = (
job_details["AlgorithmSpecification"]["TrainingImageConfig"]
.get("TrainingRepositoryAuthConfig", {})
.get("TrainingRepositoryCredentialsProviderArn")
)
else:
raise RuntimeError(
"Invalid AlgorithmSpecification. Either TrainingImage or "
Expand Down Expand Up @@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
else:
train_args["retry_strategy"] = None

if estimator.training_repository_access_mode is not None:
training_image_config = {
"TrainingRepositoryAccessMode": estimator.training_repository_access_mode
}
if estimator.training_repository_credentials_provider_arn is not None:
training_image_config["TrainingRepositoryAuthConfig"] = {}
training_image_config["TrainingRepositoryAuthConfig"][
"TrainingRepositoryCredentialsProviderArn"
] = estimator.training_repository_credentials_provider_arn
train_args["training_image_config"] = training_image_config

# encrypt_inter_container_traffic may be a pipeline variable place holder object
# which is parsed in execution time
if estimator.encrypt_inter_container_traffic:
Expand Down Expand Up @@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri):

@classmethod
def update(
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
cls,
estimator,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
):
"""Update a running Amazon SageMaker training job.

Expand Down Expand Up @@ -2321,6 +2371,8 @@ def __init__(
entry_point: Optional[Union[str, PipelineVariable]] = None,
dependencies: Optional[List[str]] = None,
instance_groups: Optional[List[InstanceGroup]] = None,
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -2654,6 +2706,18 @@ def __init__(
`Train Using a Heterogeneous Cluster
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
in the *Amazon SageMaker developer guide*.
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
Docker image that contains the training algorithm (default: None).
Set this to one of the following values:
* 'Platform' - The training image is hosted in Amazon ECR.
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
When left as None (the default), the behavior is the same as 'Platform': the image is
hosted in Amazon ECR.
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
private Docker registry where your training image is hosted (default: None).
When it's set to None, SageMaker does not authenticate before pulling the image
from the private Docker registry.
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -2698,6 +2762,8 @@ def __init__(
dependencies=dependencies,
hyperparameters=hyperparameters,
instance_groups=instance_groups,
training_repository_access_mode=training_repository_access_mode,
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
**kwargs,
)

Expand Down
50 changes: 50 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def train( # noqa: C901
metric_definitions,
enable_network_isolation=False,
image_uri=None,
training_image_config=None,
algorithm_arn=None,
encrypt_inter_container_traffic=False,
use_spot_instances=False,
Expand Down Expand Up @@ -548,6 +549,28 @@ def train( # noqa: C901
enable_network_isolation (bool): Whether to request for the training job to run with
network isolation or not.
image_uri (str): Docker image containing training code.
training_image_config (dict): Training image configuration.
Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
For example,

.. code:: python

training_image_config = {
"TrainingRepositoryAccessMode": "Vpc",
"TrainingRepositoryAuthConfig": {
"TrainingRepositoryCredentialsProviderArn":
"arn:aws:lambda:us-west-2:1234567890:function:test"
},
}

If TrainingRepositoryAccessMode is set to 'Vpc', the training image is accessed
through a private Docker registry in the customer's VPC. If it's set to 'Platform'
or None, the training image is accessed through Amazon ECR.
If TrainingRepositoryCredentialsProviderArn is provided, the credentials used to
authenticate to the private Docker registry are retrieved from that AWS Lambda
function. If it is not provided, SageMaker does not authenticate before pulling
the image from the private Docker registry (default: ``None``).
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
Expand Down Expand Up @@ -606,6 +629,7 @@ def train( # noqa: C901
metric_definitions=metric_definitions,
enable_network_isolation=enable_network_isolation,
image_uri=image_uri,
training_image_config=training_image_config,
algorithm_arn=algorithm_arn,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
use_spot_instances=use_spot_instances,
Expand Down Expand Up @@ -644,6 +668,7 @@ def _get_train_request( # noqa: C901
metric_definitions,
enable_network_isolation=False,
image_uri=None,
training_image_config=None,
algorithm_arn=None,
encrypt_inter_container_traffic=False,
use_spot_instances=False,
Expand Down Expand Up @@ -704,6 +729,28 @@ def _get_train_request( # noqa: C901
enable_network_isolation (bool): Whether to request for the training job to run with
network isolation or not.
image_uri (str): Docker image containing training code.
training_image_config (dict): Training image configuration.
Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
For example,

.. code:: python

training_image_config = {
"TrainingRepositoryAccessMode": "Vpc",
"TrainingRepositoryAuthConfig": {
"TrainingRepositoryCredentialsProviderArn":
"arn:aws:lambda:us-west-2:1234567890:function:test"
},
}

If TrainingRepositoryAccessMode is set to 'Vpc', the training image is accessed
through a private Docker registry in the customer's VPC. If it's set to 'Platform'
or None, the training image is accessed through Amazon ECR.
If TrainingRepositoryCredentialsProviderArn is provided, the credentials used to
authenticate to the private Docker registry are retrieved from that AWS Lambda
function. If it is not provided, SageMaker does not authenticate before pulling
the image from the private Docker registry (default: ``None``).
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
Expand Down Expand Up @@ -768,6 +815,9 @@ def _get_train_request( # noqa: C901
if image_uri is not None:
train_request["AlgorithmSpecification"]["TrainingImage"] = image_uri

if training_image_config is not None:
train_request["AlgorithmSpecification"]["TrainingImageConfig"] = training_image_config

if algorithm_arn is not None:
train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn

Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
CODECOMMIT_BRANCH = "master"
REPO_DIR = "/tmp/repo_dir"
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
# Use the exact spelling documented for TrainingRepositoryAccessMode
# ('Platform' or 'Vpc'), so the fixture mirrors a value a real request may carry.
TRAINING_REPOSITORY_ACCESS_MODE = "Vpc"
TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN = "arn:aws:lambda:us-west-2:1234567890:function:test"

DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}

Expand Down Expand Up @@ -391,6 +393,70 @@ def test_framework_with_keep_alive_period(sagemaker_session):
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS


def test_framework_with_both_training_repository_config(sagemaker_session):
    """Access mode and credentials provider ARN both end up in ``training_image_config``."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE,
        training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN,
    )

    estimator.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    _positional, train_kwargs = sagemaker_session.train.call_args

    # The access mode sits at the top level of the config ...
    access_mode = train_kwargs["training_image_config"]["TrainingRepositoryAccessMode"]
    assert access_mode == TRAINING_REPOSITORY_ACCESS_MODE

    # ... while the provider ARN is nested under the auth config.
    auth_config = train_kwargs["training_image_config"]["TrainingRepositoryAuthConfig"]
    assert (
        auth_config["TrainingRepositoryCredentialsProviderArn"]
        == TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN
    )


def test_framework_with_training_repository_access_mode(sagemaker_session):
    """With only an access mode set, no auth config is added to ``training_image_config``."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE,
    )

    estimator.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    _positional, train_kwargs = sagemaker_session.train.call_args

    access_mode = train_kwargs["training_image_config"]["TrainingRepositoryAccessMode"]
    assert access_mode == TRAINING_REPOSITORY_ACCESS_MODE
    # No credentials provider ARN was supplied, so the auth section must be absent.
    assert "TrainingRepositoryAuthConfig" not in train_kwargs["training_image_config"]


def test_framework_without_training_repository_config(sagemaker_session):
    """Without repository settings, ``train`` receives no ``training_image_config``."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
    )

    estimator.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    _positional, train_kwargs = sagemaker_session.train.call_args
    # The key may be missing or explicitly None; both satisfy this check.
    assert train_kwargs.get("training_image_config") is None


def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
base_config=rule_configs.stalled_training_rule(),
Expand Down Expand Up @@ -3763,6 +3829,28 @@ def test_prepare_init_params_from_job_description_with_retry_strategy():
assert init_params["max_retry_attempts"] == 2


def test_prepare_init_params_from_job_description_with_training_image_config():
    """``TrainingImageConfig`` in a job description is mapped onto the init params.

    Checks that ``TrainingRepositoryAccessMode`` and the nested
    ``TrainingRepositoryCredentialsProviderArn`` are surfaced as
    ``training_repository_access_mode`` and
    ``training_repository_credentials_provider_arn`` respectively.
    """
    # Local import keeps this test self-contained.
    import copy

    # Deep copy: a shallow ``.copy()`` shares the nested "AlgorithmSpecification"
    # dict with the module-level fixture, so the assignment below would leak
    # "TrainingImageConfig" into every other test that reads the fixture.
    job_description = copy.deepcopy(RETURNED_JOB_DESCRIPTION)
    job_description["AlgorithmSpecification"]["TrainingImageConfig"] = {
        "TrainingRepositoryAccessMode": "Vpc",
        "TrainingRepositoryAuthConfig": {
            "TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567890:function:test"
        },
    }

    init_params = EstimatorBase._prepare_init_params_from_job_description(
        job_details=job_description
    )

    assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
    assert init_params["instance_count"] == 1
    assert init_params["training_repository_access_mode"] == "Vpc"
    assert (
        init_params["training_repository_credentials_provider_arn"]
        == "arn:aws:lambda:us-west-2:1234567890:function:test"
    )


def test_prepare_init_params_from_job_description_with_invalid_training_job():

invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()
Expand Down
Loading