feat: local download dir for Model and Estimator classes #3535

Closed
47 commits
c97c467
fix: type hint of PySparkProcessor __init__ (#3297)
NivekNey Dec 2, 2022
de58941
fix: fix PySparkProcessor __init__ params type (#3354)
andre-marcos-perez Dec 2, 2022
41dd330
fix: Allow Py 3.7 for MMS Test Docker env (#3080)
shreyapandit Dec 2, 2022
1e23a3f
refactoring : using with statement (#3286)
maldil Dec 2, 2022
19efadf
Update local_requirements.txt PyYAML version (#3095)
shreyapandit Dec 2, 2022
76f7782
feature: Update TF 2.9 and TF 2.10 inference DLCs (#3465)
arjkesh Dec 2, 2022
fde0738
feature: Added transform with monitoring pipeline step in transformer…
keshav-chandak Dec 2, 2022
7f9f3b0
fix: Fix bug forcing uploaded tar to be named sourcedir (#3412)
claytonparnell Dec 2, 2022
5d59767
feature: Add Code Owners file (#3503)
navinsoni Dec 2, 2022
0f5cf18
prepare release v2.119.0
Dec 3, 2022
f1f0013
update development version to v2.119.1.dev0
Dec 3, 2022
bb4b689
feature: Add DXB region to frameworks by DLC (#3387)
RadhikaB-97 Dec 5, 2022
b68bcd9
fix: support idempotency for framework and spark processors (#3460)
brockwade633 Dec 5, 2022
32969da
feature: Update registries with new region account number mappings. (…
kenny-ezirim Dec 6, 2022
767da0a
feature: Adding support for SageMaker Training Compiler in PyTorch es…
Lokiiiiii Dec 7, 2022
d779d1b
feature: Add Neo image uri config for Pytorch 1.12 (#3507)
HappyAmazonian Dec 7, 2022
83327fb
prepare release v2.120.0
Dec 7, 2022
5bffb04
update development version to v2.120.1.dev0
Dec 7, 2022
b828396
feature: Algorithms Region Expansion OSU/DXB (#3508)
malav-shastri Dec 7, 2022
357f732
fix: Add constraints file for apache-airflow (#3510)
navinsoni Dec 7, 2022
a28d1dd
fix: FrameworkProcessor S3 uploads (#3493)
brockwade633 Dec 8, 2022
11d2475
prepare release v2.121.0
Dec 8, 2022
24171b5
update development version to v2.121.1.dev0
Dec 8, 2022
d5847d5
Fix: Differentiate SageMaker Training Compiler's PT DLCs from base PT…
Lokiiiiii Dec 8, 2022
3f6ea88
fix: Fix failing jumpstart cache unit tests (#3514)
evakravi Dec 8, 2022
4570aa6
fix: Pop out ModelPackageName from pipeline definition (#3472)
qidewenwhen Dec 9, 2022
959ea1a
prepare release v2.121.1
Dec 9, 2022
b2e8b66
update development version to v2.121.2.dev0
Dec 9, 2022
355975d
fix: Skip Bad Transform Test (#3521)
amzn-choeric Dec 9, 2022
fadc817
fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524)
mufaddal-rohawala Dec 9, 2022
c5fc93f
change: Update for Tensorflow Serving 2.11 inference DLCs (#3509)
hballuru Dec 9, 2022
ec8da98
prepare release v2.121.2
Dec 12, 2022
0352122
update development version to v2.121.3.dev0
Dec 12, 2022
d6c0214
feature: Add OSU region to frameworks for DLC (#3532)
kace Dec 12, 2022
5af4feb
fix: Remove content type image/jpg from analysis configuration schema…
xgchena Dec 12, 2022
4389847
fix: unpin packaging version (#3533)
claytonparnell Dec 13, 2022
ef0c3e0
feat: local download dir for Model and Estimator classes
evakravi Dec 13, 2022
27cfe6c
Merge branch 'master' into feat/local-download-dir
evakravi Dec 13, 2022
f3e0feb
chore: move local_download_dir to SessionSettings
evakravi Dec 13, 2022
d314b67
fix: tox errors
evakravi Dec 14, 2022
f912edf
Merge remote-tracking branch 'origin' into feat/local-download-dir
evakravi Dec 30, 2022
51899a9
chore: clean git diff
evakravi Dec 30, 2022
eb3ba1b
Merge remote-tracking branch 'origin' into feat/local-download-dir
evakravi Jan 5, 2023
094e476
fix: pytorch compiler unit tests
evakravi Jan 5, 2023
0086f6b
Merge branch 'master' into feat/local-download-dir
evakravi Jan 12, 2023
7ff0891
Merge branch 'master' into feat/local-download-dir
evakravi Jan 17, 2023
b6f9918
Merge branch 'master' into feat/local-download-dir
evakravi Jan 18, 2023
15 changes: 14 additions & 1 deletion src/sagemaker/fw_utils.py
@@ -422,7 +422,20 @@ def tar_and_upload_dir(
     script_name = script if directory else os.path.basename(script)
     dependencies = dependencies or []
     key = "%s/sourcedir.tar.gz" % s3_key_prefix
-    tmp = tempfile.mkdtemp()
+    if (
+        settings is not None
+        and settings.local_download_dir is not None
+        and not (
+            os.path.exists(settings.local_download_dir)
+            and os.path.isdir(settings.local_download_dir)
+        )
+    ):
+        raise ValueError(
+            "Inputted directory for storing newly generated temporary directory does "
+            f"not exist: '{settings.local_download_dir}'"
+        )
Contributor:
How about attempting to create the target dir with os.makedirs(dir, exist_ok=True) instead?

Member Author:
I'd rather not. Making directories on behalf of customers seems risky.

+    local_download_dir = None if settings is None else settings.local_download_dir
+    tmp = tempfile.mkdtemp(dir=local_download_dir)
     encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts

     try:
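The guard added to tar_and_upload_dir follows a validate-then-create pattern, and it matches the author's position in the review thread: raise on a missing directory rather than create it for the caller. The sketch below is a standalone illustration, not code from the PR; `make_scratch_dir` is a hypothetical name. It also relies on the fact that `os.path.isdir` already returns False for a path that does not exist, so the separate `os.path.exists` call is not strictly needed.

```python
import os
import tempfile


def make_scratch_dir(local_download_dir=None):
    """Hypothetical helper mirroring the check in the hunk above."""
    # os.path.isdir() is False for a missing path, so one call covers both
    # "does not exist" and "exists but is not a directory".
    if local_download_dir is not None and not os.path.isdir(local_download_dir):
        raise ValueError(
            "Inputted directory for storing newly generated temporary directory does "
            f"not exist: '{local_download_dir}'"
        )
    # dir=None lets tempfile pick the platform default location.
    return tempfile.mkdtemp(dir=local_download_dir)
```

Calling `make_scratch_dir()` with no argument behaves like the pre-change `tempfile.mkdtemp()` call; passing an existing directory places the new temporary directory inside it.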
10 changes: 9 additions & 1 deletion src/sagemaker/session_settings.py
@@ -18,17 +18,25 @@
 class SessionSettings(object):
     """Optional container class for settings to apply to a SageMaker session."""

-    def __init__(self, encrypt_repacked_artifacts=True) -> None:
+    def __init__(self, encrypt_repacked_artifacts=True, local_download_dir=None) -> None:
         """Initialize the ``SessionSettings`` of a SageMaker ``Session``.

         Args:
             encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts
                 at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key
                 is not provided (Default: True).
+            local_download_dir (str): Optional. A path specifying the local directory
+                for downloading artifacts. (Default: None).
         """
         self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
+        self._local_download_dir = local_download_dir

     @property
     def encrypt_repacked_artifacts(self) -> bool:
         """Return True if repacked artifacts at rest in S3 should be encrypted by default."""
         return self._encrypt_repacked_artifacts
+
+    @property
+    def local_download_dir(self) -> str:
+        """Return path specifying the local directory for downloading artifacts."""
+        return self._local_download_dir
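To make the new setting's behaviour concrete without importing the SDK, here is a minimal stand-in for the class in the hunk above. `SessionSettingsSketch` and the path `/mnt/scratch` are hypothetical names used only for illustration. The sketch annotates the property as `Optional[str]`, since the default is None; the diff itself annotates it as `str`.

```python
from typing import Optional


class SessionSettingsSketch:
    """Stand-in for the SessionSettings class in the diff (illustration only)."""

    def __init__(
        self,
        encrypt_repacked_artifacts: bool = True,
        local_download_dir: Optional[str] = None,
    ) -> None:
        self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
        self._local_download_dir = local_download_dir

    @property
    def encrypt_repacked_artifacts(self) -> bool:
        return self._encrypt_repacked_artifacts

    @property
    def local_download_dir(self) -> Optional[str]:
        # None means "fall back to the platform default temp location".
        return self._local_download_dir


default_settings = SessionSettingsSketch()
custom_settings = SessionSettingsSketch(local_download_dir="/mnt/scratch")
```

Both properties are read-only (no setter is defined), so the directory is fixed when the settings object is constructed. With the real class, the object would presumably be handed to a session as `settings=SessionSettings(local_download_dir=...)`, as the mocked fixtures later in this PR suggest.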
19 changes: 16 additions & 3 deletions src/sagemaker/utils.py
@@ -358,7 +358,7 @@ def create_tar_file(source_files, target=None):


 @contextlib.contextmanager
-def _tmpdir(suffix="", prefix="tmp"):
+def _tmpdir(suffix="", prefix="tmp", directory=None):
     """Create a temporary directory with a context manager.

     The file is deleted when the context exits.
@@ -369,11 +369,18 @@ def _tmpdir(suffix="", prefix="tmp"):
             suffix, otherwise there will be no suffix.
         prefix (str): If prefix is specified, the file name will begin with that
             prefix; otherwise, a default prefix is used.
+        directory (str): If a directory is specified, the file will be downloaded
+            in this directory; otherwise, a default directory is used.

     Returns:
         str: path to the directory
     """
-    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None)
+    if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)):
+        raise ValueError(
+            "Inputted directory for storing newly generated temporary "
+            f"directory does not exist: '{directory}'"
+        )
+    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
     yield tmp
     shutil.rmtree(tmp)
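One property of the generator above is worth spelling out: in a `contextlib.contextmanager` function, an exception raised inside the `with` body is re-raised at the `yield`, so the `shutil.rmtree(tmp)` line after a bare `yield` is skipped on failure. The following is a hedged sketch of a variant that cleans up in both cases; `tmpdir_sketch` is a hypothetical name and this is not part of the PR.

```python
import contextlib
import os
import shutil
import tempfile


@contextlib.contextmanager
def tmpdir_sketch(suffix="", prefix="tmp", directory=None):
    """Hypothetical variant of _tmpdir that cleans up even when the body raises."""
    if directory is not None and not os.path.isdir(directory):
        raise ValueError(
            "Inputted directory for storing newly generated temporary "
            f"directory does not exist: '{directory}'"
        )
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
    try:
        yield tmp
    finally:
        # Runs on normal exit and when the with-body raises.
        shutil.rmtree(tmp, ignore_errors=True)
```

This matters more once `directory` points at a user-chosen location, because leftover directories then accumulate there instead of in the system temp area.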

@@ -427,7 +434,13 @@ def repack_model(
     """
     dependencies = dependencies or []

-    with _tmpdir() as tmp:
+    local_download_dir = (
+        None
+        if sagemaker_session.settings is None
+        or sagemaker_session.settings.local_download_dir is None
+        else sagemaker_session.settings.local_download_dir
+    )
+    with _tmpdir(directory=local_download_dir) as tmp:
         model_dir = _extract_model(model_uri, sagemaker_session, tmp)

         _create_or_update_code_dir(
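The conditional in repack_model resolves to None when the session has no settings or when the setting is unset; the second clause is redundant, because returning an unset `local_download_dir` yields None anyway. A compact equivalent can be sketched as below. `resolve_local_download_dir` is a hypothetical helper, and `SimpleNamespace` objects stand in for the session and its settings.

```python
from types import SimpleNamespace


def resolve_local_download_dir(sagemaker_session):
    """Hypothetical helper: return the configured directory, or None."""
    settings = getattr(sagemaker_session, "settings", None)
    # getattr on None falls through to the default, covering settings=None.
    return getattr(settings, "local_download_dir", None)


no_settings = SimpleNamespace(settings=None)
unset = SimpleNamespace(settings=SimpleNamespace(local_download_dir=None))
configured = SimpleNamespace(settings=SimpleNamespace(local_download_dir="/mnt/scratch"))
```

The same helper could serve both call sites touched by this PR (tar_and_upload_dir and repack_model), which currently each spell out the None checks.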
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/automl/test_auto_ml.py
@@ -18,6 +18,7 @@
 from mock import Mock, patch
 from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel
 from sagemaker.predictor import Predictor
+from sagemaker.session_settings import SessionSettings
 from sagemaker.workflow.functions import Join

 MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -254,6 +255,7 @@ def sagemaker_session():
         boto_region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     sms.upload_data = Mock(name="upload_data", return_value=DEFAULT_S3_INPUT_DATA)
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/huggingface/test_estimator.py
@@ -20,6 +20,8 @@
 from mock import MagicMock, Mock, patch

 from sagemaker.huggingface import HuggingFace, HuggingFaceModel
+from sagemaker.session_settings import SessionSettings
+

 from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

@@ -63,6 +65,7 @@ def fixture_sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/huggingface/test_processing.py
@@ -17,6 +17,7 @@

 from sagemaker.huggingface.processing import HuggingFaceProcessor
 from sagemaker.fw_utils import UploadedCode
+from sagemaker.session_settings import SessionSettings

 from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

@@ -42,6 +43,7 @@ def sagemaker_session():
         boto_region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

3 changes: 3 additions & 0 deletions tests/unit/sagemaker/image_uris/jumpstart/conftest.py
@@ -15,6 +15,8 @@
 from mock.mock import Mock
 import pytest

+from sagemaker.session_settings import SessionSettings
+
 REGION_NAME = "us-west-2"
 BUCKET_NAME = "some-bucket-name"

@@ -26,6 +28,7 @@ def session():
         boto_session=boto_mock,
         boto_region_name=REGION_NAME,
         config=None,
+        settings=SessionSettings(),
     )
     sms.default_bucket = Mock(return_value=BUCKET_NAME)
     return sms
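The PR does not state why so many fixtures gain `settings=SessionSettings()`, but one plausible reason is how Mock attributes behave: accessing an attribute that was never set returns a child Mock, not None, so a guard written as `settings.local_download_dir is not None` passes and the value is then treated as a path. The sketch below demonstrates this with the standard-library `unittest.mock` (the tests in this PR import the `mock` package); `guard` is a hypothetical function with the same shape as the check added to tar_and_upload_dir, and `SimpleNamespace` stands in for `SessionSettings()`.

```python
import os
from types import SimpleNamespace
from unittest.mock import Mock


def guard(settings):
    """Hypothetical check with the same shape as the one added in this PR."""
    if (
        settings is not None
        and settings.local_download_dir is not None
        and not os.path.isdir(settings.local_download_dir)
    ):
        raise ValueError(f"not a directory: {settings.local_download_dir!r}")
    return None if settings is None else settings.local_download_dir


bare_session = Mock(name="sagemaker_session")
# Attribute access on a Mock creates child Mocks, so this is never None.
auto_value = bare_session.settings.local_download_dir

explicit_session = Mock(
    name="sagemaker_session",
    settings=SimpleNamespace(local_download_dir=None),  # stand-in for SessionSettings()
)
```

With `explicit_session`, `guard(explicit_session.settings)` returns None and the default temp location would be used; with `bare_session`, the guard fails because a Mock is not a usable path.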
5 changes: 4 additions & 1 deletion tests/unit/sagemaker/local/test_local_utils.py
@@ -17,6 +17,7 @@
 from mock import patch, Mock

 import sagemaker.local.utils
+from sagemaker.session_settings import SessionSettings


 @patch("sagemaker.local.utils.os.path")
@@ -42,7 +43,9 @@ def test_move_to_destination_local(recursive_copy):
 @patch("shutil.rmtree", Mock())
 @patch("sagemaker.local.utils.recursive_copy")
 def test_move_to_destination_s3(recursive_copy):
-    sms = Mock()
+    sms = Mock(
+        settings=SessionSettings(),
+    )

     # without trailing slash in prefix
     sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms)
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/model/test_framework_model.py
@@ -21,6 +21,8 @@
 import pytest
 from mock import MagicMock, Mock, patch

+from sagemaker.session_settings import SessionSettings
+
 MODEL_DATA = "s3://bucket/model.tar.gz"
 MODEL_IMAGE = "mi"
 ENTRY_POINT = "blah.py"
@@ -89,6 +91,7 @@ def sagemaker_session():
         local_mode=False,
         s3_client=None,
         s3_resource=None,
+        settings=SessionSettings(),
     )
     sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     return sms
24 changes: 24 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
@@ -776,3 +776,27 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses
         == get_model_package_args.call_args_list[0][1]["validation_specification"]
     ), """ValidationSpecification from model.register method is not identical to validation_spec from
     get_model_package_args"""
+
+
+@patch("sagemaker.utils.repack_model")
+def test_model_local_download_dir(repack_model, sagemaker_session):
+
+    source_dir = "s3://blah/blah/blah"
+    local_download_dir = "local download dir"
+
+    sagemaker_session.settings.local_download_dir = local_download_dir
+
+    t = Model(
+        entry_point=ENTRY_POINT_INFERENCE,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        source_dir=source_dir,
+        image_uri=IMAGE_URI,
+        model_data=MODEL_DATA,
+    )
+    t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
+
+    assert (
+        repack_model.call_args_list[0][1]["sagemaker_session"].settings.local_download_dir
+        == local_download_dir
+    )
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/spark/test_processing.py
@@ -18,6 +18,7 @@
 import pytest

 from sagemaker.processing import ProcessingInput, ProcessingOutput
+from sagemaker.session_settings import SessionSettings
 from sagemaker.spark.processing import (
     PySparkProcessor,
     SparkJarProcessor,
@@ -57,6 +58,7 @@ def sagemaker_session():
         boto_region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

2 changes: 2 additions & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
@@ -21,6 +21,7 @@
 import pytest

 from sagemaker.estimator import _TrainingJob
+from sagemaker.session_settings import SessionSettings
 from sagemaker.tensorflow import TensorFlow
 from sagemaker.instance_group import InstanceGroup
 from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
@@ -71,6 +72,7 @@ def sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )
     session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     session.expand_role = Mock(name="expand_role", return_value=ROLE)

@@ -24,6 +24,7 @@
 from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
 from sagemaker.huggingface.model import HuggingFaceModel
 from sagemaker.instance_group import InstanceGroup
+from sagemaker.session_settings import SessionSettings

 from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

@@ -72,6 +73,7 @@ def fixture_sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}

@@ -22,6 +22,7 @@
 from sagemaker import image_uris
 from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
 from sagemaker.huggingface.model import HuggingFaceModel
+from sagemaker.session_settings import SessionSettings

 from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

@@ -70,6 +71,7 @@ def fixture_sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}

@@ -24,6 +24,7 @@
 from sagemaker.pytorch import PyTorch, TrainingCompilerConfig
 from sagemaker.pytorch.model import PyTorchModel
 from sagemaker.instance_group import InstanceGroup
+from sagemaker.session_settings import SessionSettings

 from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES

@@ -71,6 +72,7 @@ def fixture_sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}

@@ -21,6 +21,7 @@
 from mock import MagicMock, Mock, patch

 from sagemaker import image_uris
+from sagemaker.session_settings import SessionSettings
 from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

 from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES
@@ -76,6 +77,7 @@ def fixture_sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/wrangler/test_processing.py
@@ -14,6 +14,7 @@

 import pytest
 from mock import Mock, MagicMock
+from sagemaker.session_settings import SessionSettings

 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.processing import ProcessingInput
@@ -36,6 +37,7 @@ def sagemaker_session():
         boto_region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     session_mock.expand_role.return_value = ROLE
     return session_mock
2 changes: 2 additions & 0 deletions tests/unit/test_amazon_estimator.py
@@ -23,6 +23,7 @@
     _build_shards,
     FileSystemRecordSet,
 )
+from sagemaker.session_settings import SessionSettings

 COMMON_ARGS = {"role": "myrole", "instance_count": 1, "instance_type": "ml.c4.xlarge"}

@@ -40,6 +41,7 @@ def sagemaker_session():
         region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     sms.boto_region_name = REGION
     sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
2 changes: 2 additions & 0 deletions tests/unit/test_analytics.py
@@ -24,6 +24,7 @@
     HyperparameterTuningJobAnalytics,
     TrainingJobAnalytics,
 )
+from sagemaker.session_settings import SessionSettings

 BUCKET_NAME = "mybucket"
 REGION = "us-west-2"
@@ -47,6 +48,7 @@ def create_sagemaker_session(
         boto_region_name=REGION,
         config=None,
         local_mode=False,
+        settings=SessionSettings(),
     )
     sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
2 changes: 2 additions & 0 deletions tests/unit/test_chainer.py
@@ -24,6 +24,7 @@
 from sagemaker.chainer import defaults
 from sagemaker.chainer import Chainer
 from sagemaker.chainer import ChainerPredictor, ChainerModel
+from sagemaker.session_settings import SessionSettings

 DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
 SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -62,6 +63,7 @@ def sagemaker_session():
         local_mode=False,
         s3_resource=None,
         s3_client=None,
+        settings=SessionSettings(),
     )

     describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}