|
17 | 17 | import os
|
18 | 18 | import pytest
|
19 | 19 | from mock import ANY, MagicMock, Mock, patch
|
| 20 | +from packaging.version import Version |
20 | 21 |
|
21 | 22 | from sagemaker.pytorch import defaults
|
22 | 23 | from sagemaker.pytorch import PyTorch
|
@@ -585,46 +586,40 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version):
|
585 | 586 | warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
|
586 | 587 |
|
587 | 588 |
|
588 |
| -def test_pt_enable_sm_metrics(sagemaker_session, pytorch_full_version): |
| 589 | +def test_pt_enable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): |
589 | 590 | pytorch = _pytorch_estimator(
|
590 | 591 | sagemaker_session,
|
591 |
| - framework_version=pytorch_full_version, |
592 |
| - py_version="py3", |
| 592 | + framework_version=pytorch_version, |
| 593 | + py_version=pytorch_py_version, |
593 | 594 | enable_sagemaker_metrics=True,
|
594 | 595 | )
|
595 | 596 | assert pytorch.enable_sagemaker_metrics
|
596 | 597 |
|
597 | 598 |
|
598 |
| -def test_pt_disable_sm_metrics(sagemaker_session, pytorch_full_version): |
| 599 | +def test_pt_disable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): |
599 | 600 | pytorch = _pytorch_estimator(
|
600 | 601 | sagemaker_session,
|
601 |
| - framework_version=pytorch_full_version, |
602 |
| - py_version="py3", |
| 602 | + framework_version=pytorch_version, |
| 603 | + py_version=pytorch_py_version, |
603 | 604 | enable_sagemaker_metrics=False,
|
604 | 605 | )
|
605 | 606 | assert not pytorch.enable_sagemaker_metrics
|
606 | 607 |
|
607 | 608 |
|
608 |
| -def test_pt_disable_sm_metrics_if_pt_ver_is_less_than_1_15(sagemaker_session): |
609 |
| - for fw_version in ["1.1", "1.2"]: |
610 |
| - pytorch = _pytorch_estimator( |
611 |
| - sagemaker_session, framework_version=fw_version, py_version="py3" |
612 |
| - ) |
| 609 | +def test_pt_default_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): |
| 610 | + pytorch = _pytorch_estimator( |
| 611 | + sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version |
| 612 | + ) |
| 613 | + if Version(pytorch_version) < Version("1.3"): |
613 | 614 | assert pytorch.enable_sagemaker_metrics is None
|
614 |
| - |
615 |
| - |
616 |
| -def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session): |
617 |
| - for fw_version in ["1.3", "1.4", "2.0", "2.1"]: |
618 |
| - pytorch = _pytorch_estimator( |
619 |
| - sagemaker_session, framework_version=fw_version, py_version="py3" |
620 |
| - ) |
| 615 | + else: |
621 | 616 | assert pytorch.enable_sagemaker_metrics
|
622 | 617 |
|
623 | 618 |
|
624 |
| -def test_custom_image_estimator_deploy(sagemaker_session, pytorch_full_version): |
| 619 | +def test_custom_image_estimator_deploy(sagemaker_session, pytorch_version, pytorch_py_version): |
625 | 620 | custom_image = "mycustomimage:latest"
|
626 | 621 | pytorch = _pytorch_estimator(
|
627 |
| - sagemaker_session, framework_version=pytorch_full_version, py_version="py3" |
| 622 | + sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version |
628 | 623 | )
|
629 | 624 | pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
|
630 | 625 | model = pytorch.create_model(image=custom_image)
|
|
0 commit comments