40
40
VulnerableJumpStartModelError ,
41
41
)
42
42
from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartVersionedModelId
43
- from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
43
+ from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec , get_prototype_manifest
44
44
from mock import MagicMock
45
45
46
46
@@ -1178,7 +1178,7 @@ def test_mime_type_enum_from_str():
1178
1178
class TestIsValidModelId (TestCase ):
1179
1179
@patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest" )
1180
1180
@patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs" )
1181
- def test_validate_model_id_and_get_type_true (
1181
+ def test_validate_model_id_and_get_type_open_weights (
1182
1182
self ,
1183
1183
mock_get_model_specs : Mock ,
1184
1184
mock_get_manifest : Mock ,
@@ -1197,11 +1197,11 @@ def test_validate_model_id_and_get_type_true(
1197
1197
)
1198
1198
1199
1199
with patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" , patched ):
1200
- self . assertTrue ( utils .validate_model_id_and_get_type ("bee" ))
1200
+ assert utils .validate_model_id_and_get_type ("bee" ) == JumpStartModelType . OPEN_WEIGHTS
1201
1201
mock_get_manifest .assert_called_with (
1202
1202
region = JUMPSTART_DEFAULT_REGION_NAME ,
1203
1203
s3_client = mock_s3_client_value ,
1204
- model_type = JumpStartModelType .PROPRIETARY ,
1204
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
1205
1205
)
1206
1206
mock_get_model_specs .assert_not_called ()
1207
1207
@@ -1215,25 +1215,30 @@ def test_validate_model_id_and_get_type_true(
1215
1215
]
1216
1216
1217
1217
mock_get_model_specs .return_value = Mock (training_supported = True )
1218
- self .assertTrue (
1218
+ self .assertIsNone (
1219
+ utils .validate_model_id_and_get_type (
1220
+ "invalid" , script = JumpStartScriptScope .TRAINING
1221
+ )
1222
+ )
1223
+ assert (
1219
1224
utils .validate_model_id_and_get_type ("bee" , script = JumpStartScriptScope .TRAINING )
1225
+ == JumpStartModelType .OPEN_WEIGHTS
1220
1226
)
1227
+
1221
1228
mock_get_manifest .assert_called_with (
1222
1229
region = JUMPSTART_DEFAULT_REGION_NAME ,
1223
1230
s3_client = mock_s3_client_value ,
1224
- model_type = JumpStartModelType .PROPRIETARY ,
1231
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
1225
1232
)
1226
1233
1227
1234
@patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest" )
1228
1235
@patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs" )
1229
- def test_validate_model_id_and_get_type_false (
1236
+ def test_validate_model_id_and_get_type_invalid (
1230
1237
self , mock_get_model_specs : Mock , mock_get_manifest : Mock
1231
1238
):
1232
- mock_get_manifest .return_value = [
1233
- Mock (model_id = "ay" ),
1234
- Mock (model_id = "bee" ),
1235
- Mock (model_id = "see" ),
1236
- ]
1239
+ mock_get_manifest .side_effect = (
1240
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest (region , model_type )
1241
+ )
1237
1242
1238
1243
mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1239
1244
mock_s3_client_value = mock_session_value .s3_client
@@ -1244,10 +1249,10 @@ def test_validate_model_id_and_get_type_false(
1244
1249
1245
1250
with patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" , patched ):
1246
1251
1247
- self .assertFalse (utils .validate_model_id_and_get_type ("dee" ))
1248
- self .assertFalse (utils .validate_model_id_and_get_type ("" ))
1249
- self .assertFalse (utils .validate_model_id_and_get_type (None ))
1250
- self .assertFalse (utils .validate_model_id_and_get_type (set ()))
1252
+ self .assertIsNone (utils .validate_model_id_and_get_type ("dee" ))
1253
+ self .assertIsNone (utils .validate_model_id_and_get_type ("" ))
1254
+ self .assertIsNone (utils .validate_model_id_and_get_type (None ))
1255
+ self .assertIsNone (utils .validate_model_id_and_get_type (set ()))
1251
1256
1252
1257
mock_get_manifest .assert_called ()
1253
1258
@@ -1256,53 +1261,44 @@ def test_validate_model_id_and_get_type_false(
1256
1261
mock_get_manifest .reset_mock ()
1257
1262
mock_get_model_specs .reset_mock ()
1258
1263
1259
- mock_get_manifest .return_value = [
1260
- Mock (model_id = "ay" ),
1261
- Mock (model_id = "bee" ),
1262
- Mock (model_id = "see" ),
1263
- ]
1264
- self .assertFalse (
1265
- utils .validate_model_id_and_get_type ("dee" , script = JumpStartScriptScope .TRAINING )
1264
+ assert (
1265
+ utils .validate_model_id_and_get_type ("ai21-summarization" )
1266
+ == JumpStartModelType .PROPRIETARY
1266
1267
)
1268
+ self .assertIsNone (utils .validate_model_id_and_get_type ("ai21-summarization-2" ))
1269
+
1267
1270
mock_get_manifest .assert_called_with (
1268
1271
region = JUMPSTART_DEFAULT_REGION_NAME ,
1269
1272
s3_client = mock_s3_client_value ,
1270
1273
model_type = JumpStartModelType .PROPRIETARY ,
1271
1274
)
1272
1275
1273
- mock_get_manifest .reset_mock ()
1274
-
1275
- self .assertFalse (
1276
+ self .assertIsNone (
1276
1277
utils .validate_model_id_and_get_type ("dee" , script = JumpStartScriptScope .TRAINING )
1277
1278
)
1278
- self .assertFalse (
1279
+ self .assertIsNone (
1279
1280
utils .validate_model_id_and_get_type ("" , script = JumpStartScriptScope .TRAINING )
1280
1281
)
1281
- self .assertFalse (
1282
+ self .assertIsNone (
1282
1283
utils .validate_model_id_and_get_type (None , script = JumpStartScriptScope .TRAINING )
1283
1284
)
1284
- self .assertFalse (
1285
+ self .assertIsNone (
1285
1286
utils .validate_model_id_and_get_type (set (), script = JumpStartScriptScope .TRAINING )
1286
1287
)
1287
1288
1288
- mock_get_model_specs .assert_not_called ()
1289
+ assert (
1290
+ utils .validate_model_id_and_get_type ("pytorch-eqa-bert-base-cased" )
1291
+ == JumpStartModelType .OPEN_WEIGHTS
1292
+ )
1289
1293
mock_get_manifest .assert_called_with (
1290
1294
region = JUMPSTART_DEFAULT_REGION_NAME ,
1291
1295
s3_client = mock_s3_client_value ,
1292
- model_type = JumpStartModelType .PROPRIETARY ,
1296
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
1293
1297
)
1294
1298
1295
- mock_get_manifest .reset_mock ()
1296
- mock_get_model_specs .reset_mock ()
1297
-
1298
- mock_get_model_specs .return_value = Mock (training_supported = False )
1299
- self .assertTrue (
1300
- utils .validate_model_id_and_get_type ("ay" , script = JumpStartScriptScope .TRAINING )
1301
- )
1302
- mock_get_manifest .assert_called_with (
1303
- region = JUMPSTART_DEFAULT_REGION_NAME ,
1304
- s3_client = mock_s3_client_value ,
1305
- model_type = JumpStartModelType .PROPRIETARY ,
1299
+ with pytest .raises (ValueError ):
1300
+ utils .validate_model_id_and_get_type (
1301
+ "ai21-summarization" , script = JumpStartScriptScope .TRAINING
1306
1302
)
1307
1303
1308
1304
0 commit comments