diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 1da3767448..c83a107a2d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1713,6 +1713,7 @@ def _stage_user_code_in_s3(self): directory=self.source_dir, dependencies=self.dependencies, kms_key=kms_key, + s3_resource=self.sagemaker_session.s3_resource, ) def _model_source_dir(self): diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index f8479838f0..fa209908b8 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -307,7 +307,14 @@ def validate_source_dir(script, directory): def tar_and_upload_dir( - session, bucket, s3_key_prefix, script, directory=None, dependencies=None, kms_key=None + session, + bucket, + s3_key_prefix, + script, + directory=None, + dependencies=None, + kms_key=None, + s3_resource=None, ): """Package source files and upload a compress tar file to S3. The S3 location will be ``s3:///s3_key_prefix/sourcedir.tar.gz``. @@ -331,6 +338,9 @@ def tar_and_upload_dir( copied into /opt/ml/lib kms_key (str): Optional. KMS key ID used to upload objects to the bucket (default: None). + s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource + for S3 connections, can be used to customize the configuration, + e.g. set the endpoint URL (default: None). Returns: sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name. @@ -354,7 +364,12 @@ def tar_and_upload_dir( else: extra_args = None - session.resource("s3").Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args) + if s3_resource is None: + s3_resource = session.resource("s3", region_name=session.region_name) + else: + print("Using provided s3_resource") + + s3_resource.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args) finally: shutil.rmtree(tmp) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index c2ac5f7575..8fe3550c03 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -46,6 +46,7 @@ # Environment variables to be set during training REGION_ENV_NAME = "AWS_REGION" TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" +S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" logger = logging.getLogger(__name__) @@ -139,6 +140,11 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name REGION_ENV_NAME: self.sagemaker_session.boto_region_name, TRAINING_JOB_NAME_ENV_NAME: job_name, } + if self.sagemaker_session.s3_resource is not None: + training_env_vars[ + S3_ENDPOINT_URL_ENV_NAME + ] = self.sagemaker_session.s3_resource.meta.client._endpoint.host + compose_data = self._generate_compose_file( "train", additional_volumes=volumes, additional_env_vars=training_env_vars ) @@ -206,6 +212,7 @@ def serve(self, model_dir, environment): "serve", additional_env_vars=environment, additional_volumes=volumes ) compose_command = self._compose() + self.container = _HostingContainer(compose_command) self.container.start() diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index e407e49e4a..1e80f6e1b4 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -98,6 +98,7 @@ def create_training_job( ) training_job = _LocalTrainingJob(container) hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {} + logger.info("Starting training job") training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName) LocalSagemakerClient._training_jobs[TrainingJobName] = training_job @@ -377,7 +378,9 @@ def invoke_endpoint( class LocalSession(Session): """Placeholder docstring""" - def __init__(self, boto_session=None): + def __init__(self, boto_session=None, s3_endpoint_url=None): + self.s3_endpoint_url = s3_endpoint_url + super(LocalSession, self).__init__(boto_session) if platform.system() == "Windows": @@ -407,6 +410,10 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + if self.s3_endpoint_url is not None: + self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url) + self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url) + sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") if os.path.exists(sagemaker_config_file): try: diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index f6eb5682e0..89dbef5ebf 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -92,7 +92,13 @@ def __init__( self.model = model self.container_mode = MULTI_MODEL_CONTAINER_MODE self.sagemaker_session = sagemaker_session or Session() - self.s3_client = self.sagemaker_session.boto_session.client("s3") + + if self.sagemaker_session.s3_client is None: + self.s3_client = self.sagemaker_session.boto_session.client( + "s3", region_name=self.sagemaker_session.boto_session.region_name + ) + else: + self.s3_client = self.sagemaker_session.s3_client # Set the ``Model`` parameters if the model parameter is not specified if not self.model: diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3ef73f0b15..0dc5bf3161 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -107,8 +107,8 @@ def __init__( """ self._default_bucket = None self._default_bucket_name_override = default_bucket - - # currently is used for local_code in local mode + self.s3_resource = None + self.s3_client = None self.config = None self._initialize( @@ -199,7 +199,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): key_suffix = name bucket = bucket or self.default_bucket() - s3 = self.boto_session.resource("s3") + if self.s3_resource is None: + s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_resource for local_path, s3_key in files: s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args) @@ -227,7 +230,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None): str: The S3 URI of the uploaded file. The URI format is: ``s3://{bucket name}/{key}``. """ - s3 = self.boto_session.resource("s3") + if self.s3_resource is None: + s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_resource + s3_object = s3.Object(bucket_name=bucket, key=key) if kms_key is not None: @@ -254,7 +261,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): """ # Initialize the S3 client. - s3 = self.boto_session.client("s3") + if self.s3_client is None: + s3 = self.boto_session.client("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_client # Initialize the variables used to loop through the contents of the S3 bucket. keys = [] @@ -299,7 +309,10 @@ def read_s3_file(self, bucket, key_prefix): str: The body of the s3 file as a string. """ - s3 = self.boto_session.client("s3") + if self.s3_client is None: + s3 = self.boto_session.client("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_client # Explicitly passing a None kms_key to boto3 throws a validation error. s3_object = s3.get_object(Bucket=bucket, Key=key_prefix) @@ -317,7 +330,10 @@ def list_s3_files(self, bucket, key_prefix): [str]: The list of files at the S3 path. """ - s3 = self.boto_session.resource("s3") + if self.s3_resource is None: + s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) + else: + s3 = self.s3_resource s3_bucket = s3.Bucket(name=bucket) s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() @@ -330,6 +346,7 @@ def default_bucket(self): str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. """ + if self._default_bucket: return self._default_bucket @@ -364,10 +381,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): already being created, no exception is raised. """ - bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name) + if self.s3_resource is None: + s3 = self.boto_session.resource("s3", region_name=region) + else: + s3 = self.s3_resource + + bucket = s3.Bucket(name=bucket_name) if bucket.creation_date is None: try: - s3 = self.boto_session.resource("s3", region_name=region) if region == "us-east-1": # 'us-east-1' cannot be specified because it is the default region: # https://github.com/boto/boto3/issues/125 diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index b65fbfa30a..b878069c59 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -36,6 +36,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session._default_bucket = BUCKET_NAME diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 07cf9f5837..4f04e78d55 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -62,6 +62,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 3069860b1e..eb8a5e3385 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -174,6 +174,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock( diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 2b5d0eaba5..5c4ac46900 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -55,6 +55,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=False, + s3_resource=False, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 1047533d0f..0f2f49d95c 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -158,7 +158,9 @@ def cd(path): @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) - session_mock = Mock(name="sagemaker_session", boto_session=boto_mock) + session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None + ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session_mock.expand_role = Mock(name="expand_role", return_value=ROLE) session_mock.sagemaker_client.describe_training_job = Mock( diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 0a8b86f51f..694f2d7e7a 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -73,7 +73,9 @@ def estimator(sagemaker_session): @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session") - mock_session = Mock(name="sagemaker_session", boto_session=boto_mock) + mock_session = Mock( + name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None + ) mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) return mock_session diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 7ed837a558..555b78b451 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -49,6 +49,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 613b3960ea..b839cdefa9 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -55,6 +55,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 91a543e938..4bb4d5a594 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -39,7 +39,14 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) - sms = Mock(name="sagemaker_session", boto_session=boto_mock, config=None, local_mode=False) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock( diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index ed466b8954..3b16e85c02 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -50,6 +50,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_local_session.py b/tests/unit/test_local_session.py index ad7bc0fb85..a263d6b640 100644 --- a/tests/unit/test_local_session.py +++ b/tests/unit/test_local_session.py @@ -14,9 +14,10 @@ import pytest import urllib3 - +import os from botocore.exceptions import ClientError from mock import Mock, patch +from tests.unit import DATA_DIR import sagemaker @@ -33,6 +34,10 @@ MODEL_NAME = "test-model" PRIMARY_CONTAINER = {"ModelDataUrl": "/some/model/path", "Environment": {"env1": 1, "env2": "b"}} +ENDPOINT_URL = "http://127.0.0.1:9000" +BUCKET_NAME = "mybucket" +LS_FILES = {"Contents": [{"Key": "/data/test.csv"}]} + @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") @@ -475,3 +480,53 @@ def test_local_session_is_set_to_local_mode(): boto_session = Mock(region_name="us-west-2") local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session) assert local_session.local_mode + + +@pytest.fixture() +def sagemaker_session_custom_endpoint(): + + boto_session = Mock("boto_session") + resource_mock = Mock("resource") + client_mock = Mock("client") + boto_attrs = {"region_name": "us-east-1"} + boto_session.configure_mock(**boto_attrs) + boto_session.resource = Mock(name="resource", return_value=resource_mock) + boto_session.client = Mock(name="client", return_value=client_mock) + + local_session = sagemaker.local.local_session.LocalSession( + boto_session=boto_session, s3_endpoint_url=ENDPOINT_URL + ) + + local_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + return local_session + + +def test_local_session_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint): + + boto_session = sagemaker_session_custom_endpoint.boto_session + + boto_session.client.assert_called_with("s3", endpoint_url=ENDPOINT_URL) + boto_session.resource.assert_called_with("s3", endpoint_url=ENDPOINT_URL) + + assert sagemaker_session_custom_endpoint.s3_client is not None + assert sagemaker_session_custom_endpoint.s3_resource is not None + + +def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint): + + DOWNLOAD_DATA_TESTS_FILES_DIR = os.path.join(DATA_DIR, "download_data_tests") + sagemaker_session_custom_endpoint.s3_client.list_objects_v2 = Mock( + name="list_objects_v2", return_value=LS_FILES + ) + sagemaker_session_custom_endpoint.s3_client.download_file = Mock(name="download_file") + + sagemaker_session_custom_endpoint.download_data( + DOWNLOAD_DATA_TESTS_FILES_DIR, BUCKET_NAME, key_prefix="/data/test.csv" + ) + + sagemaker_session_custom_endpoint.s3_client.download_file.assert_called_with( + Bucket=BUCKET_NAME, + Key="/data/test.csv", + Filename="{}/{}".format(DOWNLOAD_DATA_TESTS_FILES_DIR, "test.csv"), + ExtraArgs=None, + ) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 82c42ae049..d25d267551 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -130,6 +130,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index edcd7df621..84aed3fc58 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -64,6 +64,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index e1bd92e233..3bd3111516 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -71,6 +71,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 1f843e04d7..6fddd4b475 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -49,6 +49,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index cde721d6f9..9c4175fdec 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -57,6 +57,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 5e6bed83b7..ffa6bafe19 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -49,6 +49,8 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 31ac46704a..114920fa5b 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -66,6 +66,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_client=None, + s3_resource=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 9d984ab5e0..04cb8f0cbf 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -65,6 +65,8 @@ def fixture_sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index a1a82cf3d5..80c5a02336 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -62,6 +62,8 @@ def fixture_sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index de41d16fb9..7b23f69c9e 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -64,6 +64,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index d7ec041ca4..f8ec996d62 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -74,6 +74,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/test_tfs.py b/tests/unit/test_tfs.py index 0b915a710c..8de1a5424d 100644 --- a/tests/unit/test_tfs.py +++ b/tests/unit/test_tfs.py @@ -57,6 +57,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) session.default_bucket = Mock(name="default_bucket", return_value="my_bucket") session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 29a503ff1e..490ae96145 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -40,7 +40,7 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) - sms = Mock(name="sagemaker_session", boto_session=boto_mock) + sms = Mock(name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.config = None diff --git a/tests/unit/test_upload_data.py b/tests/unit/test_upload_data.py index 7c0ec51608..e2fab2e6c7 100644 --- a/tests/unit/test_upload_data.py +++ b/tests/unit/test_upload_data.py @@ -25,6 +25,7 @@ UPLOAD_DATA_TESTS_SINGLE_FILE = os.path.join(UPLOAD_DATA_TESTS_FILES_DIR, SINGLE_FILE_NAME) BUCKET_NAME = "mybucket" AES_ENCRYPTION_ENABLED = {"ServerSideEncryption": "AES256"} +ENDPOINT_URL = "http://127.0.0.1:9000" @pytest.fixture() @@ -35,6 +36,25 @@ def sagemaker_session(): return ims +@pytest.fixture() +def sagemaker_session_custom_endpoint(): + + boto_session = Mock("boto_session") + resource_mock = Mock("resource") + client_mock = Mock("client") + boto_attrs = {"region_name": "us-east-1"} + boto_session.configure_mock(**boto_attrs) + boto_session.resource = Mock(name="resource", return_value=resource_mock) + boto_session.client = Mock(name="client", return_value=client_mock) + + local_session = sagemaker.local.local_session.LocalSession( + boto_session=boto_session, s3_endpoint_url=ENDPOINT_URL + ) + + local_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + return local_session + + def test_upload_data_absolute_dir(sagemaker_session): result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR) @@ -50,6 +70,24 @@ def test_upload_data_absolute_dir(sagemaker_session): assert kwargs["ExtraArgs"] is None +def test_upload_data_absolute_dir_custom_endpoint(sagemaker_session_custom_endpoint): + + sagemaker_session_custom_endpoint.s3_resource.Object = Mock() + + result_s3_uri = sagemaker_session_custom_endpoint.upload_data(UPLOAD_DATA_TESTS_FILES_DIR) + + uploaded_files_with_args = [ + (args[0], kwargs) + for name, args, kwargs in sagemaker_session_custom_endpoint.s3_resource.mock_calls + if name == "Object().upload_file" + ] + assert result_s3_uri == "s3://{}/data".format(BUCKET_NAME) + assert len(uploaded_files_with_args) == 4 + for file, kwargs in uploaded_files_with_args: + assert os.path.exists(file) + assert kwargs["ExtraArgs"] is None + + def test_upload_data_absolute_file(sagemaker_session): result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE) diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 6ef43bea58..35b8b6acd4 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -65,6 +65,8 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + s3_resource=None, + s3_client=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}