16
16
import pytest
17
17
import re
18
18
import time
19
+ import warnings
20
+ from packaging import version
19
21
20
22
from sagemaker import image_uris
23
+ import sagemaker .fw_utils as fw
21
24
from sagemaker .pytorch import PyTorch
22
25
from sagemaker .tensorflow import TensorFlow
23
26
from sagemaker .debugger .profiler_config import ProfilerConfig , FrameworkProfile
@@ -656,6 +659,21 @@ def test_validation():
656
659
REGION = "us-west-2"
657
660
658
661
662
+ def _check_framework_profile_deprecation_warning (framework_version , framework_name , warn_list ):
663
+ """Check the collected warnings for a framework fromfile DeprecationWarning"""
664
+
665
+ thresh = version .parse ("2.12" ) if framework_name == "tensorflow" else version .parse ("2.0" )
666
+ actual = version .parse (framework_version )
667
+
668
+ if actual >= thresh :
669
+ # should find a Framework profiling deprecation warning
670
+ for w in warn_list :
671
+ if issubclass (w .category , DeprecationWarning ):
672
+ if "Framework profiling" in str (w .message ):
673
+ return
674
+ assert 0 # Should have found a deprecation and exited above
675
+
676
+
659
677
def test_create_pytorch_estimator_with_framework_profile (
660
678
sagemaker_session ,
661
679
pytorch_inference_version ,
@@ -664,18 +682,24 @@ def test_create_pytorch_estimator_with_framework_profile(
664
682
):
665
683
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
666
684
667
- pytorch = PyTorch (
668
- entry_point = SCRIPT_PATH ,
669
- framework_version = pytorch_inference_version ,
670
- py_version = pytorch_inference_py_version ,
671
- role = ROLE ,
672
- sagemaker_session = sagemaker_session ,
673
- instance_count = INSTANCE_COUNT ,
674
- instance_type = INSTANCE_TYPE ,
675
- base_job_name = "job" ,
676
- profiler_config = profiler_config ,
677
- )
678
- assert pytorch ._framework_name == "pytorch"
685
+ with warnings .catch_warnings (record = True ) as warn_list :
686
+ warnings .simplefilter ("always" )
687
+ framework_version = pytorch_inference_version
688
+ pytorch = PyTorch (
689
+ entry_point = SCRIPT_PATH ,
690
+ framework_version = framework_version ,
691
+ py_version = pytorch_inference_py_version ,
692
+ role = ROLE ,
693
+ sagemaker_session = sagemaker_session ,
694
+ instance_count = INSTANCE_COUNT ,
695
+ instance_type = INSTANCE_TYPE ,
696
+ base_job_name = "job" ,
697
+ profiler_config = profiler_config ,
698
+ )
699
+
700
+ _check_framework_profile_deprecation_warning (
701
+ framework_version , pytorch ._framework_name , warn_list
702
+ )
679
703
680
704
681
705
def test_create_pytorch_estimator_w_image_with_framework_profile (
@@ -696,35 +720,53 @@ def test_create_pytorch_estimator_w_image_with_framework_profile(
696
720
697
721
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
698
722
699
- pytorch = PyTorch (
700
- entry_point = SCRIPT_PATH ,
701
- role = ROLE ,
702
- sagemaker_session = sagemaker_session ,
703
- instance_count = INSTANCE_COUNT ,
704
- instance_type = gpu_pytorch_instance_type ,
705
- image_uri = image_uri ,
706
- profiler_config = profiler_config ,
707
- )
708
- assert pytorch ._framework_name == "pytorch"
723
+ with warnings .catch_warnings (record = True ) as warn_list :
724
+ warnings .simplefilter ("always" )
725
+ pytorch = PyTorch (
726
+ entry_point = SCRIPT_PATH ,
727
+ role = ROLE ,
728
+ sagemaker_session = sagemaker_session ,
729
+ instance_count = INSTANCE_COUNT ,
730
+ instance_type = gpu_pytorch_instance_type ,
731
+ image_uri = image_uri ,
732
+ profiler_config = profiler_config ,
733
+ )
734
+
735
+ framework_version = None
736
+ _ , _ , image_tag , _ = fw .framework_name_from_image (image_uri )
709
737
738
+ if image_tag is not None :
739
+ framework_version = fw .framework_version_from_tag (image_tag )
710
740
711
- def test_create_tf_estimator_with_framework_profile (
741
+ if framework_version is not None :
742
+ _check_framework_profile_deprecation_warning (
743
+ framework_version , pytorch ._framework_name , warn_list
744
+ )
745
+
746
+
747
+ def test_create_tf_estimator_with_framework_profile_212 (
712
748
sagemaker_session ,
713
749
default_framework_profile ,
714
750
):
715
751
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
716
752
717
- tf = TensorFlow (
718
- entry_point = SCRIPT_PATH ,
719
- role = ROLE ,
720
- framework_version = "2.8" ,
721
- py_version = "py39" ,
722
- sagemaker_session = sagemaker_session ,
723
- instance_count = INSTANCE_COUNT ,
724
- instance_type = INSTANCE_TYPE ,
725
- profiler_config = profiler_config ,
726
- )
727
- assert tf ._framework_name == "tensorflow"
753
+ with warnings .catch_warnings (record = True ) as warn_list :
754
+ warnings .simplefilter ("always" )
755
+ framework_version = "2.12"
756
+ tf = TensorFlow (
757
+ entry_point = SCRIPT_PATH ,
758
+ role = ROLE ,
759
+ framework_version = framework_version ,
760
+ py_version = "py39" ,
761
+ sagemaker_session = sagemaker_session ,
762
+ instance_count = INSTANCE_COUNT ,
763
+ instance_type = INSTANCE_TYPE ,
764
+ profiler_config = profiler_config ,
765
+ )
766
+
767
+ _check_framework_profile_deprecation_warning (
768
+ framework_version , tf ._framework_name , warn_list
769
+ )
728
770
729
771
730
772
def test_create_tf_estimator_w_image_with_framework_profile (
@@ -744,13 +786,25 @@ def test_create_tf_estimator_w_image_with_framework_profile(
744
786
745
787
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
746
788
747
- tf = TensorFlow (
748
- entry_point = SCRIPT_PATH ,
749
- role = ROLE ,
750
- sagemaker_session = sagemaker_session ,
751
- instance_count = INSTANCE_COUNT ,
752
- instance_type = INSTANCE_TYPE ,
753
- image_uri = image_uri ,
754
- profiler_config = profiler_config ,
755
- )
756
- assert tf ._framework_name == "tensorflow"
789
+ with warnings .catch_warnings (record = True ) as warn_list :
790
+ warnings .simplefilter ("always" )
791
+ tf = TensorFlow (
792
+ entry_point = SCRIPT_PATH ,
793
+ role = ROLE ,
794
+ sagemaker_session = sagemaker_session ,
795
+ instance_count = INSTANCE_COUNT ,
796
+ instance_type = INSTANCE_TYPE ,
797
+ image_uri = image_uri ,
798
+ profiler_config = profiler_config ,
799
+ )
800
+
801
+ framework_version = None
802
+ _ , _ , image_tag , _ = fw .framework_name_from_image (image_uri )
803
+
804
+ if image_tag is not None :
805
+ framework_version = fw .framework_version_from_tag (image_tag )
806
+
807
+ if framework_version is not None :
808
+ _check_framework_profile_deprecation_warning (
809
+ framework_version , tf ._framework_name , warn_list
810
+ )
0 commit comments