Skip to content

Commit 70bddb2

Browse files
committed
address PR comment
1 parent e5e327b commit 70bddb2

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

tests/unit/test_pytorch.py

+15-20
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import pytest
1919
from mock import ANY, MagicMock, Mock, patch
20+
from packaging.version import Version
2021

2122
from sagemaker.pytorch import defaults
2223
from sagemaker.pytorch import PyTorch
@@ -585,46 +586,40 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version):
585586
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
586587

587588

588-
def test_pt_enable_sm_metrics(sagemaker_session, pytorch_full_version):
589+
def test_pt_enable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version):
589590
pytorch = _pytorch_estimator(
590591
sagemaker_session,
591-
framework_version=pytorch_full_version,
592-
py_version="py3",
592+
framework_version=pytorch_version,
593+
py_version=pytorch_py_version,
593594
enable_sagemaker_metrics=True,
594595
)
595596
assert pytorch.enable_sagemaker_metrics
596597

597598

598-
def test_pt_disable_sm_metrics(sagemaker_session, pytorch_full_version):
599+
def test_pt_disable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version):
599600
pytorch = _pytorch_estimator(
600601
sagemaker_session,
601-
framework_version=pytorch_full_version,
602-
py_version="py3",
602+
framework_version=pytorch_version,
603+
py_version=pytorch_py_version,
603604
enable_sagemaker_metrics=False,
604605
)
605606
assert not pytorch.enable_sagemaker_metrics
606607

607608

608-
def test_pt_disable_sm_metrics_if_pt_ver_is_less_than_1_15(sagemaker_session):
609-
for fw_version in ["1.1", "1.2"]:
610-
pytorch = _pytorch_estimator(
611-
sagemaker_session, framework_version=fw_version, py_version="py3"
612-
)
609+
def test_pt_default_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version):
610+
pytorch = _pytorch_estimator(
611+
sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version
612+
)
613+
if Version(pytorch_version) < Version("1.3"):
613614
assert pytorch.enable_sagemaker_metrics is None
614-
615-
616-
def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
617-
for fw_version in ["1.3", "1.4", "2.0", "2.1"]:
618-
pytorch = _pytorch_estimator(
619-
sagemaker_session, framework_version=fw_version, py_version="py3"
620-
)
615+
else:
621616
assert pytorch.enable_sagemaker_metrics
622617

623618

624-
def test_custom_image_estimator_deploy(sagemaker_session, pytorch_full_version):
619+
def test_custom_image_estimator_deploy(sagemaker_session, pytorch_version, pytorch_py_version):
625620
custom_image = "mycustomimage:latest"
626621
pytorch = _pytorch_estimator(
627-
sagemaker_session, framework_version=pytorch_full_version, py_version="py3"
622+
sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version
628623
)
629624
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
630625
model = pytorch.create_model(image=custom_image)

0 commit comments

Comments
 (0)