diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 3d29e15cdb..654ba544f0 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -13,10 +13,14 @@ """Configuration for collecting system and framework metrics in SageMaker training jobs.""" from __future__ import absolute_import +import logging from typing import Optional, Union from sagemaker.debugger.framework_profile import FrameworkProfile from sagemaker.workflow.entities import PipelineVariable +from sagemaker.deprecations import deprecation_warn_base + +logger = logging.getLogger(__name__) class ProfilerConfig(object): @@ -81,6 +85,11 @@ class and SageMaker Framework estimators. self.framework_profile_params = framework_profile_params self.disable_profiler = disable_profiler + if self.framework_profile_params is not None: + deprecation_warn_base( + "Framework profiling will be deprecated from tensorflow 2.12 and pytorch 2.0" + ) + def _to_request_dict(self): """Generate a request dictionary using the parameters provided when initializing the object. diff --git a/src/sagemaker/deprecations.py b/src/sagemaker/deprecations.py index b9a95483ae..a8ba298082 100644 --- a/src/sagemaker/deprecations.py +++ b/src/sagemaker/deprecations.py @@ -64,6 +64,15 @@ def deprecation_warn(name, date, msg=None): _warn(f"{name} will be deprecated on {date}.{msg}") +def deprecation_warn_base(msg): + """Raise a warning for soon to be deprecated feature in sagemaker>=2 + + Args: + msg (str): the warning message. + """ + _warn(msg) + + def deprecation_warning(date, msg=None): """Decorator for raising deprecation warning for a feature in sagemaker>=2 diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index fa8f35af0c..88e3afe3f3 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -22,6 +22,7 @@ import tempfile from collections import namedtuple from typing import Optional, Union, Dict +from packaging import version import sagemaker.image_uris from sagemaker.session_settings import SessionSettings @@ -30,6 +31,7 @@ from sagemaker.deprecations import renamed_warning, renamed_kwargs from sagemaker.workflow.entities import PipelineVariable +from sagemaker.deprecations import deprecation_warn_base logger = logging.getLogger(__name__) @@ -638,6 +640,35 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING) +def profiler_config_deprecation_warning( + profiler_config, image_uri, framework_name, framework_version +): + """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0""" + if profiler_config is None or profiler_config.framework_profile_params is None: + return + + if framework_name not in ("pytorch", "tensorflow"): + return + + if framework_version is None: + framework_name, _, image_tag, _ = framework_name_from_image(image_uri) + + if image_tag is not None: + framework_version = framework_version_from_tag(image_tag) + + if framework_version is not None: + framework_profile_thresh = ( + version.parse("2.0") if framework_name == "pytorch" else version.parse("2.12") + ) + framework_profile = version.parse(framework_version) + if framework_profile >= framework_profile_thresh: + deprecation_warn_base( + f"Framework profiling is deprecated from\ + {framework_name} version {framework_version}.\ + No framework metrics will be collected" + ) + + def validate_smdistributed( instance_type, framework_name, framework_version, py_version, distribution, image_uri=None ): diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 21ebc48351..2950c73cf0 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,6 +25,7 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, + profiler_config_deprecation_warning, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -298,6 +299,11 @@ def __init__( ) self.compiler_config = compiler_config + if "profiler_config" in kwargs: + profiler_config_deprecation_warning( + kwargs["profiler_config"], image_uri, self._framework_name, framework_version + ) + def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index c7463dfc03..914a56dd3f 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -235,6 +235,11 @@ def __init__( compiler_config.validate(self) self.compiler_config = compiler_config + if "profiler_config" in kwargs: + fw.profiler_config_deprecation_warning( + kwargs["profiler_config"], image_uri, self._framework_name, framework_version + ) + def _validate_args(self, py_version): """Placeholder docstring""" diff --git a/tests/unit/test_profiler_config.py b/tests/unit/test_profiler_config.py index c7afce2c45..92077fa926 100644 --- a/tests/unit/test_profiler_config.py +++ b/tests/unit/test_profiler_config.py @@ -12,11 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os import pytest import re import time +import warnings +from packaging import version - +from sagemaker import image_uris +import sagemaker.fw_utils as fw +from sagemaker.pytorch import PyTorch +from sagemaker.tensorflow import TensorFlow from sagemaker.debugger.profiler_config import ProfilerConfig, FrameworkProfile from sagemaker.debugger.metrics_config import ( @@ -643,3 +649,162 @@ def test_validation(): with pytest.raises(AssertionError, match=ErrorMessages.INVALID_CPROFILE_TIMER.value): PythonProfilingConfig(cprofile_timer="bad_cprofile_timer") + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.p3.2xlarge" +ROLE = "Dummy" +REGION = "us-west-2" + + +def _check_framework_profile_deprecation_warning(framework_version, framework_name, warn_list): + """Check the collected warnings for a framework fromfile DeprecationWarning""" + + thresh = version.parse("2.12") if framework_name == "tensorflow" else version.parse("2.0") + actual = version.parse(framework_version) + + if actual >= thresh: + # should find a Framework profiling deprecation warning + for w in warn_list: + if issubclass(w.category, DeprecationWarning): + if "Framework profiling" in str(w.message): + return + assert 0 # Should have found a deprecation and exited above + + +def test_create_pytorch_estimator_with_framework_profile( + sagemaker_session, + pytorch_inference_version, + pytorch_inference_py_version, + default_framework_profile, +): + profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile) + + with warnings.catch_warnings(record=True) as warn_list: + warnings.simplefilter("always") + framework_version = pytorch_inference_version + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + framework_version=framework_version, + py_version=pytorch_inference_py_version, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + base_job_name="job", + profiler_config=profiler_config, + ) + + _check_framework_profile_deprecation_warning( + framework_version, pytorch._framework_name, warn_list + ) + + +def test_create_pytorch_estimator_w_image_with_framework_profile( + sagemaker_session, + pytorch_inference_version, + pytorch_inference_py_version, + gpu_pytorch_instance_type, + default_framework_profile, +): + image_uri = image_uris.retrieve( + "pytorch", + REGION, + version=pytorch_inference_version, + py_version=pytorch_inference_py_version, + instance_type=gpu_pytorch_instance_type, + image_scope="inference", + ) + + profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile) + + with warnings.catch_warnings(record=True) as warn_list: + warnings.simplefilter("always") + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=gpu_pytorch_instance_type, + image_uri=image_uri, + profiler_config=profiler_config, + ) + + framework_version = None + _, _, image_tag, _ = fw.framework_name_from_image(image_uri) + + if image_tag is not None: + framework_version = fw.framework_version_from_tag(image_tag) + + if framework_version is not None: + _check_framework_profile_deprecation_warning( + framework_version, pytorch._framework_name, warn_list + ) + + +def test_create_tf_estimator_with_framework_profile_212( + sagemaker_session, + default_framework_profile, +): + profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile) + + with warnings.catch_warnings(record=True) as warn_list: + warnings.simplefilter("always") + framework_version = "2.12" + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + framework_version=framework_version, + py_version="py39", + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + profiler_config=profiler_config, + ) + + _check_framework_profile_deprecation_warning( + framework_version, tf._framework_name, warn_list + ) + + +def test_create_tf_estimator_w_image_with_framework_profile( + sagemaker_session, + tensorflow_inference_version, + tensorflow_inference_py_version, + default_framework_profile, +): + image_uri = image_uris.retrieve( + "tensorflow", + REGION, + version=tensorflow_inference_version, + py_version=tensorflow_inference_py_version, + instance_type=INSTANCE_TYPE, + image_scope="inference", + ) + + profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile) + + with warnings.catch_warnings(record=True) as warn_list: + warnings.simplefilter("always") + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + image_uri=image_uri, + profiler_config=profiler_config, + ) + + framework_version = None + _, _, image_tag, _ = fw.framework_name_from_image(image_uri) + + if image_tag is not None: + framework_version = fw.framework_version_from_tag(image_tag) + + if framework_version is not None: + _check_framework_profile_deprecation_warning( + framework_version, tf._framework_name, warn_list + )