diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index d5d879cb08..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": @@ -89,6 +86,10 @@ def main(sys_args=None): client_python_version, conda_env, dependency_settings ) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) + exit_code = SUCCESS_EXIT_CODE except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while bootstrapping runtime environment: %s", e) diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 0dd5f0d219..13493c1d15 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,8 +24,6 @@ import dataclasses import json -import sagemaker - class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" @@ -330,6 +328,7 @@ def _current_python_version(self): def _current_sagemaker_pysdk_version(self): """Returns the current sagemaker python sdk version where program is running""" + import sagemaker return sagemaker.__version__ @@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version): ): logger.warning( "Inconsistent sagemaker versions found: " - "sagemaker pysdk version found in the container is " + "sagemaker python sdk version found in the container is " "'%s' which does not match the '%s' on the local client. " - "Please make sure that the python version used in the training container " - "is the same as the local python version in case of unexpected behaviors.", + "Please make sure that the sagemaker version used in the training container " + "is the same as the local sagemaker version in case of unexpected behaviors.", job_sagemaker_pysdk_version, client_sagemaker_pysdk_version, ) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index b7d9e10047..ef35c965e9 100644 --- a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -269,7 +269,7 @@ def test_main_failure_remote_job_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) @@ -317,7 +317,7 @@ def test_main_failure_pipeline_step_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err))