diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py
index 1492b6bafa..69e06419f2 100644
--- a/src/sagemaker/experiments/run.py
+++ b/src/sagemaker/experiments/run.py
@@ -664,6 +664,10 @@ def __enter__(self):
             if self._inside_load_context:
                 raise RuntimeError(nested_with_err_msg_template.format("load_run"))
             self._inside_load_context = True
+            if not self._inside_init_context:
+                # Add to the run context only if load_run is called on its own,
+                # i.e. not under a Run init context
+                _RunContext.add_run_object(self)
         else:
             if _RunContext.get_current_run():
                 raise RuntimeError(nested_with_err_msg_template.format("Run"))
@@ -692,6 +696,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         if self._in_load:
             self._inside_load_context = False
             self._in_load = False
+            if not self._inside_init_context:
+                _RunContext.drop_current_run()
         else:
             self._inside_init_context = False
             _RunContext.drop_current_run()
diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py
index 713a6a3792..96aad30dc0 100644
--- a/tests/integ/sagemaker/experiments/test_run.py
+++ b/tests/integ/sagemaker/experiments/test_run.py
@@ -170,11 +170,11 @@ def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_na
 
 def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar):
     # Notes:
-    # 1. The 1st Run TC created locally and its exp config was auto passed to the job
+    # 1. The 1st Run created locally and its exp config was auto passed to the job
     # 2. In training job, the same exp and run names are given in the Run constructor
-    #    which will load the 1st Run TC in training job and log parameters
+    #    which will load the 1st Run in training job and log parameters
     #    and metrics there
-    # 3. In a different training job, load the same Run TC and log more parameters there.
+    # 3. In a different training job, load the same Run and log more parameters there.
     exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     estimator = _generate_estimator(
         sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name
@@ -253,12 +253,12 @@ def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, d
 
 def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar):
     # Notes:
-    # 1. The 1st Run TC created locally and its exp config was auto passed to the job
-    # 2. In training job, different exp and run names (i.e. 2nd Run TC) are given
-    #    in the Run constructor which will create a Run TC according to the run_name
+    # 1. The 1st Run created locally and its exp config was auto passed to the job
+    # 2. In training job, different exp and run names (i.e. 2nd Run) are given
+    #    in the Run constructor which will create a Run according to the run_name
     #    passed in there and ignore the exp config in the job
-    # 3. Both metrics and parameters are logged in the Run TC created in job
-    # 4. In a different training job, load the 2nd Run TC and log more parameters there.
+    # 3. Both metrics and parameters are logged in the Run created in job
+    # 4. In a different training job, load the 2nd Run and log more parameters there.
     exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     estimator = _generate_estimator(
@@ -328,11 +328,11 @@ def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, d
 
 def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar):
     # Notes:
-    # 1. No Run TC created locally or specified in experiment config
+    # 1. No Run created locally or specified in experiment config
     # 2. In training job, Run is initialized
-    #    which will create a Run TC according to the run_name passed in there
-    # 3. Both metrics and parameters are logged in the Run TC created in job
-    # 4. In a different training job, load the same Run TC and log more parameters there.
+    #    which will create a Run according to the run_name passed in there
+    # 3. Both metrics and parameters are logged in the Run created in job
+    # 4. In a different training job, load the same Run and log more parameters there.
     exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     estimator = _generate_estimator(
         sdk_tar=dev_sdk_tar,
@@ -370,13 +370,13 @@ def test_run_from_processing_job_and_override_default_exp_config(
     sagemaker_session, dev_sdk_tar, run_obj
 ):
     # Notes:
-    # 1. The 1st Run TC (run) created locally
-    # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run
-    #    but override the default experiment config in context of 2nd Run TC
-    #    with the experiment config of the 1st Run TC
-    # 3. In the processing job script, load the 1st Run TC via the experiment config
+    # 1. The 1st Run (run) created locally
+    # 2. Within the 2nd Run (run_obj)'s context, invoke processor.run
+    #    but override the default experiment config in context of 2nd Run
+    #    with the experiment config of the 1st Run
+    # 3. In the processing job script, load the 1st Run via the experiment config
     #    fetched from the job env
-    # 4. All data are logged in the Run TC either locally or in the processing job
+    # 4. All data are logged in the Run either locally or in the processing job
     exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     processor = FrameworkProcessor(
         estimator_cls=PyTorch,
@@ -441,14 +441,15 @@ def test_run_from_processing_job_and_override_default_exp_config(
 
 
 # dev_sdk_tar is required to trigger generating the dev SDK tar
-def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version):
+def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_version):
     # Notes:
-    # 1. The 1st Run TC (run) created locally
-    # 2. In the inference script running in a transform job, load the 1st Run TC
-    #    via explicitly passing the experiment_name and run_name of the 1st Run TC
+    # 1. The 1st Run (run) created locally
+    # 2. In the inference script running in a transform job, load the 1st Run
+    #    via explicitly passing the experiment_name and run_name of the 1st Run
     # TODO: once we're able to retrieve exp config from the transform job env,
     #  we should expand this test and add the load_run() without explicitly supplying the names
-    # 3. All data are logged in the Run TC either locally or in the transform job
+    # 3. All data are logged in the Run either locally or in the transform job
+    exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     xgb_model_data_s3 = sagemaker_session.upload_data(
         path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"),
         key_prefix="integ-test-data/xgboost/model",
@@ -461,8 +462,8 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost
         source_dir=_EXP_DIR,
         framework_version=xgboost_latest_version,
         env={
-            "EXPERIMENT_NAME": run_obj.experiment_name,
-            "RUN_NAME": run_obj.run_name,
+            "EXPERIMENT_NAME": exp_name,
+            "RUN_NAME": _RUN_NAME_IN_SCRIPT,
         },
     )
     transformer = xgboost_model.transformer(
@@ -481,25 +482,83 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost
         os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session
     )
 
-    with run_obj:
-        _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
-        transformer.transform(
-            data=input_data,
-            content_type="text/libsvm",
-            split_type="Line",
-            wait=True,
-            job_name=f"transform-job-{name()}",
-        )
-
-    _check_run_from_local_end_result(
-        tc=run_obj._trial_component,
-        sagemaker_session=sagemaker_session,
-        is_complete_log=False,
-    )
-    tc_name = Run._generate_trial_component_name(
-        experiment_name=run_obj.experiment_name, run_name=run_obj.run_name
-    )
-    _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False)
+    with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+        with Run(
+            experiment_name=exp_name,
+            run_name=_RUN_NAME_IN_SCRIPT,
+            sagemaker_session=sagemaker_session,
+        ) as run:
+            _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
+            transformer.transform(
+                data=input_data,
+                content_type="text/libsvm",
+                split_type="Line",
+                wait=True,
+                job_name=f"transform-job-{name()}",
+            )
+
+        _check_run_from_local_end_result(
+            tc=run._trial_component,
+            sagemaker_session=sagemaker_session,
+            is_complete_log=False,
+        )
+        tc_name = Run._generate_trial_component_name(
+            experiment_name=run.experiment_name, run_name=run.run_name
+        )
+        _check_run_from_job_result(
+            tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+        )
+
+
+# dev_sdk_tar is required to trigger generating the dev SDK tar
+def test_load_run_auto_pass_in_exp_config_to_job(sagemaker_session, dev_sdk_tar):
+    # Notes:
+    # 1. On the local side, load the Run created previously and invoke a job under the load context
+    # 2. In the job script, load the Run via the exp config auto-passed to the job env
+    # 3. All data are logged in the Run either locally or in the processing job
+    exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+    processor = FrameworkProcessor(
+        estimator_cls=PyTorch,
+        framework_version="1.10",
+        py_version="py38",
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+        role=EXECUTION_ROLE,
+        sagemaker_session=sagemaker_session,
+    )
+
+    with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+        with Run(
+            experiment_name=exp_name,
+            run_name=_RUN_NAME_IN_SCRIPT,
+            sagemaker_session=sagemaker_session,
+        ) as run:
+            _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
+
+            with load_run(
+                experiment_name=run.experiment_name,
+                run_name=run.run_name,
+                sagemaker_session=sagemaker_session,
+            ):
+                processor.run(
+                    code=_PYTHON_PROCESS_SCRIPT,
+                    source_dir=_EXP_DIR,
+                    job_name=f"process-job-{name()}",
+                    wait=True,  # wait for the job to finish
+                    logs=False,
+                )
+
+        _check_run_from_local_end_result(
+            tc=run._trial_component,
+            sagemaker_session=sagemaker_session,
+            is_complete_log=False,
+        )
+        tc_name = Run._generate_trial_component_name(
+            experiment_name=run.experiment_name, run_name=run.run_name
+        )
+        _check_run_from_job_result(
+            tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+        )
 
 
 def test_list(run_obj, sagemaker_session):
diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py
index 7e068136a1..e63a1256a5 100644
--- a/tests/unit/sagemaker/experiments/test_run_context.py
+++ b/tests/unit/sagemaker/experiments/test_run_context.py
@@ -16,11 +16,15 @@
 import pytest
 
+from sagemaker import Processor
 from sagemaker.estimator import Estimator, _TrainingJob
 from sagemaker.experiments.experiment import _Experiment
 from sagemaker.experiments.run import _RunContext
 from sagemaker.experiments import load_run, Run
 from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.processing import ProcessingJob
+from sagemaker.transformer import _TransformJob, Transformer
 from tests.unit.sagemaker.experiments.helpers import (
     TEST_EXP_NAME,
     mock_trial_load_or_create_func,
 )
@@ -56,6 +60,97 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker
     assert not _RunContext.get_current_run()
 
 
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._Experiment._load_or_create")
+@patch("sagemaker.experiments.run._Trial._load_or_create")
+@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
+@patch.object(_TrainingJob, "start_new")
+def test_auto_pass_in_exp_config_under_load_run(
+    mock_start_job, mock_load_tc, mock_load_trial, mock_load_exp, run_obj, sagemaker_session
+):
+    mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job")
+    mock_load_tc.return_value = run_obj._trial_component, True
+    mock_load_trial.return_value = run_obj._trial
+    mock_load_exp.return_value = run_obj._experiment
+    with load_run(
+        run_name=run_obj.run_name,
+        experiment_name=run_obj.experiment_name,
+        sagemaker_session=sagemaker_session,
+    ):
+        estimator = Estimator(
+            role="arn:my-role",
+            image_uri="my-image",
+            sagemaker_session=sagemaker_session,
+            output_path=_train_output_path,
+        )
+        estimator.fit(
+            inputs=_train_input_path,
+            wait=False,
+        )
+
+        loaded_run = _RunContext.get_current_run()
+        assert loaded_run.run_name == run_obj.run_name
+        assert loaded_run.experiment_config == run_obj.experiment_config
+
+        expected_exp_config = run_obj.experiment_config
+        mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
+
+    # _RunContext is cleaned up after exiting the with statement
+    assert not _RunContext.get_current_run()
+
+
+@patch.object(ProcessingJob, "start_new")
+def test_auto_pass_in_exp_config_to_process_job(mock_start_job, run_obj, sagemaker_session):
+    mock_start_job.return_value = ProcessingJob(sagemaker_session, "my-job", [], [], "")
+    with run_obj:
+        processor = Processor(
+            role="arn:my-role",
+            image_uri="my-image",
+            instance_count=1,
+            instance_type="ml.m5.large",
+            sagemaker_session=sagemaker_session,
+        )
+        processor.run(wait=False, logs=False)
+
+        assert _RunContext.get_current_run() == run_obj
+
+        expected_exp_config = run_obj.experiment_config
+        assert mock_start_job.call_args[1]["experiment_config"] == expected_exp_config
+
+    # _RunContext is cleaned up after exiting the with statement
+    assert not _RunContext.get_current_run()
+
+
+@patch.object(_TransformJob, "start_new")
+def test_auto_pass_in_exp_config_to_transform_job(mock_start_job, run_obj, sagemaker_session):
+    bucket_name = "my-bucket"
+    job_name = "my-job"
+    mock_start_job.return_value = _TransformJob(sagemaker_session, job_name)
+    with run_obj:
+        transformer = Transformer(
+            model_name="my-model",
+            instance_count=1,
+            instance_type="ml.m5.large",
+            output_path=f"s3://{bucket_name}/output",
+            sagemaker_session=sagemaker_session,
+        )
+        transformer.transform(
+            data=f"s3://{bucket_name}/data", wait=False, logs=False, job_name=job_name
+        )
+
+        assert _RunContext.get_current_run() == run_obj
+
+        expected_exp_config = run_obj.experiment_config
+        assert mock_start_job.call_args[0][9] == expected_exp_config
+
+    # _RunContext is cleaned up after exiting the with statement
+    assert not _RunContext.get_current_run()
+
+
+# TODO: add unit test for test_auto_pass_in_exp_config_to_tuning_job once ready
+
+
 @patch.object(_TrainingJob, "start_new")
 def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session):
     mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job")
@@ -94,6 +189,7 @@ def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session):
     def train():
         with load_run(sagemaker_session=sagemaker_session) as run_load:
             assert run_load == run_obj
+            assert _RunContext.get_current_run() == run_obj
             assert run_obj._inside_init_context
             assert run_obj._inside_load_context
             assert run_obj._in_load
@@ -105,14 +201,14 @@ def train():
         assert run_obj._inside_init_context
         assert not run_obj._inside_load_context
         assert not run_obj._in_load
-        assert _RunContext.get_current_run()
+        assert _RunContext.get_current_run() == run_obj
 
         train()
 
         assert run_obj._inside_init_context
         assert not run_obj._inside_load_context
         assert not run_obj._in_load
-        assert _RunContext.get_current_run()
+        assert _RunContext.get_current_run() == run_obj
 
         run_obj.log_parameters({"a": "b", "c": 2})
 
@@ -132,7 +228,7 @@ def train():
         assert run_obj._inside_init_context
         assert not run_obj._inside_load_context
         assert not run_obj._in_load
-        assert _RunContext.get_current_run()
+        assert _RunContext.get_current_run() == run_obj
 
     assert not run_obj._inside_init_context
     assert not run_obj._inside_load_context
@@ -176,6 +272,14 @@ def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_sess
                 pass
     assert "It is not allowed to use nested 'with' statements on the Run" in str(err)
 
+    with pytest.raises(RuntimeError) as err:
+        with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+            assert _RunContext.get_current_run()
+
+            with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+                pass
+    assert "It is not allowed to use nested 'with' statements on the Run" in str(err)
+
 
 def test_nested_run_load_context(run_obj, sagemaker_session):
     assert not _RunContext.get_current_run()
@@ -189,3 +293,37 @@ def test_nested_run_load_context(run_obj, sagemaker_session):
             with run_load:
                 pass
     assert "It is not allowed to use nested 'with' statements on the load_run" in str(err)
+
+    with pytest.raises(RuntimeError) as err:
+        with run_obj:
+            assert _RunContext.get_current_run()
+
+            with load_run():
+                with load_run():
+                    pass
+    assert "It is not allowed to use nested 'with' statements on the load_run" in str(err)
+
+
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._Experiment._load_or_create")
+@patch("sagemaker.experiments.run._Trial._load_or_create")
+@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
+def test_run_init_under_run_load_context(
+    mock_load_tc, mock_load_trial, mock_load_exp, run_obj, sagemaker_session
+):
+    mock_load_tc.return_value = run_obj._trial_component, True
+    mock_load_trial.return_value = run_obj._trial
+    mock_load_exp.return_value = run_obj._experiment
+    assert not _RunContext.get_current_run()
+
+    with pytest.raises(RuntimeError) as err:
+        with load_run(
+            run_name=run_obj.run_name,
+            experiment_name=run_obj.experiment_name,
+            sagemaker_session=sagemaker_session,
+        ):
+            assert _RunContext.get_current_run()
+
+            with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+                pass
+    assert "It is not allowed to use nested 'with' statements on the Run" in str(err)
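
Usage sketch (illustrative only, not part of the patch): with this change, entering load_run() as a standalone context manager registers the loaded run in _RunContext, so a job launched inside the with-block picks up that run's experiment config automatically, mirroring the existing behavior of a Run init context; on exit the run is dropped from _RunContext again. The experiment name, run name, training script, role ARN, and S3 paths below are hypothetical placeholders.

    from sagemaker.experiments import load_run
    from sagemaker.pytorch import PyTorch

    # Load a run created earlier; __enter__ now adds it to _RunContext
    # even though no Run init context is active.
    with load_run(experiment_name="my-experiment", run_name="my-run"):
        estimator = PyTorch(
            entry_point="train.py",  # hypothetical training script
            role="arn:aws:iam::123456789012:role/my-sagemaker-role",  # hypothetical role ARN
            framework_version="1.10",
            py_version="py38",
            instance_count=1,
            instance_type="ml.m5.xlarge",
        )
        # The loaded run is the current run in _RunContext, so its experiment
        # config is attached to the training job; inside the job script,
        # load_run() with no arguments resumes logging to the same run.
        estimator.fit(inputs="s3://my-bucket/train")  # hypothetical input location
    # __exit__ drops the run from _RunContext, leaving no current run behind.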