Skip to content

Commit cf4cb16

Browse files
committed
fix tests
1 parent 4643d81 commit cf4cb16

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,11 @@ def model_post_init(self, __context: Any):
482482
logger.debug("Hyperparameters loaded as JSON")
483483
except json.JSONDecodeError:
484484
try:
485+
logger.info(f"contents: {contents}")
485486
self.hyperparameters = yaml.safe_load(contents)
487+
if not isinstance(self.hyperparameters, dict):
488+
raise ValueError(f"YAML contents must be a valid mapping")
489+
logger.info(f"hyperparameters: {self.hyperparameters}")
486490
logger.debug("Hyperparameters loaded as YAML")
487491
except (yaml.YAMLError, ValueError):
488492
raise ValueError(

tests/integ/sagemaker/modules/train/test_model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
entry_script="train.py",
4343
)
4444

45-
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
45+
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
4646

4747

4848
def test_hp_contract_basic_py_script(modules_sagemaker_session):

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,8 +1148,21 @@ def test_hyperparameters_not_exist(modules_session):
11481148
@patch("os.path.exists")
11491149
def test_hyperparameters_invalid(mock_exists, modules_session):
11501150
mock_exists.return_value = True
1151-
# Must be valid YAML or JSON
1152-
mock_file_open = mock_open(read_data="invalid")
1151+
1152+
# YAML contents must be a valid mapping
1153+
mock_file_open = mock_open(read_data="- item1\n- item2")
1154+
with patch("builtins.open", mock_file_open):
1155+
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
1156+
ModelTrainer(
1157+
training_image=DEFAULT_IMAGE,
1158+
role=DEFAULT_ROLE,
1159+
sagemaker_session=modules_session,
1160+
compute=DEFAULT_COMPUTE_CONFIG,
1161+
hyperparameters="hyperparameters.yaml",
1162+
)
1163+
1164+
# Must be valid YAML
1165+
mock_file_open = mock_open(read_data="* invalid")
11531166
with patch("builtins.open", mock_file_open):
11541167
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
11551168
ModelTrainer(

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ passenv =
8383
commands =
8484
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
8585
pip install 'apache-airflow==2.9.3' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.3/constraints-3.8.txt"
86-
pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
87-
pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
86+
pip install 'torch==2.0.1' -f 'https://download.pytorch.org/whl/torch_stable.html'
87+
pip install 'torchvision==0.15.2' -f 'https://download.pytorch.org/whl/torch_stable.html'
8888
pip install 'dill>=0.3.8'
8989

9090
pytest {posargs}

0 commit comments

Comments
 (0)