Skip to content

Commit 8c4b6b6

Browse files
benieric and pintaoz-aws
authored and committed
feature: support script mode with local train.sh (#1523)
* feature: support script mode with local train.sh
* Stop tracking train.sh and add it to .gitignore
* update message
* make dir if not exist
* fix docs
* fix: docstyle
* Address comments
* fix hyperparams
* Revert pydantic custom error
* pylint
1 parent e6244db commit 8c4b6b6

File tree

8 files changed

+386
-44
lines changed

8 files changed

+386
-44
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ env/
3232
.python-version
3333
*.html
3434
**/_repack_script_launcher.sh
35+
src/sagemaker/modules/scripts/train.sh
3536
tests/data/**/_repack_model.py
3637
tests/data/experiment/sagemaker-dev-1.0.tar.gz
3738
src/sagemaker/serve/tmp_workspace

src/sagemaker/modules/configs.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Configuration classes."""
13+
"""This module provides the configuration classes used in `sagemaker.modules`.
14+
15+
Some of these classes are re-exported from `sagemaker-core.shapes`. For convenience,
16+
users can import these classes directly from `sagemaker.modules.configs`.
17+
18+
For more documentation on `sagemaker-core.shapes`, see:
19+
- https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
20+
"""
21+
1422
from __future__ import absolute_import
1523

1624
from typing import Optional
@@ -20,8 +28,8 @@
2028
ResourceConfig,
2129
StoppingCondition,
2230
OutputDataConfig,
23-
AlgorithmSpecification,
2431
Channel,
32+
DataSource,
2533
S3DataSource,
2634
FileSystemDataSource,
2735
TrainingImageConfig,
@@ -33,8 +41,8 @@
3341
"ResourceConfig",
3442
"StoppingCondition",
3543
"OutputDataConfig",
36-
"AlgorithmSpecification",
3744
"Channel",
45+
"DataSource",
3846
"S3DataSource",
3947
"FileSystemDataSource",
4048
"TrainingImageConfig",
@@ -62,7 +70,7 @@ class SourceCodeConfig(BaseModel):
6270
job container. If not specified, command must be provided.
6371
"""
6472

65-
command: Optional[str]
66-
source_dir: Optional[str]
67-
requirements: Optional[str]
68-
entry_script: Optional[str]
73+
command: Optional[str] = None
74+
source_dir: Optional[str] = None
75+
requirements: Optional[str] = None
76+
entry_script: Optional[str] = None

src/sagemaker/modules/constants.py

+14
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,19 @@
1212
# language governing permissions and limitations under the License.
1313
"""Constants module."""
1414
from __future__ import absolute_import
15+
import os
1516

1617
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
18+
19+
SOURCE_CODE_CONTAINER_PATH = "/opt/ml/input/data/code"
20+
21+
SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/sm_code"
22+
SM_CODE_LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
23+
TRAIN_SCRIPT = "train.sh"
24+
25+
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
26+
DEFAULT_CONTAINER_ARGUMENTS = [
27+
"-c",
28+
f"chmod +x {SM_CODE_CONTAINER_PATH}/{TRAIN_SCRIPT} "
29+
+ f"&& {SM_CODE_CONTAINER_PATH}/{TRAIN_SCRIPT}",
30+
]

src/sagemaker/modules/image_spec.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,13 @@ def __init__(
6969
self.config_name = config_name
7070
self.sagemaker_session = sagemaker_session
7171

72-
def get_image_uri(self):
72+
def get_image_uri(
73+
self, image_scope: Optional[str] = None, instance_type: Optional[str] = None
74+
) -> str:
7375
"""Get image URI for a specific framework version."""
76+
77+
self.image_scope = image_scope or self.image_scope
78+
self.instance_type = instance_type or self.instance_type
7479
return image_uris.retrieve(
7580
framework=self.framework_name,
7681
image_scope=self.image_scope,

src/sagemaker/modules/scripts/train.sh

-1
This file was deleted.

src/sagemaker/modules/templates.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Templates module."""
14+
from __future__ import absolute_import
15+
16+
TRAIN_SCRIPT_TEMPLATE = """
17+
#!/bin/bash
18+
echo "Starting training script"
19+
20+
echo "/opt/ml/input/config/resourceconfig.json:"
21+
cat /opt/ml/input/config/resourceconfig.json
22+
echo
23+
24+
echo "/opt/ml/input/config/inputdataconfig.json:"
25+
cat /opt/ml/input/config/inputdataconfig.json
26+
echo
27+
28+
echo "/opt/ml/input/config/hyperparameters.json:"
29+
cat /opt/ml/input/config/hyperparameters.json
30+
echo
31+
32+
python --version
33+
{working_dir}
34+
{install_requirements}
35+
CMD="{command}"
36+
echo "Running command: $CMD"
37+
eval $CMD
38+
EXIT_STATUS=$?
39+
40+
if [ $EXIT_STATUS -ne 0 ]; then
41+
echo "Command failed with exit status $EXIT_STATUS"
42+
if [ ! -s /opt/ml/output/failure ]; then
43+
echo "Command failed with exit code $EXIT_STATUS.
44+
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
45+
TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure
46+
fi
47+
exit $EXIT_STATUS
48+
else
49+
echo "Command succeeded"
50+
fi
51+
"""

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+25-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,32 @@
1919
"outputs": [],
2020
"source": [
2121
"from sagemaker.modules.train.model_trainer import ModelTrainer\n",
22+
"from sagemaker.modules.image_spec import ImageSpec\n",
2223
"\n",
23-
"model_trainer = ModelTrainer(training_image=\"python:3.10.15-slim\")\n"
24+
"pytorch_image = ImageSpec(\n",
25+
" framework_name=\"pytorch\",\n",
26+
" version=\"1.13.1\",\n",
27+
" py_version=\"py39\"\n",
28+
")\n",
29+
"\n",
30+
"python_ecr_image = \"public.ecr.aws/docker/library/python:3.10.15-slim\"\n",
31+
"python_docker_image = \"python:3.10.15-slim\"\n",
32+
"\n",
33+
"model_trainer = ModelTrainer(training_image=pytorch_image)\n"
34+
]
35+
},
36+
{
37+
"cell_type": "code",
38+
"execution_count": null,
39+
"metadata": {},
40+
"outputs": [],
41+
"source": [
42+
"from sagemaker.modules.configs import SourceCodeConfig\n",
43+
"\n",
44+
"source_code_config = SourceCodeConfig(\n",
45+
" command=\"echo 'Hello World' && env\",\n",
46+
")\n",
47+
"model_trainer.train(source_code_config=source_code_config)"
2448
]
2549
},
2650
{

0 commit comments

Comments
 (0)