Skip to content

Commit 516da5b

Browse files
authored
Merge branch 'master' into feat/jumpstart-instance-types
2 parents d3c77aa + 93f33d9 commit 516da5b

File tree

7 files changed

+228
-1
lines changed

7 files changed

+228
-1
lines changed

src/sagemaker/debugger/profiler_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
from typing import Optional, Union
1718

1819
from sagemaker.debugger.framework_profile import FrameworkProfile
1920
from sagemaker.workflow.entities import PipelineVariable
21+
from sagemaker.deprecations import deprecation_warn_base
22+
23+
logger = logging.getLogger(__name__)
2024

2125

2226
class ProfilerConfig(object):
@@ -81,6 +85,11 @@ class and SageMaker Framework estimators.
8185
self.framework_profile_params = framework_profile_params
8286
self.disable_profiler = disable_profiler
8387

88+
if self.framework_profile_params is not None:
89+
deprecation_warn_base(
90+
"Framework profiling will be deprecated from tensorflow 2.12 and pytorch 2.0"
91+
)
92+
8493
def _to_request_dict(self):
8594
"""Generate a request dictionary using the parameters provided when initializing the object.
8695

src/sagemaker/deprecations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def deprecation_warn(name, date, msg=None):
6464
_warn(f"{name} will be deprecated on {date}.{msg}")
6565

6666

67+
def deprecation_warn_base(msg):
    """Emit a warning for a soon-to-be-deprecated feature in sagemaker>=2.

    The message is passed through unchanged to the module's ``_warn`` helper.

    Args:
        msg (str): the warning message.
    """
    _warn(msg)
74+
75+
6776
def deprecation_warning(date, msg=None):
6877
"""Decorator for raising deprecation warning for a feature in sagemaker>=2
6978

src/sagemaker/fw_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import tempfile
2323
from collections import namedtuple
2424
from typing import Optional, Union, Dict
25+
from packaging import version
2526

2627
import sagemaker.image_uris
2728
from sagemaker.session_settings import SessionSettings
@@ -30,6 +31,7 @@
3031

3132
from sagemaker.deprecations import renamed_warning, renamed_kwargs
3233
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.deprecations import deprecation_warn_base
3335

3436
logger = logging.getLogger(__name__)
3537

@@ -638,6 +640,35 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
638640
logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
639641

640642

643+
def profiler_config_deprecation_warning(
    profiler_config, image_uri, framework_name, framework_version
):
    """Warn when framework profiling is requested for TF >= 2.12 or PT >= 2.0.

    Nothing is emitted when no framework profile parameters are configured, when
    the framework is neither "pytorch" nor "tensorflow", or when no framework
    version is given and none can be derived from the image URI.

    Args:
        profiler_config: Profiler configuration whose ``framework_profile_params``
            attribute is inspected; may be ``None``.
        image_uri (str): Image URI used to derive the framework name and version
            when ``framework_version`` is ``None``.
        framework_name (str): Framework name; only "pytorch" and "tensorflow"
            are checked.
        framework_version (str): Framework version, or ``None`` to derive it from
            ``image_uri``.
    """
    if profiler_config is None or profiler_config.framework_profile_params is None:
        return

    if framework_name not in ("pytorch", "tensorflow"):
        return

    if framework_version is None:
        # No explicit version: fall back to what the image URI encodes. The
        # framework name is replaced by the one parsed from the image as well.
        framework_name, _, image_tag, _ = framework_name_from_image(image_uri)

        if image_tag is not None:
            framework_version = framework_version_from_tag(image_tag)

    if framework_version is None:
        return

    framework_profile_thresh = version.parse("2.0" if framework_name == "pytorch" else "2.12")
    if version.parse(framework_version) >= framework_profile_thresh:
        # Adjacent literals keep the message on one clean line; a backslash
        # continuation inside the string would embed the source indentation
        # in the emitted warning text.
        deprecation_warn_base(
            "Framework profiling is deprecated from "
            f"{framework_name} version {framework_version}. "
            "No framework metrics will be collected"
        )
670+
671+
641672
def validate_smdistributed(
642673
instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
643674
):

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
python_deprecation_warning,
2626
validate_version_or_image_args,
2727
validate_distribution,
28+
profiler_config_deprecation_warning,
2829
)
2930
from sagemaker.pytorch import defaults
3031
from sagemaker.pytorch.model import PyTorchModel
@@ -298,6 +299,11 @@ def __init__(
298299
)
299300
self.compiler_config = compiler_config
300301

302+
if "profiler_config" in kwargs:
303+
profiler_config_deprecation_warning(
304+
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
305+
)
306+
301307
def _pytorch_distribution_configuration(self, distribution):
302308
"""Returns a dict of distribution config for PyTorch training
303309

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ def __init__(
235235
compiler_config.validate(self)
236236
self.compiler_config = compiler_config
237237

238+
if "profiler_config" in kwargs:
239+
fw.profiler_config_deprecation_warning(
240+
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
241+
)
242+
238243
def _validate_args(self, py_version):
239244
"""Placeholder docstring"""
240245

src/sagemaker/tuner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ def to_input_req(self):
559559
"""
560560
completion_criteria_config = {}
561561
if self.max_number_of_training_jobs_not_improving is not None:
562+
completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING] = {}
562563
completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING][
563564
MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
564565
] = self.max_number_of_training_jobs_not_improving
@@ -569,6 +570,7 @@ def to_input_req(self):
569570
] = self.target_objective_metric_value
570571

571572
if self.complete_on_convergence is not None:
573+
completion_criteria_config[CONVERGENCE_DETECTED] = {}
572574
completion_criteria_config[CONVERGENCE_DETECTED][COMPLETE_ON_CONVERGENCE_DETECTED] = (
573575
"Enabled" if self.complete_on_convergence else "Disabled"
574576
)

tests/unit/test_profiler_config.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import pytest
1617
import re
1718
import time
19+
import warnings
20+
from packaging import version
1821

19-
22+
from sagemaker import image_uris
23+
import sagemaker.fw_utils as fw
24+
from sagemaker.pytorch import PyTorch
25+
from sagemaker.tensorflow import TensorFlow
2026
from sagemaker.debugger.profiler_config import ProfilerConfig, FrameworkProfile
2127

2228
from sagemaker.debugger.metrics_config import (
@@ -643,3 +649,162 @@ def test_validation():
643649

644650
with pytest.raises(AssertionError, match=ErrorMessages.INVALID_CPROFILE_TIMER.value):
645651
PythonProfilingConfig(cprofile_timer="bad_cprofile_timer")
652+
653+
654+
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
655+
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
656+
INSTANCE_COUNT = 1
657+
INSTANCE_TYPE = "ml.p3.2xlarge"
658+
ROLE = "Dummy"
659+
REGION = "us-west-2"
660+
661+
662+
def _check_framework_profile_deprecation_warning(framework_version, framework_name, warn_list):
    """Check the collected warnings for a framework profile DeprecationWarning.

    The check only applies when ``framework_version`` is at or above the
    threshold (2.12 for "tensorflow", 2.0 for any other framework name); below
    it nothing is verified.

    Args:
        framework_version (str): Framework version to compare with the threshold.
        framework_name (str): Framework name; "tensorflow" selects the 2.12 threshold.
        warn_list (list): Warnings recorded by ``warnings.catch_warnings(record=True)``.

    Raises:
        AssertionError: If the version is at or above the threshold and no
            DeprecationWarning mentioning "Framework profiling" was recorded.
    """
    thresh = version.parse("2.12") if framework_name == "tensorflow" else version.parse("2.0")
    actual = version.parse(framework_version)

    if actual >= thresh:
        # should find a Framework profiling deprecation warning
        found = any(
            issubclass(w.category, DeprecationWarning) and "Framework profiling" in str(w.message)
            for w in warn_list
        )
        assert found, (
            "expected a 'Framework profiling' DeprecationWarning for "
            f"{framework_name} {framework_version}"
        )
675+
676+
677+
def test_create_pytorch_estimator_with_framework_profile(
    sagemaker_session,
    pytorch_inference_version,
    pytorch_inference_py_version,
    default_framework_profile,
):
    """Create a PyTorch estimator with framework profiling and check the warnings."""
    framework_version = pytorch_inference_version
    profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

    estimator_args = {
        "entry_point": SCRIPT_PATH,
        "framework_version": framework_version,
        "py_version": pytorch_inference_py_version,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "base_job_name": "job",
        "profiler_config": profiler_config,
    }

    # Record every warning raised while the estimator is constructed.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        estimator = PyTorch(**estimator_args)

        _check_framework_profile_deprecation_warning(
            framework_version, estimator._framework_name, captured
        )
703+
704+
705+
def test_create_pytorch_estimator_w_image_with_framework_profile(
    sagemaker_session,
    pytorch_inference_version,
    pytorch_inference_py_version,
    gpu_pytorch_instance_type,
    default_framework_profile,
):
    """Create a PyTorch estimator from an image URI with framework profiling.

    The warnings are only checked when a framework version can be derived from
    the image tag.
    """
    retrieve_args = {
        "version": pytorch_inference_version,
        "py_version": pytorch_inference_py_version,
        "instance_type": gpu_pytorch_instance_type,
        "image_scope": "inference",
    }
    image_uri = image_uris.retrieve("pytorch", REGION, **retrieve_args)

    profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

    estimator_args = {
        "entry_point": SCRIPT_PATH,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": gpu_pytorch_instance_type,
        "image_uri": image_uri,
        "profiler_config": profiler_config,
    }

    # Record every warning raised while the estimator is constructed.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        estimator = PyTorch(**estimator_args)

        _, _, image_tag, _ = fw.framework_name_from_image(image_uri)
        framework_version = (
            fw.framework_version_from_tag(image_tag) if image_tag is not None else None
        )

        if framework_version is not None:
            _check_framework_profile_deprecation_warning(
                framework_version, estimator._framework_name, captured
            )
745+
746+
747+
def test_create_tf_estimator_with_framework_profile_212(
    sagemaker_session,
    default_framework_profile,
):
    """Create a TensorFlow 2.12 estimator with framework profiling and check the warnings."""
    framework_version = "2.12"
    profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

    estimator_args = {
        "entry_point": SCRIPT_PATH,
        "role": ROLE,
        "framework_version": framework_version,
        "py_version": "py39",
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "profiler_config": profiler_config,
    }

    # Record every warning raised while the estimator is constructed.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        estimator = TensorFlow(**estimator_args)

        _check_framework_profile_deprecation_warning(
            framework_version, estimator._framework_name, captured
        )
770+
771+
772+
def test_create_tf_estimator_w_image_with_framework_profile(
    sagemaker_session,
    tensorflow_inference_version,
    tensorflow_inference_py_version,
    default_framework_profile,
):
    """Create a TensorFlow estimator from an image URI with framework profiling.

    The warnings are only checked when a framework version can be derived from
    the image tag.
    """
    retrieve_args = {
        "version": tensorflow_inference_version,
        "py_version": tensorflow_inference_py_version,
        "instance_type": INSTANCE_TYPE,
        "image_scope": "inference",
    }
    image_uri = image_uris.retrieve("tensorflow", REGION, **retrieve_args)

    profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

    estimator_args = {
        "entry_point": SCRIPT_PATH,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "image_uri": image_uri,
        "profiler_config": profiler_config,
    }

    # Record every warning raised while the estimator is constructed.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        estimator = TensorFlow(**estimator_args)

        _, _, image_tag, _ = fw.framework_name_from_image(image_uri)
        framework_version = (
            fw.framework_version_from_tag(image_tag) if image_tag is not None else None
        )

        if framework_version is not None:
            _check_framework_profile_deprecation_warning(
                framework_version, estimator._framework_name, captured
            )

0 commit comments

Comments
 (0)