|
30 | 30 | from sagemaker.session import s3_input, ShuffleConfig
|
31 | 31 | from sagemaker.transformer import Transformer
|
32 | 32 | from botocore.exceptions import ClientError
|
| 33 | +import sagemaker.local |
33 | 34 |
|
34 | 35 | MODEL_DATA = "s3://bucket/model.tar.gz"
|
35 | 36 | MODEL_IMAGE = "mi"
|
@@ -560,6 +561,7 @@ def test_local_code_location():
|
560 | 561 | boto_region_name=REGION,
|
561 | 562 | config=config,
|
562 | 563 | local_mode=True,
|
| 564 | + spec=sagemaker.local.LocalSession, |
563 | 565 | )
|
564 | 566 | t = DummyFramework(
|
565 | 567 | entry_point=SCRIPT_PATH,
|
@@ -2231,7 +2233,7 @@ def test_deploy_with_no_model_name(sagemaker_session):
|
2231 | 2233 | @patch("sagemaker.estimator.LocalSession")
|
2232 | 2234 | @patch("sagemaker.estimator.Session")
|
2233 | 2235 | def test_local_mode(session_class, local_session_class):
|
2234 |
| - local_session = Mock() |
| 2236 | + local_session = Mock(spec=sagemaker.local.LocalSession) |
2235 | 2237 | local_session.local_mode = True
|
2236 | 2238 |
|
2237 | 2239 | session = Mock()
|
@@ -2259,7 +2261,7 @@ def test_distributed_gpu_local_mode(LocalSession):
|
2259 | 2261 |
|
2260 | 2262 | @patch("sagemaker.estimator.LocalSession")
|
2261 | 2263 | def test_local_mode_file_output_path(local_session_class):
|
2262 |
| - local_session = Mock() |
| 2264 | + local_session = Mock(spec=sagemaker.local.LocalSession) |
2263 | 2265 | local_session.local_mode = True
|
2264 | 2266 | local_session_class.return_value = local_session
|
2265 | 2267 |
|
@@ -2392,3 +2394,28 @@ def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session):
|
2392 | 2394 | '"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together'
|
2393 | 2395 | in str(error)
|
2394 | 2396 | )
|
| 2397 | + |
| 2398 | + |
| 2399 | +def test_estimator_local_mode_error(sagemaker_session): |
| 2400 | + # When using instance local with a session which is not LocalSession we should error out |
| 2401 | + with pytest.raises(RuntimeError): |
| 2402 | + Estimator( |
| 2403 | + image_name="some-image", |
| 2404 | + role="some_image", |
| 2405 | + train_instance_count=1, |
| 2406 | + train_instance_type="local", |
| 2407 | + sagemaker_session=sagemaker_session, |
| 2408 | + base_job_name="base_job_name", |
| 2409 | + ) |
| 2410 | + |
| 2411 | + |
| 2412 | +def test_estimator_local_mode_ok(sagemaker_local_session): |
| 2413 | + # When using instance local with a session which is not LocalSession we should error out |
| 2414 | + Estimator( |
| 2415 | + image_name="some-image", |
| 2416 | + role="some_image", |
| 2417 | + train_instance_count=1, |
| 2418 | + train_instance_type="local", |
| 2419 | + sagemaker_session=sagemaker_local_session, |
| 2420 | + base_job_name="base_job_name", |
| 2421 | + ) |
0 commit comments