Skip to content

Commit f3dab1e

Browse files
beniericroot
authored and
root
committed
fix: forbid extras in Configs (aws#5042)
* fix: make configs safer
* fix: safer destructor in ModelTrainer
* format
* Update error message
* pylint
* Create BaseConfig
1 parent 945db32 commit f3dab1e

File tree

4 files changed

+64
-13
lines changed

4 files changed

+64
-13
lines changed

src/sagemaker/modules/configs.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from __future__ import absolute_import
2323

2424
from typing import Optional, Union
25-
from pydantic import BaseModel, model_validator
25+
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker_core.shapes as shapes
2828

@@ -74,7 +74,13 @@
7474
]
7575

7676

77-
class SourceCode(BaseModel):
77+
class BaseConfig(BaseModel):
    """Common pydantic base class for the configuration models.

    Config classes derive from this base so they share one strict model
    configuration:

    * ``extra="forbid"`` - passing a field that the model does not declare
      fails validation instead of being silently accepted.
    * ``validate_assignment=True`` - values assigned to a field after
      construction are validated as well, not only constructor arguments.
    """

    model_config = ConfigDict(extra="forbid", validate_assignment=True)
81+
82+
83+
class SourceCode(BaseConfig):
7884
"""SourceCode.
7985
8086
The SourceCode class allows the user to specify the source code location, dependencies,
@@ -194,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig:
194200
return shapes.VpcConfig(**filtered_dict)
195201

196202

197-
class InputData(BaseModel):
203+
class InputData(BaseConfig):
198204
"""InputData.
199205
200206
This config allows the user to specify an input data source for the training job.

src/sagemaker/modules/distributed.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional, Dict, Any, List
17-
from pydantic import BaseModel, PrivateAttr
17+
from pydantic import PrivateAttr
1818
from sagemaker.modules.utils import safe_serialize
19+
from sagemaker.modules.configs import BaseConfig
1920

2021

21-
class SMP(BaseModel):
22+
class SMP(BaseConfig):
2223
"""SMP.
2324
2425
This class is used for configuring the SageMaker Model Parallelism v2 parameters.
@@ -72,7 +73,7 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7273
return hyperparameters
7374

7475

75-
class DistributedConfig(BaseModel):
76+
class DistributedConfig(BaseConfig):
7677
"""Base class for distributed training configurations."""
7778

7879
_type: str = PrivateAttr()

src/sagemaker/modules/train/model_trainer.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
205205
"LOCAL_CONTAINER" mode.
206206
"""
207207

208-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
208+
model_config = ConfigDict(
209+
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
210+
)
209211

210212
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
211213
sagemaker_session: Optional[Session] = None
@@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
363365

364366
def __del__(self):
365367
"""Destructor method to clean up the temporary directory."""
366-
# Clean up the temporary directory if it exists
367-
if self._temp_recipe_train_dir is not None:
368-
self._temp_recipe_train_dir.cleanup()
368+
# Clean up the temporary directory if it exists and class was initialized
369+
if hasattr(self, "__pydantic_fields_set__"):
370+
if self._temp_recipe_train_dir is not None:
371+
self._temp_recipe_train_dir.cleanup()
369372

370373
def _validate_training_image_and_algorithm_name(
371374
self, training_image: Optional[str], algorithm_name: Optional[str]
@@ -792,14 +795,14 @@ def _prepare_train_script(
792795
"""Prepare the training script to be executed in the training job container.
793796
794797
Args:
795-
source_code (SourceCodeConfig): The source code configuration.
798+
source_code (SourceCode): The source code configuration.
796799
"""
797800

798801
base_command = ""
799802
if source_code.command:
800803
if source_code.entry_script:
801804
logger.warning(
802-
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
805+
"Both 'command' and 'entry_script' are provided in the SourceCode. "
803806
+ "Defaulting to 'command'."
804807
)
805808
base_command = source_code.command.split()
@@ -831,6 +834,13 @@ def _prepare_train_script(
831834
+ "Only .py and .sh scripts are supported."
832835
)
833836
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
837+
else:
838+
# This should never be reached, as the source_code should have been validated.
839+
raise ValueError(
840+
f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}."
841+
+ "Please provide a valid configuration with atleast one of 'command'"
842+
+ " or entry_script'."
843+
)
834844

835845
train_script = TRAIN_SCRIPT_TEMPLATE.format(
836846
working_dir=working_dir,

tests/unit/sagemaker/modules/train/test_model_trainer.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
import pytest
21+
from pydantic import ValidationError
2122
from unittest.mock import patch, MagicMock, ANY
2223

2324
from sagemaker import image_uris
@@ -438,7 +439,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
438439
{
439440
"source_code": DEFAULT_SOURCE_CODE,
440441
"distributed": MPI(
441-
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
442+
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
442443
),
443444
"expected_template": EXECUTE_MPI_DRIVER,
444445
"expected_hyperparameters": {},
@@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix):
10591060
hyper_parameters=hyperparameters,
10601061
environment=environment,
10611062
)
1063+
1064+
1065+
def test_safe_configs():
    """SourceCode must refuse undeclared fields and wrongly typed values."""
    rejected_kwargs = (
        # Test extra fails
        {"entry_point": "train.py"},
        # Test invalid type fails
        {"entry_script": 1},
    )
    for bad_kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            SourceCode(**bad_kwargs)
1072+
1073+
1074+
@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
def test_destructor_cleanup(mock_tmp_dir, modules_session):
    """Check when ModelTrainer's destructor cleans up its temporary directory.

    Two scenarios are exercised:

    1. Construction fails validation: ``cleanup`` must not be called.
    2. Construction succeeds and a temporary directory object is attached:
       deleting the trainer must call ``cleanup`` exactly once.

    Args:
        mock_tmp_dir: Mock that replaces ``TemporaryDirectory`` in the
            ``model_trainer`` module for the duration of the test.
        modules_session: Session fixture passed as ``sagemaker_session``.
    """

    # `compute="test"` is rejected by validation, so no instance is ever bound
    # to `model_trainer` here.
    with pytest.raises(ValidationError):
        model_trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            sagemaker_session=modules_session,
            compute="test",
        )
    # Checked after the `with` block: a statement following the raising call
    # inside the block would never run.
    mock_tmp_dir.cleanup.assert_not_called()

    model_trainer = ModelTrainer(
        training_image=DEFAULT_IMAGE,
        role=DEFAULT_ROLE,
        sagemaker_session=modules_session,
        compute=DEFAULT_COMPUTE_CONFIG,
    )
    # Attach the mock as the trainer's temporary directory so the destructor
    # has something to clean up.
    model_trainer._temp_recipe_train_dir = mock_tmp_dir
    # The patched TemporaryDirectory itself was never invoked.
    mock_tmp_dir.assert_not_called()
    # NOTE(review): relies on `del` dropping the last reference so that
    # `__del__` runs immediately (CPython reference counting) - confirm if the
    # test suite is ever run on another interpreter.
    del model_trainer
    mock_tmp_dir.cleanup.assert_called_once()

0 commit comments

Comments
 (0)