Skip to content

Commit 2907ea5

Browse files
authored
fix: forward network_isolation parameter to Estimators when False (aws#4543)
A sagemaker config file would override this parameter if it was True in the config but False in the Estimator parameters. Closes aws#4542
1 parent 8bb857f commit 2907ea5

14 files changed

+23
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2434,6 +2434,7 @@ def start_new(cls, estimator, inputs, experiment_config):
24342434
"""
24352435
train_args = cls._get_train_args(estimator, inputs, experiment_config)
24362436

2437+
logger.debug("Train args after processing defaults: %s", train_args)
24372438
estimator.sagemaker_session.train(**train_args)
24382439

24392440
return cls(estimator.sagemaker_session, estimator._current_job_name)
@@ -2499,7 +2500,13 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
24992500

25002501
# enable_network_isolation may be a pipeline variable place holder object
25012502
# which is parsed in execution time
2502-
if estimator.enable_network_isolation():
2503+
2504+
# Should be defaulted to False
2505+
train_args["enable_network_isolation"] = False
2506+
2507+
# Only change it if it's explicitly passed so the sagemaker config
2508+
# doesn't override the kwarg.
2509+
if estimator.enable_network_isolation() is not None:
25032510
train_args["enable_network_isolation"] = estimator.enable_network_isolation()
25042511

25052512
if estimator.max_retry_attempts is not None:

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def _create_train_job(version, base_framework_version):
145145
"environment": None,
146146
"retry_strategy": None,
147147
"experiment_config": None,
148+
"enable_network_isolation": False,
148149
"debugger_hook_config": {
149150
"CollectionConfigurations": [],
150151
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
141141
"vpc_config": None,
142142
"metric_definitions": None,
143143
"environment": None,
144+
"enable_network_isolation": False,
144145
"experiment_config": None,
145146
"profiler_config": {
146147
"DisableProfiler": False,

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(
143143
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
144144
"tags": None,
145145
"vpc_config": None,
146+
"enable_network_isolation": False,
146147
"metric_definitions": None,
147148
"environment": None,
148149
"retry_strategy": None,

tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def _create_train_job(
144144
"environment": None,
145145
"retry_strategy": None,
146146
"experiment_config": EXPERIMENT_CONFIG,
147+
"enable_network_isolation": False,
147148
"debugger_hook_config": {
148149
"CollectionConfigurations": [],
149150
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(
143143
"environment": None,
144144
"retry_strategy": None,
145145
"experiment_config": EXPERIMENT_CONFIG,
146+
"enable_network_isolation": False,
146147
"debugger_hook_config": {
147148
"CollectionConfigurations": [],
148149
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
149149
"environment": None,
150150
"retry_strategy": None,
151151
"experiment_config": EXPERIMENT_CONFIG,
152+
"enable_network_isolation": False,
152153
"debugger_hook_config": {
153154
"CollectionConfigurations": [],
154155
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(version, py_version):
150150
"tags": None,
151151
"vpc_config": None,
152152
"metric_definitions": None,
153+
"enable_network_isolation": False,
153154
"environment": None,
154155
"experiment_config": None,
155156
"debugger_hook_config": {

tests/unit/test_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1895,6 +1895,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
18951895
"encrypt_inter_container_traffic": True,
18961896
"use_spot_instances": True,
18971897
"checkpoint_s3_uri": "s3://mybucket/checkpoints/",
1898+
"enable_network_isolation": False,
18981899
"checkpoint_local_path": "/tmp/checkpoints",
18991900
"environment": None,
19001901
"experiment_config": None,
@@ -3441,6 +3442,7 @@ def test_unsupported_type_in_dict():
34413442
"vpc_config": None,
34423443
"metric_definitions": None,
34433444
"environment": None,
3445+
"enable_network_isolation": False,
34443446
"experiment_config": None,
34453447
}
34463448

@@ -3831,7 +3833,7 @@ def test_generic_to_fit_with_network_isolation(sagemaker_session):
38313833

38323834
sagemaker_session.train.assert_called_once()
38333835
args = sagemaker_session.train.call_args[1]
3834-
assert args["enable_network_isolation"]
3836+
assert args["enable_network_isolation"] is True
38353837

38363838

38373839
def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _get_train_args(job_name):
167167
"environment": None,
168168
"retry_strategy": None,
169169
"experiment_config": None,
170+
"enable_network_isolation": False,
170171
"debugger_hook_config": {
171172
"CollectionConfigurations": [],
172173
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/test_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _create_train_job(version, py_version):
165165
"environment": None,
166166
"retry_strategy": None,
167167
"experiment_config": None,
168+
"enable_network_isolation": False,
168169
"debugger_hook_config": {
169170
"CollectionConfigurations": [],
170171
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
155155
],
156156
"environment": None,
157157
"experiment_config": None,
158+
"enable_network_isolation": False,
158159
"debugger_hook_config": {
159160
"CollectionConfigurations": [],
160161
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/test_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _create_train_job(version):
142142
"vpc_config": None,
143143
"environment": None,
144144
"experiment_config": None,
145+
"enable_network_isolation": False,
145146
"debugger_hook_config": {
146147
"CollectionConfigurations": [],
147148
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

tests/unit/test_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
155155
"vpc_config": None,
156156
"environment": None,
157157
"experiment_config": None,
158+
"enable_network_isolation": False,
158159
"debugger_hook_config": {
159160
"CollectionConfigurations": [],
160161
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),

0 commit comments

Comments
 (0)