Skip to content

feat: local download dir for Model and Estimator classes #3602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,20 @@ def tar_and_upload_dir(
script_name = script if directory else os.path.basename(script)
dependencies = dependencies or []
key = "%s/sourcedir.tar.gz" % s3_key_prefix
tmp = tempfile.mkdtemp()
if (
settings is not None
and settings.local_download_dir is not None
and not (
os.path.exists(settings.local_download_dir)
and os.path.isdir(settings.local_download_dir)
)
):
raise ValueError(
"Inputted directory for storing newly generated temporary directory does "
f"not exist: '{settings.local_download_dir}'"
)
local_download_dir = None if settings is None else settings.local_download_dir
tmp = tempfile.mkdtemp(dir=local_download_dir)
encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts

try:
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/session_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@
class SessionSettings(object):
    """Optional container class for settings to apply to a SageMaker session."""

    def __init__(self, encrypt_repacked_artifacts=True, local_download_dir=None) -> None:
        """Initialize the ``SessionSettings`` of a SageMaker ``Session``.

        Args:
            encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts
                at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key
                is not provided (Default: True).
            local_download_dir (str): Optional. A path specifying the local directory
                for downloading artifacts. (Default: None).
        """
        self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
        self._local_download_dir = local_download_dir

    @property
    def encrypt_repacked_artifacts(self) -> bool:
        """Return True if repacked artifacts at rest in S3 should be encrypted by default."""
        return self._encrypt_repacked_artifacts

    @property
    def local_download_dir(self) -> str:
        """Return path specifying the local directory for downloading artifacts.

        The value is ``None`` when no directory was passed to the constructor.
        """
        return self._local_download_dir
19 changes: 16 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def create_tar_file(source_files, target=None):


@contextlib.contextmanager
def _tmpdir(suffix="", prefix="tmp"):
def _tmpdir(suffix="", prefix="tmp", directory=None):
"""Create a temporary directory with a context manager.

The file is deleted when the context exits.
Expand All @@ -369,11 +369,18 @@ def _tmpdir(suffix="", prefix="tmp"):
suffix, otherwise there will be no suffix.
prefix (str): If prefix is specified, the file name will begin with that
prefix; otherwise, a default prefix is used.
directory (str): If a directory is specified, the file will be downloaded
in this directory; otherwise, a default directory is used.

Returns:
str: path to the directory
"""
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None)
if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)):
raise ValueError(
"Inputted directory for storing newly generated temporary "
f"directory does not exist: '{directory}'"
)
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
yield tmp
shutil.rmtree(tmp)

Expand Down Expand Up @@ -427,7 +434,13 @@ def repack_model(
"""
dependencies = dependencies or []

with _tmpdir() as tmp:
local_download_dir = (
None
if sagemaker_session.settings is None
or sagemaker_session.settings.local_download_dir is None
else sagemaker_session.settings.local_download_dir
)
with _tmpdir(directory=local_download_dir) as tmp:
model_dir = _extract_model(model_uri, sagemaker_session, tmp)

_create_or_update_code_dir(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/automl/test_auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mock import Mock, patch
from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel
from sagemaker.predictor import Predictor
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow.functions import Join

MODEL_DATA = "s3://bucket/model.tar.gz"
Expand Down Expand Up @@ -254,6 +255,7 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
sms.upload_data = Mock(name="upload_data", return_value=DEFAULT_S3_INPUT_DATA)
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/huggingface/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mock import MagicMock, Mock, patch

from sagemaker.huggingface import HuggingFace, HuggingFaceModel
from sagemaker.session_settings import SessionSettings


from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

Expand Down Expand Up @@ -63,6 +65,7 @@ def fixture_sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/huggingface/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from sagemaker.huggingface.processing import HuggingFaceProcessor
from sagemaker.fw_utils import UploadedCode
from sagemaker.session_settings import SessionSettings

from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

Expand All @@ -42,6 +43,7 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/image_uris/jumpstart/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mock.mock import Mock
import pytest

from sagemaker.session_settings import SessionSettings

REGION_NAME = "us-west-2"
BUCKET_NAME = "some-bucket-name"

Expand All @@ -26,6 +28,7 @@ def session():
boto_session=boto_mock,
boto_region_name=REGION_NAME,
config=None,
settings=SessionSettings(),
)
sms.default_bucket = Mock(return_value=BUCKET_NAME)
return sms
5 changes: 4 additions & 1 deletion tests/unit/sagemaker/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mock import patch, Mock

import sagemaker.local.utils
from sagemaker.session_settings import SessionSettings


@patch("sagemaker.local.utils.os.path")
Expand All @@ -42,7 +43,9 @@ def test_move_to_destination_local(recursive_copy):
@patch("shutil.rmtree", Mock())
@patch("sagemaker.local.utils.recursive_copy")
def test_move_to_destination_s3(recursive_copy):
sms = Mock()
sms = Mock(
settings=SessionSettings(),
)

# without trailing slash in prefix
sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms)
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/model/test_framework_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import pytest
from mock import MagicMock, Mock, patch

from sagemaker.session_settings import SessionSettings

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
ENTRY_POINT = "blah.py"
Expand Down Expand Up @@ -89,6 +91,7 @@ def sagemaker_session():
local_mode=False,
s3_client=None,
s3_resource=None,
settings=SessionSettings(),
)
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
return sms
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,27 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses
== get_model_package_args.call_args_list[0][1]["validation_specification"]
), """ValidationSpecification from model.register method is not identical to validation_spec from
get_model_package_args"""


@patch("sagemaker.utils.repack_model")
def test_model_local_download_dir(repack_model, sagemaker_session):
    """Check that the session handed to ``repack_model`` carries ``local_download_dir``."""
    expected_dir = "local download dir"
    sagemaker_session.settings.local_download_dir = expected_dir

    model = Model(
        entry_point=ENTRY_POINT_INFERENCE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        source_dir="s3://blah/blah/blah",
        image_uri=IMAGE_URI,
        model_data=MODEL_DATA,
    )
    model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

    # call_args_list[0][1] is the keyword-argument dict of the first repack_model call.
    first_call_kwargs = repack_model.call_args_list[0][1]
    assert first_call_kwargs["sagemaker_session"].settings.local_download_dir == expected_dir
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/spark/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.session_settings import SessionSettings
from sagemaker.spark.processing import (
PySparkProcessor,
SparkJarProcessor,
Expand Down Expand Up @@ -57,6 +58,7 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from sagemaker.estimator import _TrainingJob
from sagemaker.session_settings import SessionSettings
from sagemaker.tensorflow import TensorFlow
from sagemaker.instance_group import InstanceGroup
from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
Expand Down Expand Up @@ -71,6 +72,7 @@ def sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)
session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
session.expand_role = Mock(name="expand_role", return_value=ROLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.instance_group import InstanceGroup
from sagemaker.session_settings import SessionSettings

from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

Expand Down Expand Up @@ -72,6 +73,7 @@ def fixture_sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker import image_uris
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.session_settings import SessionSettings

from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

Expand Down Expand Up @@ -70,6 +71,7 @@ def fixture_sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sagemaker.pytorch import PyTorch, TrainingCompilerConfig
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.instance_group import InstanceGroup
from sagemaker.session_settings import SessionSettings

from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

Expand Down Expand Up @@ -71,6 +72,7 @@ def fixture_sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mock import MagicMock, Mock, patch

from sagemaker import image_uris
from sagemaker.session_settings import SessionSettings
from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES
Expand Down Expand Up @@ -76,6 +77,7 @@ def fixture_sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/wrangler/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pytest
from mock import Mock, MagicMock
from sagemaker.session_settings import SessionSettings

from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.processing import ProcessingInput
Expand All @@ -36,6 +37,7 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
session_mock.expand_role.return_value = ROLE
return session_mock
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_build_shards,
FileSystemRecordSet,
)
from sagemaker.session_settings import SessionSettings

COMMON_ARGS = {"role": "myrole", "instance_count": 1, "instance_type": "ml.c4.xlarge"}

Expand All @@ -40,6 +41,7 @@ def sagemaker_session():
region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
HyperparameterTuningJobAnalytics,
TrainingJobAnalytics,
)
from sagemaker.session_settings import SessionSettings

BUCKET_NAME = "mybucket"
REGION = "us-west-2"
Expand All @@ -47,6 +48,7 @@ def create_sagemaker_session(
boto_region_name=REGION,
config=None,
local_mode=False,
settings=SessionSettings(),
)
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sagemaker.chainer import defaults
from sagemaker.chainer import Chainer
from sagemaker.chainer import ChainerPredictor, ChainerModel
from sagemaker.session_settings import SessionSettings

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
Expand Down Expand Up @@ -62,6 +63,7 @@ def sagemaker_session():
local_mode=False,
s3_resource=None,
s3_client=None,
settings=SessionSettings(),
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
Loading