Skip to content

Commit 0b449df

Browse files
authored
change: Check that session is a LocalSession when using local mode (#1356)
1 parent 95bf7fa commit 0b449df

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

src/sagemaker/estimator.py

+5
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ def __init__(
218218
if self.train_instance_type == "local_gpu" and self.train_instance_count > 1:
219219
raise RuntimeError("Distributed Training in Local GPU is not supported")
220220
self.sagemaker_session = sagemaker_session or LocalSession()
221+
if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession):
222+
raise RuntimeError(
223+
"instance_type local or local_gpu is only supported with an"
224+
"instance of LocalSession"
225+
)
221226
else:
222227
self.sagemaker_session = sagemaker_session or Session()
223228

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
443443

444444
@pytest.mark.canary_quick
445445
def test_chainer_airflow_config_uploads_data_source_to_s3(
446-
sagemaker_session, cpu_instance_type, chainer_full_version
446+
sagemaker_local_session, cpu_instance_type, chainer_full_version
447447
):
448448
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
449449
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -456,7 +456,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
456456
train_instance_type="local",
457457
framework_version=chainer_full_version,
458458
py_version=PYTHON_VERSION,
459-
sagemaker_session=sagemaker_session,
459+
sagemaker_session=sagemaker_local_session,
460460
hyperparameters={"epochs": 1},
461461
use_mpi=True,
462462
num_processes=2,
@@ -474,7 +474,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
474474
)
475475

476476
_assert_that_s3_url_contains_data(
477-
sagemaker_session,
477+
sagemaker_local_session,
478478
training_config["HyperParameters"]["sagemaker_submit_directory"].strip('"'),
479479
)
480480

tests/unit/test_estimator.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sagemaker.session import s3_input, ShuffleConfig
3131
from sagemaker.transformer import Transformer
3232
from botocore.exceptions import ClientError
33+
import sagemaker.local
3334

3435
MODEL_DATA = "s3://bucket/model.tar.gz"
3536
MODEL_IMAGE = "mi"
@@ -560,6 +561,7 @@ def test_local_code_location():
560561
boto_region_name=REGION,
561562
config=config,
562563
local_mode=True,
564+
spec=sagemaker.local.LocalSession,
563565
)
564566
t = DummyFramework(
565567
entry_point=SCRIPT_PATH,
@@ -2231,7 +2233,7 @@ def test_deploy_with_no_model_name(sagemaker_session):
22312233
@patch("sagemaker.estimator.LocalSession")
22322234
@patch("sagemaker.estimator.Session")
22332235
def test_local_mode(session_class, local_session_class):
2234-
local_session = Mock()
2236+
local_session = Mock(spec=sagemaker.local.LocalSession)
22352237
local_session.local_mode = True
22362238

22372239
session = Mock()
@@ -2259,7 +2261,7 @@ def test_distributed_gpu_local_mode(LocalSession):
22592261

22602262
@patch("sagemaker.estimator.LocalSession")
22612263
def test_local_mode_file_output_path(local_session_class):
2262-
local_session = Mock()
2264+
local_session = Mock(spec=sagemaker.local.LocalSession)
22632265
local_session.local_mode = True
22642266
local_session_class.return_value = local_session
22652267

@@ -2392,3 +2394,28 @@ def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session):
23922394
'"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together'
23932395
in str(error)
23942396
)
2397+
2398+
2399+
def test_estimator_local_mode_error(sagemaker_session):
2400+
# When using instance local with a session which is not LocalSession we should error out
2401+
with pytest.raises(RuntimeError):
2402+
Estimator(
2403+
image_name="some-image",
2404+
role="some_image",
2405+
train_instance_count=1,
2406+
train_instance_type="local",
2407+
sagemaker_session=sagemaker_session,
2408+
base_job_name="base_job_name",
2409+
)
2410+
2411+
2412+
def test_estimator_local_mode_ok(sagemaker_local_session):
2413+
# When using instance local with a session which is not LocalSession we should error out
2414+
Estimator(
2415+
image_name="some-image",
2416+
role="some_image",
2417+
train_instance_count=1,
2418+
train_instance_type="local",
2419+
sagemaker_session=sagemaker_local_session,
2420+
base_job_name="base_job_name",
2421+
)

0 commit comments

Comments
 (0)