@@ -664,8 +664,6 @@ def test_create_pytorch_estimator_with_framework_profile(
664
664
):
665
665
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
666
666
667
- container_log_level = '"logging.INFO"'
668
- source_dir = "s3://mybucket/source"
669
667
pytorch = PyTorch (
670
668
entry_point = SCRIPT_PATH ,
671
669
framework_version = pytorch_inference_version ,
@@ -677,6 +675,7 @@ def test_create_pytorch_estimator_with_framework_profile(
677
675
base_job_name = "job" ,
678
676
profiler_config = profiler_config ,
679
677
)
678
+ assert pytorch ._framework_name == "pytorch"
680
679
681
680
682
681
def test_create_pytorch_estimator_w_image_with_framework_profile (
@@ -706,6 +705,7 @@ def test_create_pytorch_estimator_w_image_with_framework_profile(
706
705
image_uri = image_uri ,
707
706
profiler_config = profiler_config ,
708
707
)
708
+ assert pytorch ._framework_name == "pytorch"
709
709
710
710
711
711
def test_create_tf_estimator_with_framework_profile (
@@ -724,14 +724,7 @@ def test_create_tf_estimator_with_framework_profile(
724
724
instance_type = INSTANCE_TYPE ,
725
725
profiler_config = profiler_config ,
726
726
)
727
-
728
-
729
- """
730
- ... ValueError: TF 1.5 supports only legacy mode.
731
- Please supply the image URI directly with
732
- 'image_uri=520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.5-cpu-py2'
733
- and set 'model_dir=False' etc etc
734
- """
727
+ assert tf ._framework_name == "tensorflow"
735
728
736
729
737
730
def test_create_tf_estimator_w_image_with_framework_profile (
@@ -749,8 +742,6 @@ def test_create_tf_estimator_w_image_with_framework_profile(
749
742
image_scope = "inference" ,
750
743
)
751
744
752
- assert image_uri is not None
753
-
754
745
profiler_config = ProfilerConfig (framework_profile_params = default_framework_profile )
755
746
756
747
tf = TensorFlow (
@@ -762,3 +753,4 @@ def test_create_tf_estimator_w_image_with_framework_profile(
762
753
image_uri = image_uri ,
763
754
profiler_config = profiler_config ,
764
755
)
756
+ assert tf ._framework_name == "tensorflow"
0 commit comments