Skip to content

Commit 50a6da2

Browse files
mariumofknikure
authored andcommitted
fix: added framework profiling version deprecation warning
1 parent 1268d6a commit 50a6da2

File tree

4 files changed

+165
-1
lines changed

4 files changed

+165
-1
lines changed

src/sagemaker/fw_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import tempfile
2323
from collections import namedtuple
2424
from typing import Optional, Union, Dict
25+
from packaging import version
2526

2627
import sagemaker.image_uris
2728
from sagemaker.session_settings import SessionSettings
@@ -638,6 +639,39 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
638639
logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
639640

640641

642+
def profiler_config_deprecation_warning(
643+
profiler_config, image_uri, framework_name, framework_version
644+
):
645+
"""
646+
Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0
647+
"""
648+
if profiler_config is None or profiler_config.framework_profile_params is None:
649+
return
650+
651+
if framework_name not in ("pytorch", "tensorflow"):
652+
return
653+
654+
if framework_version is None:
655+
logger.warning(framework_version)
656+
framework_name, _, image_tag, _ = framework_name_from_image(image_uri)
657+
658+
if image_tag is not None:
659+
framework_version = framework_version_from_tag(image_tag)
660+
661+
if framework_version is not None:
662+
framework_profile_thresh = (
663+
version.parse("2.0") if framework_name == "pytorch" else version.parse("2.12")
664+
)
665+
framework_profile = version.parse(framework_version)
666+
if framework_profile >= framework_profile_thresh:
667+
logger.warning(
668+
"Framework profiling is deprecated from %s version %s."
669+
"No framework metrics will be collected.",
670+
framework_name,
671+
framework_version,
672+
)
673+
674+
641675
def validate_smdistributed(
642676
instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
643677
):

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
python_deprecation_warning,
2626
validate_version_or_image_args,
2727
validate_distribution,
28+
profiler_config_deprecation_warning,
2829
)
2930
from sagemaker.pytorch import defaults
3031
from sagemaker.pytorch.model import PyTorchModel
@@ -298,6 +299,11 @@ def __init__(
298299
)
299300
self.compiler_config = compiler_config
300301

302+
if "profiler_config" in kwargs:
303+
profiler_config_deprecation_warning(
304+
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
305+
)
306+
301307
def _pytorch_distribution_configuration(self, distribution):
302308
"""Returns a dict of distribution config for PyTorch training
303309

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ def __init__(
235235
compiler_config.validate(self)
236236
self.compiler_config = compiler_config
237237

238+
if "profiler_config" in kwargs:
239+
fw.profiler_config_deprecation_warning(
240+
kwargs["profiler_config"], image_uri, self._framework_name, framework_version
241+
)
242+
238243
def _validate_args(self, py_version):
239244
"""Placeholder docstring"""
240245

tests/unit/test_profiler_config.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import pytest
1617
import re
1718
import time
1819

19-
20+
from sagemaker import image_uris
21+
from sagemaker.pytorch import PyTorch
22+
from sagemaker.tensorflow import TensorFlow
2023
from sagemaker.debugger.profiler_config import ProfilerConfig, FrameworkProfile
2124

2225
from sagemaker.debugger.metrics_config import (
@@ -643,3 +646,119 @@ def test_validation():
643646

644647
with pytest.raises(AssertionError, match=ErrorMessages.INVALID_CPROFILE_TIMER.value):
645648
PythonProfilingConfig(cprofile_timer="bad_cprofile_timer")
649+
650+
651+
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
652+
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
653+
INSTANCE_COUNT = 1
654+
INSTANCE_TYPE = "ml.p3.2xlarge"
655+
ROLE = "Dummy"
656+
REGION = "us-west-2"
657+
658+
659+
def test_create_pytorch_estimator_with_framework_profile(
660+
sagemaker_session,
661+
pytorch_inference_version,
662+
pytorch_inference_py_version,
663+
default_framework_profile,
664+
):
665+
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
666+
667+
container_log_level = '"logging.INFO"'
668+
source_dir = "s3://mybucket/source"
669+
pytorch = PyTorch(
670+
entry_point=SCRIPT_PATH,
671+
framework_version=pytorch_inference_version,
672+
py_version=pytorch_inference_py_version,
673+
role=ROLE,
674+
sagemaker_session=sagemaker_session,
675+
instance_count=INSTANCE_COUNT,
676+
instance_type=INSTANCE_TYPE,
677+
base_job_name="job",
678+
profiler_config=profiler_config,
679+
)
680+
681+
682+
def test_create_pytorch_estimator_w_image_with_framework_profile(
683+
sagemaker_session,
684+
pytorch_inference_version,
685+
pytorch_inference_py_version,
686+
gpu_pytorch_instance_type,
687+
default_framework_profile,
688+
):
689+
image_uri = image_uris.retrieve(
690+
"pytorch",
691+
REGION,
692+
version=pytorch_inference_version,
693+
py_version=pytorch_inference_py_version,
694+
instance_type=gpu_pytorch_instance_type,
695+
image_scope="inference",
696+
)
697+
698+
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
699+
700+
pytorch = PyTorch(
701+
entry_point=SCRIPT_PATH,
702+
role=ROLE,
703+
sagemaker_session=sagemaker_session,
704+
instance_count=INSTANCE_COUNT,
705+
instance_type=gpu_pytorch_instance_type,
706+
image_uri=image_uri,
707+
profiler_config=profiler_config,
708+
)
709+
710+
711+
"""
712+
def test_create_tf_estimator_with_framework_profile(
713+
sagemaker_session,
714+
tensorflow_inference_version,
715+
tensorflow_inference_py_version,
716+
default_framework_profile,
717+
):
718+
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
719+
720+
tf = TensorFlow(
721+
entry_point=SCRIPT_PATH,
722+
role=ROLE,
723+
framework_version=tensorflow_inference_version,
724+
py_version=tensorflow_inference_py_version,
725+
sagemaker_session=sagemaker_session,
726+
instance_count=INSTANCE_COUNT,
727+
instance_type=INSTANCE_TYPE,
728+
profiler_config=profiler_config,
729+
)
730+
... ValueError: TF 1.5 supports only legacy mode.
731+
Please supply the image URI directly with
732+
'image_uri=520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.5-cpu-py2'
733+
and set 'model_dir=False' etc etc
734+
"""
735+
736+
737+
def test_create_tf_estimator_w_image_with_framework_profile(
738+
sagemaker_session,
739+
tensorflow_inference_version,
740+
tensorflow_inference_py_version,
741+
default_framework_profile,
742+
):
743+
image_uri = image_uris.retrieve(
744+
"tensorflow",
745+
REGION,
746+
version=tensorflow_inference_version,
747+
py_version=tensorflow_inference_py_version,
748+
instance_type=INSTANCE_TYPE,
749+
image_scope="inference",
750+
)
751+
752+
assert image_uri is not None
753+
754+
profiler_config = ProfilerConfig(framework_profile_params=default_framework_profile)
755+
756+
tf = TensorFlow(
757+
entry_point=SCRIPT_PATH,
758+
role=ROLE,
759+
sagemaker_session=sagemaker_session,
760+
instance_count=INSTANCE_COUNT,
761+
instance_type=INSTANCE_TYPE,
762+
image_uri=image_uri,
763+
profiler_config=profiler_config,
764+
)

0 commit comments

Comments
 (0)