Skip to content

Commit cebfd71

Browse files
authored
Change: Raise error if custom VPC is halfway defined (#4269)
1 parent b001698 commit cebfd71

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/sagemaker/estimator.py

+11
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,17 @@ def __init__(
578578
self.dependencies = dependencies or []
579579
self.uploaded_code: Optional[UploadedCode] = None
580580

581+
# Check that the user properly sets both subnet and secutiry_groupe_ids
582+
if (
583+
subnets is not None
584+
and security_group_ids is None
585+
or security_group_ids is not None
586+
and subnets is None
587+
):
588+
raise RuntimeError(
589+
"When setting up custom VPC, both subnets and security_group_ids must be set"
590+
)
591+
581592
if self.instance_type in ("local", "local_gpu"):
582593
if self.instance_type == "local_gpu" and self.instance_count > 1:
583594
raise RuntimeError("Distributed Training in Local GPU is not supported")

tests/unit/test_estimator.py

+18
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,24 @@ def test_framework_all_init_args(sagemaker_session):
512512
}
513513

514514

515+
def test_subnets_without_security_groups(sagemaker_session):
516+
with pytest.raises(RuntimeError):
517+
DummyFramework(
518+
entry_point=SCRIPT_PATH,
519+
sagemaker_session=sagemaker_session,
520+
subnets=["123"],
521+
)
522+
523+
524+
def test_security_groups_without_subnets(sagemaker_session):
525+
with pytest.raises(RuntimeError):
526+
DummyFramework(
527+
entry_point=SCRIPT_PATH,
528+
sagemaker_session=sagemaker_session,
529+
security_group_ids=["123"],
530+
)
531+
532+
515533
def test_framework_without_role_parameter(sagemaker_session):
516534
with pytest.raises(ValueError):
517535
DummyFramework(

0 commit comments

Comments
 (0)