Skip to content

change: Support local mode for remote function #4306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def wait(self, timeout: int = None):
"""

self._last_describe_response = _logs_for_job(
boto_session=self.sagemaker_session.boto_session,
sagemaker_session=self.sagemaker_session,
job_name=self.job_name,
wait=True,
timeout=timeout,
Expand Down
13 changes: 6 additions & 7 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5447,7 +5447,7 @@ def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=No
exceptions.CapacityError: If the training job fails with CapacityError.
exceptions.UnexpectedStatusException: If waiting and the training job fails.
"""
_logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout)
_logs_for_job(self, job_name, wait, poll, log_type, timeout)

def logs_for_processing_job(self, job_name, wait=False, poll=10):
"""Display logs for a given processing job, optionally tailing them until the job is complete.
Expand Down Expand Up @@ -7330,17 +7330,16 @@ def _rule_statuses_changed(current_statuses, last_statuses):


def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None
sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
):
"""Display logs for a given training job, optionally tailing them until job is complete.

If the output is a tty or a Jupyter cell, it will be color-coded
based on which instance the log entry is from.

Args:
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
calls are delegated to (default: None). If not provided, one is created with
default AWS configuration chain.
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions.
job_name (str): Name of the training job to display the logs for.
wait (bool): Whether to keep looking for new log entries until the job completes
(default: False).
Expand All @@ -7357,13 +7356,13 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
exceptions.CapacityError: If the training job fails with CapacityError.
exceptions.UnexpectedStatusException: If waiting and the training job fails.
"""
sagemaker_client = boto_session.client("sagemaker")
sagemaker_client = sagemaker_session.sagemaker_client
request_end_time = time.time() + timeout if timeout else None
description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
print(secondary_training_status_message(description, None), end="")

instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
boto_session, description, job="Training"
sagemaker_session.boto_session, description, job="Training"
)

state = _get_initial_job_state(description, "TrainingJobStatus", wait)
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/sagemaker/remote_function/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_with_additional_dependencies(
def cuberoot(x):
from scipy.special import cbrt

return cbrt(27)
return cbrt(x)

assert cuberoot(27) == 3

Expand Down Expand Up @@ -742,7 +742,7 @@ def test_with_user_and_workdir_set_in_the_image(
def cuberoot(x):
from scipy.special import cbrt

return cbrt(27)
return cbrt(x)

assert cuberoot(27) == 3

Expand Down
72 changes: 47 additions & 25 deletions tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import stopit

import tests.integ.lock as lock
from sagemaker.remote_function import remote
from sagemaker.workflow.step_outputs import get_step
from tests.integ.sagemaker.conftest import _build_container, DOCKERFILE_TEMPLATE
from sagemaker.config import SESSION_DEFAULT_S3_BUCKET_PATH
Expand Down Expand Up @@ -58,6 +59,7 @@
LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_local_mode_lock")
DATA_PATH = os.path.join(DATA_DIR, "iris", "data")
DEFAULT_REGION = "us-west-2"
ROLE = "SageMakerRole"


class LocalNoS3Session(LocalSession):
Expand Down Expand Up @@ -147,7 +149,7 @@ def _create_model(output_path):

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
instance_count=1,
instance_type="local",
output_path=output_path,
Expand Down Expand Up @@ -218,7 +220,7 @@ def test_mxnet_local_mode(

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
py_version=mxnet_training_latest_py_version,
instance_count=1,
instance_type="local",
Expand Down Expand Up @@ -254,7 +256,7 @@ def test_mxnet_distributed_local_mode(

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
py_version=mxnet_training_latest_py_version,
instance_count=2,
instance_type="local",
Expand Down Expand Up @@ -289,7 +291,7 @@ def test_mxnet_local_data_local_script(

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
instance_count=1,
instance_type="local",
framework_version=mxnet_training_latest_version,
Expand Down Expand Up @@ -324,7 +326,7 @@ def test_mxnet_local_training_env(mxnet_training_latest_version, mxnet_training_

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
instance_count=1,
instance_type="local",
framework_version=mxnet_training_latest_version,
Expand All @@ -347,7 +349,7 @@ def test_mxnet_training_failure(

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
framework_version=mxnet_training_latest_version,
py_version=mxnet_training_latest_py_version,
instance_count=1,
Expand Down Expand Up @@ -377,7 +379,7 @@ def test_local_transform_mxnet(

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
instance_count=1,
instance_type="local",
framework_version=mxnet_inference_latest_version,
Expand Down Expand Up @@ -426,7 +428,7 @@ def test_local_processing_sklearn(sagemaker_local_session_no_local_code, sklearn

sklearn_processor = SKLearnProcessor(
framework_version=sklearn_latest_version,
role="SageMakerRole",
role=ROLE,
instance_type="local",
instance_count=1,
command=["python3"],
Expand Down Expand Up @@ -457,7 +459,7 @@ def test_local_processing_script_processor(sagemaker_local_session, sklearn_imag
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")

script_processor = ScriptProcessor(
role="SageMakerRole",
role=ROLE,
image_uri=sklearn_image_uri,
command=["python3"],
instance_count=1,
Expand Down Expand Up @@ -527,7 +529,7 @@ def test_local_pipeline_with_processing_step(sklearn_latest_version, local_pipel
string_container_arg = ParameterString(name="ProcessingContainerArg", default_value="foo")
sklearn_processor = SKLearnProcessor(
framework_version=sklearn_latest_version,
role="SageMakerRole",
role=ROLE,
instance_type="local",
instance_count=1,
command=["python3"],
Expand All @@ -549,7 +551,7 @@ def test_local_pipeline_with_processing_step(sklearn_latest_version, local_pipel
sagemaker_session=local_pipeline_session,
parameters=[string_container_arg],
)
pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand Down Expand Up @@ -586,7 +588,7 @@ def test_local_pipeline_with_training_and_transform_steps(
# define Estimator
mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
role=ROLE,
instance_count=instance_count,
instance_type="local",
framework_version=mxnet_training_latest_version,
Expand Down Expand Up @@ -614,7 +616,7 @@ def test_local_pipeline_with_training_and_transform_steps(
image_uri=inference_image_uri,
model_data=training_step.properties.ModelArtifacts.S3ModelArtifacts,
sagemaker_session=session,
role="SageMakerRole",
role=ROLE,
)

# define create model step
Expand Down Expand Up @@ -647,7 +649,7 @@ def test_local_pipeline_with_training_and_transform_steps(
sagemaker_session=session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start(parameters={"InstanceCountParam": 1})
Expand All @@ -667,7 +669,7 @@ def test_local_pipeline_with_training_and_transform_steps(
def test_local_pipeline_with_eval_cond_fail_steps(sklearn_image_uri, local_pipeline_session):
processor = ScriptProcessor(
image_uri=sklearn_image_uri,
role="SageMakerRole",
role=ROLE,
instance_count=1,
instance_type="local",
sagemaker_session=local_pipeline_session,
Expand Down Expand Up @@ -729,7 +731,7 @@ def test_local_pipeline_with_eval_cond_fail_steps(sklearn_image_uri, local_pipel
sagemaker_session=local_pipeline_session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand Down Expand Up @@ -763,7 +765,7 @@ def test_local_pipeline_with_step_decorator_and_step_dependency(
local_pipeline_session, dummy_container
):
step_settings = dict(
role="SageMakerRole",
role=ROLE,
instance_type="ml.m5.xlarge",
image_uri=dummy_container,
keep_alive_period_in_seconds=60,
Expand All @@ -787,7 +789,7 @@ def sum(a, b):
sagemaker_session=local_pipeline_session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand All @@ -808,7 +810,7 @@ def test_local_pipeline_with_step_decorator_and_pre_exe_script(
local_pipeline_session, dummy_container
):
step_settings = dict(
role="SageMakerRole",
role=ROLE,
instance_type="local",
image_uri=dummy_container,
keep_alive_period_in_seconds=60,
Expand All @@ -833,7 +835,7 @@ def validate_file_exists(files_exists, files_does_not_exist):
sagemaker_session=local_pipeline_session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand All @@ -851,7 +853,7 @@ def test_local_pipeline_with_step_decorator_and_condition_step(
local_pipeline_session, dummy_container
):
step_settings = dict(
role="SageMakerRole",
role=ROLE,
instance_type="local",
image_uri=dummy_container,
keep_alive_period_in_seconds=60,
Expand Down Expand Up @@ -888,7 +890,7 @@ def else_step():
sagemaker_session=local_pipeline_session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand Down Expand Up @@ -916,7 +918,7 @@ def test_local_pipeline_with_step_decorator_data_referenced_by_other_steps(
@step(
name="step1",
image_uri=dummy_container,
role="SageMakerRole",
role=ROLE,
instance_type="ml.m5.xlarge",
keep_alive_period_in_seconds=60,
)
Expand All @@ -933,7 +935,7 @@ def func(var: int):

sklearn_processor = SKLearnProcessor(
framework_version=sklearn_latest_version,
role="SageMakerRole",
role=ROLE,
instance_type="local",
instance_count=step_output[1],
command=["python3"],
Expand Down Expand Up @@ -967,7 +969,7 @@ def func(var: int):
sagemaker_session=local_pipeline_session,
)

pipeline.create("SageMakerRole", "pipeline for sdk integ testing")
pipeline.create(ROLE, "pipeline for sdk integ testing")

with lock.lock(LOCK_PATH):
execution = pipeline.start()
Expand All @@ -983,3 +985,23 @@ def func(var: int):
assert exe_step_result["StepStatus"] == "Succeeded"
if exe_step_result["StepName"] == cond_step.name:
assert exe_step_result["Metadata"]["Condition"]["Outcome"] is True


def test_local_remote_function_with_additional_dependencies(
    local_pipeline_session, dummy_container
):
    """Run a ``@remote``-decorated function with ``instance_type="local"``.

    The decorated function imports ``scipy`` lazily inside its body, and the
    decorator is given a requirements file through ``dependencies``.
    The test passes when calling the function with 27 yields 3.

    Args:
        local_pipeline_session: Fixture passed to the decorator as
            ``sagemaker_session``.
        dummy_container: Fixture passed to the decorator as ``image_uri``.
    """
    requirements_file = os.path.join(DATA_DIR, "remote_function", "requirements.txt")

    # Collect the decorator settings up front so the decoration reads as one step.
    remote_settings = dict(
        role=ROLE,
        image_uri=dummy_container,
        dependencies=requirements_file,
        instance_type="local",
        sagemaker_session=local_pipeline_session,
    )

    @remote(**remote_settings)
    def cuberoot(x):
        # Imported inside the function body rather than at module level.
        # NOTE(review): presumably scipy is supplied by the requirements file
        # at execution time — confirm against the file's contents.
        from scipy.special import cbrt

        return cbrt(x)

    result = cuberoot(27)
    assert result == 3
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/remote_function/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def test_wait(session, mock_stored_function, mock_logs_for_job, *args):
job.wait(timeout=10)

mock_logs_for_job.assert_called_with(
boto_session=ANY, job_name=job.job_name, wait=True, timeout=10
sagemaker_session=ANY, job_name=job.job_name, wait=True, timeout=10
)


Expand Down
Loading