@@ -89,7 +89,24 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
89
89
)
90
90
image_scope = "eia"
91
91
92
- _validate_arg ("image scope" , image_scope , config .get ("scope" , config .keys ()))
92
+ available_scopes = config .get ("scope" , config .keys ())
93
+ if len (available_scopes ) == 1 :
94
+ if image_scope and image_scope != available_scopes [0 ]:
95
+ logger .warning (
96
+ "Defaulting to only supported image scope: %s. Ignoring image scope: %s." ,
97
+ available_scopes [0 ],
98
+ image_scope ,
99
+ )
100
+ image_scope = available_scopes [0 ]
101
+
102
+ if not image_scope and "scope" in config and set (available_scopes ) == {"training" , "inference" }:
103
+ logger .info (
104
+ "Same images used for training and inference. Defaulting to image scope: %s." ,
105
+ available_scopes [0 ],
106
+ )
107
+ image_scope = available_scopes [0 ]
108
+
109
+ _validate_arg (image_scope , available_scopes , "image scope" )
93
110
return config if "scope" in config else config [image_scope ]
94
111
95
112
@@ -116,8 +133,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
116
133
117
134
return available_versions [0 ]
118
135
119
- _validate_arg ("{} version" .format (framework ), version , available_versions + aliased_versions )
120
-
136
+ _validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
121
137
return version
122
138
123
139
@@ -132,7 +148,7 @@ def _version_for_config(version, config):
132
148
133
149
def _registry_from_region (region , registry_dict ):
134
150
"""Returns the ECR registry (AWS account number) for the given region."""
135
- _validate_arg (" region" , region , registry_dict .keys ())
151
+ _validate_arg (region , registry_dict .keys (), "region" )
136
152
return registry_dict [region ]
137
153
138
154
@@ -159,7 +175,7 @@ def _processor(instance_type, available_processors):
159
175
family = instance_type .split ("." )[1 ]
160
176
processor = "gpu" if family [0 ] in ("g" , "p" ) else "cpu"
161
177
162
- _validate_arg (" processor" , processor , available_processors )
178
+ _validate_arg (processor , available_processors , "processor" )
163
179
return processor
164
180
165
181
@@ -179,11 +195,11 @@ def _validate_py_version_and_set_if_needed(py_version, version_config):
179
195
logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
180
196
return available_versions [0 ]
181
197
182
- _validate_arg ("Python version" , py_version , available_versions )
198
+ _validate_arg (py_version , available_versions , "Python version" )
183
199
return py_version
184
200
185
201
186
- def _validate_arg (arg_name , arg , available_options ):
202
+ def _validate_arg (arg , available_options , arg_name ):
187
203
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
188
204
if arg not in available_options :
189
205
raise ValueError (
0 commit comments