Skip to content

Commit 53cef39

Browse files
nmadan (Namrata Madan) and Namrata Madan committed
convert source_dir (str) argument to include_local_workdir (bool) (aws#883)
Co-authored-by: Namrata Madan <[email protected]>
1 parent 7f2b368 commit 53cef39

File tree

8 files changed

+56
-70
lines changed

8 files changed

+56
-70
lines changed

src/sagemaker/config/config_schema.py

-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
INSTANCE_TYPE = "InstanceType"
5252
S3_KMS_KEY_ID = "S3KmsKeyId"
5353
S3_ROOT_URI = "S3RootUri"
54-
SOURCE_DIR = "SourceDir"
5554
JOB_CONDA_ENV = "JobCondaEnvironment"
5655
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
5756
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
@@ -255,7 +254,6 @@ def _simple_path(*args: str):
255254
)
256255

257256

258-
<<<<<<< HEAD
259257
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
260258
"$schema": "https://json-schema.org/draft/2020-12/schema",
261259
TYPE: OBJECT,

src/sagemaker/remote_function/client.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def remote(
6161
dependencies: str = None,
6262
environment_variables: Dict[str, str] = None,
6363
image_uri: str = None,
64+
include_local_workdir: bool = False,
6465
instance_count: int = 1,
6566
instance_type: str = None,
6667
job_conda_env: str = None,
@@ -73,7 +74,6 @@ def remote(
7374
s3_root_uri: str = None,
7475
sagemaker_session: Session = None,
7576
security_group_ids: List[str] = None,
76-
source_dir: str = None,
7777
subnets: List[str] = None,
7878
tags: List[Tuple[str, str]] = None,
7979
volume_kms_key: str = None,
@@ -87,6 +87,8 @@ def remote(
8787
``from_active_conda_env``. Defaults to None.
8888
environment_variables (Dict): environment variables
8989
image_uri (str): Docker image URI on ECR.
90+
include_local_workdir (bool): Set to ``True`` if the remote function code imports local
91+
modules and methods that are not available via PyPI or conda. Default value is ``False``.
9092
instance_count (int): Number of instances to use. Default is 1.
9193
instance_type (str): EC2 instance type.
9294
job_conda_env (str): Name of the conda environment to activate during execution of the job.
@@ -106,8 +108,6 @@ def remote(
106108
AWS service calls are delegated to (default: None). If not provided, one is created
107109
with default AWS configuration chain.
108110
security_group_ids (List[str]): List of security group IDs.
109-
source_dir (str): Path to locally defined modules that are used in the remote function.
110-
Default is None.
111111
subnets (List[str]): List of subnet IDs.
112112
tags (List[Tuple[str, str]]): List of tags attached to the job.
113113
volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training
@@ -126,6 +126,7 @@ def wrapper(*args, **kwargs):
126126
dependencies=dependencies,
127127
environment_variables=environment_variables,
128128
image_uri=image_uri,
129+
include_local_workdir=include_local_workdir,
129130
instance_count=instance_count,
130131
instance_type=instance_type,
131132
job_conda_env=job_conda_env,
@@ -138,7 +139,6 @@ def wrapper(*args, **kwargs):
138139
s3_root_uri=s3_root_uri,
139140
sagemaker_session=sagemaker_session,
140141
security_group_ids=security_group_ids,
141-
source_dir=source_dir,
142142
subnets=subnets,
143143
tags=tags,
144144
volume_kms_key=volume_kms_key,
@@ -309,6 +309,7 @@ def __init__(
309309
dependencies: str = None,
310310
environment_variables: Dict[str, str] = None,
311311
image_uri: str = None,
312+
include_local_workdir: bool = False,
312313
instance_count: int = 1,
313314
instance_type: str = None,
314315
job_conda_env: str = None,
@@ -322,7 +323,6 @@ def __init__(
322323
s3_root_uri: str = None,
323324
sagemaker_session: Session = None,
324325
security_group_ids: List[str] = None,
325-
source_dir: str = None,
326326
subnets: List[str] = None,
327327
tags: List[Tuple[str, str]] = None,
328328
volume_kms_key: str = None,
@@ -336,6 +336,9 @@ def __init__(
336336
environment_variables (Dict): Environment variables passed to the underlying sagemaker
337337
job. Defaults to None
338338
image_uri (str): Docker image URI on ECR. Defaults to base Python image.
339+
include_local_workdir (bool): Set to ``True`` if the remote function code imports local
340+
modules and methods that are not available via PyPI or conda. Default value is
341+
``False``.
339342
instance_count (int): Number of instances to use. Defaults to 1.
340343
instance_type (str): EC2 instance type.
341344
job_conda_env (str): Name of the conda environment to activate during execution
@@ -360,7 +363,6 @@ def __init__(
360363
AWS service calls are delegated to (default: None). If not provided, one is created
361364
with default AWS configuration chain.
362365
security_group_ids (List[str]): List of security group IDs. Defaults to None.
363-
source_dir: path to local dependencies/modules.
364366
subnets (List[str]): List of subnet IDs. Defaults to None.
365367
tags (List[Tuple[str, str]]): List of tags attached to the job. Defaults to None.
366368
volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training
@@ -377,6 +379,7 @@ def __init__(
377379
dependencies=dependencies,
378380
environment_variables=environment_variables,
379381
image_uri=image_uri,
382+
include_local_workdir=include_local_workdir,
380383
instance_count=instance_count,
381384
instance_type=instance_type,
382385
job_conda_env=job_conda_env,
@@ -389,7 +392,6 @@ def __init__(
389392
s3_root_uri=s3_root_uri,
390393
sagemaker_session=sagemaker_session,
391394
security_group_ids=security_group_ids,
392-
source_dir=source_dir,
393395
subnets=subnets,
394396
tags=tags,
395397
volume_kms_key=volume_kms_key,

src/sagemaker/remote_function/job.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(
107107
dependencies: str = None,
108108
environment_variables: Dict[str, str] = None,
109109
image_uri: str = None,
110+
include_local_workdir: bool = False,
110111
instance_count: int = 1,
111112
instance_type: str = None,
112113
job_conda_env: str = None,
@@ -119,7 +120,6 @@ def __init__(
119120
s3_root_uri: str = None,
120121
sagemaker_session: Session = None,
121122
security_group_ids: List[str] = None,
122-
source_dir: str = None,
123123
subnets: List[str] = None,
124124
tags: List[Tuple[str, str]] = None,
125125
volume_kms_key: str = None,
@@ -142,7 +142,7 @@ def __init__(
142142
self.image_uri = self._get_default_image(self.sagemaker_session)
143143

144144
self.dependencies = self._get_from_config(dependencies, config_schema.DEPENDENCIES)
145-
145+
self.include_local_workdir = include_local_workdir
146146
self.instance_type = self._get_from_config(
147147
instance_type, config_schema.INSTANCE_TYPE, required=True
148148
)
@@ -151,7 +151,6 @@ def __init__(
151151
self.max_runtime_in_seconds = max_runtime_in_seconds
152152
self.max_retry_attempts = max_retry_attempts
153153
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
154-
self.source_dir = self._get_from_config(source_dir, config_schema.SOURCE_DIR)
155154
self.job_conda_env = self._get_from_config(job_conda_env, config_schema.JOB_CONDA_ENV)
156155
self.job_name_prefix = job_name_prefix
157156

@@ -280,7 +279,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
280279
dependencies_list_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
281280
user_dependencies_s3uri = _prepare_and_upload_dependencies(
282281
local_dependencies_path=dependencies_list_path,
283-
workdir_path=job_settings.source_dir,
282+
include_local_workdir=job_settings.include_local_workdir,
284283
s3_base_uri=s3_base_uri,
285284
s3_kms_key=job_settings.s3_kms_key,
286285
sagemaker_session=job_settings.sagemaker_session,
@@ -481,23 +480,23 @@ def _prepare_and_upload_runtime_scripts(
481480

482481
def _prepare_and_upload_dependencies(
483482
local_dependencies_path: str,
484-
workdir_path: str,
483+
include_local_workdir: bool,
485484
s3_base_uri: str,
486485
s3_kms_key: str,
487486
sagemaker_session: Session,
488487
) -> str:
489488
"""Upload the job dependencies to S3 if present"""
490489

491-
if not local_dependencies_path and not workdir_path:
490+
if not local_dependencies_path and not include_local_workdir:
492491
return None
493492

494493
with _tmpdir() as tmp_workspace:
495494
# TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
496495
tmp_workspace = os.path.join(tmp_workspace, "remote_function/")
497496

498-
if workdir_path:
497+
if include_local_workdir:
499498
shutil.copytree(
500-
workdir_path,
499+
os.getcwd(),
501500
tmp_workspace,
502501
ignore=_filter_non_python_files,
503502
)
@@ -516,12 +515,14 @@ def _prepare_and_upload_dependencies(
516515
workspace_archive_path = shutil.make_archive(workspace_archive_path, "zip", tmp_workspace)
517516
logger.info("Successfully created workdir archive at '%s'", workspace_archive_path)
518517

519-
return S3Uploader.upload(
518+
upload_path = S3Uploader.upload(
520519
workspace_archive_path,
521520
s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE),
522521
s3_kms_key,
523522
sagemaker_session,
524523
)
524+
logger.info("Successfully uploaded workdir to '%s'", upload_path)
525+
return upload_path
525526

526527

527528
def _convert_run_to_json(run: Run) -> str:

tests/data/config/config.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ SageMaker:
1212
S3KmsKeyId: "somekmskeyid"
1313
S3RootUri: "s3://bucket/key"
1414
SecurityGroupIds: ["sg123"]
15-
SourceDir: "../mymodule"
1615
Subnets: ["subnet-1234"]
1716
Tags: [{"someTagKey": "someTagValue"}]
1817
VolumeKmsKeyId: "somekmskeyid"

tests/data/remote_function/config.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ SageMaker:
99
JobCondaEnvironment: "my_conda_env"
1010
S3KmsKeyId: "someS3KmsKey"
1111
SecurityGroupIds: ["sg123"]
12-
SourceDir: "../mymodule"
1312
Subnets: ["subnet-1234"]
1413
Tags: [{"someTagKey": "someTagValue"}, {"someTagKey2": "someTagValue2"}]
1514
VolumeKmsKeyId: "someVolumeKmsKey"

tests/integ/sagemaker/remote_function/test_decorator.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -113,26 +113,28 @@ def divide(x, y):
113113

114114

115115
def test_with_local_dependencies(
116-
sagemaker_session, dummy_container_without_error, cpu_instance_type
116+
sagemaker_session, dummy_container_without_error, cpu_instance_type, monkeypatch
117117
):
118-
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
119118
source_dir_path = os.path.join(os.path.dirname(__file__))
119+
with monkeypatch.context() as m:
120+
m.chdir(source_dir_path)
121+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
120122

121-
@remote(
122-
role=ROLE,
123-
image_uri=dummy_container_without_error,
124-
dependencies=dependencies_path,
125-
instance_type=cpu_instance_type,
126-
sagemaker_session=sagemaker_session,
127-
source_dir=source_dir_path,
128-
)
129-
def train(x):
130-
from helpers import local_module
131-
from helpers.nested_helper import local_module2
123+
@remote(
124+
role=ROLE,
125+
image_uri=dummy_container_without_error,
126+
dependencies=dependencies_path,
127+
instance_type=cpu_instance_type,
128+
sagemaker_session=sagemaker_session,
129+
include_local_workdir=True,
130+
)
131+
def train(x):
132+
from helpers import local_module
133+
from helpers.nested_helper import local_module2
132134

133-
return local_module.square(x) + local_module2.cube(x)
135+
return local_module.square(x) + local_module2.cube(x)
134136

135-
assert train(2) == 12
137+
assert train(2) == 12
136138

137139

138140
def test_with_additional_dependencies(

tests/unit/sagemaker/config/conftest.py

-16
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,9 @@ def base_config_with_schema():
2222
return {"SchemaVersion": "1.0"}
2323

2424

25-
<<<<<<< HEAD
2625
@pytest.fixture()
2726
def valid_vpc_config():
2827
return {"SecurityGroupIds": ["sg123"], "Subnets": ["subnet-1234"]}
29-
=======
30-
@pytest.fixture(scope="module")
31-
def valid_vpc_subnet():
32-
return "subnet-1234"
33-
34-
35-
@pytest.fixture(scope="module")
36-
def valid_vpc_security_group():
37-
return "sg123"
38-
39-
40-
@pytest.fixture(scope="module")
41-
def valid_vpc_config(valid_vpc_security_group, valid_vpc_subnet):
42-
return {"SecurityGroupIds": [valid_vpc_security_group], "Subnets": [valid_vpc_subnet]}
43-
>>>>>>> 740cd77d (feature: support intelligent defaults config for pathways)
4428

4529

4630
@pytest.fixture()

0 commit comments

Comments
 (0)