Skip to content

change: Check that session is a LocalSession when using local mode #1356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 20, 2020
5 changes: 5 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def __init__(
if self.train_instance_type == "local_gpu" and self.train_instance_count > 1:
raise RuntimeError("Distributed Training in Local GPU is not supported")
self.sagemaker_session = sagemaker_session or LocalSession()
if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession):
raise RuntimeError(
"instance_type local or local_gpu is only supported with an"
"instance of LocalSession"
)
else:
self.sagemaker_session = sagemaker_session or Session()

Expand Down
6 changes: 3 additions & 3 deletions tests/integ/test_airflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

@pytest.mark.canary_quick
def test_chainer_airflow_config_uploads_data_source_to_s3(
sagemaker_session, cpu_instance_type, chainer_full_version
sagemaker_local_session, cpu_instance_type, chainer_full_version
):
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
Expand All @@ -456,7 +456,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
train_instance_type="local",
framework_version=chainer_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
sagemaker_session=sagemaker_local_session,
hyperparameters={"epochs": 1},
use_mpi=True,
num_processes=2,
Expand All @@ -474,7 +474,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
)

_assert_that_s3_url_contains_data(
sagemaker_session,
sagemaker_local_session,
training_config["HyperParameters"]["sagemaker_submit_directory"].strip('"'),
)

Expand Down
31 changes: 29 additions & 2 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sagemaker.session import s3_input, ShuffleConfig
from sagemaker.transformer import Transformer
from botocore.exceptions import ClientError
import sagemaker.local

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
Expand Down Expand Up @@ -560,6 +561,7 @@ def test_local_code_location():
boto_region_name=REGION,
config=config,
local_mode=True,
spec=sagemaker.local.LocalSession,
)
t = DummyFramework(
entry_point=SCRIPT_PATH,
Expand Down Expand Up @@ -2231,7 +2233,7 @@ def test_deploy_with_no_model_name(sagemaker_session):
@patch("sagemaker.estimator.LocalSession")
@patch("sagemaker.estimator.Session")
def test_local_mode(session_class, local_session_class):
local_session = Mock()
local_session = Mock(spec=sagemaker.local.LocalSession)
local_session.local_mode = True

session = Mock()
Expand Down Expand Up @@ -2259,7 +2261,7 @@ def test_distributed_gpu_local_mode(LocalSession):

@patch("sagemaker.estimator.LocalSession")
def test_local_mode_file_output_path(local_session_class):
local_session = Mock()
local_session = Mock(spec=sagemaker.local.LocalSession)
local_session.local_mode = True
local_session_class.return_value = local_session

Expand Down Expand Up @@ -2392,3 +2394,28 @@ def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session):
'"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together'
in str(error)
)


def test_estimator_local_mode_error(sagemaker_session):
# When using instance local with a session which is not LocalSession we should error out
with pytest.raises(RuntimeError):
Estimator(
image_name="some-image",
role="some_image",
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_session,
base_job_name="base_job_name",
)


def test_estimator_local_mode_ok(sagemaker_local_session):
# When using instance local with a session which is not LocalSession we should error out
Estimator(
image_name="some-image",
role="some_image",
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_local_session,
base_job_name="base_job_name",
)