
Commit da9d05a

trungleduc authored and knikure committed
fix: clone distribution in validate_distribution (aws#4205)
1 parent 95bec79 commit da9d05a
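
In short: `validate_distribution` previously wrote its derived `instance_groups` names back into the very dict the caller passed in, so the caller's `distribution` argument was silently mutated. The fix validates against a shallow copy and returns that copy instead. A minimal sketch of the hazard and of the fix, as a standalone illustration rather than the SDK code itself:

    # Before: the validator mutated its input in place.
    def validate_in_place(distribution, group_names):
        distribution["instance_groups"] = group_names  # caller's dict changes!
        return distribution

    # After: validate a copy, as this commit does via dict(distribution).
    def validate_copy(distribution, group_names):
        validated = dict(distribution)  # shallow copy; caller's dict untouched
        validated["instance_groups"] = group_names
        return validated

    user_distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
    validated = validate_copy(user_distribution, ["train_group"])
    assert "instance_groups" not in user_distribution  # input preserved
    assert validated is not user_distribution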

File tree

src/sagemaker/fw_utils.py
tests/unit/test_fw_utils.py

2 files changed: +47 -22 lines changed

src/sagemaker/fw_utils.py (+25 -22)

@@ -21,10 +21,11 @@
 import shutil
 import tempfile
 from collections import namedtuple
-from typing import Optional, Union, Dict
+from typing import List, Optional, Union, Dict
 from packaging import version
 
 import sagemaker.image_uris
+from sagemaker.instance_group import InstanceGroup
 from sagemaker.s3_utils import s3_path_join
 from sagemaker.session_settings import SessionSettings
 import sagemaker.utils
@@ -828,14 +829,14 @@ def _validate_smdataparallel_args(
 
 
 def validate_distribution(
-    distribution,
-    instance_groups,
-    framework_name,
-    framework_version,
-    py_version,
-    image_uri,
-    kwargs,
-):
+    distribution: Dict,
+    instance_groups: List[InstanceGroup],
+    framework_name: str,
+    framework_version: str,
+    py_version: str,
+    image_uri: str,
+    kwargs: Dict,
+) -> Dict:
     """Check if distribution strategy is correctly invoked by the user.
 
     Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
@@ -872,7 +873,9 @@ def validate_distribution(
         strategy-specific inputs are incorrect/unsupported or
         heterogeneous cluster set up is incorrect
     """
-    train_instance_groups = distribution.get("instance_groups", [])
+    validated_distribution = dict(distribution)
+
+    train_instance_groups = validated_distribution.get("instance_groups", [])
     if instance_groups is None:
         if len(train_instance_groups) >= 1:
             # if estimator's instance_groups is not defined but
@@ -902,77 +905,77 @@ def validate_distribution(
             instance_type = train_instance_group.instance_type
             validate_distribution_for_instance_type(
                 instance_type=instance_type,
-                distribution=distribution,
+                distribution=validated_distribution,
             )
             validate_smdistributed(
                 instance_type=instance_type,
                 framework_name=framework_name,
                 framework_version=framework_version,
                 py_version=py_version,
-                distribution=distribution,
+                distribution=validated_distribution,
                 image_uri=image_uri,
             )
             if framework_name and framework_name == "pytorch":
                 # We need to validate only for PyTorch framework
                 validate_pytorch_distribution(
-                    distribution=distribution,
+                    distribution=validated_distribution,
                     framework_name=framework_name,
                     framework_version=framework_version,
                     py_version=py_version,
                     image_uri=image_uri,
                 )
                 validate_torch_distributed_distribution(
                     instance_type=instance_type,
-                    distribution=distribution,
+                    distribution=validated_distribution,
                     framework_version=framework_version,
                     py_version=py_version,
                     image_uri=image_uri,
                     entry_point=kwargs["entry_point"],
                 )
             warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=instance_type, distribution=distribution
+                training_instance_type=instance_type, distribution=validated_distribution
             )
             # get instance group names
             instance_group_names.append(train_instance_group.instance_group_name)
-        distribution["instance_groups"] = instance_group_names
+        validated_distribution["instance_groups"] = instance_group_names
     else:
         # in this case, we are handling a normal training job (without heterogeneous cluster)
         instance_type = renamed_kwargs(
             "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
         )
         validate_distribution_for_instance_type(
             instance_type=instance_type,
-            distribution=distribution,
+            distribution=validated_distribution,
         )
         validate_smdistributed(
             instance_type=instance_type,
             framework_name=framework_name,
             framework_version=framework_version,
             py_version=py_version,
-            distribution=distribution,
+            distribution=validated_distribution,
             image_uri=image_uri,
         )
         if framework_name and framework_name == "pytorch":
            # We need to validate only for PyTorch framework
            validate_pytorch_distribution(
-                distribution=distribution,
+                distribution=validated_distribution,
                framework_name=framework_name,
                framework_version=framework_version,
                py_version=py_version,
                image_uri=image_uri,
            )
            validate_torch_distributed_distribution(
                instance_type=instance_type,
-                distribution=distribution,
+                distribution=validated_distribution,
                framework_version=framework_version,
                py_version=py_version,
                image_uri=image_uri,
                entry_point=kwargs["entry_point"],
            )
         warn_if_parameter_server_with_multi_gpu(
-            training_instance_type=instance_type, distribution=distribution
+            training_instance_type=instance_type, distribution=validated_distribution
         )
-    return distribution
+    return validated_distribution
 
 
 def validate_distribution_for_instance_type(instance_type, distribution):
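
Note that `dict(distribution)` is a shallow copy: `validate_distribution` can now replace top-level keys such as `instance_groups` without touching the caller's dict, but nested dicts (e.g. the `smdistributed` block) are still shared between the two objects. A quick illustration of that boundary, with assumed example values:

    original = {"smdistributed": {"dataparallel": {"enabled": True}}}
    copy = dict(original)

    copy["instance_groups"] = ["train_group"]  # new top-level key: original unaffected
    assert "instance_groups" not in original

    copy["smdistributed"]["dataparallel"]["enabled"] = False  # nested dict is shared
    assert original["smdistributed"]["dataparallel"]["enabled"] is False

A shallow copy suffices here because the diff above only assigns a new top-level `instance_groups` key; it never writes into the nested strategy dicts.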

tests/unit/test_fw_utils.py (+22 -0)

@@ -784,6 +784,28 @@ def test_validate_distribution_raises():
     )
 
 
+def test_validate_distribution_copy():
+    train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1)
+    instance_groups = [train_group]
+    framework = "tensorflow"
+    distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
+    validated = fw_utils.validate_distribution(
+        distribution,
+        instance_groups,
+        framework,
+        None,
+        None,
+        "custom-container",
+        {"entry_point": "train.py"},
+    )
+
+    assert validated == {
+        "instance_groups": ["train_group"],
+        "smdistributed": {"dataparallel": {"enabled": True}},
+    }
+    assert validated is not distribution
+
+
 def test_validate_smdistributed_not_raises():
     smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
     smdataparallel_enabled_custom_mpi = {
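
The new test pins down the contract: the returned dict carries the derived `instance_groups` names, and `validated is not distribution` guarantees the caller's object is a distinct dict. A hypothetical follow-on check in the same spirit, reusing the test's setup, is that validation is now repeatable with a single dict:

    from sagemaker import fw_utils
    from sagemaker.instance_group import InstanceGroup

    # Same inputs as test_validate_distribution_copy above.
    train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1)
    distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}

    # Validating twice with one dict must give the same answer both times,
    # which only holds because the first call no longer pollutes the input.
    first = fw_utils.validate_distribution(
        distribution, [train_group], "tensorflow", None, None,
        "custom-container", {"entry_point": "train.py"},
    )
    second = fw_utils.validate_distribution(
        distribution, [train_group], "tensorflow", None, None,
        "custom-container", {"entry_point": "train.py"},
    )
    assert first == second
    assert "instance_groups" not in distribution  # caller's dict stays clean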
