Skip to content

Commit 4a7a2f4

Browse files
qidewenwhenroot
authored and
root
committed
fix: Move sagemaker pysdk version check after bootstrap in remote job (aws#4487)
1 parent 69f9fa7 commit 4a7a2f4

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def main(sys_args=None):
6565
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6666

6767
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68-
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69-
client_sagemaker_pysdk_version
70-
)
7168

7269
user = getpass.getuser()
7370
if user != "root":
@@ -89,6 +86,10 @@ def main(sys_args=None):
8986
client_python_version, conda_env, dependency_settings
9087
)
9188

89+
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
90+
client_sagemaker_pysdk_version
91+
)
92+
9293
exit_code = SUCCESS_EXIT_CODE
9394
except Exception as e: # pylint: disable=broad-except
9495
logger.exception("Error encountered while bootstrapping runtime environment: %s", e)

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import dataclasses
2525
import json
2626

27-
import sagemaker
28-
2927

3028
class _UTCFormatter(logging.Formatter):
3129
"""Class that overrides the default local time provider in log formatter."""
@@ -330,6 +328,7 @@ def _current_python_version(self):
330328

331329
def _current_sagemaker_pysdk_version(self):
332330
"""Returns the current sagemaker python sdk version where program is running"""
331+
import sagemaker
333332

334333
return sagemaker.__version__
335334

@@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
366365
):
367366
logger.warning(
368367
"Inconsistent sagemaker versions found: "
369-
"sagemaker pysdk version found in the container is "
368+
"sagemaker python sdk version found in the container is "
370369
"'%s' which does not match the '%s' on the local client. "
371-
"Please make sure that the python version used in the training container "
372-
"is the same as the local python version in case of unexpected behaviors.",
370+
"Please make sure that the sagemaker version used in the training container "
371+
"is the same as the local sagemaker version in case of unexpected behaviors.",
373372
job_sagemaker_pysdk_version,
374373
client_sagemaker_pysdk_version,
375374
)

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_main_failure_remote_job_with_root_user(
269269

270270
change_dir_permission.assert_not_called()
271271
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
272-
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
272+
validate_sagemaker.assert_not_called()
273273
run_pre_exec_script.assert_not_called()
274274
bootstrap_runtime.assert_called()
275275
write_failure.assert_called_with(str(runtime_err))
@@ -317,7 +317,7 @@ def test_main_failure_pipeline_step_with_root_user(
317317

318318
change_dir_permission.assert_not_called()
319319
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
320-
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
320+
validate_sagemaker.assert_not_called()
321321
run_pre_exec_script.assert_not_called()
322322
bootstrap_runtime.assert_called()
323323
write_failure.assert_called_with(str(runtime_err))

0 commit comments

Comments
 (0)