Skip to content

Commit b97ebb5

Browse files
committed
fix: Move sagemaker pysdk version check after bootstrap in remote job
1 parent 615a8ad commit b97ebb5

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def main(sys_args=None):
6565
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6666

6767
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68-
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69-
client_sagemaker_pysdk_version
70-
)
7168

7269
user = getpass.getuser()
7370
if user != "root":
@@ -89,6 +86,10 @@ def main(sys_args=None):
8986
client_python_version, conda_env, dependency_settings
9087
)
9188

89+
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
90+
client_sagemaker_pysdk_version
91+
)
92+
9293
exit_code = SUCCESS_EXIT_CODE
9394
except Exception as e: # pylint: disable=broad-except
9495
logger.exception("Error encountered while bootstrapping runtime environment: %s", e)

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import dataclasses
2525
import json
2626

27-
import sagemaker
28-
2927

3028
class _UTCFormatter(logging.Formatter):
3129
"""Class that overrides the default local time provider in log formatter."""
@@ -330,6 +328,7 @@ def _current_python_version(self):
330328

331329
def _current_sagemaker_pysdk_version(self):
332330
"""Returns the current sagemaker python sdk version where program is running"""
331+
import sagemaker
333332

334333
return sagemaker.__version__
335334

@@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
366365
):
367366
logger.warning(
368367
"Inconsistent sagemaker versions found: "
369-
"sagemaker pysdk version found in the container is "
368+
"sagemaker python sdk version found in the container is "
370369
"'%s' which does not match the '%s' on the local client. "
371-
"Please make sure that the python version used in the training container "
372-
"is the same as the local python version in case of unexpected behaviors.",
370+
"Please make sure that the sagemaker version used in the training container "
371+
"is the same as the local sagemaker version in case of unexpected behaviors.",
373372
job_sagemaker_pysdk_version,
374373
client_sagemaker_pysdk_version,
375374
)

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

-12
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,6 @@ def test_main_success_pipeline_step_with_root_user(
228228
_exit_process.assert_called_with(0)
229229

230230

231-
@patch(
232-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
233-
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
234-
)
235231
@patch(
236232
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
237233
"RuntimeEnvironmentManager._validate_python_version"
@@ -260,7 +256,6 @@ def test_main_failure_remote_job_with_root_user(
260256
write_failure,
261257
_exit_process,
262258
validate_python,
263-
validate_sagemaker,
264259
):
265260
runtime_err = RuntimeEnvironmentError("some failure reason")
266261
bootstrap_runtime.side_effect = runtime_err
@@ -269,17 +264,12 @@ def test_main_failure_remote_job_with_root_user(
269264

270265
change_dir_permission.assert_not_called()
271266
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
272-
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
273267
run_pre_exec_script.assert_not_called()
274268
bootstrap_runtime.assert_called()
275269
write_failure.assert_called_with(str(runtime_err))
276270
_exit_process.assert_called_with(1)
277271

278272

279-
@patch(
280-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
281-
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
282-
)
283273
@patch(
284274
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
285275
"RuntimeEnvironmentManager._validate_python_version"
@@ -308,7 +298,6 @@ def test_main_failure_pipeline_step_with_root_user(
308298
write_failure,
309299
_exit_process,
310300
validate_python,
311-
validate_sagemaker,
312301
):
313302
runtime_err = RuntimeEnvironmentError("some failure reason")
314303
bootstrap_runtime.side_effect = runtime_err
@@ -317,7 +306,6 @@ def test_main_failure_pipeline_step_with_root_user(
317306

318307
change_dir_permission.assert_not_called()
319308
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
320-
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
321309
run_pre_exec_script.assert_not_called()
322310
bootstrap_runtime.assert_called()
323311
write_failure.assert_called_with(str(runtime_err))

0 commit comments

Comments
 (0)