Skip to content

Commit 6ceb318

Browse files
aoguo64Rohan GujarathirohangujarathiAo Guo
authored and
Namrata Madan
committed
FruitStand support (aws#892)
Co-authored-by: Rohan Gujarathi <[email protected]> Co-authored-by: Rohan Gujarathi <[email protected]> Co-authored-by: Ao Guo <[email protected]>
1 parent 0aa6926 commit 6ceb318

File tree

5 files changed

+53
-4
lines changed

5 files changed

+53
-4
lines changed

src/sagemaker/config/config_schema.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
S3_KMS_KEY_ID = "S3KmsKeyId"
5353
S3_ROOT_URI = "S3RootUri"
5454
JOB_CONDA_ENV = "JobCondaEnvironment"
55+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption"
5556
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
5657
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
5758
S3_STORAGE_CONFIG = "S3StorageConfig"

src/sagemaker/remote_function/client.py

+22
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def remote(
7878
tags: List[Tuple[str, str]] = None,
7979
volume_kms_key: str = None,
8080
volume_size: int = 30,
81+
encrypt_inter_container_traffic: bool = None,
82+
enable_network_isolation: bool = None,
8183
):
8284
"""Function that starts a new SageMaker job synchronously with overridden runtime settings.
8385
@@ -114,6 +116,13 @@ def remote(
114116
instance.
115117
volume_size (int): Size in GB of the storage volume to use for storing input and output
116118
data. Default is 30.
119+
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
120+
containers is encrypted for the training job (default: ``False``).
121+
enable_network_isolation (bool): Specifies whether the container will
122+
run in network isolation mode (default: ``False``). Network
123+
isolation mode restricts the container's access to outside networks
124+
(such as the Internet). The container does not make any inbound or
125+
outbound network calls. Also known as Internet-free mode.
117126
"""
118127

119128
def _remote(func):
@@ -143,6 +152,8 @@ def wrapper(*args, **kwargs):
143152
tags=tags,
144153
volume_kms_key=volume_kms_key,
145154
volume_size=volume_size,
155+
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
156+
enable_network_isolation=enable_network_isolation,
146157
)
147158
job = _Job.start(job_settings, func, args, kwargs)
148159

@@ -331,6 +342,8 @@ def __init__(
331342
tags: List[Tuple[str, str]] = None,
332343
volume_kms_key: str = None,
333344
volume_size: int = 30,
345+
encrypt_inter_container_traffic: bool = None,
346+
enable_network_isolation: bool = None,
334347
):
335348
"""Initiates a ``RemoteExecutor`` instance.
336349
@@ -373,6 +386,13 @@ def __init__(
373386
instance.
374387
volume_size (int): Size in GB of the storage volume to use for storing input and output
375388
data. Defaults to 30.
389+
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
390+
containers is encrypted for the training job (default: ``False``).
391+
enable_network_isolation (bool): Specifies whether the container will
392+
run in network isolation mode (default: ``False``). Network
393+
isolation mode restricts the container's access to outside networks
394+
(such as the Internet). The container does not make any inbound or
395+
outbound network calls. Also known as Internet-free mode.
376396
"""
377397
self.max_parallel_jobs = max_parallel_jobs
378398

@@ -400,6 +420,8 @@ def __init__(
400420
tags=tags,
401421
volume_kms_key=volume_kms_key,
402422
volume_size=volume_size,
423+
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
424+
enable_network_isolation=enable_network_isolation,
403425
)
404426

405427
self._state_condition = threading.Condition()

src/sagemaker/remote_function/job.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def __init__(
124124
tags: List[Tuple[str, str]] = None,
125125
volume_kms_key: str = None,
126126
volume_size: int = 30,
127+
encrypt_inter_container_traffic: bool = None,
128+
enable_network_isolation: bool = None,
127129
):
128130

129131
self.sagemaker_config = SageMakerConfigFactory.build_sagemaker_config(
@@ -157,6 +159,14 @@ def __init__(
157159
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
158160
self.job_conda_env = self._get_from_config(job_conda_env, config_schema.JOB_CONDA_ENV)
159161
self.job_name_prefix = job_name_prefix
162+
self.encrypt_inter_container_traffic = self._get_from_config(
163+
encrypt_inter_container_traffic,
164+
config_schema.ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
165+
default=False,
166+
)
167+
self.enable_network_isolation = self._get_from_config(
168+
enable_network_isolation, config_schema.ENABLE_NETWORK_ISOLATION, default=False
169+
)
160170

161171
_role = self._get_from_config(role, config_schema.ROLE_ARN)
162172
if _role:
@@ -195,7 +205,7 @@ def _get_from_config(
195205
required=False,
196206
):
197207
"""Get value from sagemaker config."""
198-
if override_value:
208+
if override_value is not None:
199209
return override_value
200210
config_value = self.sagemaker_config.get_config_value(
201211
"{}.{}.{}.{}.{}".format(
@@ -206,9 +216,9 @@ def _get_from_config(
206216
sagemaker_config_key,
207217
)
208218
)
209-
if config_value:
219+
if config_value is not None:
210220
return transform(config_value)
211-
if not default and required:
221+
if default is None and required:
212222
raise ValueError(f"{sagemaker_config_key} is a required parameter!")
213223
return default
214224

@@ -391,6 +401,14 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
391401

392402
request_dict["ResourceConfig"] = resource_config
393403

404+
if job_settings.enable_network_isolation is not None:
405+
request_dict["EnableNetworkIsolation"] = job_settings.enable_network_isolation
406+
407+
if job_settings.encrypt_inter_container_traffic is not None:
408+
request_dict[
409+
"EnableInterContainerTrafficEncryption"
410+
] = job_settings.encrypt_inter_container_traffic
411+
394412
if job_settings.vpc_config:
395413
request_dict["VpcConfig"] = job_settings.vpc_config
396414

tests/data/remote_function/config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ SageMaker:
1111
SecurityGroupIds: ["sg123"]
1212
Subnets: ["subnet-1234"]
1313
Tags: [{"someTagKey": "someTagValue"}, {"someTagKey2": "someTagValue2"}]
14-
VolumeKmsKeyId: "someVolumeKmsKey"
14+
VolumeKmsKeyId: "someVolumeKmsKey"
15+
EnableNetworkIsolation: true

tests/unit/sagemaker/remote_function/test_job.py

+7
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def test_sagemaker_config_job_settings_with_configuration_file(
149149
assert job_settings.volume_kms_key == "someVolumeKmsKey"
150150
assert job_settings.s3_kms_key == "someS3KmsKey"
151151
assert job_settings.instance_type == "ml.m5.large"
152+
assert job_settings.enable_network_isolation is True
153+
assert job_settings.encrypt_inter_container_traffic is False
152154

153155
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
154156

@@ -210,6 +212,7 @@ def test_start(
210212
role=ROLE_ARN,
211213
include_local_workdir=True,
212214
instance_type="ml.m5.large",
215+
encrypt_inter_container_traffic=True,
213216
)
214217

215218
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -297,6 +300,8 @@ def test_start(
297300
InstanceType="ml.m5.large",
298301
KeepAlivePeriodInSeconds=0,
299302
),
303+
EnableNetworkIsolation=False,
304+
EnableInterContainerTrafficEncryption=True,
300305
Environment={"AWS_DEFAULT_REGION": "us-west-2"},
301306
)
302307

@@ -416,6 +421,8 @@ def test_start_with_complete_job_settings(
416421
VolumeKmsKeyId=KMS_KEY_ARN,
417422
KeepAlivePeriodInSeconds=120,
418423
),
424+
EnableNetworkIsolation=False,
425+
EnableInterContainerTrafficEncryption=False,
419426
VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
420427
Environment={"AWS_DEFAULT_REGION": "us-east-2"},
421428
)

0 commit comments

Comments
 (0)