14
14
from __future__ import absolute_import
15
15
16
16
import json
17
+ import logging
17
18
import os
18
19
19
20
from sagemaker import utils
20
21
22
+ logger = logging .getLogger (__name__ )
23
+
21
24
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
22
25
23
26
24
- def retrieve (framework , region , version = None , py_version = None , instance_type = None ):
27
+ def retrieve (
28
+ framework ,
29
+ region ,
30
+ version = None ,
31
+ py_version = None ,
32
+ instance_type = None ,
33
+ accelerator_type = None ,
34
+ image_scope = None ,
35
+ ):
25
36
"""Retrieves the ECR URI for the Docker image matching the given arguments.
26
37
27
38
Args:
@@ -34,28 +45,48 @@ def retrieve(framework, region, version=None, py_version=None, instance_type=Non
34
45
instance_type (str): The SageMaker instance type. For supported types, see
35
46
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
36
47
there are different images for different processor types.
48
+ accelerator_type (str): Elastic Inference accelerator type. For more, see
49
+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
50
+ image_scope (str): The image type, i.e. what it is used for.
51
+ Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
52
+ ``image_scope`` is ignored.
37
53
38
54
Returns:
39
55
str: the ECR URI for the corresponding SageMaker Docker image.
40
56
41
57
Raises:
42
- ValueError: If the framework version, Python version, processor type, or region is
43
- not supported given the other arguments.
58
+ ValueError: If the combination of arguments specified is not supported.
44
59
"""
45
- config = config_for_framework (framework )
60
+ config = _config_for_framework_and_scope (framework , image_scope , accelerator_type )
46
61
version_config = config ["versions" ][_version_for_config (version , config , framework )]
47
62
63
+ py_version = _validate_py_version_and_set_if_needed (py_version , version_config )
64
+ version_config = version_config .get (py_version ) or version_config
65
+
48
66
registry = _registry_from_region (region , version_config ["registries" ])
49
67
hostname = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )["hostname" ]
50
68
51
69
repo = version_config ["repository" ]
52
-
53
- _validate_py_version (py_version , version_config ["py_versions" ], framework , version )
54
- tag = "{}-{}-{}" .format (version , _processor (instance_type , config ["processors" ]), py_version )
70
+ tag = _format_tag (version , _processor (instance_type , config ["processors" ]), py_version )
55
71
56
72
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo , tag = tag )
57
73
58
74
75
+ def _config_for_framework_and_scope (framework , image_scope , accelerator_type = None ):
76
+ """Loads the JSON config for the given framework and image scope."""
77
+ config = config_for_framework (framework )
78
+
79
+ if accelerator_type :
80
+ if image_scope not in ("eia" , "inference" ):
81
+ logger .warning (
82
+ "Elastic inference is for inference only. Ignoring image scope: %s." , image_scope
83
+ )
84
+ image_scope = "eia"
85
+
86
+ _validate_arg ("image scope" , image_scope , config .get ("scope" , config .keys ()))
87
+ return config if "scope" in config else config [image_scope ]
88
+
89
+
59
90
def config_for_framework (framework ):
60
91
"""Loads the JSON config for the given framework."""
61
92
fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
@@ -69,27 +100,13 @@ def _version_for_config(version, config, framework):
69
100
if version in config ["version_aliases" ].keys ():
70
101
return config ["version_aliases" ][version ]
71
102
72
- available_versions = config ["versions" ].keys ()
73
- if version in available_versions :
74
- return version
75
-
76
- raise ValueError (
77
- "Unsupported {} version: {}. "
78
- "You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
79
- "Supported version(s): {}." .format (framework , version , ", " .join (available_versions ))
80
- )
103
+ _validate_arg ("{} version" .format (framework ), version , config ["versions" ].keys ())
104
+ return version
81
105
82
106
83
107
def _registry_from_region (region , registry_dict ):
84
108
"""Returns the ECR registry (AWS account number) for the given region."""
85
- available_regions = registry_dict .keys ()
86
- if region not in available_regions :
87
- raise ValueError (
88
- "Unsupported region: {}. You may need to upgrade "
89
- "your SDK version (pip install -U sagemaker) for newer regions. "
90
- "Supported region(s): {}." .format (region , ", " .join (available_regions ))
91
- )
92
-
109
+ _validate_arg ("region" , region , registry_dict .keys ())
93
110
return registry_dict [region ]
94
111
95
112
@@ -106,22 +123,37 @@ def _processor(instance_type, available_processors):
106
123
family = instance_type .split ("." )[1 ]
107
124
processor = "gpu" if family [0 ] in ("g" , "p" ) else "cpu"
108
125
109
- if processor in available_processors :
110
- return processor
111
-
112
- raise ValueError (
113
- "Unsupported processor type: {} (for {}). "
114
- "Supported type(s): {}." .format (processor , instance_type , ", " .join (available_processors ))
115
- )
126
+ _validate_arg ("processor" , processor , available_processors )
127
+ return processor
116
128
117
129
118
- def _validate_py_version (py_version , available_versions , framework , fw_version ):
130
+ def _validate_py_version_and_set_if_needed (py_version , version_config ):
119
131
"""Checks if the Python version is one of the supported versions."""
120
- if py_version not in available_versions :
132
+ available_versions = version_config .get ("py_versions" , version_config .keys ())
133
+
134
+ if len (available_versions ) == 0 :
135
+ if py_version :
136
+ logger .info ("Ignoring unnecessary Python version: %s." , py_version )
137
+ return None
138
+
139
+ if py_version is None and len (available_versions ) == 1 :
140
+ logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
141
+ return available_versions [0 ]
142
+
143
+ _validate_arg ("Python version" , py_version , available_versions )
144
+ return py_version
145
+
146
+
147
+ def _validate_arg (arg_name , arg , available_options ):
148
+ """Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
149
+ if arg not in available_options :
121
150
raise ValueError (
122
- "Unsupported Python version for {} {}: {}. You may need to upgrade "
123
- "your SDK version (pip install -U sagemaker) for newer versions. "
124
- "Supported Python version(s): {}." .format (
125
- framework , fw_version , py_version , ", " .join (available_versions )
126
- )
151
+ "Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
152
+ "(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
153
+ "{options}." .format (arg_name = arg_name , arg = arg , options = ", " .join (available_options ))
127
154
)
155
+
156
+
157
+ def _format_tag (version , processor , py_version ):
158
+ """Creates a tag for the image URI."""
159
+ return "-" .join ([x for x in (version , processor , py_version ) if x ])
0 commit comments