Skip to content

Commit 6bb2d41

Browse files
authored
change: Add validation for sagemaker version on remote job (aws#4393)
1 parent e4de46c commit 6bb2d41

File tree

8 files changed

+195
-3
lines changed

8 files changed

+195
-3
lines changed

src/sagemaker/remote_function/core/serialization.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,12 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
141141
return cloudpickle.loads(bytes_to_deserialize)
142142
except Exception as e:
143143
raise DeserializationError(
144-
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
144+
"Error when deserializing bytes downloaded from {}: {}. "
145+
"NOTE: this may be caused by inconsistent sagemaker python sdk versions "
146+
"where remote function runs versus the one used on client side. "
147+
"If the sagemaker versions do not match, a warning message would "
148+
"be logged starting with 'Inconsistent sagemaker versions found'. "
149+
"Please check it to validate.".format(s3_uri, repr(e))
145150
) from e
146151

147152

src/sagemaker/remote_function/job.py

+6
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,12 @@ def compile(
786786
container_args.extend(
787787
["--client_python_version", RuntimeEnvironmentManager()._current_python_version()]
788788
)
789+
container_args.extend(
790+
[
791+
"--client_sagemaker_pysdk_version",
792+
RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(),
793+
]
794+
)
789795
container_args.extend(
790796
[
791797
"--dependency_settings",

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def main(sys_args=None):
5656
try:
5757
args = _parse_args(sys_args)
5858
client_python_version = args.client_python_version
59+
client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
5960
job_conda_env = args.job_conda_env
6061
pipeline_execution_id = args.pipeline_execution_id
6162
dependency_settings = _DependencySettings.from_string(args.dependency_settings)
@@ -64,6 +65,9 @@ def main(sys_args=None):
6465
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6566

6667
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68+
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69+
client_sagemaker_pysdk_version
70+
)
6771

6872
user = getpass.getuser()
6973
if user != "root":
@@ -274,6 +278,7 @@ def _parse_args(sys_args):
274278
parser = argparse.ArgumentParser()
275279
parser.add_argument("--job_conda_env", type=str)
276280
parser.add_argument("--client_python_version", type=str)
281+
parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None)
277282
parser.add_argument("--pipeline_execution_id", type=str)
278283
parser.add_argument("--dependency_settings", type=str)
279284
parser.add_argument("--func_step_s3_dir", type=str)

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

+30
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import dataclasses
2525
import json
2626

27+
import sagemaker
28+
2729

2830
class _UTCFormatter(logging.Formatter):
2931
"""Class that overrides the default local time provider in log formatter."""
@@ -326,6 +328,11 @@ def _current_python_version(self):
326328

327329
return f"{sys.version_info.major}.{sys.version_info.minor}".strip()
328330

331+
def _current_sagemaker_pysdk_version(self):
332+
"""Returns the current sagemaker python sdk version where program is running"""
333+
334+
return sagemaker.__version__
335+
329336
def _validate_python_version(self, client_python_version: str, conda_env: str = None):
330337
"""Validate the python version
331338
@@ -344,6 +351,29 @@ def _validate_python_version(self, client_python_version: str, conda_env: str =
344351
f"is same as the local python version."
345352
)
346353

354+
def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
355+
"""Validate the sagemaker python sdk version
356+
357+
Validates if the sagemaker python sdk version where remote function runs
358+
matches the one used on client side.
359+
Otherwise, log a warning to call out that unexpected behaviors
360+
may occur in this case.
361+
"""
362+
job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
363+
if (
364+
client_sagemaker_pysdk_version
365+
and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version
366+
):
367+
logger.warning(
368+
"Inconsistent sagemaker versions found: "
369+
"sagemaker pysdk version found in the container is "
370+
"'%s' which does not match the '%s' on the local client. "
371+
"Please make sure that the python version used in the training container "
372+
"is the same as the local python version in case of unexpected behaviors.",
373+
job_sagemaker_pysdk_version,
374+
client_sagemaker_pysdk_version,
375+
)
376+
347377

348378
def _run_and_get_output_shell_cmd(cmd: str) -> str:
349379
"""Run and return the output of the given shell command"""

tests/unit/sagemaker/remote_function/core/test_serialization.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def square(x):
198198
with pytest.raises(
199199
DeserializationError,
200200
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
201-
+ r"RuntimeError\('some failure when loads'\)",
201+
+ r"RuntimeError\('some failure when loads'\). "
202+
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
202203
):
203204
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
204205

@@ -397,7 +398,8 @@ def __init__(self, x):
397398
with pytest.raises(
398399
DeserializationError,
399400
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
400-
+ r"RuntimeError\('some failure when loads'\)",
401+
+ r"RuntimeError\('some failure when loads'\). "
402+
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
401403
):
402404
deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
403405

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

+102
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
CURR_WORKING_DIR = "/user/set/workdir"
2828
TEST_DEPENDENCIES_PATH = "/user/set/workdir/sagemaker_remote_function_workspace"
2929
TEST_PYTHON_VERSION = "3.10"
30+
TEST_SAGEMAKER_PYSDK_VERSION = "2.205.0"
3031
TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws"
3132
TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip"
3233
TEST_EXECUTION_ID = "test_execution_id"
@@ -44,6 +45,8 @@ def args_for_remote():
4445
TEST_JOB_CONDA_ENV,
4546
"--client_python_version",
4647
TEST_PYTHON_VERSION,
48+
"--client_sagemaker_pysdk_version",
49+
TEST_SAGEMAKER_PYSDK_VERSION,
4750
"--dependency_settings",
4851
_DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(),
4952
]
@@ -55,6 +58,8 @@ def args_for_step():
5558
TEST_JOB_CONDA_ENV,
5659
"--client_python_version",
5760
TEST_PYTHON_VERSION,
61+
"--client_sagemaker_pysdk_version",
62+
TEST_SAGEMAKER_PYSDK_VERSION,
5863
"--pipeline_execution_id",
5964
TEST_EXECUTION_ID,
6065
"--func_step_s3_dir",
@@ -63,6 +68,10 @@ def args_for_step():
6368

6469

6570
@patch("sys.exit")
71+
@patch(
72+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
73+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
74+
)
6675
@patch(
6776
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
6877
"RuntimeEnvironmentManager._validate_python_version"
@@ -90,12 +99,75 @@ def test_main_success_remote_job_with_root_user(
9099
run_pre_exec_script,
91100
bootstrap_runtime,
92101
validate_python,
102+
validate_sagemaker,
93103
_exit_process,
94104
):
95105
bootstrap.main(args_for_remote())
96106

97107
change_dir_permission.assert_not_called()
98108
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
109+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
110+
bootstrap_remote.assert_called_once_with(
111+
TEST_PYTHON_VERSION,
112+
TEST_JOB_CONDA_ENV,
113+
_DependencySettings(TEST_DEPENDENCY_FILE_NAME),
114+
)
115+
run_pre_exec_script.assert_not_called()
116+
bootstrap_runtime.assert_not_called()
117+
_exit_process.assert_called_with(0)
118+
119+
120+
@patch("sys.exit")
121+
@patch(
122+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
123+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
124+
)
125+
@patch(
126+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
127+
"RuntimeEnvironmentManager._validate_python_version"
128+
)
129+
@patch(
130+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
131+
"RuntimeEnvironmentManager.bootstrap"
132+
)
133+
@patch(
134+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
135+
"RuntimeEnvironmentManager.run_pre_exec_script"
136+
)
137+
@patch(
138+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment."
139+
"_bootstrap_runtime_env_for_remote_function"
140+
)
141+
@patch("getpass.getuser", MagicMock(return_value="root"))
142+
@patch(
143+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
144+
"RuntimeEnvironmentManager.change_dir_permission"
145+
)
146+
def test_main_success_with_obsoleted_args_that_missing_sagemaker_version(
147+
change_dir_permission,
148+
bootstrap_remote,
149+
run_pre_exec_script,
150+
bootstrap_runtime,
151+
validate_python,
152+
validate_sagemaker,
153+
_exit_process,
154+
):
155+
# This test is to test the backward compatibility
156+
# In old version of SDK, the client side sagemaker_pysdk_version is not passed to job
157+
# thus it would be None and would not lead to the warning
158+
obsoleted_args = [
159+
"--job_conda_env",
160+
TEST_JOB_CONDA_ENV,
161+
"--client_python_version",
162+
TEST_PYTHON_VERSION,
163+
"--dependency_settings",
164+
_DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(),
165+
]
166+
bootstrap.main(obsoleted_args)
167+
168+
change_dir_permission.assert_not_called()
169+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
170+
validate_sagemaker.assert_called_once_with(None)
99171
bootstrap_remote.assert_called_once_with(
100172
TEST_PYTHON_VERSION,
101173
TEST_JOB_CONDA_ENV,
@@ -107,6 +179,10 @@ def test_main_success_remote_job_with_root_user(
107179

108180

109181
@patch("sys.exit")
182+
@patch(
183+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
184+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
185+
)
110186
@patch(
111187
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
112188
"RuntimeEnvironmentManager._validate_python_version"
@@ -134,11 +210,13 @@ def test_main_success_pipeline_step_with_root_user(
134210
run_pre_exec_script,
135211
bootstrap_runtime,
136212
validate_python,
213+
validate_sagemaker,
137214
_exit_process,
138215
):
139216
bootstrap.main(args_for_step())
140217
change_dir_permission.assert_not_called()
141218
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
219+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
142220
bootstrap_step.assert_called_once_with(
143221
TEST_PYTHON_VERSION,
144222
FUNC_STEP_WORKSPACE,
@@ -150,6 +228,10 @@ def test_main_success_pipeline_step_with_root_user(
150228
_exit_process.assert_called_with(0)
151229

152230

231+
@patch(
232+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
233+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
234+
)
153235
@patch(
154236
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
155237
"RuntimeEnvironmentManager._validate_python_version"
@@ -178,6 +260,7 @@ def test_main_failure_remote_job_with_root_user(
178260
write_failure,
179261
_exit_process,
180262
validate_python,
263+
validate_sagemaker,
181264
):
182265
runtime_err = RuntimeEnvironmentError("some failure reason")
183266
bootstrap_runtime.side_effect = runtime_err
@@ -186,12 +269,17 @@ def test_main_failure_remote_job_with_root_user(
186269

187270
change_dir_permission.assert_not_called()
188271
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
272+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
189273
run_pre_exec_script.assert_not_called()
190274
bootstrap_runtime.assert_called()
191275
write_failure.assert_called_with(str(runtime_err))
192276
_exit_process.assert_called_with(1)
193277

194278

279+
@patch(
280+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
281+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
282+
)
195283
@patch(
196284
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
197285
"RuntimeEnvironmentManager._validate_python_version"
@@ -220,6 +308,7 @@ def test_main_failure_pipeline_step_with_root_user(
220308
write_failure,
221309
_exit_process,
222310
validate_python,
311+
validate_sagemaker,
223312
):
224313
runtime_err = RuntimeEnvironmentError("some failure reason")
225314
bootstrap_runtime.side_effect = runtime_err
@@ -228,13 +317,18 @@ def test_main_failure_pipeline_step_with_root_user(
228317

229318
change_dir_permission.assert_not_called()
230319
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
320+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
231321
run_pre_exec_script.assert_not_called()
232322
bootstrap_runtime.assert_called()
233323
write_failure.assert_called_with(str(runtime_err))
234324
_exit_process.assert_called_with(1)
235325

236326

237327
@patch("sys.exit")
328+
@patch(
329+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
330+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
331+
)
238332
@patch(
239333
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
240334
"RuntimeEnvironmentManager._validate_python_version"
@@ -262,6 +356,7 @@ def test_main_remote_job_with_non_root_user(
262356
run_pre_exec_script,
263357
bootstrap_runtime,
264358
validate_python,
359+
validate_sagemaker,
265360
_exit_process,
266361
):
267362
bootstrap.main(args_for_remote())
@@ -270,6 +365,7 @@ def test_main_remote_job_with_non_root_user(
270365
dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777"
271366
)
272367
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
368+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
273369
bootstrap_remote.assert_called_once_with(
274370
TEST_PYTHON_VERSION,
275371
TEST_JOB_CONDA_ENV,
@@ -281,6 +377,10 @@ def test_main_remote_job_with_non_root_user(
281377

282378

283379
@patch("sys.exit")
380+
@patch(
381+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
382+
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
383+
)
284384
@patch(
285385
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
286386
"RuntimeEnvironmentManager._validate_python_version"
@@ -308,6 +408,7 @@ def test_main_pipeline_step_with_non_root_user(
308408
run_pre_exec_script,
309409
bootstrap_runtime,
310410
validate_python,
411+
validate_sagemaker,
311412
_exit_process,
312413
):
313414
bootstrap.main(args_for_step())
@@ -316,6 +417,7 @@ def test_main_pipeline_step_with_non_root_user(
316417
dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777"
317418
)
318419
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
420+
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
319421
bootstrap_step.assert_called_once_with(
320422
TEST_PYTHON_VERSION,
321423
FUNC_STEP_WORKSPACE,

0 commit comments

Comments
 (0)