Skip to content

Commit c6d1c1a

Browse files
navaj0 and Zhankuil
authored and
Namrata Madan
committed
Set AWS_DEFAULT_REGION environment variable (aws#889)
Co-authored-by: Zhankui Lu <[email protected]>
1 parent 65f21df commit c6d1c1a

File tree

6 files changed

+60
-43
lines changed

6 files changed

+60
-43
lines changed

src/sagemaker/remote_function/client.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ def _submit_worker(executor):
222222
def has_work_to_do():
223223
return (
224224
len(executor._pending_request_queue) > 0
225-
and len(executor._running_jobs) < executor.max_parallel_job
225+
and len(executor._running_jobs) < executor.max_parallel_jobs
226226
)
227227

228228
try:
@@ -315,7 +315,7 @@ def __init__(
315315
job_conda_env: str = None,
316316
job_name_prefix: str = None,
317317
keep_alive_period_in_seconds: int = 0,
318-
max_parallel_job: int = 1,
318+
max_parallel_jobs: int = 1,
319319
max_retry_attempts: int = 1,
320320
max_runtime_in_seconds: int = 24 * 60 * 60,
321321
role: str = None,
@@ -346,7 +346,7 @@ def __init__(
346346
job_name_prefix (str): Prefix used to identify the underlying sagemaker job.
347347
keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured
348348
resources in a warm pool for subsequent training jobs. Defaults to 0.
349-
max_parallel_job (int): Maximal number of jobs that run in parallel. Default to 1.
349+
max_parallel_jobs (int): Maximal number of jobs that run in parallel. Default to 1.
350350
max_retry_attempts (int): Max number of times the job is retried on
351351
InternalServerFailure.Defaults to 1.
352352
max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of
@@ -370,10 +370,10 @@ def __init__(
370370
volume_size (int): Size in GB of the storage volume to use for storing input and output
371371
data. Defaults to 30.
372372
"""
373-
self.max_parallel_job = max_parallel_job
373+
self.max_parallel_jobs = max_parallel_jobs
374374

375-
if self.max_parallel_job <= 0:
376-
raise ValueError("max_parallel_job must be greater than 0.")
375+
if self.max_parallel_jobs <= 0:
376+
raise ValueError("max_parallel_jobs must be greater than 0.")
377377

378378
self.job_settings = _JobSettings(
379379
dependencies=dependencies,

src/sagemaker/remote_function/job.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -131,8 +131,12 @@ def __init__(
131131
)
132132
self.sagemaker_session = sagemaker_session or Session()
133133

134-
self.environment_variables = self._get_from_config(
135-
environment_variables, config_schema.ENVIRONMENT_VARIABLES
134+
self.environment_variables = {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name}
135+
136+
self.environment_variables.update(
137+
self._get_from_config(
138+
environment_variables, config_schema.ENVIRONMENT_VARIABLES, default={}
139+
)
136140
)
137141

138142
_image_uri = self._get_from_config(image_uri, config_schema.IMAGE_URI)

tests/integ/sagemaker/remote_function/test_decorator.py

+6-12
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,6 @@
1717
import pytest
1818
import os
1919
import pandas as pd
20-
import boto3
21-
from sagemaker import Session
2220
from sagemaker.experiments.run import Run, load_run
2321
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources
2422
from sagemaker.experiments.trial_component import _TrialComponent
@@ -273,6 +271,7 @@ def test_with_non_existent_dependencies(
273271
dependencies=dependencies_path,
274272
instance_type=cpu_instance_type,
275273
sagemaker_session=sagemaker_session,
274+
keep_alive_period_in_seconds=30,
276275
)
277276
def divide(x, y):
278277
return x / y
@@ -293,6 +292,7 @@ def test_with_incompatible_dependencies(
293292
dependencies=dependencies_path,
294293
instance_type=cpu_instance_type,
295294
sagemaker_session=sagemaker_session,
295+
keep_alive_period_in_seconds=30,
296296
)
297297
def mul_ten(df: pd.DataFrame):
298298
return df.mul(10)
@@ -318,16 +318,11 @@ def test_decorator_with_exp_and_run_names_passed_to_remote_function(
318318
image_uri=dummy_container_without_error,
319319
instance_type=cpu_instance_type,
320320
sagemaker_session=sagemaker_session,
321+
keep_alive_period_in_seconds=30,
321322
)
322323
def train(exp_name, run_name):
323-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
324-
sagemaker_session = Session(boto_session=boto_session)
325324

326-
with Run(
327-
experiment_name=exp_name,
328-
run_name=run_name,
329-
sagemaker_session=sagemaker_session,
330-
) as run:
325+
with Run(experiment_name=exp_name, run_name=run_name) as run:
331326
print(f"Experiment name: {run.experiment_name}")
332327
print(f"Run name: {run.run_name}")
333328
print(f"Trial component name: {run._trial_component.trial_component_name}")
@@ -380,6 +375,7 @@ def test_decorator_load_run_inside_remote_function(
380375
image_uri=dummy_container_without_error,
381376
instance_type=cpu_instance_type,
382377
sagemaker_session=sagemaker_session,
378+
keep_alive_period_in_seconds=30,
383379
)
384380
def train():
385381
with load_run() as run:
@@ -419,14 +415,12 @@ def test_decorator_with_nested_exp_run(
419415
image_uri=dummy_container_without_error,
420416
instance_type=cpu_instance_type,
421417
sagemaker_session=sagemaker_session,
418+
keep_alive_period_in_seconds=30,
422419
)
423420
def train(exp_name, run_name):
424-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
425-
sagemaker_session = Session(boto_session=boto_session)
426421
with Run(
427422
experiment_name=exp_name,
428423
run_name=run_name,
429-
sagemaker_session=sagemaker_session,
430424
) as run:
431425
print(f"Experiment name: {run.experiment_name}")
432426
print(f"Run name: {run.run_name}")

tests/integ/sagemaker/remote_function/test_executor.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def cube(x):
3030
return x * x * x
3131

3232
with RemoteExecutor(
33-
max_parallel_job=1,
33+
max_parallel_jobs=1,
3434
role=ROLE,
3535
image_uri=dummy_container_without_error,
3636
instance_type=cpu_instance_type,
@@ -59,7 +59,7 @@ def power(a, b):
5959
return a**b
6060

6161
with RemoteExecutor(
62-
max_parallel_job=1,
62+
max_parallel_jobs=1,
6363
role=ROLE,
6464
image_uri=dummy_container_without_error,
6565
instance_type=cpu_instance_type,
@@ -98,7 +98,7 @@ def cube(x):
9898

9999
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
100100
with RemoteExecutor(
101-
max_parallel_job=1,
101+
max_parallel_jobs=1,
102102
role=ROLE,
103103
image_uri=dummy_container_without_error,
104104
instance_type=cpu_instance_type,
@@ -162,7 +162,7 @@ def cube(x):
162162
sagemaker_session=sagemaker_session,
163163
):
164164
with RemoteExecutor(
165-
max_parallel_job=1,
165+
max_parallel_jobs=1,
166166
role=ROLE,
167167
image_uri=dummy_container_without_error,
168168
instance_type=cpu_instance_type,
@@ -213,7 +213,7 @@ def square(x):
213213
sagemaker_session=sagemaker_session,
214214
):
215215
with RemoteExecutor(
216-
max_parallel_job=2,
216+
max_parallel_jobs=2,
217217
role=ROLE,
218218
image_uri=dummy_container_without_error,
219219
instance_type=cpu_instance_type,
@@ -227,7 +227,7 @@ def square(x):
227227
assert results[1] == 16
228228

229229
with RemoteExecutor(
230-
max_parallel_job=2,
230+
max_parallel_jobs=2,
231231
role=ROLE,
232232
image_uri=dummy_container_without_error,
233233
instance_type=cpu_instance_type,

tests/unit/sagemaker/remote_function/test_client.py

+14-14
Original file line number | Diff line number | Diff line change
@@ -453,14 +453,14 @@ def decorated_function(a, b, c=1, *, d, e, f=3):
453453

454454
def test_executor_invalid_arguments():
455455
with pytest.raises(ValueError):
456-
with RemoteExecutor(max_parallel_job=0, s3_root_uri="s3://bucket/") as e:
456+
with RemoteExecutor(max_parallel_jobs=0, s3_root_uri="s3://bucket/") as e:
457457
e.submit(job_function, 1, 2, c=3, d=4)
458458

459459

460460
@patch("sagemaker.remote_function.client._JobSettings")
461461
def test_executor_submit_after_shutdown(*args):
462462
with pytest.raises(RuntimeError):
463-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
463+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
464464
pass
465465
e.submit(job_function, 1, 2, c=3, d=4)
466466

@@ -476,7 +476,7 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
476476
mock_job_4 = create_mock_job("job_4", COMPLETED_TRAINING_JOB)
477477
mock_start.side_effect = [mock_job_1, mock_job_2, mock_job_3, mock_job_4]
478478

479-
with RemoteExecutor(max_parallel_job=parallelism, s3_root_uri="s3://bucket/") as e:
479+
with RemoteExecutor(max_parallel_jobs=parallelism, s3_root_uri="s3://bucket/") as e:
480480
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
481481
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
482482
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
@@ -514,7 +514,7 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
514514
run_info = _RunInfo(run_obj.experiment_name, run_obj.run_name)
515515

516516
with run_obj:
517-
with RemoteExecutor(max_parallel_job=2, s3_root_uri="s3://bucket/") as e:
517+
with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as e:
518518
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
519519
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
520520

@@ -530,7 +530,7 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
530530
assert future_1.done()
531531
assert future_2.done()
532532

533-
with RemoteExecutor(max_parallel_job=2, s3_root_uri="s3://bucket/") as e:
533+
with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as e:
534534
with run_obj:
535535
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
536536
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
@@ -556,7 +556,7 @@ def test_executor_submit_enforcing_max_parallel_jobs(mock_start, *args):
556556
mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
557557
mock_start.side_effect = [mock_job_1, mock_job_2]
558558

559-
e = RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/")
559+
e = RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/")
560560
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
561561
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
562562

@@ -588,7 +588,7 @@ def test_executor_fails_to_start_job(mock_start, *args):
588588

589589
mock_start.side_effect = [TypeError(), mock_job]
590590

591-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
591+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
592592
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
593593
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
594594

@@ -606,7 +606,7 @@ def test_executor_submit_and_cancel(mock_start, *args):
606606
mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
607607
mock_start.side_effect = [mock_job_1, mock_job_2]
608608

609-
e = RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/")
609+
e = RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/")
610610

611611
# submit first job and stay in progress
612612
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
@@ -645,7 +645,7 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args):
645645
]
646646
mock_start.return_value = mock_job
647647

648-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
648+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
649649
# submit first job
650650
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
651651
# submit second job
@@ -663,7 +663,7 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
663663
mock_job.describe.side_effect = RuntimeError()
664664
mock_start.return_value = mock_job
665665

666-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
666+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
667667
# submit first job
668668
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
669669
# submit second job
@@ -695,7 +695,7 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
695695
@patch("sagemaker.remote_function.client._JobSettings")
696696
def test_executor_submit_invalid_function_args(mock_job_settings, args, kwargs, error_message):
697697
with pytest.raises(TypeError) as e:
698-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as executor:
698+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as executor:
699699
executor.submit(job_function, *args, **kwargs)
700700
assert error_message in str(e.value)
701701

@@ -1063,7 +1063,7 @@ def test_executor_map_happy_case(mock_deserialized, mock_start, mock_job_setting
10631063

10641064
mock_deserialized.side_effect = [1, 16]
10651065

1066-
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as executor:
1066+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as executor:
10671067
results = executor.map(job_function2, [1, 2], [3, 4])
10681068

10691069
mock_start.assert_has_calls(
@@ -1095,7 +1095,7 @@ def test_executor_map_with_run(mock_deserialized, mock_start, mock_job_settings,
10951095
run_info = _RunInfo(run_obj.experiment_name, run_obj.run_name)
10961096

10971097
with run_obj:
1098-
with RemoteExecutor(max_parallel_job=2, s3_root_uri="s3://bucket/") as executor:
1098+
with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as executor:
10991099
results_12 = executor.map(job_function2, [1, 2], [3, 4])
11001100

11011101
mock_start.assert_has_calls(
@@ -1112,7 +1112,7 @@ def test_executor_map_with_run(mock_deserialized, mock_start, mock_job_settings,
11121112

11131113
mock_deserialized.side_effect = [1, 16]
11141114

1115-
with RemoteExecutor(max_parallel_job=2, s3_root_uri="s3://bucket/") as executor:
1115+
with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as executor:
11161116
with run_obj:
11171117
results_34 = executor.map(job_function2, [1, 2], [3, 4])
11181118

tests/unit/sagemaker/remote_function/test_job.py

+22-3
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,21 @@ def job_function(a, b=1, *, c, d=3):
114114
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
115115
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
116116
def test_sagemaker_config_job_settings(get_execution_role, session, monkeypatch):
117+
118+
job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge")
119+
assert job_settings.image_uri == "image_uri"
120+
assert job_settings.s3_root_uri == f"s3://{BUCKET}"
121+
assert job_settings.role == DEFAULT_ROLE_ARN
122+
assert job_settings.environment_variables == {"AWS_DEFAULT_REGION": "us-west-2"}
123+
assert job_settings.include_local_workdir is False
124+
assert job_settings.instance_type == "ml.m5.xlarge"
125+
126+
127+
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
128+
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
129+
def test_sagemaker_config_job_settings_with_configuration_file(
130+
get_execution_role, session, monkeypatch
131+
):
117132
monkeypatch.setenv(
118133
"SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", os.path.join(DATA_DIR, "remote_function")
119134
)
@@ -125,7 +140,10 @@ def test_sagemaker_config_job_settings(get_execution_role, session, monkeypatch)
125140
assert job_settings.tags == [("someTagKey", "someTagValue"), ("someTagKey2", "someTagValue2")]
126141
assert job_settings.vpc_config == {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]}
127142
assert job_settings.dependencies == "path/to/requirements.txt"
128-
assert job_settings.environment_variables == {"EnvVarKey": "EnvVarValue"}
143+
assert job_settings.environment_variables == {
144+
"AWS_DEFAULT_REGION": "us-west-2",
145+
"EnvVarKey": "EnvVarValue",
146+
}
129147
assert job_settings.job_conda_env == "my_conda_env"
130148
assert job_settings.include_local_workdir is False
131149
assert job_settings.volume_kms_key == "someVolumeKmsKey"
@@ -276,6 +294,7 @@ def test_start(
276294
InstanceType="ml.m5.large",
277295
KeepAlivePeriodInSeconds=0,
278296
),
297+
Environment={"AWS_DEFAULT_REGION": "us-west-2"},
279298
)
280299

281300

@@ -292,7 +311,7 @@ def test_start_with_complete_job_settings(
292311

293312
job_settings = _JobSettings(
294313
dependencies="path/to/dependencies/req.txt",
295-
environment_variables={"REGION": "us-west-2"},
314+
environment_variables={"AWS_DEFAULT_REGION": "us-east-2"},
296315
image_uri=IMAGE,
297316
s3_root_uri=S3_URI,
298317
s3_kms_key=KMS_KEY_ARN,
@@ -392,7 +411,7 @@ def test_start_with_complete_job_settings(
392411
KeepAlivePeriodInSeconds=120,
393412
),
394413
VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
395-
Environment={"REGION": "us-west-2"},
414+
Environment={"AWS_DEFAULT_REGION": "us-east-2"},
396415
)
397416

398417

0 commit comments

Comments
 (0)