|
25 | 25 | DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
26 | 26 | JUMPSTART_DEFAULT_REGION_NAME,
|
27 | 27 | )
|
| 28 | +from sagemaker.jumpstart.constants import ( |
| 29 | + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, |
| 30 | + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, |
| 31 | +) |
28 | 32 | from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType
|
29 | 33 |
|
30 | 34 | from sagemaker.jumpstart.model import JumpStartModel
|
|
41 | 45 | overwrite_dictionary,
|
42 | 46 | get_special_model_spec_for_inference_component_based_endpoint,
|
43 | 47 | get_prototype_manifest,
|
| 48 | + get_prototype_model_spec, |
44 | 49 | )
|
45 | 50 | import boto3
|
46 | 51 |
|
@@ -1365,6 +1370,50 @@ def test_jumpstart_model_session(
|
1365 | 1370 | assert len(s3_clients) == 1
|
1366 | 1371 | assert list(s3_clients)[0] == session.s3_client
|
1367 | 1372 |
|
@mock.patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {
        ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
        ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
    },
)
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_local_mode(
    self,
    mock_model_deploy: mock.Mock,
    mock_get_model_specs: mock.Mock,
    mock_session: mock.Mock,
    mock_get_manifest: mock.Mock,
):
    """Deploy a JumpStartModel with the local-root override env vars set.

    The manifest and specs local-root override variables are patched into
    ``sagemaker.jumpstart.cache.os.environ`` for the duration of the test.
    Manifest and spec lookups are served by the prototype helpers, the
    session factory returns the shared test session, and ``Model.deploy``
    is mocked so that only the arguments forwarded to it are checked.
    """

    def _prototype_manifest(region, model_type, *args, **kwargs):
        # Only region and model type are forwarded; any extra arguments
        # supplied by the accessor are deliberately dropped.
        return get_prototype_manifest(region, model_type)

    # Wire up the mocks. Patch decorators apply bottom-up, so the
    # parameter order above mirrors the decorator stack in reverse.
    mock_session.return_value = sagemaker_session
    mock_model_deploy.return_value = default_predictor
    mock_get_manifest.side_effect = _prototype_manifest
    mock_get_model_specs.side_effect = get_prototype_model_spec

    model_id, _ = "pytorch-eqa-bert-base-cased", "*"
    instance_type = "ml.p2.xlarge"

    model = JumpStartModel(model_id=model_id, instance_type=instance_type)
    model.deploy()

    expected_tags = [
        {"Key": JumpStartTag.MODEL_ID, "Value": model_id},
        {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
    ]
    mock_model_deploy.assert_called_once_with(
        initial_instance_count=1,
        instance_type=instance_type,
        tags=expected_tags,
        wait=True,
        endpoint_logging=False,
    )
| 1416 | + |
1368 | 1417 |
|
1369 | 1418 | def test_jumpstart_model_requires_model_id():
|
1370 | 1419 | with pytest.raises(ValueError):
|
|
0 commit comments