Skip to content

Commit b5d5f6b

Browse files
committed
Deprecation warning for TF2.12 ant PT 2.0 for framework profiling
1 parent eb6d511 commit b5d5f6b

File tree

4 files changed

+114
-50
lines changed

4 files changed

+114
-50
lines changed

src/sagemaker/debugger/profiler_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sagemaker.debugger.framework_profile import FrameworkProfile
2020
from sagemaker.workflow.entities import PipelineVariable
21+
from sagemaker.deprecations import deprecation_warn_base
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -85,7 +86,7 @@ class and SageMaker Framework estimators.
8586
self.disable_profiler = disable_profiler
8687

8788
if self.framework_profile_params is not None:
88-
logger.warning(
89+
deprecation_warn_base(
8990
"Framework profiling will be deprecated from tensorflow 2.12 and pytorch 2.0"
9091
)
9192

src/sagemaker/deprecations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def deprecation_warn(name, date, msg=None):
6464
_warn(f"{name} will be deprecated on {date}.{msg}")
6565

6666

67+
def deprecation_warn_base(msg):
68+
"""Raise a warning for soon to be deprecated feature in sagemaker>=2
69+
70+
Args:
71+
msg (str): the warning message.
72+
"""
73+
_warn(msg)
74+
75+
6776
def deprecation_warning(date, msg=None):
6877
"""Decorator for raising deprecation warning for a feature in sagemaker>=2
6978

src/sagemaker/fw_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from sagemaker.deprecations import renamed_warning, renamed_kwargs
3333
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.deprecations import deprecation_warn_base
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -661,11 +662,10 @@ def profiler_config_deprecation_warning(
661662
)
662663
framework_profile = version.parse(framework_version)
663664
if framework_profile >= framework_profile_thresh:
664-
logger.warning(
665-
"Framework profiling is deprecated from %s version %s."
666-
"No framework metrics will be collected.",
667-
framework_name,
668-
framework_version,
665+
deprecation_warn_base(
666+
f"Framework profiling is deprecated from\
667+
{framework_name} version {framework_version}.\
668+
No framework metrics will be collected"
669669
)
670670

671671

tests/unit/test_profiler_config.py

Lines changed: 98 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
import pytest
1717
import re
1818
import time
19+
import warnings
20+
from packaging import version
1921

2022
from sagemaker import image_uris
23+
import sagemaker.fw_utils as fw
2124
from sagemaker.pytorch import PyTorch
2225
from sagemaker.tensorflow import TensorFlow
2326
from sagemaker.debugger.profiler_config import ProfilerConfig, FrameworkProfile
@@ -656,6 +659,21 @@ def test_validation():
656659
REGION = "us-west-2"
657660

658661

662+
def _check_framework_profile_deprecation_warning(framework_version, framework_name, warn_list):
663+
"""Check the collected warnings for a framework fromfile DeprecationWarning"""
664+
665+
thresh = version.parse("2.12") if framework_name == "tensorflow" else version.parse("2.0")
666+
actual = version.parse(framework_version)
667+
668+
if actual >= thresh:
669+
# should find a Framework profiling deprecation warning
670+
for w in warn_list:
671+
if issubclass(w.category, DeprecationWarning):
672+
if "Framework profiling" in str(w.message):
673+
return
674+
assert 0 # Should have found a deprecation and exited above
675+
676+
659677
def test_create_pytorch_estimator_with_framework_profile(
660678
sagemaker_session,
661679
pytorch_inference_version,
@@ -664,18 +682,24 @@ def test_create_pytorch_estimator_with_framework_profile(
664682
):
665683
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
666684

667-
pytorch = PyTorch(
668-
entry_point=SCRIPT_PATH,
669-
framework_version=pytorch_inference_version,
670-
py_version=pytorch_inference_py_version,
671-
role=ROLE,
672-
sagemaker_session=sagemaker_session,
673-
instance_count=INSTANCE_COUNT,
674-
instance_type=INSTANCE_TYPE,
675-
base_job_name="job",
676-
profiler_config=profiler_config,
677-
)
678-
assert pytorch._framework_name == "pytorch"
685+
with warnings.catch_warnings(record=True) as warn_list:
686+
warnings.simplefilter("always")
687+
framework_version = pytorch_inference_version
688+
pytorch = PyTorch(
689+
entry_point=SCRIPT_PATH,
690+
framework_version=framework_version,
691+
py_version=pytorch_inference_py_version,
692+
role=ROLE,
693+
sagemaker_session=sagemaker_session,
694+
instance_count=INSTANCE_COUNT,
695+
instance_type=INSTANCE_TYPE,
696+
base_job_name="job",
697+
profiler_config=profiler_config,
698+
)
699+
700+
_check_framework_profile_deprecation_warning(
701+
framework_version, pytorch._framework_name, warn_list
702+
)
679703

680704

681705
def test_create_pytorch_estimator_w_image_with_framework_profile(
@@ -696,35 +720,53 @@ def test_create_pytorch_estimator_w_image_with_framework_profile(
696720

697721
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
698722

699-
pytorch = PyTorch(
700-
entry_point=SCRIPT_PATH,
701-
role=ROLE,
702-
sagemaker_session=sagemaker_session,
703-
instance_count=INSTANCE_COUNT,
704-
instance_type=gpu_pytorch_instance_type,
705-
image_uri=image_uri,
706-
profiler_config=profiler_config,
707-
)
708-
assert pytorch._framework_name == "pytorch"
723+
with warnings.catch_warnings(record=True) as warn_list:
724+
warnings.simplefilter("always")
725+
pytorch = PyTorch(
726+
entry_point=SCRIPT_PATH,
727+
role=ROLE,
728+
sagemaker_session=sagemaker_session,
729+
instance_count=INSTANCE_COUNT,
730+
instance_type=gpu_pytorch_instance_type,
731+
image_uri=image_uri,
732+
profiler_config=profiler_config,
733+
)
734+
735+
framework_version = None
736+
_, _, image_tag, _ = fw.framework_name_from_image(image_uri)
709737

738+
if image_tag is not None:
739+
framework_version = fw.framework_version_from_tag(image_tag)
710740

711-
def test_create_tf_estimator_with_framework_profile(
741+
if framework_version is not None:
742+
_check_framework_profile_deprecation_warning(
743+
framework_version, pytorch._framework_name, warn_list
744+
)
745+
746+
747+
def test_create_tf_estimator_with_framework_profile_212(
712748
sagemaker_session,
713749
default_framework_profile,
714750
):
715751
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
716752

717-
tf = TensorFlow(
718-
entry_point=SCRIPT_PATH,
719-
role=ROLE,
720-
framework_version="2.8",
721-
py_version="py39",
722-
sagemaker_session=sagemaker_session,
723-
instance_count=INSTANCE_COUNT,
724-
instance_type=INSTANCE_TYPE,
725-
profiler_config=profiler_config,
726-
)
727-
assert tf._framework_name == "tensorflow"
753+
with warnings.catch_warnings(record=True) as warn_list:
754+
warnings.simplefilter("always")
755+
framework_version = "2.12"
756+
tf = TensorFlow(
757+
entry_point=SCRIPT_PATH,
758+
role=ROLE,
759+
framework_version=framework_version,
760+
py_version="py39",
761+
sagemaker_session=sagemaker_session,
762+
instance_count=INSTANCE_COUNT,
763+
instance_type=INSTANCE_TYPE,
764+
profiler_config=profiler_config,
765+
)
766+
767+
_check_framework_profile_deprecation_warning(
768+
framework_version, tf._framework_name, warn_list
769+
)
728770

729771

730772
def test_create_tf_estimator_w_image_with_framework_profile(
@@ -744,13 +786,25 @@ def test_create_tf_estimator_w_image_with_framework_profile(
744786

745787
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
746788

747-
tf = TensorFlow(
748-
entry_point=SCRIPT_PATH,
749-
role=ROLE,
750-
sagemaker_session=sagemaker_session,
751-
instance_count=INSTANCE_COUNT,
752-
instance_type=INSTANCE_TYPE,
753-
image_uri=image_uri,
754-
profiler_config=profiler_config,
755-
)
756-
assert tf._framework_name == "tensorflow"
789+
with warnings.catch_warnings(record=True) as warn_list:
790+
warnings.simplefilter("always")
791+
tf = TensorFlow(
792+
entry_point=SCRIPT_PATH,
793+
role=ROLE,
794+
sagemaker_session=sagemaker_session,
795+
instance_count=INSTANCE_COUNT,
796+
instance_type=INSTANCE_TYPE,
797+
image_uri=image_uri,
798+
profiler_config=profiler_config,
799+
)
800+
801+
framework_version = None
802+
_, _, image_tag, _ = fw.framework_name_from_image(image_uri)
803+
804+
if image_tag is not None:
805+
framework_version = fw.framework_version_from_tag(image_tag)
806+
807+
if framework_version is not None:
808+
_check_framework_profile_deprecation_warning(
809+
framework_version, tf._framework_name, warn_list
810+
)

0 commit comments

Comments
 (0)