diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index b3bc3d4365..59204fbc99 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -422,7 +422,20 @@ def tar_and_upload_dir( script_name = script if directory else os.path.basename(script) dependencies = dependencies or [] key = "%s/sourcedir.tar.gz" % s3_key_prefix - tmp = tempfile.mkdtemp() + if ( + settings is not None + and settings.local_download_dir is not None + and not ( + os.path.exists(settings.local_download_dir) + and os.path.isdir(settings.local_download_dir) + ) + ): + raise ValueError( + "Inputted directory for storing newly generated temporary directory does " + f"not exist: '{settings.local_download_dir}'" + ) + local_download_dir = None if settings is None else settings.local_download_dir + tmp = tempfile.mkdtemp(dir=local_download_dir) encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts try: diff --git a/src/sagemaker/session_settings.py b/src/sagemaker/session_settings.py index 53ff9a9f0d..6c7e48dce2 100644 --- a/src/sagemaker/session_settings.py +++ b/src/sagemaker/session_settings.py @@ -18,17 +18,25 @@ class SessionSettings(object): """Optional container class for settings to apply to a SageMaker session.""" - def __init__(self, encrypt_repacked_artifacts=True) -> None: + def __init__(self, encrypt_repacked_artifacts=True, local_download_dir=None) -> None: """Initialize the ``SessionSettings`` of a SageMaker ``Session``. Args: encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key is not provided (Default: True). + local_download_dir (str): Optional. A path specifying the local directory + for downloading artifacts. (Default: None). """ self._encrypt_repacked_artifacts = encrypt_repacked_artifacts + self._local_download_dir = local_download_dir @property def encrypt_repacked_artifacts(self) -> bool: """Return True if repacked artifacts at rest in S3 should be encrypted by default.""" return self._encrypt_repacked_artifacts + + @property + def local_download_dir(self) -> str: + """Return path specifying the local directory for downloading artifacts.""" + return self._local_download_dir diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 9d28e3bf4e..2f1870f1fc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -358,7 +358,7 @@ def create_tar_file(source_files, target=None): @contextlib.contextmanager -def _tmpdir(suffix="", prefix="tmp"): +def _tmpdir(suffix="", prefix="tmp", directory=None): """Create a temporary directory with a context manager. The file is deleted when the context exits. @@ -369,11 +369,18 @@ def _tmpdir(suffix="", prefix="tmp"): suffix, otherwise there will be no suffix. prefix (str): If prefix is specified, the file name will begin with that prefix; otherwise, a default prefix is used. + directory (str): If a directory is specified, the file will be downloaded + in this directory; otherwise, a default directory is used. Returns: str: path to the directory """ - tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None) + if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)): + raise ValueError( + "Inputted directory for storing newly generated temporary " + f"directory does not exist: '{directory}'" + ) + tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory) yield tmp shutil.rmtree(tmp) @@ -427,7 +434,13 @@ def repack_model( """ dependencies = dependencies or [] - with _tmpdir() as tmp: + local_download_dir = ( + None + if sagemaker_session.settings is None + or sagemaker_session.settings.local_download_dir is None + else sagemaker_session.settings.local_download_dir + ) + with _tmpdir(directory=local_download_dir) as tmp: model_dir = _extract_model(model_uri, sagemaker_session, tmp) _create_or_update_code_dir( diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index 2c997397c5..e68a019ce4 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -18,6 +18,7 @@ from mock import Mock, patch from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel from sagemaker.predictor import Predictor +from sagemaker.session_settings import SessionSettings from sagemaker.workflow.functions import Join MODEL_DATA = "s3://bucket/model.tar.gz" @@ -254,6 +255,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.upload_data = Mock(name="upload_data", return_value=DEFAULT_S3_INPUT_DATA) diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 072eefeb83..2d7261cdc6 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -20,6 +20,8 @@ from mock import MagicMock, Mock, patch from sagemaker.huggingface import HuggingFace, HuggingFaceModel +from sagemaker.session_settings import SessionSettings + from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION @@ -63,6 +65,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index 96a94f42e7..e7887cd794 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -17,6 +17,7 @@ from sagemaker.huggingface.processing import HuggingFaceProcessor from sagemaker.fw_utils import UploadedCode +from sagemaker.session_settings import SessionSettings from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION @@ -42,6 +43,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py index 66fedcab8c..a67dca6ab4 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/conftest.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/conftest.py @@ -15,6 +15,8 @@ from mock.mock import Mock import pytest +from sagemaker.session_settings import SessionSettings + REGION_NAME = "us-west-2" BUCKET_NAME = "some-bucket-name" @@ -26,6 +28,7 @@ def session(): boto_session=boto_mock, boto_region_name=REGION_NAME, config=None, + settings=SessionSettings(), ) sms.default_bucket = Mock(return_value=BUCKET_NAME) return sms diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 0129e574ea..668efd7a41 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -17,6 +17,7 @@ from mock import patch, Mock import sagemaker.local.utils +from sagemaker.session_settings import SessionSettings @patch("sagemaker.local.utils.os.path") @@ -42,7 +43,9 @@ def test_move_to_destination_local(recursive_copy): @patch("shutil.rmtree", Mock()) @patch("sagemaker.local.utils.recursive_copy") def test_move_to_destination_s3(recursive_copy): - sms = Mock() + sms = Mock( + settings=SessionSettings(), + ) # without trailing slash in prefix sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms) diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index caa6eb0779..73ff09ef07 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -21,6 +21,8 @@ import pytest from mock import MagicMock, Mock, patch +from sagemaker.session_settings import SessionSettings + MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" ENTRY_POINT = "blah.py" @@ -89,6 +91,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 0b04d3c8bc..2274a5feb6 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -776,3 +776,27 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses == get_model_package_args.call_args_list[0][1]["validation_specification"] ), """ValidationSpecification from model.register method is not identical to validation_spec from get_model_package_args""" + + +@patch("sagemaker.utils.repack_model") +def test_model_local_download_dir(repack_model, sagemaker_session): + + source_dir = "s3://blah/blah/blah" + local_download_dir = "local download dir" + + sagemaker_session.settings.local_download_dir = local_download_dir + + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert ( + repack_model.call_args_list[0][1]["sagemaker_session"].settings.local_download_dir + == local_download_dir + ) diff --git a/tests/unit/sagemaker/spark/test_processing.py b/tests/unit/sagemaker/spark/test_processing.py index 4f784e1c66..ba08f82fad 100644 --- a/tests/unit/sagemaker/spark/test_processing.py +++ b/tests/unit/sagemaker/spark/test_processing.py @@ -18,6 +18,7 @@ import pytest from sagemaker.processing import ProcessingInput, ProcessingOutput +from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import ( PySparkProcessor, SparkJarProcessor, @@ -57,6 +58,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 771b18b35a..78e4a0d281 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -21,6 +21,7 @@ import pytest from sagemaker.estimator import _TrainingJob +from sagemaker.session_settings import SessionSettings from sagemaker.tensorflow import TensorFlow from sagemaker.instance_group import InstanceGroup from sagemaker.workflow.parameters import ParameterString, ParameterBoolean @@ -71,6 +72,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 656730a47c..06490c4813 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -24,6 +24,7 @@ from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig from sagemaker.huggingface.model import HuggingFaceModel from sagemaker.instance_group import InstanceGroup +from sagemaker.session_settings import SessionSettings from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -72,6 +73,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index c3684ac649..4f048aa536 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -22,6 +22,7 @@ from sagemaker import image_uris from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig from sagemaker.huggingface.model import HuggingFaceModel +from sagemaker.session_settings import SessionSettings from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -70,6 +71,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 068bb4e4b9..0b3f0e8de6 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -24,6 +24,7 @@ from sagemaker.pytorch import PyTorch, TrainingCompilerConfig from sagemaker.pytorch.model import PyTorchModel from sagemaker.instance_group import InstanceGroup +from sagemaker.session_settings import SessionSettings from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -71,6 +72,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index a5c14b1626..5a8fce34ef 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -21,6 +21,7 @@ from mock import MagicMock, Mock, patch from sagemaker import image_uris +from sagemaker.session_settings import SessionSettings from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -76,6 +77,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index 87a98744b6..f83ae74dfa 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -14,6 +14,7 @@ import pytest from mock import Mock, MagicMock +from sagemaker.session_settings import SessionSettings from sagemaker.wrangler.processing import DataWranglerProcessor from sagemaker.processing import ProcessingInput @@ -36,6 +37,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) session_mock.expand_role.return_value = ROLE return session_mock diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 44b5818fc8..18a576d44f 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -23,6 +23,7 @@ _build_shards, FileSystemRecordSet, ) +from sagemaker.session_settings import SessionSettings COMMON_ARGS = {"role": "myrole", "instance_count": 1, "instance_type": "ml.c4.xlarge"} @@ -40,6 +41,7 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py index e30abb7da9..8cb90dbf46 100644 --- a/tests/unit/test_analytics.py +++ b/tests/unit/test_analytics.py @@ -24,6 +24,7 @@ HyperparameterTuningJobAnalytics, TrainingJobAnalytics, ) +from sagemaker.session_settings import SessionSettings BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -47,6 +48,7 @@ def create_sagemaker_session( boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index eca4a9bf80..dbcedc1d99 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -24,6 +24,7 @@ from sagemaker.chainer import defaults from sagemaker.chainer import Chainer from sagemaker.chainer import ChainerPredictor, ChainerModel +from sagemaker.session_settings import SessionSettings DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -62,6 +63,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 8b771f9184..45a944ce1a 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -52,6 +52,7 @@ from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor from sagemaker.pytorch.estimator import PyTorch +from sagemaker.session_settings import SessionSettings from sagemaker.sklearn.estimator import SKLearn from sagemaker.tensorflow.estimator import TensorFlow from sagemaker.predictor_async import AsyncPredictor @@ -231,6 +232,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock( @@ -253,7 +255,7 @@ def pipeline_session(): type(role_mock).arn = PropertyMock(return_value=ROLE) resource_mock = Mock() resource_mock.Role.return_value = role_mock - session_mock = Mock(region_name=REGION) + session_mock = Mock(region_name=REGION, settings=SessionSettings()) session_mock.resource.return_value = resource_mock session_mock.client.return_value = client_mock return PipelineSession( @@ -754,6 +756,7 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -783,6 +786,7 @@ def test_framework_with_debugger_config_set_up_in_unsupported_region(region): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -809,6 +813,7 @@ def test_framework_enable_profiling_in_unsupported_region(region): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -835,6 +840,7 @@ def test_framework_update_profiling_in_unsupported_region(region): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -861,6 +867,7 @@ def test_framework_disable_profiling_in_unsupported_region(region): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -4645,3 +4652,44 @@ def test_script_mode_estimator_escapes_hyperparameters_as_json( - set(sagemaker_session.train.call_args_list[0][1]["hyperparameters"].items()) == set() ) + + +@patch("time.time", return_value=TIME) +@patch("sagemaker.estimator.tar_and_upload_dir") +@patch("sagemaker.model.Model._upload_code") +def test_estimator_local_download_dir( + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session +): + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + sagemaker_session.boto_region_name = REGION + + local_download_dir = "some/download/dir" + + sagemaker_session.settings.local_download_dir = local_download_dir + + instance_type = "ml.p2.xlarge" + instance_count = 1 + + training_data_uri = "s3://bucket/mydata" + + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" + + generic_estimator = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + sagemaker_session=sagemaker_session, + instance_count=instance_count, + instance_type=instance_type, + source_dir=jumpstart_source_dir, + image_uri=IMAGE_URI, + model_uri=MODEL_DATA, + ) + generic_estimator.fit(training_data_uri) + + assert ( + patched_tar_and_upload_dir.call_args_list[0][1]["settings"].local_download_dir + == local_download_dir + ) diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 985262ca64..ceefeb9b3e 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -21,6 +21,7 @@ FactorizationMachinesPredictor, ) from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -58,6 +59,7 @@ def sagemaker_session(): local_mode=False, s3_client=False, s3_resource=False, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 279a90eeeb..528af4ada0 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -43,7 +43,11 @@ def cd(path): def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name="us-west-2") session_mock = Mock( - name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None + name="sagemaker_session", + boto_session=boto_mock, + s3_client=None, + s3_resource=None, + settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value="my-bucket") session_mock.expand_role = Mock(name="expand_role", return_value="my-role") diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 3190ea2e18..478d33fc06 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.ipinsights import IPInsights, IPInsightsPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings # Mocked training config ROLE = "myrole" @@ -55,6 +56,7 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index be1862175c..790ee73576 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.kmeans import KMeans, KMeansPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -52,6 +53,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 9afe468cea..02dc073d9d 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.knn import KNN, KNNPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -58,6 +59,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 0da89384ed..f2574b30b5 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.lda import LDA, LDAPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -47,6 +48,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index bd41b9e49b..bb2a140200 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -53,6 +54,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index f12d8e160f..2395856acd 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -28,6 +28,7 @@ from sagemaker.mxnet import defaults from sagemaker.mxnet import MXNet from sagemaker.mxnet import MXNetPredictor, MXNetModel +from sagemaker.session_settings import SessionSettings DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_NAME = "dummy_script.py" @@ -83,6 +84,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index f18a15457e..cbe9f18e36 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.ntm import NTM, NTMPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -52,6 +53,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 5963feb5bb..e6aaf770fa 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -19,6 +19,7 @@ from sagemaker.amazon.object2vec import Object2Vec from sagemaker.predictor import Predictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -60,6 +61,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 9ea9dcbcd6..222021caa3 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.pca import PCA, PCAPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -52,6 +53,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 38c9b373e3..913ffbc556 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -18,6 +18,7 @@ from sagemaker.model import FrameworkModel from sagemaker.pipeline import PipelineModel from sagemaker.predictor import Predictor +from sagemaker.session_settings import SessionSettings from sagemaker.sparkml import SparkMLModel ENTRY_POINT = "blah.py" @@ -65,6 +66,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 81579d397a..34c530747d 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -29,6 +29,7 @@ ScriptProcessor, ProcessingJob, ) +from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import PySparkProcessor from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.pytorch.processing import PyTorchProcessor @@ -68,6 +69,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) @@ -89,6 +91,7 @@ def pipeline_session(): boto_region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5691834c3a..a11738e25c 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -23,6 +23,7 @@ from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel from sagemaker.instance_group import InstanceGroup +from sagemaker.session_settings import SessionSettings DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -71,6 +72,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index a2a07e5296..3e3bca00dc 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -18,6 +18,7 @@ from sagemaker import image_uris from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestPredictor from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.session_settings import SessionSettings ROLE = "myrole" INSTANCE_COUNT = 1 @@ -52,6 +53,7 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 0c0a9c6d64..fea49a7548 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -21,6 +21,7 @@ from sagemaker.mxnet import MXNetModel, MXNetPredictor from sagemaker.rl import RLEstimator, RLFramework, RLToolkit, TOOLKIT_FRAMEWORK_VERSION_MAP +from sagemaker.session_settings import SessionSettings from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor @@ -64,6 +65,7 @@ def fixture_sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 430cb484b4..d16e887d18 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -21,6 +21,7 @@ from mock import patch from sagemaker.fw_utils import UploadedCode +from sagemaker.session_settings import SessionSettings from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") @@ -66,6 +67,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index ae975b1ac2..3fb21d62d2 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -16,6 +16,7 @@ from mock import Mock from sagemaker import image_uris +from sagemaker.session_settings import SessionSettings from sagemaker.sparkml import SparkMLModel, SparkMLPredictor MODEL_DATA = "s3://bucket/model.tar.gz" @@ -40,6 +41,7 @@ def sagemaker_session(): region_name=REGION, config=None, local_mode=False, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py index 6d40a2e6dd..bded6ce2cc 100644 --- a/tests/unit/test_timeout.py +++ b/tests/unit/test_timeout.py @@ -23,6 +23,7 @@ import stopit from botocore.exceptions import ClientError +from sagemaker.session_settings import SessionSettings from tests.integ.timeout import ( timeout, @@ -56,6 +57,7 @@ def session(): boto_region_name=REGION, config=None, local_mode=True, + settings=SessionSettings(), ) sms.default_bucket = Mock(name=DEFAULT_BUCKET_NAME, return_value=BUCKET_NAME) return sms diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 7e556c7d23..c7e489784b 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -30,6 +30,7 @@ from sagemaker.jumpstart.enums import JumpStartTag from sagemaker.mxnet import MXNet from sagemaker.parameter import ParameterRange +from sagemaker.session_settings import SessionSettings from sagemaker.tuner import ( HYPERBAND_MAX_RESOURCE, HYPERBAND_MIN_RESOURCE, @@ -55,6 +56,7 @@ def sagemaker_session(): boto_session=boto_mock, s3_client=None, s3_resource=None, + settings=SessionSettings(), ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8bcbed41c2..962f43b6cf 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -558,7 +558,7 @@ def test_repack_model_from_file_to_file(tmp): model_tar_path = os.path.join(tmp, "model.tar.gz") sagemaker.utils.create_tar_file([os.path.join(tmp, "model")], model_tar_path) - sagemaker_session = MagicMock() + sagemaker_session = MagicMock(settings=SessionSettings()) file_mode_path = "file://%s" % model_tar_path destination_path = "file://%s" % os.path.join(tmp, "repacked-model.tar.gz") @@ -608,7 +608,7 @@ def test_repack_model_from_file_to_folder(tmp): [], file_mode_path, "file://%s/repacked-model.tar.gz" % tmp, - MagicMock(), + MagicMock(settings=SessionSettings()), ) assert list_tar_files("file://%s/repacked-model.tar.gz" % tmp, tmp) == { @@ -679,7 +679,7 @@ def test_repack_model_with_same_inference_file_name(tmp, fake_s3): class FakeS3(object): def __init__(self, tmp): self.tmp = tmp - self.sagemaker_session = MagicMock() + self.sagemaker_session = MagicMock(settings=SessionSettings()) self.location_map = {} self.current_bucket = None self.object_mock = MagicMock() diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 87a853d5d0..8fe5a0bc78 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -23,6 +23,7 @@ from sagemaker.fw_utils import UploadedCode +from sagemaker.session_settings import SessionSettings from sagemaker.xgboost import XGBoost, XGBoostModel, XGBoostPredictor @@ -69,6 +70,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + settings=SessionSettings(), ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}