Skip to content

Commit 6a0224f

Browse files
benieric authored and pintaoz-aws committed
Add unit tests for ModelTrainer (#1527)
* Add unit tests for ModelTrainer
* Flake8
* format
1 parent fba3285 commit 6a0224f

File tree

5 files changed

+484
-19
lines changed

5 files changed

+484
-19
lines changed

src/sagemaker/modules/train/model_trainer.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class ModelTrainer(BaseModel):
108108
The VPC configuration.
109109
"""
110110

111-
model_config = ConfigDict(arbitrary_types_allowed=True)
111+
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
112112

113113
session: Optional[Session] = None
114114
role: Optional[str] = None
@@ -274,19 +274,20 @@ def train(
274274
if input_data_channels:
275275
self.input_data_channels = input_data_channels
276276
if source_code_config:
277+
self._validate_source_code_config(source_code_config)
277278
self.source_code_config = source_code_config
278279
if hyper_parameters:
279280
self.hyper_parameters = hyper_parameters
280281
if environment:
281282
self.environment = environment
282283

283-
input_data_config = self._get_input_data_config(self.input_data_channels)
284+
input_data_config = []
285+
if self.input_data_channels:
286+
input_data_config = self._get_input_data_config(self.input_data_channels)
284287

285288
container_entrypoint = None
286289
container_arguments = None
287290
if self.source_code_config:
288-
if not input_data_config:
289-
input_data_config = []
290291

291292
# If source code is provided, create a channel for the source code
292293
# The source code will be mounted at /opt/ml/input/data/code in the container
@@ -397,9 +398,14 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
397398
else:
398399
raise ValueError(f"Not a valid S3 URI or local file path: {data_source}.")
399400
elif isinstance(data_source, S3DataSource):
400-
channel = Channel(channel_name=channel_name, data_source=data_source)
401+
channel = Channel(
402+
channel_name=channel_name, data_source=DataSource(s3_data_source=data_source)
403+
)
401404
elif isinstance(data_source, FileSystemDataSource):
402-
channel = Channel(channel_name=channel_name, data_source=data_source)
405+
channel = Channel(
406+
channel_name=channel_name,
407+
data_source=DataSource(file_system_data_source=data_source),
408+
)
403409
return channel
404410

405411
def _get_input_data_config(

src/sagemaker/modules/utils.py

-13
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,6 @@
1919
from typing import Literal
2020

2121

22-
def _is_valid_ecr_image(image: str) -> bool:
23-
"""Check if the image is a valid ECR image URI.
24-
25-
Args:
26-
image (str): The image URI sting to validate
27-
28-
Returns:
29-
bool: True if the image is a valid ECR image URI, False otherwise
30-
"""
31-
pattern = r"^\d{12}\.dkr\.ecr\.\w+-\d{1,2}\.amazonaws\.com\/[\w-]+(:[\w.-]+)?$"
32-
return bool(re.match(pattern, image))
33-
34-
3522
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
3623
"""Check if the path is a valid S3 URI.
3724
+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utils Tests."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
18+
from tests.unit import DATA_DIR
19+
from sagemaker.modules.utils import (
20+
_is_valid_s3_uri,
21+
_is_valid_path,
22+
_get_unique_name,
23+
_get_repo_name_from_image,
24+
)
25+
26+
27+
@pytest.mark.parametrize(
    ("path", "path_type", "expected"),
    [
        ("s3://bucket/key", "Any", True),
        ("s3://bucket/key", "File", True),
        ("s3://bucket/key/", "Directory", True),
        ("s3://bucket/key/", "File", False),
        ("s3://bucket/key", "Directory", False),
        ("/bucket/key", "Any", False),
    ],
)
def test_is_valid_s3_uri(path, path_type, expected):
    """Check `_is_valid_s3_uri` against each (path, path_type) pair."""
    outcome = _is_valid_s3_uri(path, path_type)
    assert outcome == expected
64+
65+
66+
@pytest.mark.parametrize(
    ("path", "path_type", "expected"),
    [
        (DATA_DIR, "Any", True),
        (DATA_DIR, "Directory", True),
        (f"{DATA_DIR}/dummy_input.txt", "File", True),
        (f"{DATA_DIR}/dummy_input.txt", "Directory", False),
        (f"{DATA_DIR}/non_existent", "Any", False),
    ],
)
def test_is_valid_path(path, path_type, expected):
    """Check `_is_valid_path` against paths under the unit-test data directory."""
    outcome = _is_valid_path(path, path_type)
    assert outcome == expected
98+
99+
100+
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "base": "test",
            "max_length": 5,
        },
        {
            "base": "1111111111" * 7,
            "max_length": None,
        },
    ],
)
def test_get_unique_name(test_case):
    """Check that `_get_unique_name` never returns a name longer than its limit.

    The previous assertion was
    ``assert (len(...) <= max_length if max_length else 63)``. The conditional
    expression binds looser than ``<=``, so when ``max_length`` was None the
    test asserted the truthy constant ``63`` and never called the function.
    """
    base = test_case["base"]
    max_length = test_case.get("max_length")

    if max_length is None:
        # Omit the argument so the function's own default applies; passing
        # None explicitly would override it.
        # NOTE(review): assumes `_get_unique_name` defaults max_length to 63,
        # the value the original test encoded — confirm against its signature.
        unique_name = _get_unique_name(base)
        limit = 63
    else:
        unique_name = _get_unique_name(base, max_length)
        limit = max_length

    assert len(unique_name) <= limit
120+
121+
122+
@pytest.mark.parametrize(
    ("image", "expected"),
    [
        (
            "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:latest",
            "my-custom-image",
        ),
        ("my-custom-image:latest", "my-custom-image"),
        ("public.ecr.aws/docker/library/my-custom-image:latest", "my-custom-image"),
    ],
)
def test_get_repo_name_from_image(image, expected):
    """Check the repository name `_get_repo_name_from_image` returns for each image URI."""
    repo_name = _get_repo_name_from_image(image)
    assert repo_name == expected

tests/unit/sagemaker/modules/train/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)