Skip to content

Commit d739fcd

Browse files
authored
Merge branch 'master' into master
2 parents b409c62 + d37292f commit d739fcd

File tree

3 files changed

+47
-52
lines changed

3 files changed

+47
-52
lines changed

src/sagemaker/fw_utils.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,37 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Utility methods used by framework classes"""
13+
"""Utility methods used by framework classes."""
1414
from __future__ import absolute_import
1515

1616
import json
1717
import logging
1818
import os
1919
import re
20-
import time
2120
import shutil
2221
import tempfile
22+
import time
2323
from collections import namedtuple
24-
from typing import List, Optional, Union, Dict
24+
from typing import Dict, List, Optional, Union
25+
2526
from packaging import version
2627

2728
import sagemaker.image_uris
29+
import sagemaker.utils
30+
from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
2831
from sagemaker.instance_group import InstanceGroup
2932
from sagemaker.s3_utils import s3_path_join
3033
from sagemaker.session_settings import SessionSettings
31-
import sagemaker.utils
3234
from sagemaker.workflow import is_pipeline_variable
33-
34-
from sagemaker.deprecations import renamed_warning, renamed_kwargs
3535
from sagemaker.workflow.entities import PipelineVariable
36-
from sagemaker.deprecations import deprecation_warn_base
3736

3837
logger = logging.getLogger(__name__)
3938

4039
_TAR_SOURCE_FILENAME = "source.tar.gz"
4140

4241
UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
4342
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43+
4444
This is for the source code used for the entry point with an ``Estimator``. It can be
4545
instantiated with positional or keyword arguments.
4646
"""
@@ -211,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
211211
git_config: Optional[Dict[str, str]] = None,
212212
enable_network_isolation: Union[bool, PipelineVariable] = False,
213213
):
214-
"""Validate source code input against pipeline variables
214+
"""Validate source code input against pipeline variables.
215215
216216
Args:
217217
entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -481,7 +481,7 @@ def tar_and_upload_dir(
481481

482482

483483
def _list_files_to_compress(script, directory):
484-
"""Placeholder docstring"""
484+
"""Placeholder docstring."""
485485
if directory is None:
486486
return [script]
487487

@@ -585,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
585585
The location returned is a potential concatenation of 2 parts
586586
1. code_location_key_prefix if it exists
587587
2. model_name or a name derived from the image
588-
589588
Args:
590589
code_location_key_prefix (str): the s3 key prefix from code_location
591590
model_name (str): the name of the model
@@ -620,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
620619
"enabled": True
621620
}
622621
}
623-
624-
625622
"""
626623
if training_instance_type == "local" or distribution is None:
627624
return
@@ -646,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
646643
def profiler_config_deprecation_warning(
647644
profiler_config, image_uri, framework_name, framework_version
648645
):
649-
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
646+
"""Warn of deprecation if framework profiling is specified for TF >= 2.12 and PT >= 2.0."""
650647
if profiler_config is None or profiler_config.framework_profile_params is None:
651648
return
652649

@@ -692,6 +689,7 @@ def validate_smdistributed(
692689
framework_name (str): A string representing the name of framework selected.
693690
framework_version (str): A string representing the framework version selected.
694691
py_version (str): A string representing the python version selected.
692+
Ex: `py38, py39, py310, py311`
695693
distribution (dict): A dictionary with information to enable distributed training.
696694
(Defaults to None if distributed training is not enabled.) For example:
697695
@@ -763,7 +761,8 @@ def _validate_smdataparallel_args(
763761
instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
764762
framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
765763
framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
766-
py_version (str): A string representing the python version selected. Ex: `py3`
764+
py_version (str): A string representing the python version selected.
765+
Ex: `py38, py39, py310, py311`
767766
distribution (dict): A dictionary with information to enable distributed training.
768767
(Defaults to None if distributed training is not enabled.) Ex:
769768
@@ -847,6 +846,7 @@ def validate_distribution(
847846
framework_name (str): A string representing the name of framework selected.
848847
framework_version (str): A string representing the framework version selected.
849848
py_version (str): A string representing the python version selected.
849+
Ex: `py38, py39, py310, py311`
850850
image_uri (str): A string representing a Docker image URI.
851851
kwargs(dict): Additional kwargs passed to this function
852852
@@ -953,7 +953,7 @@ def validate_distribution(
953953

954954

955955
def validate_distribution_for_instance_type(instance_type, distribution):
956-
"""Check if the provided distribution strategy is supported for the instance_type
956+
"""Check if the provided distribution strategy is supported for the instance_type.
957957
958958
Args:
959959
instance_type (str): A string representing the type of training instance selected.
@@ -1010,6 +1010,7 @@ def validate_torch_distributed_distribution(
10101010
}
10111011
framework_version (str): A string representing the framework version selected.
10121012
py_version (str): A string representing the python version selected.
1013+
Ex: `py38, py39, py310, py311`
10131014
image_uri (str): A string representing a Docker image URI.
10141015
entry_point (str or PipelineVariable): The absolute or relative path to the local Python
10151016
source file that should be executed as the entry point to
@@ -1072,7 +1073,7 @@ def validate_torch_distributed_distribution(
10721073

10731074

10741075
def _is_gpu_instance(instance_type):
1075-
"""Returns bool indicating whether instance_type supports GPU
1076+
"""Returns bool indicating whether instance_type supports GPU.
10761077
10771078
Args:
10781079
instance_type (str): Name of the instance_type to check against.
@@ -1091,7 +1092,7 @@ def _is_gpu_instance(instance_type):
10911092

10921093

10931094
def _is_trainium_instance(instance_type):
1094-
"""Returns bool indicating whether instance_type is a Trainium instance
1095+
"""Returns bool indicating whether instance_type is a Trainium instance.
10951096
10961097
Args:
10971098
instance_type (str): Name of the instance_type to check against.
@@ -1107,7 +1108,7 @@ def _is_trainium_instance(instance_type):
11071108

11081109

11091110
def python_deprecation_warning(framework, latest_supported_version):
1110-
"""Placeholder docstring"""
1111+
"""Placeholder docstring."""
11111112
return PYTHON_2_DEPRECATION_WARNING.format(
11121113
framework=framework, latest_supported_version=latest_supported_version
11131114
)
@@ -1121,7 +1122,6 @@ def _region_supports_debugger(region_name):
11211122
11221123
Returns:
11231124
bool: Whether or not the region supports Amazon SageMaker Debugger.
1124-
11251125
"""
11261126
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
11271127

@@ -1134,7 +1134,6 @@ def _region_supports_profiler(region_name):
11341134
11351135
Returns:
11361136
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1137-
11381137
"""
11391138
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
11401139

@@ -1162,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11621161
11631162
Args:
11641163
framework_version (str): The version of the framework.
1165-
py_version (str): The version of Python.
1164+
py_version (str): A string representing the python version selected.
1165+
Ex: `py38, py39, py310, py311`
11661166
image_uri (str): The URI of the image.
11671167
11681168
Raises:
@@ -1194,9 +1194,8 @@ def create_image_uri(
11941194
instance_type (str): SageMaker instance type. Used to determine device
11951195
type (cpu/gpu/family-specific optimized).
11961196
framework_version (str): The version of the framework.
1197-
py_version (str): Optional. Python version. If specified, should be one
1198-
of 'py2' or 'py3'. If not specified, image uri will not include a
1199-
python component.
1197+
py_version (str): Optional. Python version. Ex: `py38, py39, py310, py311`.
1198+
If not specified, image uri will not include a python component.
12001199
account (str): AWS account that contains the image. (default:
12011200
'520713654638')
12021201
accelerator_type (str): SageMaker Elastic Inference accelerator type.

src/sagemaker/huggingface/estimator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@
1515

1616
import logging
1717
import re
18-
from typing import Optional, Union, Dict
18+
from typing import Dict, Optional, Union
1919

20-
from sagemaker.estimator import Framework, EstimatorBase
21-
from sagemaker.fw_utils import (
22-
framework_name_from_image,
23-
validate_distribution,
24-
)
20+
from sagemaker.estimator import EstimatorBase, Framework
21+
from sagemaker.fw_utils import framework_name_from_image, validate_distribution
2522
from sagemaker.huggingface.model import HuggingFaceModel
26-
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27-
2823
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
24+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2925
from sagemaker.workflow.entities import PipelineVariable
3026

3127
logger = logging.getLogger("sagemaker")
@@ -66,7 +62,7 @@ def __init__(
6662
Args:
6763
py_version (str): Python version you want to use for executing your model training
6864
code. Defaults to ``None``. Required unless ``image_uri`` is provided. If
69-
using PyTorch, the current supported version is ``py36``. If using TensorFlow,
65+
using PyTorch, the current supported version is ``py39``. If using TensorFlow,
7066
the current supported version is ``py37``.
7167
entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source
7268
file which should be executed as the entry point to training.

src/sagemaker/processing.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,51 +18,51 @@
1818
"""
1919
from __future__ import absolute_import
2020

21+
import logging
2122
import os
2223
import pathlib
23-
import logging
24+
import re
25+
from copy import copy
2426
from textwrap import dedent
2527
from typing import Dict, List, Optional, Union
26-
from copy import copy
27-
import re
2828

2929
import attr
30-
3130
from six.moves.urllib.parse import urlparse
3231
from six.moves.urllib.request import url2pathname
32+
3333
from sagemaker import s3
34+
from sagemaker.apiutils._base_types import ApiObject
3435
from sagemaker.config import (
36+
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
37+
PROCESSING_JOB_ENVIRONMENT_PATH,
38+
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
3539
PROCESSING_JOB_KMS_KEY_ID_PATH,
40+
PROCESSING_JOB_ROLE_ARN_PATH,
3641
PROCESSING_JOB_SECURITY_GROUP_IDS_PATH,
3742
PROCESSING_JOB_SUBNETS_PATH,
38-
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
3943
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH,
40-
PROCESSING_JOB_ROLE_ARN_PATH,
41-
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
42-
PROCESSING_JOB_ENVIRONMENT_PATH,
4344
)
45+
from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input
4446
from sagemaker.job import _Job
4547
from sagemaker.local import LocalSession
4648
from sagemaker.network import NetworkConfig
49+
from sagemaker.s3 import S3Uploader
50+
from sagemaker.session import Session
4751
from sagemaker.utils import (
52+
Tags,
4853
base_name_from_image,
54+
check_and_get_run_experiment_config,
55+
format_tags,
4956
get_config_value,
5057
name_from_base,
51-
check_and_get_run_experiment_config,
52-
resolve_value_from_config,
5358
resolve_class_attribute_from_config,
54-
Tags,
55-
format_tags,
59+
resolve_value_from_config,
5660
)
57-
from sagemaker.session import Session
5861
from sagemaker.workflow import is_pipeline_variable
62+
from sagemaker.workflow.entities import PipelineVariable
63+
from sagemaker.workflow.execution_variables import ExecutionVariables
5964
from sagemaker.workflow.functions import Join
6065
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
61-
from sagemaker.workflow.execution_variables import ExecutionVariables
62-
from sagemaker.workflow.entities import PipelineVariable
63-
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
64-
from sagemaker.apiutils._base_types import ApiObject
65-
from sagemaker.s3 import S3Uploader
6666

6767
logger = logging.getLogger(__name__)
6868

@@ -1465,7 +1465,7 @@ def __init__(
14651465
instance_type (str or PipelineVariable): The type of EC2 instance to use for
14661466
processing, for example, 'ml.c4.xlarge'.
14671467
py_version (str): Python version you want to use for executing your
1468-
model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value
1468+
model training code. Ex: `py38, py39, py310, py311`. Value
14691469
is ignored when ``image_uri`` is provided.
14701470
image_uri (str or PipelineVariable): The URI of the Docker image to use for the
14711471
processing jobs (default: None).

0 commit comments

Comments
 (0)