Skip to content

Deprecation warning for framework profiling for TF 2.12 and on, PT 2.0 and on #3728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/sagemaker/debugger/profiler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
from __future__ import absolute_import

import logging
from typing import Optional, Union

from sagemaker.debugger.framework_profile import FrameworkProfile
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.deprecations import deprecation_warn_base

logger = logging.getLogger(__name__)


class ProfilerConfig(object):
Expand Down Expand Up @@ -81,6 +85,11 @@ class and SageMaker Framework estimators.
self.framework_profile_params = framework_profile_params
self.disable_profiler = disable_profiler

if self.framework_profile_params is not None:
deprecation_warn_base(
"Framework profiling will be deprecated from tensorflow 2.12 and pytorch 2.0"
)

def _to_request_dict(self):
"""Generate a request dictionary using the parameters provided when initializing the object.

Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def deprecation_warn(name, date, msg=None):
_warn(f"{name} will be deprecated on {date}.{msg}")


def deprecation_warn_base(msg):
"""Raise a warning for soon to be deprecated feature in sagemaker>=2

Args:
msg (str): the warning message.
"""
_warn(msg)


def deprecation_warning(date, msg=None):
"""Decorator for raising deprecation warning for a feature in sagemaker>=2

Expand Down
31 changes: 31 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tempfile
from collections import namedtuple
from typing import Optional, Union, Dict
from packaging import version

import sagemaker.image_uris
from sagemaker.session_settings import SessionSettings
Expand All @@ -30,6 +31,7 @@

from sagemaker.deprecations import renamed_warning, renamed_kwargs
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.deprecations import deprecation_warn_base

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -638,6 +640,35 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)


def profiler_config_deprecation_warning(
profiler_config, image_uri, framework_name, framework_version
):
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
if profiler_config is None or profiler_config.framework_profile_params is None:
return

if framework_name not in ("pytorch", "tensorflow"):
return

if framework_version is None:
framework_name, _, image_tag, _ = framework_name_from_image(image_uri)

if image_tag is not None:
framework_version = framework_version_from_tag(image_tag)

if framework_version is not None:
framework_profile_thresh = (
version.parse("2.0") if framework_name == "pytorch" else version.parse("2.12")
)
framework_profile = version.parse(framework_version)
if framework_profile >= framework_profile_thresh:
deprecation_warn_base(
f"Framework profiling is deprecated from\
{framework_name} version {framework_version}.\
No framework metrics will be collected"
)


def validate_smdistributed(
instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
):
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
python_deprecation_warning,
validate_version_or_image_args,
validate_distribution,
profiler_config_deprecation_warning,
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
Expand Down Expand Up @@ -298,6 +299,11 @@ def __init__(
)
self.compiler_config = compiler_config

if "profiler_config" in kwargs:
profiler_config_deprecation_warning(
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
)

def _pytorch_distribution_configuration(self, distribution):
"""Returns a dict of distribution config for PyTorch training

Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ def __init__(
compiler_config.validate(self)
self.compiler_config = compiler_config

if "profiler_config" in kwargs:
fw.profiler_config_deprecation_warning(
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
)

def _validate_args(self, py_version):
"""Placeholder docstring"""

Expand Down
167 changes: 166 additions & 1 deletion tests/unit/test_profiler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import pytest
import re
import time
import warnings
from packaging import version


from sagemaker import image_uris
import sagemaker.fw_utils as fw
from sagemaker.pytorch import PyTorch
from sagemaker.tensorflow import TensorFlow
from sagemaker.debugger.profiler_config import ProfilerConfig, FrameworkProfile

from sagemaker.debugger.metrics_config import (
Expand Down Expand Up @@ -643,3 +649,162 @@ def test_validation():

with pytest.raises(AssertionError, match=ErrorMessages.INVALID_CPROFILE_TIMER.value):
PythonProfilingConfig(cprofile_timer="bad_cprofile_timer")


DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.p3.2xlarge"
ROLE = "Dummy"
REGION = "us-west-2"


def _check_framework_profile_deprecation_warning(framework_version, framework_name, warn_list):
"""Check the collected warnings for a framework fromfile DeprecationWarning"""

thresh = version.parse("2.12") if framework_name == "tensorflow" else version.parse("2.0")
actual = version.parse(framework_version)

if actual >= thresh:
# should find a Framework profiling deprecation warning
for w in warn_list:
if issubclass(w.category, DeprecationWarning):
if "Framework profiling" in str(w.message):
return
assert 0 # Should have found a deprecation and exited above


def test_create_pytorch_estimator_with_framework_profile(
sagemaker_session,
pytorch_inference_version,
pytorch_inference_py_version,
default_framework_profile,
):
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

with warnings.catch_warnings(record=True) as warn_list:
warnings.simplefilter("always")
framework_version = pytorch_inference_version
pytorch = PyTorch(
entry_point=SCRIPT_PATH,
framework_version=framework_version,
py_version=pytorch_inference_py_version,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
base_job_name="job",
profiler_config=profiler_config,
)

_check_framework_profile_deprecation_warning(
framework_version, pytorch._framework_name, warn_list
)


def test_create_pytorch_estimator_w_image_with_framework_profile(
sagemaker_session,
pytorch_inference_version,
pytorch_inference_py_version,
gpu_pytorch_instance_type,
default_framework_profile,
):
image_uri = image_uris.retrieve(
"pytorch",
REGION,
version=pytorch_inference_version,
py_version=pytorch_inference_py_version,
instance_type=gpu_pytorch_instance_type,
image_scope="inference",
)

profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

with warnings.catch_warnings(record=True) as warn_list:
warnings.simplefilter("always")
pytorch = PyTorch(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=gpu_pytorch_instance_type,
image_uri=image_uri,
profiler_config=profiler_config,
)

framework_version = None
_, _, image_tag, _ = fw.framework_name_from_image(image_uri)

if image_tag is not None:
framework_version = fw.framework_version_from_tag(image_tag)

if framework_version is not None:
_check_framework_profile_deprecation_warning(
framework_version, pytorch._framework_name, warn_list
)


def test_create_tf_estimator_with_framework_profile_212(
sagemaker_session,
default_framework_profile,
):
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

with warnings.catch_warnings(record=True) as warn_list:
warnings.simplefilter("always")
framework_version = "2.12"
tf = TensorFlow(
entry_point=SCRIPT_PATH,
role=ROLE,
framework_version=framework_version,
py_version="py39",
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
profiler_config=profiler_config,
)

_check_framework_profile_deprecation_warning(
framework_version, tf._framework_name, warn_list
)


def test_create_tf_estimator_w_image_with_framework_profile(
sagemaker_session,
tensorflow_inference_version,
tensorflow_inference_py_version,
default_framework_profile,
):
image_uri = image_uris.retrieve(
"tensorflow",
REGION,
version=tensorflow_inference_version,
py_version=tensorflow_inference_py_version,
instance_type=INSTANCE_TYPE,
image_scope="inference",
)

profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)

with warnings.catch_warnings(record=True) as warn_list:
warnings.simplefilter("always")
tf = TensorFlow(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
image_uri=image_uri,
profiler_config=profiler_config,
)

framework_version = None
_, _, image_tag, _ = fw.framework_name_from_image(image_uri)

if image_tag is not None:
framework_version = fw.framework_version_from_tag(image_tag)

if framework_version is not None:
_check_framework_profile_deprecation_warning(
framework_version, tf._framework_name, warn_list
)