diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 0c074a9881..db47786d39 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -433,3 +433,44 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( job_description = mx.latest_training_job.describe() assert job_description["ProfilerConfig"]["S3OutputPath"] == mx.output_path + + +def test_mxnet_profiling_with_disable_debugger_hook( + sagemaker_session, + mxnet_training_latest_version, + mxnet_training_latest_py_version, + cpu_instance_type, +): + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py") + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + framework_version=mxnet_training_latest_version, + py_version=mxnet_training_latest_py_version, + instance_count=1, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + debugger_hook_config=False, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + + training_job_name = unique_name_from_base("test-profiler-mxnet-training") + mx.fit( + inputs={"train": train_input, "test": test_input}, + job_name=training_job_name, + wait=False, + ) + + job_description = mx.latest_training_job.describe() + # setting debugger_hook_config to false would not disable profiling + # https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-turn-off.html + assert job_description.get("ProfilingStatus") == "Enabled"