21 | 21 | import shutil
22 | 22 | import tempfile
23 | 23 | from collections import namedtuple
24 |    | -from typing import Optional, Union, Dict
   | 24 | +from typing import List, Optional, Union, Dict
25 | 25 | from packaging import version
26 | 26 |
27 | 27 | import sagemaker.image_uris
   | 28 | +from sagemaker.instance_group import InstanceGroup
28 | 29 | from sagemaker.s3_utils import s3_path_join
29 | 30 | from sagemaker.session_settings import SessionSettings
30 | 31 | import sagemaker.utils
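The new `InstanceGroup` import backs the `List[InstanceGroup]` annotation added to `validate_distribution` below. For context, a heterogeneous-cluster estimator passes instance groups built roughly like this (a minimal sketch; the group names, instance types, and counts are illustrative placeholders, not values taken from this change):

```python
from sagemaker.instance_group import InstanceGroup

# Hypothetical heterogeneous cluster: one CPU group for data loading,
# one GPU group for training. All values are placeholders.
instance_groups = [
    InstanceGroup(instance_group_name="data_group", instance_type="ml.c5.xlarge", instance_count=2),
    InstanceGroup(instance_group_name="train_group", instance_type="ml.p3.16xlarge", instance_count=1),
]
```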
@@ -828,14 +829,14 @@ def _validate_smdataparallel_args(
828 | 829 |
829 | 830 |
830 | 831 | def validate_distribution(
831 |     | -    distribution,
832 |     | -    instance_groups,
833 |     | -    framework_name,
834 |     | -    framework_version,
835 |     | -    py_version,
836 |     | -    image_uri,
837 |     | -    kwargs,
838 |     | -):
    | 832 | +    distribution: Dict,
    | 833 | +    instance_groups: List[InstanceGroup],
    | 834 | +    framework_name: str,
    | 835 | +    framework_version: str,
    | 836 | +    py_version: str,
    | 837 | +    image_uri: str,
    | 838 | +    kwargs: Dict,
    | 839 | +) -> Dict:
839 | 840 |     """Check if distribution strategy is correctly invoked by the user.
840 | 841 |
841 | 842 |     Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
@@ -872,7 +873,9 @@ def validate_distribution(
872 | 873 |             strategy-specific inputs are incorrect/unsupported or
873 | 874 |             heterogeneous cluster set up is incorrect
874 | 875 |     """
875 |     | -    train_instance_groups = distribution.get("instance_groups", [])
    | 876 | +    validated_distribution = dict(distribution)
    | 877 | +
    | 878 | +    train_instance_groups = validated_distribution.get("instance_groups", [])
876 | 879 |     if instance_groups is None:
877 | 880 |         if len(train_instance_groups) >= 1:
878 | 881 |             # if estimator's instance_groups is not defined but
@@ -902,77 +905,77 @@ def validate_distribution(
902 | 905 |             instance_type = train_instance_group.instance_type
903 | 906 |             validate_distribution_for_instance_type(
904 | 907 |                 instance_type=instance_type,
905 |     | -                distribution=distribution,
    | 908 | +                distribution=validated_distribution,
906 | 909 |             )
907 | 910 |             validate_smdistributed(
908 | 911 |                 instance_type=instance_type,
909 | 912 |                 framework_name=framework_name,
910 | 913 |                 framework_version=framework_version,
911 | 914 |                 py_version=py_version,
912 |     | -                distribution=distribution,
    | 915 | +                distribution=validated_distribution,
913 | 916 |                 image_uri=image_uri,
914 | 917 |             )
915 | 918 |             if framework_name and framework_name == "pytorch":
916 | 919 |                 # We need to validate only for PyTorch framework
917 | 920 |                 validate_pytorch_distribution(
918 |     | -                    distribution=distribution,
    | 921 | +                    distribution=validated_distribution,
919 | 922 |                     framework_name=framework_name,
920 | 923 |                     framework_version=framework_version,
921 | 924 |                     py_version=py_version,
922 | 925 |                     image_uri=image_uri,
923 | 926 |                 )
924 | 927 |                 validate_torch_distributed_distribution(
925 | 928 |                     instance_type=instance_type,
926 |     | -                    distribution=distribution,
    | 929 | +                    distribution=validated_distribution,
927 | 930 |                     framework_version=framework_version,
928 | 931 |                     py_version=py_version,
929 | 932 |                     image_uri=image_uri,
930 | 933 |                     entry_point=kwargs["entry_point"],
931 | 934 |                 )
932 | 935 |             warn_if_parameter_server_with_multi_gpu(
933 |     | -                training_instance_type=instance_type, distribution=distribution
    | 936 | +                training_instance_type=instance_type, distribution=validated_distribution
934 | 937 |             )
935 | 938 |             # get instance group names
936 | 939 |             instance_group_names.append(train_instance_group.instance_group_name)
937 |     | -        distribution["instance_groups"] = instance_group_names
    | 940 | +        validated_distribution["instance_groups"] = instance_group_names
938 | 941 |     else:
939 | 942 |         # in this case, we are handling a normal training job (without heterogeneous cluster)
940 | 943 |         instance_type = renamed_kwargs(
941 | 944 |             "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
942 | 945 |         )
943 | 946 |         validate_distribution_for_instance_type(
944 | 947 |             instance_type=instance_type,
945 |     | -            distribution=distribution,
    | 948 | +            distribution=validated_distribution,
946 | 949 |         )
947 | 950 |         validate_smdistributed(
948 | 951 |             instance_type=instance_type,
949 | 952 |             framework_name=framework_name,
950 | 953 |             framework_version=framework_version,
951 | 954 |             py_version=py_version,
952 |     | -            distribution=distribution,
    | 955 | +            distribution=validated_distribution,
953 | 956 |             image_uri=image_uri,
954 | 957 |         )
955 | 958 |         if framework_name and framework_name == "pytorch":
956 | 959 |             # We need to validate only for PyTorch framework
957 | 960 |             validate_pytorch_distribution(
958 |     | -                distribution=distribution,
    | 961 | +                distribution=validated_distribution,
959 | 962 |                 framework_name=framework_name,
960 | 963 |                 framework_version=framework_version,
961 | 964 |                 py_version=py_version,
962 | 965 |                 image_uri=image_uri,
963 | 966 |             )
964 | 967 |             validate_torch_distributed_distribution(
965 | 968 |                 instance_type=instance_type,
966 |     | -                distribution=distribution,
    | 969 | +                distribution=validated_distribution,
967 | 970 |                 framework_version=framework_version,
968 | 971 |                 py_version=py_version,
969 | 972 |                 image_uri=image_uri,
970 | 973 |                 entry_point=kwargs["entry_point"],
971 | 974 |             )
972 | 975 |         warn_if_parameter_server_with_multi_gpu(
973 |     | -            training_instance_type=instance_type, distribution=distribution
    | 976 | +            training_instance_type=instance_type, distribution=validated_distribution
974 | 977 |         )
975 |     | -    return distribution
    | 978 | +    return validated_distribution
976 | 979 |
977 | 980 |
978 | 981 | def validate_distribution_for_instance_type(instance_type, distribution):
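The substance of the change: `validate_distribution` now works on a shallow copy of the `distribution` argument and returns that copy, so the caller's dict is no longer mutated when `instance_groups` is rewritten to a list of group names. A minimal standalone sketch of the pattern follows (hypothetical names, not the SageMaker implementation itself):

```python
from typing import Dict, List


def validate(settings: Dict, group_names: List[str]) -> Dict:
    """Return a validated copy; the caller's dict stays untouched."""
    validated = dict(settings)                   # shallow copy, as in the diff above
    validated["instance_groups"] = group_names   # rewrite the key only on the copy
    return validated


original = {"instance_groups": ["groupA-object", "groupB-object"], "mpi": {"enabled": True}}
result = validate(original, ["groupA", "groupB"])

assert original["instance_groups"] == ["groupA-object", "groupB-object"]  # caller sees no change
assert result["instance_groups"] == ["groupA", "groupB"]                  # validated copy
```

Callers therefore need to use the return value rather than relying on the passed-in dict being updated in place.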