10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """Utility methods used by framework classes"""
13
+ """Utility methods used by framework classes. """
14
14
from __future__ import absolute_import
15
15
16
16
import json
17
17
import logging
18
18
import os
19
19
import re
20
- import time
21
20
import shutil
22
21
import tempfile
22
+ import time
23
23
from collections import namedtuple
24
- from typing import List , Optional , Union , Dict
24
+ from typing import Dict , List , Optional , Union
25
+
25
26
from packaging import version
26
27
27
28
import sagemaker .image_uris
29
+ import sagemaker .utils
30
+ from sagemaker .deprecations import deprecation_warn_base , renamed_kwargs , renamed_warning
28
31
from sagemaker .instance_group import InstanceGroup
29
32
from sagemaker .s3_utils import s3_path_join
30
33
from sagemaker .session_settings import SessionSettings
31
- import sagemaker .utils
32
34
from sagemaker .workflow import is_pipeline_variable
33
-
34
- from sagemaker .deprecations import renamed_warning , renamed_kwargs
35
35
from sagemaker .workflow .entities import PipelineVariable
36
- from sagemaker .deprecations import deprecation_warn_base
37
36
38
37
logger = logging .getLogger (__name__ )
39
38
40
39
_TAR_SOURCE_FILENAME = "source.tar.gz"
41
40
42
41
UploadedCode = namedtuple ("UploadedCode" , ["s3_prefix" , "script_name" ])
43
42
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43
+
44
44
This is for the source code used for the entry point with an ``Estimator``. It can be
45
45
instantiated with positional or keyword arguments.
46
46
"""
@@ -211,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
211
211
git_config : Optional [Dict [str , str ]] = None ,
212
212
enable_network_isolation : Union [bool , PipelineVariable ] = False ,
213
213
):
214
- """Validate source code input against pipeline variables
214
+ """Validate source code input against pipeline variables.
215
215
216
216
Args:
217
217
entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -481,7 +481,7 @@ def tar_and_upload_dir(
481
481
482
482
483
483
def _list_files_to_compress (script , directory ):
484
- """Placeholder docstring"""
484
+ """Placeholder docstring. """
485
485
if directory is None :
486
486
return [script ]
487
487
@@ -585,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
585
585
The location returned is a potential concatenation of 2 parts
586
586
1. code_location_key_prefix if it exists
587
587
2. model_name or a name derived from the image
588
-
589
588
Args:
590
589
code_location_key_prefix (str): the s3 key prefix from code_location
591
590
model_name (str): the name of the model
@@ -620,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
620
619
"enabled": True
621
620
}
622
621
}
623
-
624
-
625
622
"""
626
623
if training_instance_type == "local" or distribution is None :
627
624
return
@@ -646,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
646
643
def profiler_config_deprecation_warning (
647
644
profiler_config , image_uri , framework_name , framework_version
648
645
):
649
- """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
646
+ """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0. """
650
647
if profiler_config is None or profiler_config .framework_profile_params is None :
651
648
return
652
649
@@ -692,6 +689,7 @@ def validate_smdistributed(
692
689
framework_name (str): A string representing the name of framework selected.
693
690
framework_version (str): A string representing the framework version selected.
694
691
py_version (str): A string representing the python version selected.
692
+ Ex: `py38, py39, py310, py311`
695
693
distribution (dict): A dictionary with information to enable distributed training.
696
694
(Defaults to None if distributed training is not enabled.) For example:
697
695
@@ -763,7 +761,8 @@ def _validate_smdataparallel_args(
763
761
instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
764
762
framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
765
763
framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
766
- py_version (str): A string representing the python version selected. Ex: `py3`
764
+ py_version (str): A string representing the python version selected.
765
+ Ex: `py38, py39, py310, py311`
767
766
distribution (dict): A dictionary with information to enable distributed training.
768
767
(Defaults to None if distributed training is not enabled.) Ex:
769
768
@@ -847,6 +846,7 @@ def validate_distribution(
847
846
framework_name (str): A string representing the name of framework selected.
848
847
framework_version (str): A string representing the framework version selected.
849
848
py_version (str): A string representing the python version selected.
849
+ Ex: `py38, py39, py310, py311`
850
850
image_uri (str): A string representing a Docker image URI.
851
851
kwargs(dict): Additional kwargs passed to this function
852
852
@@ -953,7 +953,7 @@ def validate_distribution(
953
953
954
954
955
955
def validate_distribution_for_instance_type (instance_type , distribution ):
956
- """Check if the provided distribution strategy is supported for the instance_type
956
+ """Check if the provided distribution strategy is supported for the instance_type.
957
957
958
958
Args:
959
959
instance_type (str): A string representing the type of training instance selected.
@@ -1010,6 +1010,7 @@ def validate_torch_distributed_distribution(
1010
1010
}
1011
1011
framework_version (str): A string representing the framework version selected.
1012
1012
py_version (str): A string representing the python version selected.
1013
+ Ex: `py38, py39, py310, py311`
1013
1014
image_uri (str): A string representing a Docker image URI.
1014
1015
entry_point (str or PipelineVariable): The absolute or relative path to the local Python
1015
1016
source file that should be executed as the entry point to
@@ -1072,7 +1073,7 @@ def validate_torch_distributed_distribution(
1072
1073
1073
1074
1074
1075
def _is_gpu_instance (instance_type ):
1075
- """Returns bool indicating whether instance_type supports GPU
1076
+ """Returns bool indicating whether instance_type supports GPU.
1076
1077
1077
1078
Args:
1078
1079
instance_type (str): Name of the instance_type to check against.
@@ -1091,7 +1092,7 @@ def _is_gpu_instance(instance_type):
1091
1092
1092
1093
1093
1094
def _is_trainium_instance (instance_type ):
1094
- """Returns bool indicating whether instance_type is a Trainium instance
1095
+ """Returns bool indicating whether instance_type is a Trainium instance.
1095
1096
1096
1097
Args:
1097
1098
instance_type (str): Name of the instance_type to check against.
@@ -1107,7 +1108,7 @@ def _is_trainium_instance(instance_type):
1107
1108
1108
1109
1109
1110
def python_deprecation_warning (framework , latest_supported_version ):
1110
- """Placeholder docstring"""
1111
+ """Placeholder docstring. """
1111
1112
return PYTHON_2_DEPRECATION_WARNING .format (
1112
1113
framework = framework , latest_supported_version = latest_supported_version
1113
1114
)
@@ -1121,7 +1122,6 @@ def _region_supports_debugger(region_name):
1121
1122
1122
1123
Returns:
1123
1124
bool: Whether or not the region supports Amazon SageMaker Debugger.
1124
-
1125
1125
"""
1126
1126
return region_name .lower () not in DEBUGGER_UNSUPPORTED_REGIONS
1127
1127
@@ -1134,7 +1134,6 @@ def _region_supports_profiler(region_name):
1134
1134
1135
1135
Returns:
1136
1136
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1137
-
1138
1137
"""
1139
1138
return region_name .lower () not in PROFILER_UNSUPPORTED_REGIONS
1140
1139
@@ -1162,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
1162
1161
1163
1162
Args:
1164
1163
framework_version (str): The version of the framework.
1165
- py_version (str): The version of Python.
1164
+ py_version (str): A string representing the python version selected.
1165
+ Ex: `py38, py39, py310, py311`
1166
1166
image_uri (str): The URI of the image.
1167
1167
1168
1168
Raises:
@@ -1194,9 +1194,8 @@ def create_image_uri(
1194
1194
instance_type (str): SageMaker instance type. Used to determine device
1195
1195
type (cpu/gpu/family-specific optimized).
1196
1196
framework_version (str): The version of the framework.
1197
- py_version (str): Optional. Python version. If specified, should be one
1198
- of 'py2' or 'py3'. If not specified, image uri will not include a
1199
- python component.
1197
+ py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1198
+ If not specified, image uri will not include a python component.
1200
1199
account (str): AWS account that contains the image. (default:
1201
1200
'520713654638')
1202
1201
accelerator_type (str): SageMaker Elastic Inference accelerator type.
0 commit comments