Skip to content

Commit cc5853d

Browse files
qidewenwhen authored and sagemaker-bot committed
change: Support local mode for remote function (aws#4306)
1 parent 9ce672e commit cc5853d

File tree

6 files changed: 63 additions and 46 deletions.

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -891,7 +891,7 @@ def wait(self, timeout: int = None):
891891
"""
892892

893893
self._last_describe_response = _logs_for_job(
894-
boto_session=self.sagemaker_session.boto_session,
894+
sagemaker_session=self.sagemaker_session,
895895
job_name=self.job_name,
896896
wait=True,
897897
timeout=timeout,

src/sagemaker/session.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5454,7 +5454,7 @@ def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=No
54545454
exceptions.CapacityError: If the training job fails with CapacityError.
54555455
exceptions.UnexpectedStatusException: If waiting and the training job fails.
54565456
"""
5457-
_logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout)
5457+
_logs_for_job(self, job_name, wait, poll, log_type, timeout)
54585458

54595459
def logs_for_processing_job(self, job_name, wait=False, poll=10):
54605460
"""Display logs for a given processing job, optionally tailing them until the job is complete.
@@ -7337,17 +7337,16 @@ def _rule_statuses_changed(current_statuses, last_statuses):
73377337

73387338

73397339
def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
7340-
boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None
7340+
sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
73417341
):
73427342
"""Display logs for a given training job, optionally tailing them until job is complete.
73437343
73447344
If the output is a tty or a Jupyter cell, it will be color-coded
73457345
based on which instance the log entry is from.
73467346
73477347
Args:
7348-
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
7349-
calls are delegated to (default: None). If not provided, one is created with
7350-
default AWS configuration chain.
7348+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
7349+
object, used for SageMaker interactions.
73517350
job_name (str): Name of the training job to display the logs for.
73527351
wait (bool): Whether to keep looking for new log entries until the job completes
73537352
(default: False).
@@ -7364,13 +7363,13 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
73647363
exceptions.CapacityError: If the training job fails with CapacityError.
73657364
exceptions.UnexpectedStatusException: If waiting and the training job fails.
73667365
"""
7367-
sagemaker_client = boto_session.client("sagemaker")
7366+
sagemaker_client = sagemaker_session.sagemaker_client
73687367
request_end_time = time.time() + timeout if timeout else None
73697368
description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
73707369
print(secondary_training_status_message(description, None), end="")
73717370

73727371
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
7373-
boto_session, description, job="Training"
7372+
sagemaker_session.boto_session, description, job="Training"
73747373
)
73757374

73767375
state = _get_initial_job_state(description, "TrainingJobStatus", wait)

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -207,7 +207,7 @@ def test_with_additional_dependencies(
207207
def cuberoot(x):
208208
from scipy.special import cbrt
209209

210-
return cbrt(27)
210+
return cbrt(x)
211211

212212
assert cuberoot(27) == 3
213213

@@ -742,7 +742,7 @@ def test_with_user_and_workdir_set_in_the_image(
742742
def cuberoot(x):
743743
from scipy.special import cbrt
744744

745-
return cbrt(27)
745+
return cbrt(x)
746746

747747
assert cuberoot(27) == 3
748748

tests/integ/test_local_mode.py

Lines changed: 47 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
import stopit
2525

2626
import tests.integ.lock as lock
27+
from sagemaker.remote_function import remote
2728
from sagemaker.workflow.step_outputs import get_step
2829
from tests.integ.sagemaker.conftest import _build_container, DOCKERFILE_TEMPLATE
2930
from sagemaker.config import SESSION_DEFAULT_S3_BUCKET_PATH
@@ -58,6 +59,7 @@
5859
LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_local_mode_lock")
5960
DATA_PATH = os.path.join(DATA_DIR, "iris", "data")
6061
DEFAULT_REGION = "us-west-2"
62+
ROLE = "SageMakerRole"
6163

6264

6365
class LocalNoS3Session(LocalSession):
@@ -147,7 +149,7 @@ def _create_model(output_path):
147149

148150
mx = MXNet(
149151
entry_point=script_path,
150-
role="SageMakerRole",
152+
role=ROLE,
151153
instance_count=1,
152154
instance_type="local",
153155
output_path=output_path,
@@ -218,7 +220,7 @@ def test_mxnet_local_mode(
218220

219221
mx = MXNet(
220222
entry_point=script_path,
221-
role="SageMakerRole",
223+
role=ROLE,
222224
py_version=mxnet_training_latest_py_version,
223225
instance_count=1,
224226
instance_type="local",
@@ -254,7 +256,7 @@ def test_mxnet_distributed_local_mode(
254256

255257
mx = MXNet(
256258
entry_point=script_path,
257-
role="SageMakerRole",
259+
role=ROLE,
258260
py_version=mxnet_training_latest_py_version,
259261
instance_count=2,
260262
instance_type="local",
@@ -289,7 +291,7 @@ def test_mxnet_local_data_local_script(
289291

290292
mx = MXNet(
291293
entry_point=script_path,
292-
role="SageMakerRole",
294+
role=ROLE,
293295
instance_count=1,
294296
instance_type="local",
295297
framework_version=mxnet_training_latest_version,
@@ -324,7 +326,7 @@ def test_mxnet_local_training_env(mxnet_training_latest_version, mxnet_training_
324326

325327
mx = MXNet(
326328
entry_point=script_path,
327-
role="SageMakerRole",
329+
role=ROLE,
328330
instance_count=1,
329331
instance_type="local",
330332
framework_version=mxnet_training_latest_version,
@@ -347,7 +349,7 @@ def test_mxnet_training_failure(
347349

348350
mx = MXNet(
349351
entry_point=script_path,
350-
role="SageMakerRole",
352+
role=ROLE,
351353
framework_version=mxnet_training_latest_version,
352354
py_version=mxnet_training_latest_py_version,
353355
instance_count=1,
@@ -377,7 +379,7 @@ def test_local_transform_mxnet(
377379

378380
mx = MXNet(
379381
entry_point=script_path,
380-
role="SageMakerRole",
382+
role=ROLE,
381383
instance_count=1,
382384
instance_type="local",
383385
framework_version=mxnet_inference_latest_version,
@@ -426,7 +428,7 @@ def test_local_processing_sklearn(sagemaker_local_session_no_local_code, sklearn
426428

427429
sklearn_processor = SKLearnProcessor(
428430
framework_version=sklearn_latest_version,
429-
role="SageMakerRole",
431+
role=ROLE,
430432
instance_type="local",
431433
instance_count=1,
432434
command=["python3"],
@@ -457,7 +459,7 @@ def test_local_processing_script_processor(sagemaker_local_session, sklearn_imag
457459
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
458460

459461
script_processor = ScriptProcessor(
460-
role="SageMakerRole",
462+
role=ROLE,
461463
image_uri=sklearn_image_uri,
462464
command=["python3"],
463465
instance_count=1,
@@ -527,7 +529,7 @@ def test_local_pipeline_with_processing_step(sklearn_latest_version, local_pipel
527529
string_container_arg = ParameterString(name="ProcessingContainerArg", default_value="foo")
528530
sklearn_processor = SKLearnProcessor(
529531
framework_version=sklearn_latest_version,
530-
role="SageMakerRole",
532+
role=ROLE,
531533
instance_type="local",
532534
instance_count=1,
533535
command=["python3"],
@@ -549,7 +551,7 @@ def test_local_pipeline_with_processing_step(sklearn_latest_version, local_pipel
549551
sagemaker_session=local_pipeline_session,
550552
parameters=[string_container_arg],
551553
)
552-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
554+
pipeline.create(ROLE, "pipeline for sdk integ testing")
553555

554556
with lock.lock(LOCK_PATH):
555557
execution = pipeline.start()
@@ -586,7 +588,7 @@ def test_local_pipeline_with_training_and_transform_steps(
586588
# define Estimator
587589
mx = MXNet(
588590
entry_point=script_path,
589-
role="SageMakerRole",
591+
role=ROLE,
590592
instance_count=instance_count,
591593
instance_type="local",
592594
framework_version=mxnet_training_latest_version,
@@ -614,7 +616,7 @@ def test_local_pipeline_with_training_and_transform_steps(
614616
image_uri=inference_image_uri,
615617
model_data=training_step.properties.ModelArtifacts.S3ModelArtifacts,
616618
sagemaker_session=session,
617-
role="SageMakerRole",
619+
role=ROLE,
618620
)
619621

620622
# define create model step
@@ -647,7 +649,7 @@ def test_local_pipeline_with_training_and_transform_steps(
647649
sagemaker_session=session,
648650
)
649651

650-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
652+
pipeline.create(ROLE, "pipeline for sdk integ testing")
651653

652654
with lock.lock(LOCK_PATH):
653655
execution = pipeline.start(parameters={"InstanceCountParam": 1})
@@ -667,7 +669,7 @@ def test_local_pipeline_with_training_and_transform_steps(
667669
def test_local_pipeline_with_eval_cond_fail_steps(sklearn_image_uri, local_pipeline_session):
668670
processor = ScriptProcessor(
669671
image_uri=sklearn_image_uri,
670-
role="SageMakerRole",
672+
role=ROLE,
671673
instance_count=1,
672674
instance_type="local",
673675
sagemaker_session=local_pipeline_session,
@@ -729,7 +731,7 @@ def test_local_pipeline_with_eval_cond_fail_steps(sklearn_image_uri, local_pipel
729731
sagemaker_session=local_pipeline_session,
730732
)
731733

732-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
734+
pipeline.create(ROLE, "pipeline for sdk integ testing")
733735

734736
with lock.lock(LOCK_PATH):
735737
execution = pipeline.start()
@@ -763,7 +765,7 @@ def test_local_pipeline_with_step_decorator_and_step_dependency(
763765
local_pipeline_session, dummy_container
764766
):
765767
step_settings = dict(
766-
role="SageMakerRole",
768+
role=ROLE,
767769
instance_type="ml.m5.xlarge",
768770
image_uri=dummy_container,
769771
keep_alive_period_in_seconds=60,
@@ -787,7 +789,7 @@ def sum(a, b):
787789
sagemaker_session=local_pipeline_session,
788790
)
789791

790-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
792+
pipeline.create(ROLE, "pipeline for sdk integ testing")
791793

792794
with lock.lock(LOCK_PATH):
793795
execution = pipeline.start()
@@ -808,7 +810,7 @@ def test_local_pipeline_with_step_decorator_and_pre_exe_script(
808810
local_pipeline_session, dummy_container
809811
):
810812
step_settings = dict(
811-
role="SageMakerRole",
813+
role=ROLE,
812814
instance_type="local",
813815
image_uri=dummy_container,
814816
keep_alive_period_in_seconds=60,
@@ -833,7 +835,7 @@ def validate_file_exists(files_exists, files_does_not_exist):
833835
sagemaker_session=local_pipeline_session,
834836
)
835837

836-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
838+
pipeline.create(ROLE, "pipeline for sdk integ testing")
837839

838840
with lock.lock(LOCK_PATH):
839841
execution = pipeline.start()
@@ -851,7 +853,7 @@ def test_local_pipeline_with_step_decorator_and_condition_step(
851853
local_pipeline_session, dummy_container
852854
):
853855
step_settings = dict(
854-
role="SageMakerRole",
856+
role=ROLE,
855857
instance_type="local",
856858
image_uri=dummy_container,
857859
keep_alive_period_in_seconds=60,
@@ -888,7 +890,7 @@ def else_step():
888890
sagemaker_session=local_pipeline_session,
889891
)
890892

891-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
893+
pipeline.create(ROLE, "pipeline for sdk integ testing")
892894

893895
with lock.lock(LOCK_PATH):
894896
execution = pipeline.start()
@@ -916,7 +918,7 @@ def test_local_pipeline_with_step_decorator_data_referenced_by_other_steps(
916918
@step(
917919
name="step1",
918920
image_uri=dummy_container,
919-
role="SageMakerRole",
921+
role=ROLE,
920922
instance_type="ml.m5.xlarge",
921923
keep_alive_period_in_seconds=60,
922924
)
@@ -933,7 +935,7 @@ def func(var: int):
933935

934936
sklearn_processor = SKLearnProcessor(
935937
framework_version=sklearn_latest_version,
936-
role="SageMakerRole",
938+
role=ROLE,
937939
instance_type="local",
938940
instance_count=step_output[1],
939941
command=["python3"],
@@ -967,7 +969,7 @@ def func(var: int):
967969
sagemaker_session=local_pipeline_session,
968970
)
969971

970-
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
972+
pipeline.create(ROLE, "pipeline for sdk integ testing")
971973

972974
with lock.lock(LOCK_PATH):
973975
execution = pipeline.start()
@@ -983,3 +985,23 @@ def func(var: int):
983985
assert exe_step_result["StepStatus"] == "Succeeded"
984986
if exe_step_result["StepName"] == cond_step.name:
985987
assert exe_step_result["Metadata"]["Condition"]["Outcome"] is True
988+
989+
990+
def test_local_remote_function_with_additional_dependencies(
991+
local_pipeline_session, dummy_container
992+
):
993+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
994+
995+
@remote(
996+
role=ROLE,
997+
image_uri=dummy_container,
998+
dependencies=dependencies_path,
999+
instance_type="local",
1000+
sagemaker_session=local_pipeline_session,
1001+
)
1002+
def cuberoot(x):
1003+
from scipy.special import cbrt
1004+
1005+
return cbrt(x)
1006+
1007+
assert cuberoot(27) == 3

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1108,7 +1108,7 @@ def test_wait(session, mock_stored_function, mock_logs_for_job, *args):
11081108
job.wait(timeout=10)
11091109

11101110
mock_logs_for_job.assert_called_with(
1111-
boto_session=ANY, job_name=job.job_name, wait=True, timeout=10
1111+
sagemaker_session=ANY, job_name=job.job_name, wait=True, timeout=10
11121112
)
11131113

11141114

0 commit comments

Comments
 (0)