diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 1931725397..1da3767448 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -218,6 +218,11 @@ def __init__( if self.train_instance_type == "local_gpu" and self.train_instance_count > 1: raise RuntimeError("Distributed Training in Local GPU is not supported") self.sagemaker_session = sagemaker_session or LocalSession() + if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession): + raise RuntimeError( + "instance_type local or local_gpu is only supported with an" + "instance of LocalSession" + ) else: self.sagemaker_session = sagemaker_session or Session() diff --git a/tests/integ/test_airflow_config.py b/tests/integ/test_airflow_config.py index 8b1e112c49..25439f67af 100644 --- a/tests/integ/test_airflow_config.py +++ b/tests/integ/test_airflow_config.py @@ -443,7 +443,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins @pytest.mark.canary_quick def test_chainer_airflow_config_uploads_data_source_to_s3( - sagemaker_session, cpu_instance_type, chainer_full_version + sagemaker_local_session, cpu_instance_type, chainer_full_version ): with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS): script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py") @@ -456,7 +456,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3( train_instance_type="local", framework_version=chainer_full_version, py_version=PYTHON_VERSION, - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_local_session, hyperparameters={"epochs": 1}, use_mpi=True, num_processes=2, @@ -474,7 +474,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3( ) _assert_that_s3_url_contains_data( - sagemaker_session, + sagemaker_local_session, training_config["HyperParameters"]["sagemaker_submit_directory"].strip('"'), ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 62245d7500..3069860b1e 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -30,6 +30,7 @@ from sagemaker.session import s3_input, ShuffleConfig from sagemaker.transformer import Transformer from botocore.exceptions import ClientError +import sagemaker.local MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -560,6 +561,7 @@ def test_local_code_location(): boto_region_name=REGION, config=config, local_mode=True, + spec=sagemaker.local.LocalSession, ) t = DummyFramework( entry_point=SCRIPT_PATH, @@ -2231,7 +2233,7 @@ def test_deploy_with_no_model_name(sagemaker_session): @patch("sagemaker.estimator.LocalSession") @patch("sagemaker.estimator.Session") def test_local_mode(session_class, local_session_class): - local_session = Mock() + local_session = Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True session = Mock() @@ -2259,7 +2261,7 @@ def test_distributed_gpu_local_mode(LocalSession): @patch("sagemaker.estimator.LocalSession") def test_local_mode_file_output_path(local_session_class): - local_session = Mock() + local_session = Mock(spec=sagemaker.local.LocalSession) local_session.local_mode = True local_session_class.return_value = local_session @@ -2392,3 +2394,28 @@ def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session): '"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together' in str(error) ) + + +def test_estimator_local_mode_error(sagemaker_session): + # When using instance local with a session which is not LocalSession we should error out + with pytest.raises(RuntimeError): + Estimator( + image_name="some-image", + role="some_image", + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + + +def test_estimator_local_mode_ok(sagemaker_local_session): + # When using instance local with a session which is not LocalSession we should error out + Estimator( + image_name="some-image", + role="some_image", + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + base_job_name="base_job_name", + )