Skip to content

Commit 7f2b368

Browse files
aoguo64Ao GuoRohan GujarathiZhankuildpatro
committed
Master pathways mega pr (aws#882)
* Pathways - Experiments+ integration.
* fix comments
* RemoteExecutor Exp+
* change: add __init__ py in runtime_environment package
* change: integ test fixes
* Some improvement to logging
* Use UTC for the log entry timestamp
* Only configure the "sagemaker" logger instead of root logger
* change: update s3 bucket for sagemaker whl file
* change: update user workdir packing and unpacking
* change: Change remote function dependency cache dir structure
* build(deps): bump apache-airflow from 2.4.1 to 2.5.1 (aws#3722)
* py37 backward compatibility
* minor fix for flake8 due to rebasing
* Revert "Some improvement to logging"
  This reverts commit a0a37de179b00e5c3e69fb1c9ba0d7b871a0fbd8.
* minor integ test fix
* fix: list_futures integ test

---------

Co-authored-by: Ao Guo <[email protected]>
Co-authored-by: Rohan Gujarathi <[email protected]>
Co-authored-by: Zhankui Lu <[email protected]>
Co-authored-by: Dipankar Patro <[email protected]>
Co-authored-by: Mourya Baddam <[email protected]>
Co-authored-by: Kalyani Nikure <[email protected]>
Co-authored-by: Dipankar Patro <[email protected]>
Co-authored-by: Namrata Madan <[email protected]>
1 parent fd62e78 commit 7f2b368

24 files changed

+1054
-184
lines changed

src/sagemaker/experiments/run.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
715715

716716
self.close()
717717

718+
def __getstate__(self):
719+
"""Overriding this method to prevent instance of Run from being pickled.
720+
721+
Raise:
722+
NotImplementedError: If attempting to pickle this instance.
723+
"""
724+
raise NotImplementedError("Instance of Run type is not allowed to be pickled.")
725+
718726

719727
def load_run(
720728
run_name: Optional[str] = None,
@@ -792,31 +800,34 @@ def load_run(
792800

793801
verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
794802

795-
if run_name or environment:
796-
if run_name:
797-
logger.warning(
798-
"run_name is explicitly supplied in load_run, "
799-
"which will be prioritized to load the Run object. "
800-
"In other words, the run name in the experiment config, fetched from the "
801-
"job environment or the current run context, will be ignored."
802-
)
803-
else:
804-
exp_config = get_tc_and_exp_config_from_job_env(
805-
environment=environment, sagemaker_session=sagemaker_session
806-
)
807-
run_name = Run._extract_run_name_from_tc_name(
808-
trial_component_name=exp_config[RUN_NAME],
809-
experiment_name=exp_config[EXPERIMENT_NAME],
810-
)
811-
experiment_name = exp_config[EXPERIMENT_NAME]
812-
803+
if run_name:
804+
logger.warning(
805+
"run_name is explicitly supplied in load_run, "
806+
"which will be prioritized to load the Run object. "
807+
"In other words, the run name in the experiment config, fetched from the "
808+
"job environment or the current run context, will be ignored."
809+
)
813810
run_instance = Run(
814811
experiment_name=experiment_name,
815812
run_name=run_name,
816813
sagemaker_session=sagemaker_session,
817814
)
818815
elif _RunContext.get_current_run():
819816
run_instance = _RunContext.get_current_run()
817+
elif environment:
818+
exp_config = get_tc_and_exp_config_from_job_env(
819+
environment=environment, sagemaker_session=sagemaker_session
820+
)
821+
run_name = Run._extract_run_name_from_tc_name(
822+
trial_component_name=exp_config[RUN_NAME],
823+
experiment_name=exp_config[EXPERIMENT_NAME],
824+
)
825+
experiment_name = exp_config[EXPERIMENT_NAME]
826+
run_instance = Run(
827+
experiment_name=experiment_name,
828+
run_name=run_name,
829+
sagemaker_session=sagemaker_session,
830+
)
820831
else:
821832
raise RuntimeError(
822833
"Failed to load a Run object. "

src/sagemaker/remote_function/client.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from botocore.exceptions import ClientError
2626
from sagemaker.exceptions import UnexpectedStatusException
27+
from sagemaker.experiments._run_context import _RunContext
2728

2829
import sagemaker.remote_function.core.serialization as serialization
2930
from sagemaker.remote_function.errors import RemoteFunctionError, ServiceError, DeserializationError
@@ -33,7 +34,7 @@
3334

3435
from sagemaker.session import Session
3536
from sagemaker.s3 import s3_path_join
36-
from sagemaker.remote_function.job import _JobSettings, _Job
37+
from sagemaker.remote_function.job import _JobSettings, _Job, _RunInfo
3738
from sagemaker.remote_function import logging_config
3839
from sagemaker.utils import name_from_base, base_from_name
3940

@@ -143,8 +144,7 @@ def wrapper(*args, **kwargs):
143144
volume_kms_key=volume_kms_key,
144145
volume_size=volume_size,
145146
)
146-
future = Future()
147-
job = future._start_and_notify(job_settings, func, args, kwargs)
147+
job = _Job.start(job_settings, func, args, kwargs)
148148

149149
try:
150150
job.wait()
@@ -205,12 +205,15 @@ def wrapper(*args, **kwargs):
205205
class _SubmitRequest:
206206
"""Class that holds parameters and data for creating a new job."""
207207

208-
def __init__(self, future, job_settings: _JobSettings, func, func_args, func_kwargs):
208+
def __init__(
209+
self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
210+
):
209211
self.future = future
210212
self.job_settings = job_settings
211213
self.func = func
212214
self.args = func_args
213215
self.kwargs = func_kwargs
216+
self.run_info = run_info
214217

215218

216219
def _submit_worker(executor):
@@ -237,7 +240,7 @@ def has_work_to_do():
237240
time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"])
238241
# submit a new job
239242
job = request.future._start_and_notify(
240-
request.job_settings, request.func, request.args, request.kwargs
243+
request.job_settings, request.func, request.args, request.kwargs, request.run_info
241244
)
242245

243246
with executor._state_condition:
@@ -417,8 +420,14 @@ def submit(self, func, *args, **kwargs):
417420

418421
with self._state_condition:
419422
future = Future()
423+
424+
run_info = None
425+
if _RunContext.get_current_run() is not None:
426+
run = _RunContext.get_current_run()
427+
run_info = _RunInfo(run.experiment_name, run.run_name)
428+
420429
self._pending_request_queue.append(
421-
_SubmitRequest(future, self.job_settings, func, args, kwargs)
430+
_SubmitRequest(future, self.job_settings, func, args, kwargs, run_info)
422431
)
423432

424433
if self._workers is None:
@@ -605,7 +614,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
605614
future._return = job_return
606615
return future
607616

608-
def _start_and_notify(self, job_settings: _JobSettings, func, func_args, func_kwargs):
617+
def _start_and_notify(
618+
self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
619+
):
609620
"""Start and record the newly created job in the future object.
610621
611622
The job is recorded if one is successfully started. Otherwise, the exception is
@@ -615,7 +626,7 @@ def _start_and_notify(self, job_settings: _JobSettings, func, func_args, func_kw
615626
if self._state in [_PENDING]:
616627

617628
try:
618-
self._job = _Job.start(job_settings, func, func_args, func_kwargs)
629+
self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info)
619630
except (Exception,) as e: # pylint: disable=broad-except
620631
self._exception = e
621632
self._state = _FINISHED
@@ -675,11 +686,14 @@ def result(self, timeout: float = None) -> Any:
675686
"FailureReason" in self._job.describe()
676687
and self._job.describe()["FailureReason"]
677688
):
678-
raise RuntimeEnvironmentError(self._job.describe()["FailureReason"])
679-
self._exception = RemoteFunctionError(
680-
"Failed to execute remote function. "
681-
+ "Check corresponding job for details."
682-
)
689+
self._exception = RuntimeEnvironmentError(
690+
self._job.describe()["FailureReason"]
691+
)
692+
else:
693+
self._exception = RemoteFunctionError(
694+
"Failed to execute remote function. "
695+
+ "Check corresponding job for details."
696+
)
683697
else:
684698
self._exception = serr
685699
self._state = _FINISHED

src/sagemaker/remote_function/core/serialization.py

+7
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st
9090
try:
9191
bytes_to_upload = cloudpickle.dumps(obj)
9292
except Exception as e:
93+
if isinstance(
94+
e, NotImplementedError
95+
) and "Instance of Run type is not allowed to be pickled." in str(e):
96+
raise SerializationError(
97+
"Remote function does not allow parameters of Run type."
98+
) from e
99+
93100
raise SerializationError(
94101
"Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
95102
) from e

src/sagemaker/remote_function/core/stored_function.py

+3-55
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@
1313
"""SageMaker job function serializer/deserializer."""
1414
from __future__ import absolute_import
1515

16-
import os
17-
import pathlib
18-
import shutil
19-
import sys
20-
21-
from sagemaker.utils import _tmpdir
22-
from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader
16+
from sagemaker.s3 import s3_path_join
2317
from sagemaker.remote_function import logging_config
2418

2519
import sagemaker.remote_function.core.serialization as serialization
@@ -44,20 +38,16 @@ def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None):
4438
self.s3_base_uri = s3_base_uri
4539
self.s3_kms_key = s3_kms_key
4640

47-
def save(self, func, source_dir=None, *args, **kwargs):
48-
"""Serialize and persist the function and it's dependencies and arguments.
41+
def save(self, func, *args, **kwargs):
42+
"""Serialize and persist the function and arguments.
4943
5044
Args:
5145
func: the python function.
52-
source_dir: path to local dependencies/modules. Defaults to ``None``
5346
args: the positional arguments to func.
5447
kwargs: the keyword arguments to func.
5548
Returns:
5649
None
5750
"""
58-
if source_dir:
59-
self._zip_and_upload_source_dir(source_dir)
60-
6151
logger.info(
6252
f"Serializing function code to {s3_path_join(self.s3_base_uri, 'function.pkl')}"
6353
)
@@ -78,22 +68,6 @@ def save(self, func, source_dir=None, *args, **kwargs):
7868
self.s3_kms_key,
7969
)
8070

81-
def _zip_and_upload_source_dir(self, source_dir):
82-
source_dir_path = pathlib.Path(source_dir)
83-
if not source_dir_path.is_dir():
84-
raise AttributeError(source_dir + " is not a valid directory.")
85-
86-
s3_path = s3_path_join(self.s3_base_uri, "source_dir")
87-
logger.info(f"Uploading function source directory to {s3_path}")
88-
with _tmpdir() as tmp:
89-
archived_filepath = shutil.make_archive(
90-
os.path.join(tmp, source_dir_path.name),
91-
"zip",
92-
source_dir_path.parent,
93-
source_dir_path.name,
94-
)
95-
S3Uploader.upload(archived_filepath, s3_path, self.s3_kms_key, self.sagemaker_session)
96-
9771
def load_and_invoke(self) -> None:
9872
"""Load and deserialize the function and the arguments and then execute it."""
9973

@@ -111,8 +85,6 @@ def load_and_invoke(self) -> None:
11185
self.sagemaker_session, s3_path_join(self.s3_base_uri, "arguments.pkl")
11286
)
11387

114-
self._download_and_unzip_source_dir()
115-
11688
logger.info("Invoking the function")
11789
result = func(*args, **kwargs)
11890

@@ -125,27 +97,3 @@ def load_and_invoke(self) -> None:
12597
s3_path_join(self.s3_base_uri, "results.pkl"),
12698
self.s3_kms_key,
12799
)
128-
129-
def _download_and_unzip_source_dir(self):
130-
source_dir_s3_path = s3_path_join(self.s3_base_uri, "source_dir")
131-
local_source_dir_path = os.path.join(os.getcwd(), "source_dir")
132-
133-
logger.info(
134-
f"Downloading source modules from {source_dir_s3_path} to {local_source_dir_path}"
135-
)
136-
137-
downloaded_paths = S3Downloader.download(
138-
source_dir_s3_path,
139-
local_source_dir_path,
140-
kms_key=self.s3_kms_key,
141-
sagemaker_session=self.sagemaker_session,
142-
)
143-
144-
if len(downloaded_paths) < 1:
145-
return
146-
147-
source_dir_archive_path = downloaded_paths[0]
148-
shutil.unpack_archive(
149-
source_dir_archive_path, pathlib.Path(source_dir_archive_path).parent.absolute()
150-
)
151-
sys.path.append(local_source_dir_path)

src/sagemaker/remote_function/invoke_function.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,14 @@
1616

1717
import argparse
1818
import sys
19+
import json
1920

2021
import boto3
22+
from sagemaker.experiments.run import Run
23+
from sagemaker.remote_function.job import (
24+
KEY_EXPERIMENT_NAME,
25+
KEY_RUN_NAME,
26+
)
2127

2228
from sagemaker.session import Session
2329
from sagemaker.remote_function.errors import handle_error
@@ -33,6 +39,7 @@ def _parse_agrs():
3339
parser.add_argument("--region", type=str, required=True)
3440
parser.add_argument("--s3_base_uri", type=str, required=True)
3541
parser.add_argument("--s3_kms_key", type=str)
42+
parser.add_argument("--run_in_context", type=str)
3643

3744
args, _ = parser.parse_known_args()
3845
return args
@@ -44,12 +51,28 @@ def _get_sagemaker_session(region):
4451
return Session(boto_session=boto_session)
4552

4653

47-
def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key):
54+
def _load_run_object(run_in_context: str, sagemaker_session: Session) -> Run:
55+
"""Load current run in json string into run object"""
56+
run_dict = json.loads(run_in_context)
57+
return Run(
58+
experiment_name=run_dict.get(KEY_EXPERIMENT_NAME),
59+
run_name=run_dict.get(KEY_RUN_NAME),
60+
sagemaker_session=sagemaker_session,
61+
)
62+
63+
64+
def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context):
4865
"""Execute stored remote function"""
4966
from sagemaker.remote_function.core.stored_function import StoredFunction
5067

5168
stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key)
52-
stored_function.load_and_invoke()
69+
70+
if run_in_context:
71+
run_obj = _load_run_object(run_in_context, sagemaker_session)
72+
with run_obj:
73+
stored_function.load_and_invoke()
74+
else:
75+
stored_function.load_and_invoke()
5376

5477

5578
def main():
@@ -65,9 +88,10 @@ def main():
6588
region = args.region
6689
s3_base_uri = args.s3_base_uri
6790
s3_kms_key = args.s3_kms_key
91+
run_in_context = args.run_in_context
6892

6993
sagemaker_session = _get_sagemaker_session(region)
70-
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
94+
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context)
7195

7296
except Exception as e: # pylint: disable=broad-except
7397
logger.exception("Error encountered while invoking the remote function.")

0 commit comments

Comments
 (0)