@@ -36,10 +36,10 @@ def retrieve(
36
36
"""Retrieves the ECR URI for the Docker image matching the given arguments.
37
37
38
38
Args:
39
- framework (str): The name of the framework.
39
+ framework (str): The name of the framework or algorithm .
40
40
region (str): The AWS region.
41
- version (str): The framework version. This is required if there is
42
- more than one supported version for the given framework.
41
+ version (str): The framework or algorithm version. This is required if there is
42
+ more than one supported version for the given framework or algorithm .
43
43
py_version (str): The Python version. This is required if there is
44
44
more than one supported Python version for the given framework version.
45
45
instance_type (str): The SageMaker instance type. For supported types, see
@@ -58,7 +58,9 @@ def retrieve(
58
58
ValueError: If the combination of arguments specified is not supported.
59
59
"""
60
60
config = _config_for_framework_and_scope (framework , image_scope , accelerator_type )
61
- version_config = config ["versions" ][_version_for_config (version , config , framework )]
61
+
62
+ version = _validate_version_and_set_if_needed (version , config , framework )
63
+ version_config = config ["versions" ][_version_for_config (version , config )]
62
64
63
65
py_version = _validate_py_version_and_set_if_needed (py_version , version_config )
64
66
version_config = version_config .get (py_version ) or version_config
@@ -67,7 +69,7 @@ def retrieve(
67
69
hostname = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )["hostname" ]
68
70
69
71
repo = version_config ["repository" ]
70
- tag = _format_tag (version , _processor (instance_type , config [ "processors" ] ), py_version )
72
+ tag = _format_tag (version , _processor (instance_type , config . get ( "processors" ) ), py_version )
71
73
72
74
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo , tag = tag )
73
75
@@ -94,13 +96,33 @@ def config_for_framework(framework):
94
96
return json .load (f )
95
97
96
98
97
- def _version_for_config (version , config , framework ):
99
+ def _validate_version_and_set_if_needed (version , config , framework ):
100
+ """Checks if the framework/algorithm version is one of the supported versions."""
101
+ available_versions = list (config ["versions" ].keys ())
102
+
103
+ if len (available_versions ) == 1 :
104
+ log_message = "Defaulting to the only supported framework/algorithm version: {}." .format (
105
+ available_versions [0 ]
106
+ )
107
+ if version and version != available_versions [0 ]:
108
+ logger .warning ("%s Ignoring framework/algorithm version: %s." , log_message , version )
109
+ elif not version :
110
+ logger .info (log_message )
111
+
112
+ return available_versions [0 ]
113
+
114
+ available_versions += list (config .get ("version_aliases" , {}).keys ())
115
+ _validate_arg ("{} version" .format (framework ), version , available_versions )
116
+
117
+ return version
118
+
119
+
120
+ def _version_for_config (version , config ):
98
121
"""Returns the version string for retrieving a framework version's specific config."""
99
122
if "version_aliases" in config :
100
123
if version in config ["version_aliases" ].keys ():
101
124
return config ["version_aliases" ][version ]
102
125
103
- _validate_arg ("{} version" .format (framework ), version , config ["versions" ].keys ())
104
126
return version
105
127
106
128
@@ -112,6 +134,10 @@ def _registry_from_region(region, registry_dict):
112
134
113
135
def _processor (instance_type , available_processors ):
114
136
"""Returns the processor type for the given instance type."""
137
+ if not available_processors :
138
+ logger .info ("Ignoring unnecessary instance type: %s." , instance_type )
139
+ return None
140
+
115
141
if instance_type .startswith ("local" ):
116
142
processor = "cpu" if instance_type == "local" else "gpu"
117
143
elif not instance_type .startswith ("ml." ):
@@ -129,9 +155,12 @@ def _processor(instance_type, available_processors):
129
155
130
156
def _validate_py_version_and_set_if_needed (py_version , version_config ):
131
157
"""Checks if the Python version is one of the supported versions."""
132
- available_versions = version_config .get ("py_versions" , version_config .keys ())
158
+ if "repository" in version_config :
159
+ available_versions = version_config .get ("py_versions" )
160
+ else :
161
+ available_versions = list (version_config .keys ())
133
162
134
- if len ( available_versions ) == 0 :
163
+ if not available_versions :
135
164
if py_version :
136
165
logger .info ("Ignoring unnecessary Python version: %s." , py_version )
137
166
return None
0 commit comments