|
27 | 27 | from sagemaker.model import NEO_IMAGE_ACCOUNT
|
28 | 28 | from sagemaker.session import s3_input
|
29 | 29 | from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
|
30 |
| -from sagemaker.xgboost.defaults import XGBOOST_VERSION_1, XGBOOST_SUPPORTED_VERSIONS |
| 30 | +from sagemaker.xgboost.defaults import ( |
| 31 | + XGBOOST_LATEST_VERSION, |
| 32 | + XGBOOST_SUPPORTED_VERSIONS, |
| 33 | + XGBOOST_VERSION_1, |
| 34 | + XGBOOST_VERSION_EQUIVALENTS, |
| 35 | +) |
31 | 36 | from sagemaker.xgboost.estimator import get_xgboost_image_uri
|
32 | 37 |
|
33 | 38 | logger = logging.getLogger(__name__)
|
@@ -611,24 +616,46 @@ def get_image_uri(region_name, repo_name, repo_version=1):
|
611 | 616 | repo_version:
|
612 | 617 | """
|
613 | 618 | if repo_name == "xgboost":
|
| 619 | + if not _is_latest_xgboost_version(repo_version): |
| 620 | + logging.warning( |
| 621 | + "There is a more up to date SageMaker XGBoost image. " |
| 622 | + "To use the newer image, please set 'repo_version'=" |
| 623 | + "'%s'. For example:\n" |
| 624 | + "\tget_image_uri(region, 'xgboost', '%s').", |
| 625 | + XGBOOST_LATEST_VERSION, |
| 626 | + XGBOOST_LATEST_VERSION, |
| 627 | + ) |
| 628 | + |
614 | 629 | if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]:
|
615 | 630 | return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1)
|
616 | 631 |
|
617 | 632 | supported_version = [
|
618 | 633 | version
|
619 | 634 | for version in XGBOOST_SUPPORTED_VERSIONS
|
620 |
| - if repo_version in (version, version + "-cpu-py3") |
| 635 | + if repo_version in _generate_version_equivalents(version) |
621 | 636 | ]
|
622 | 637 | if supported_version:
|
623 | 638 | return get_xgboost_image_uri(region_name, supported_version[0])
|
624 | 639 |
|
625 |
| - logging.warning( |
626 |
| - "There is a more up to date SageMaker XGBoost image. " |
627 |
| - "To use the newer image, please set 'repo_version'=" |
628 |
| - "'%s'. For example:\n" |
629 |
| - "\tget_image_uri(region, 'xgboost', '%s').", |
630 |
| - XGBOOST_VERSION_1, |
631 |
| - XGBOOST_VERSION_1, |
632 |
| - ) |
633 | 640 | repo = "{}:{}".format(repo_name, repo_version)
|
634 | 641 | return "{}/{}".format(registry(region_name, repo_name), repo)
|
| 642 | + |
| 643 | + |
| 644 | +def _is_latest_xgboost_version(repo_version): |
| 645 | + """Compare xgboost image version with latest version |
| 646 | +
|
| 647 | + Args: |
| 648 | + repo_version: |
| 649 | + """ |
| 650 | + if repo_version in (1, "latest"): |
| 651 | + return False |
| 652 | + return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION) |
| 653 | + |
| 654 | + |
| 655 | +def _generate_version_equivalents(version): |
| 656 | + """Returns a list of version equivalents for XGBoost |
| 657 | +
|
| 658 | + Args: |
| 659 | + version: |
| 660 | + """ |
| 661 | + return [version + suffix for suffix in XGBOOST_VERSION_EQUIVALENTS] + [version] |
0 commit comments