|
28 | 28 | from sagemaker.session import s3_input
|
29 | 29 | from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
|
30 | 30 | from sagemaker.xgboost.defaults import (
|
| 31 | + XGBOOST_1P_VERSIONS, |
31 | 32 | XGBOOST_LATEST_VERSION,
|
| 33 | + XGBOOST_NAME, |
32 | 34 | XGBOOST_SUPPORTED_VERSIONS,
|
33 |
| - XGBOOST_VERSION_0_90_1, |
34 |
| - XGBOOST_VERSION_0_90, |
35 | 35 | XGBOOST_VERSION_EQUIVALENTS,
|
36 | 36 | )
|
37 | 37 | from sagemaker.xgboost.estimator import get_xgboost_image_uri
|
@@ -621,41 +621,68 @@ def get_image_uri(region_name, repo_name, repo_version=1):
|
621 | 621 | "in SageMaker Python SDK v2."
|
622 | 622 | )
|
623 | 623 |
|
624 |
| - if repo_name == "xgboost": |
625 |
| - if not _is_latest_xgboost_version(repo_version): |
626 |
| - logging.warning( |
627 |
| - "There is a more up to date SageMaker XGBoost image. " |
628 |
| - "To use the newer image, please set 'repo_version'=" |
629 |
| - "'%s'. For example:\n" |
630 |
| - "\tget_image_uri(region, 'xgboost', '%s').", |
631 |
| - XGBOOST_LATEST_VERSION, |
632 |
| - XGBOOST_LATEST_VERSION, |
633 |
| - ) |
| 624 | + repo_version = str(repo_version) |
| 625 | + |
| 626 | + if repo_name == XGBOOST_NAME: |
| 627 | + |
| 628 | + if repo_version in XGBOOST_1P_VERSIONS: |
| 629 | + _warn_newer_xgboost_image() |
| 630 | + return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version) |
634 | 631 |
|
635 |
| - if repo_version in [XGBOOST_VERSION_0_90] + _generate_version_equivalents( |
636 |
| - XGBOOST_VERSION_0_90_1 |
637 |
| - ): |
638 |
| - return get_xgboost_image_uri(region_name, XGBOOST_VERSION_0_90_1) |
| 632 | + if "-" not in repo_version: |
| 633 | + xgboost_version_matches = [ |
| 634 | + version |
| 635 | + for version in XGBOOST_SUPPORTED_VERSIONS |
| 636 | + if repo_version == version.split("-")[0] |
| 637 | + ] |
| 638 | + if xgboost_version_matches: |
| 639 | + # Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest. |
| 640 | + # When SageMaker version is not specified, we use the oldest one that matches |
| 641 | + # XGBoost version for backward compatibility. |
| 642 | + repo_version = xgboost_version_matches[0] |
639 | 643 |
|
640 |
| - supported_version = [ |
| 644 | + supported_framework_versions = [ |
641 | 645 | version
|
642 | 646 | for version in XGBOOST_SUPPORTED_VERSIONS
|
643 | 647 | if repo_version in _generate_version_equivalents(version)
|
644 | 648 | ]
|
645 |
| - if supported_version: |
646 |
| - return get_xgboost_image_uri(region_name, supported_version[0]) |
| 649 | + |
| 650 | + if not supported_framework_versions: |
| 651 | + raise ValueError( |
| 652 | + "SageMaker XGBoost version {} is not supported. Supported versions: {}".format( |
| 653 | + repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS) |
| 654 | + ) |
| 655 | + ) |
| 656 | + |
| 657 | + if not _is_latest_xgboost_version(repo_version): |
| 658 | + _warn_newer_xgboost_image() |
| 659 | + |
| 660 | + return get_xgboost_image_uri(region_name, supported_framework_versions[-1]) |
647 | 661 |
|
648 | 662 | repo = "{}:{}".format(repo_name, repo_version)
|
649 | 663 | return "{}/{}".format(registry(region_name, repo_name), repo)
|
650 | 664 |
|
651 | 665 |
|
| 666 | +def _warn_newer_xgboost_image(): |
| 667 | + """Print a warning when there is a newer XGBoost image""" |
| 668 | + logging.warning( |
| 669 | + "There is a more up to date SageMaker XGBoost image. " |
| 670 | + "To use the newer image, please set 'repo_version'=" |
| 671 | + "'%s'. For example:\n" |
| 672 | + "\tget_image_uri(region, '%s', '%s').", |
| 673 | + XGBOOST_LATEST_VERSION, |
| 674 | + XGBOOST_NAME, |
| 675 | + XGBOOST_LATEST_VERSION, |
| 676 | + ) |
| 677 | + |
| 678 | + |
652 | 679 | def _is_latest_xgboost_version(repo_version):
|
653 | 680 | """Compare xgboost image version with latest version
|
654 | 681 |
|
655 | 682 | Args:
|
656 | 683 | repo_version:
|
657 | 684 | """
|
658 |
| - if repo_version in (1, "latest"): |
| 685 | + if repo_version in XGBOOST_1P_VERSIONS: |
659 | 686 | return False
|
660 | 687 | return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
|
661 | 688 |
|
|
0 commit comments