Skip to content

Commit f3e0d6c

Browse files
Merge branch 'master' into feature/transformer_with_monitoring
2 parents a2ac915 + 96e417f commit f3e0d6c

File tree

13 files changed

+1450
-425
lines changed

13 files changed

+1450
-425
lines changed

CHANGELOG.md

+21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
# Changelog
22

3+
## v2.114.0 (2022-10-26)
4+
5+
### Features
6+
7+
* Graviton support for XGB and SKLearn frameworks
8+
* Graviton support for PyTorch and Tensorflow frameworks
9+
* do not expand estimator role when it is pipeline parameter
10+
* added support for batch transform with model monitoring
11+
12+
### Bug Fixes and Other Changes
13+
14+
* regex in tuning integs
15+
* remove debugger environment var set up
16+
* adjacent slash in s3 key
17+
* Fix Repack step auto install behavior
18+
* Add retry for airflow ParsingError
19+
20+
### Documentation Changes
21+
22+
* doc fix
23+
324
## v2.113.0 (2022-10-21)
425

526
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.113.1.dev0
1+
2.114.1.dev0

src/sagemaker/estimator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
UploadedCode,
4545
_region_supports_debugger,
4646
_region_supports_profiler,
47+
_instance_type_supports_profiler,
4748
get_mp_parameters,
4849
tar_and_upload_dir,
4950
validate_source_dir,
@@ -592,7 +593,9 @@ def __init__(
592593

593594
self.max_retry_attempts = max_retry_attempts
594595

595-
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
596+
if not _region_supports_profiler(
597+
self.sagemaker_session.boto_region_name
598+
) or _instance_type_supports_profiler(self.instance_type):
596599
self.disable_profiler = True
597600

598601
self.profiler_rule_configs = None

src/sagemaker/fw_utils.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
"2.8.0",
103103
"2.9",
104104
"2.9.1",
105+
"2.10",
106+
"2.10.0",
105107
],
106108
"pytorch": [
107109
"1.6",
@@ -144,6 +146,21 @@
144146
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
145147

146148

149+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = [
150+
"m6g",
151+
"m6gd",
152+
"c6g",
153+
"c6gd",
154+
"c6gn",
155+
"c7g",
156+
"r6g",
157+
"r6gd",
158+
]
159+
160+
161+
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch", "xgboost", "sklearn"])
162+
163+
147164
def validate_source_dir(script, directory):
148165
"""Validate that the source directory exists and it contains the user script.
149166
@@ -163,12 +180,6 @@ def validate_source_dir(script, directory):
163180
return True
164181

165182

166-
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = ["c6g", "t4g", "r6g", "m6g"]
167-
168-
169-
GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch"])
170-
171-
172183
def validate_source_code_input_against_pipeline_variables(
173184
entry_point: Optional[Union[str, PipelineVariable]] = None,
174185
source_dir: Optional[Union[str, PipelineVariable]] = None,
@@ -1065,6 +1076,22 @@ def _region_supports_profiler(region_name):
10651076
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
10661077

10671078

1079+
def _instance_type_supports_profiler(instance_type):
1080+
"""Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
1081+
1082+
Args:
1083+
instance_type (str): Name of the instance_type to check against.
1084+
1085+
Returns:
1086+
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1087+
"""
1088+
if isinstance(instance_type, str):
1089+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1090+
if match and match[1].startswith("trn"):
1091+
return True
1092+
return False
1093+
1094+
10681095
def validate_version_or_image_args(framework_version, py_version, image_uri):
10691096
"""Checks if version or image arguments are specified.
10701097

0 commit comments

Comments
 (0)