Skip to content

Commit 8b55e39

Browse files
committed
Merge remote-tracking branch 'origin' into feat/jumpstart-model-estimator-classes
2 parents 7154fc6 + 1caab81 commit 8b55e39

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/sagemaker/remote_function/client.py

+14
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def remote(
185185
methods that are not available via PyPI or conda. Default value is ``False``.
186186
187187
instance_count (int): The number of instances to use. Defaults to 1.
188+
NOTE: Remote function does not support instance_count > 1; a ValueError is raised when the decorated function is invoked with a larger value.
188189
189190
instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
190191
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -255,6 +256,12 @@ def _remote(func):
255256
@functools.wraps(func)
256257
def wrapper(*args, **kwargs):
257258

259+
if instance_count > 1:
260+
raise ValueError(
261+
"Remote function do not support training on multi instances. "
262+
+ "Please provide instance_count = 1"
263+
)
264+
258265
RemoteExecutor._validate_submit_args(func, *args, **kwargs)
259266

260267
job_settings = _JobSettings(
@@ -574,6 +581,7 @@ def __init__(
574581
and methods that are not available via PyPI or conda. Default value is ``False``.
575582
576583
instance_count (int): The number of instances to use. Defaults to 1.
584+
NOTE: Remote function does not support instance_count > 1; a ValueError is raised if a larger value is provided.
577585
578586
instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
579587
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -647,6 +655,12 @@ def __init__(
647655
if self.max_parallel_jobs <= 0:
648656
raise ValueError("max_parallel_jobs must be greater than 0.")
649657

658+
if instance_count > 1:
659+
raise ValueError(
660+
"Remote function do not support training on multi instances. "
661+
+ "Please provide instance_count = 1"
662+
)
663+
650664
self.job_settings = _JobSettings(
651665
dependencies=dependencies,
652666
pre_execution_commands=pre_execution_commands,

tests/unit/sagemaker/remote_function/test_client.py

+19
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,17 @@ def square(x):
373373
square(5)
374374

375375

376+
def test_decorator_instance_count_greater_than_one():
377+
@remote(image_uri=IMAGE, s3_root_uri=S3_URI, instance_count=2)
378+
def square(x):
379+
return x * x
380+
381+
with pytest.raises(
382+
ValueError, match=r"Remote function do not support training on multi instances."
383+
):
384+
square(5)
385+
386+
376387
@patch("sagemaker.remote_function.client._JobSettings")
377388
@patch("sagemaker.remote_function.client._Job.start")
378389
def test_decorator_underlying_job_timed_out(mock_start, mock_job_settings):
@@ -626,6 +637,14 @@ def test_executor_fails_to_start_job(mock_start, *args):
626637
assert future_2.done()
627638

628639

640+
def test_executor_instance_count_greater_than_one():
641+
with pytest.raises(
642+
ValueError, match=r"Remote function do not support training on multi instances."
643+
):
644+
with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/", instance_count=2) as e:
645+
e.submit(job_function, 1, 2, c=3, d=4)
646+
647+
629648
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
630649
@patch("sagemaker.remote_function.client._JobSettings")
631650
@patch("sagemaker.remote_function.client._Job.start")

0 commit comments

Comments
 (0)