diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index e96b068899..15de7833f0 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -379,7 +379,7 @@ def __init__(self, boto_session=None): if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): """Initialize this Local SageMaker Session. Args: @@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.config = yaml.load(open(sagemaker_config_file, "r")) + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): """ diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e5058541fc..ae0abbd8d1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods bucket based on a naming convention which includes the current AWS account ID. """ - def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None): + def __init__( + self, + boto_session=None, + sagemaker_client=None, + sagemaker_runtime_client=None, + default_bucket=None, + ): """Initialize a SageMaker ``Session``. Args: @@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client. If not provided, one will be created using this instance's ``boto_session``. + default_bucket (str): The default s3 bucket to be used by this session. + Ex: "sagemaker-us-west-2" + """ self._default_bucket = None # currently is used for local_code in local mode self.config = None - self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client) + self._initialize( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime_client, + default_bucket=default_bucket, + ) - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): """Initialize this SageMaker Session. Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. @@ -126,6 +140,12 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): prepend_user_agent(self.sagemaker_runtime_client) + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + + # Create default bucket on session init to verify that desired name, if specified, is valid + self.default_bucket() + self.local_mode = False @property @@ -314,11 +334,14 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket + default_bucket = self._desired_default_bucket_name region = self.boto_session.region_name - account = self.boto_session.client( - "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) - ).get_caller_identity()["Account"] - default_bucket = "sagemaker-{}-{}".format(region, account) + + if not default_bucket: + account = self.boto_session.client( + "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) + ).get_caller_identity()["Account"] + default_bucket = "sagemaker-{}-{}".format(region, account) s3 = self.boto_session.resource("s3") try: diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index a73c9e1e0d..f076a79404 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession): def __init__(self): super(LocalSession, self).__init__() - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + @pytest.fixture(scope="module") def mxnet_model(sagemaker_local_session, mxnet_full_version): diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index 364bc1d6d6..715eaac2cf 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -14,7 +14,10 @@ import os +import boto3 import pytest +from botocore.config import Config +from sagemaker import Session from sagemaker.fw_registry import default_framework_uri from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor @@ -23,6 +26,35 @@ from tests.integ.kms_utils import get_or_create_kms_key ROLE = "SageMakerRole" +DEFAULT_REGION = "us-west-2" +CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket" + + +@pytest.fixture(scope="module") +def sagemaker_session_with_custom_bucket( + boto_config, sagemaker_client_config, sagemaker_runtime_config +): + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + default_bucket=CUSTOM_BUCKET_PATH, + ) @pytest.fixture(scope="module") @@ -170,6 +202,90 @@ def test_sklearn_with_customizations( assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} +def test_sklearn_with_custom_default_bucket( + sagemaker_session_with_custom_bucket, + image_uri, + sklearn_full_version, + cpu_instance_type, + output_kms_key, +): + + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_full_version, + role=ROLE, + command=["python3"], + instance_type=cpu_instance_type, + instance_count=1, + volume_size_in_gb=100, + volume_kms_key=None, + output_kms_key=output_kms_key, + max_runtime_in_seconds=3600, + base_job_name="test-sklearn-with-customizations", + env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}, + tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}], + sagemaker_session=sagemaker_session_with_custom_bucket, + ) + + sklearn_processor.run( + code=os.path.join(DATA_DIR, "dummy_script.py"), + inputs=[ + ProcessingInput( + source=input_file_path, + destination="/opt/ml/processing/input/container/path/", + input_name="dummy_input", + s3_data_type="S3Prefix", + s3_input_mode="File", + s3_data_distribution_type="FullyReplicated", + s3_compression_type="None", + ) + ], + outputs=[ + ProcessingOutput( + source="/opt/ml/processing/output/container/path/", + output_name="dummy_output", + s3_upload_mode="EndOfJob", + ) + ], + arguments=["-v"], + wait=True, + logs=True, + ) + + job_description = sklearn_processor.latest_job.describe() + + assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"] + + assert job_description["ProcessingInputs"][1]["InputName"] == "code" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"] + + assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations") + + assert job_description["ProcessingJobStatus"] == "Completed" + + assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key + assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output" + + assert job_description["ProcessingResources"] == { + "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100} + } + + assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"] + assert job_description["AppSpecification"]["ContainerEntrypoint"] == [ + "python3", + "/opt/ml/processing/input/code/dummy_script.py", + ] + assert job_description["AppSpecification"]["ImageUri"] == image_uri + + assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + assert ROLE in job_description["RoleArn"] + + assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + def test_sklearn_with_no_inputs_or_outputs( sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type ): diff --git a/tests/unit/test_create_deploy_entities.py b/tests/unit/test_create_deploy_entities.py index cb1b6eafb7..b5146ce321 100644 --- a/tests/unit/test_create_deploy_entities.py +++ b/tests/unit/test_create_deploy_entities.py @@ -34,6 +34,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock ims = sagemaker.Session(boto_session=boto_mock) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index dfcdb83378..8c99375509 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -25,6 +25,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} ims = sagemaker.Session(boto_session=boto_mock) return ims @@ -48,11 +55,13 @@ def test_default_already_cached(sagemaker_session): existing_default = "mydefaultbucket" sagemaker_session._default_bucket = existing_default + before_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls + bucket_name = sagemaker_session.default_bucket() - create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls + after_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls assert bucket_name == existing_default - assert create_calls == [] + assert before_create_calls == after_create_calls def test_default_bucket_exists(sagemaker_session): @@ -78,22 +87,42 @@ def test_concurrent_bucket_modification(sagemaker_session): assert bucket_name == DEFAULT_BUCKET_NAME -def test_bucket_creation_client_error(sagemaker_session): +def test_bucket_creation_client_error(): with pytest.raises(ClientError): + boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} + error = ClientError( error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}}, operation_name="foo", ) - sagemaker_session.boto_session.resource().create_bucket.side_effect = error + boto_mock.resource().create_bucket.side_effect = error - sagemaker_session.default_bucket() - assert sagemaker_session._default_bucket is None + session = sagemaker.Session(boto_session=boto_mock) + assert session._default_bucket is None -def test_bucket_creation_other_error(sagemaker_session): +def test_bucket_creation_other_error(): with pytest.raises(RuntimeError): + boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} + error = RuntimeError() - sagemaker_session.boto_session.resource().create_bucket.side_effect = error + boto_mock.resource().create_bucket.side_effect = error - sagemaker_session.default_bucket() - assert sagemaker_session._default_bucket is None + session = sagemaker.Session(boto_session=boto_mock) + assert session._default_bucket is None diff --git a/tests/unit/test_endpoint_from_job.py b/tests/unit/test_endpoint_from_job.py index 2b83305afe..484812526e 100644 --- a/tests/unit/test_endpoint_from_job.py +++ b/tests/unit/test_endpoint_from_job.py @@ -43,6 +43,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock) ims.sagemaker_client.describe_training_job = Mock( name="describe_training_job", return_value=TRAINING_JOB_RESPONSE diff --git a/tests/unit/test_endpoint_from_model_data.py b/tests/unit/test_endpoint_from_model_data.py index 75f72ab221..83677abaf0 100644 --- a/tests/unit/test_endpoint_from_model_data.py +++ b/tests/unit/test_endpoint_from_model_data.py @@ -36,6 +36,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock) ims.sagemaker_client.describe_model = Mock( name="describe_model", side_effect=_raise_does_not_exist_client_error diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index dc288edc5a..9df402e69b 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -29,7 +29,13 @@ def get_sagemaker_session(returns_status): client_mock.describe_model_package = MagicMock( return_value={"ModelPackageStatus": returns_status} ) + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status}) + boto_mock.client.return_value = client_mock ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims diff --git a/tests/unit/test_local_session.py b/tests/unit/test_local_session.py index c733111206..3a8214b9f8 100644 --- a/tests/unit/test_local_session.py +++ b/tests/unit/test_local_session.py @@ -473,5 +473,12 @@ def test_file_input_content_type(): def test_local_session_is_set_to_local_mode(): boto_session = Mock(region_name="us-west-2") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_session.client.return_value = client_mock local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session) assert local_session.local_mode diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4b63134281..a10a5f0547 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -48,6 +48,11 @@ def boto_session(): client_mock._client_config.user_agent = ( "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" ) + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } boto_mock.client.return_value = client_mock return boto_mock @@ -1325,6 +1330,13 @@ def __init__(self, code): @pytest.fixture() def sagemaker_session_complete(): boto_mock = Mock(name="boto_session") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) @@ -1338,6 +1350,13 @@ def sagemaker_session_complete(): @pytest.fixture() def sagemaker_session_stopped(): boto_mock = Mock(name="boto_session") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) @@ -1349,6 +1368,13 @@ def sagemaker_session_stopped(): @pytest.fixture() def sagemaker_session_ready_lifecycle(): boto_mock = Mock(name="boto_session") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) @@ -1368,6 +1394,13 @@ def sagemaker_session_ready_lifecycle(): @pytest.fixture() def sagemaker_session_full_lifecycle(): boto_mock = Mock(name="boto_session") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock boto_mock.client("logs").describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) diff --git a/tests/unit/test_upload_data.py b/tests/unit/test_upload_data.py index 6b731cc25b..3adc82f787 100644 --- a/tests/unit/test_upload_data.py +++ b/tests/unit/test_upload_data.py @@ -30,6 +30,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session") + client_mock = Mock() + client_mock.get_caller_identity.return_value = { + "UserId": "mock_user_id", + "Account": "012345678910", + "Arn": "arn:aws:iam::012345678910:user/mock-user", + } + boto_mock.client.return_value = client_mock ims = sagemaker.Session(boto_session=boto_mock) ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return ims