Skip to content

Commit e6244db

Browse files
beniericpintaoz-aws
authored andcommitted
Base model trainer (#1521)
* Base model trainer * flake8 * add testing notebook * add param validation & set defaults * Implement simple train method
1 parent 6333914 commit e6244db

File tree

10 files changed

+658
-0
lines changed

10 files changed

+658
-0
lines changed

src/sagemaker/modules/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker modules directory."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker_core.main.utils import logger as sagemaker_core_logger
17+
18+
logger = sagemaker_core_logger

src/sagemaker/modules/configs.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Configuration classes."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional
17+
from pydantic import BaseModel
18+
19+
from sagemaker_core.shapes import (
20+
ResourceConfig,
21+
StoppingCondition,
22+
OutputDataConfig,
23+
AlgorithmSpecification,
24+
Channel,
25+
S3DataSource,
26+
FileSystemDataSource,
27+
TrainingImageConfig,
28+
VpcConfig,
29+
)
30+
31+
__all__ = [
32+
"SourceCodeConfig",
33+
"ResourceConfig",
34+
"StoppingCondition",
35+
"OutputDataConfig",
36+
"AlgorithmSpecification",
37+
"Channel",
38+
"S3DataSource",
39+
"FileSystemDataSource",
40+
"TrainingImageConfig",
41+
"VpcConfig",
42+
]
43+
44+
45+
class SourceCodeConfig(BaseModel):
46+
"""SourceCodeConfig.
47+
48+
This config allows the user to specify the source code location, dependencies,
49+
entry script, or commands to be executed in the training job container.
50+
51+
Attributes:
52+
command (Optional[str]):
53+
The command(s) to execute in the training job container. Example: "python my_script.py".
54+
If not specified, entry_script must be provided
55+
source_dir (Optional[str]):
56+
The local directory containing the source code to be used in the training job container.
57+
requirements (Optional[str]):
58+
The path within `source_dir` to a `requirements.txt` file. If specified, the listed
59+
requirements will be installed in the training job container.
60+
entry_script (Optional[str]):
61+
The path within `source_dir` to the entry script that will be executed in the training
62+
job container. If not specified, command must be provided.
63+
"""
64+
65+
command: Optional[str]
66+
source_dir: Optional[str]
67+
requirements: Optional[str]
68+
entry_script: Optional[str]

src/sagemaker/modules/constants.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Constants module."""
14+
from __future__ import absolute_import
15+
16+
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"

src/sagemaker/modules/image_spec.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""ImageSpec class module."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional
17+
18+
from sagemaker import image_uris, Session
19+
from sagemaker.serverless import ServerlessInferenceConfig
20+
from sagemaker.training_compiler.config import TrainingCompilerConfig
21+
22+
23+
class ImageSpec:
24+
"""ImageSpec class to get image URI for a specific framework version."""
25+
26+
def __init__(
27+
self,
28+
framework_name: str,
29+
version: str,
30+
image_scope: Optional[str] = None,
31+
instance_type: Optional[str] = None,
32+
py_version: Optional[str] = None,
33+
region: Optional[str] = "us-west-2",
34+
accelerator_type: Optional[str] = None,
35+
container_version: Optional[str] = None,
36+
distribution: Optional[dict] = None,
37+
base_framework_version: Optional[str] = None,
38+
training_compiler_config: Optional[TrainingCompilerConfig] = None,
39+
model_id: Optional[str] = None,
40+
model_version: Optional[str] = None,
41+
hub_arn: Optional[str] = None,
42+
tolerate_vulnerable_model: Optional[bool] = False,
43+
tolerate_deprecated_model: Optional[bool] = False,
44+
sdk_version: Optional[str] = None,
45+
inference_tool: Optional[str] = None,
46+
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
47+
config_name: Optional[str] = None,
48+
sagemaker_session: Optional[Session] = None,
49+
):
50+
self.framework_name = framework_name
51+
self.version = version
52+
self.image_scope = image_scope
53+
self.instance_type = instance_type
54+
self.py_version = py_version
55+
self.region = region
56+
self.accelerator_type = accelerator_type
57+
self.container_version = container_version
58+
self.distribution = distribution
59+
self.base_framework_version = base_framework_version
60+
self.training_compiler_config = training_compiler_config
61+
self.model_id = model_id
62+
self.model_version = model_version
63+
self.hub_arn = hub_arn
64+
self.tolerate_vulnerable_model = tolerate_vulnerable_model
65+
self.tolerate_deprecated_model = tolerate_deprecated_model
66+
self.sdk_version = sdk_version
67+
self.inference_tool = inference_tool
68+
self.serverless_inference_config = serverless_inference_config
69+
self.config_name = config_name
70+
self.sagemaker_session = sagemaker_session
71+
72+
def get_image_uri(self):
73+
"""Get image URI for a specific framework version."""
74+
return image_uris.retrieve(
75+
framework=self.framework_name,
76+
image_scope=self.image_scope,
77+
instance_type=self.instance_type,
78+
py_version=self.py_version,
79+
region=self.region,
80+
version=self.version,
81+
accelerator_type=self.accelerator_type,
82+
container_version=self.container_version,
83+
distribution=self.distribution,
84+
base_framework_version=self.base_framework_version,
85+
training_compiler_config=self.training_compiler_config,
86+
model_id=self.model_id,
87+
model_version=self.model_version,
88+
hub_arn=self.hub_arn,
89+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
90+
tolerate_deprecated_model=self.tolerate_deprecated_model,
91+
sdk_version=self.sdk_version,
92+
inference_tool=self.inference_tool,
93+
serverless_inference_config=self.serverless_inference_config,
94+
config_name=self.config_name,
95+
sagemaker_session=self.sagemaker_session,
96+
)
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# place holder for training script
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import sys, os\n",
10+
"# Get the absolute path of the root directory\n",
11+
"root_dir = os.path.abspath(os.path.join(os.getcwd(), '../../..'))\n",
12+
"sys.path.insert(0, root_dir)"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"from sagemaker.modules.train.model_trainer import ModelTrainer\n",
22+
"\n",
23+
"model_trainer = ModelTrainer(training_image=\"python:3.10.15-slim\")\n"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": []
32+
}
33+
],
34+
"metadata": {
35+
"kernelspec": {
36+
"display_name": "Python 3",
37+
"language": "python",
38+
"name": "python3"
39+
},
40+
"language_info": {
41+
"codemirror_mode": {
42+
"name": "ipython",
43+
"version": 3
44+
},
45+
"file_extension": ".py",
46+
"mimetype": "text/x-python",
47+
"name": "python",
48+
"nbconvert_exporter": "python",
49+
"pygments_lexer": "ipython3",
50+
"version": "3.10.14"
51+
}
52+
},
53+
"nbformat": 4,
54+
"nbformat_minor": 2
55+
}
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker modules train directory."""
14+
from __future__ import absolute_import

0 commit comments

Comments
 (0)