
Commit a99ae84

nargokul, benieric, pintaoz-aws, and pravali96 committed
Trainer handshake (#1535)
* Base model trainer (#1521)
* Base model trainer
* flake8
* add testing notebook
* add param validation & set defaults
* Implement simple train method
* feature: support script mode with local train.sh (#1523)
* feature: support script mode with local train.sh
* Stop tracking train.sh and add it to .gitignore
* update message
* make dir if not exist
* fix docs
* fix: docstyle
* Address comments
* fix hyperparams
* Revert pydantic custom error
* pylint
* Image Spec refactoring and updates (#1525)
* Image Spec refactoring and updates
* Unit tests and update function for Image Spec
* Fix hugging face test
* Fix Tests
* Add unit tests for ModelTrainer (#1527)
* Add unit tests for ModelTrainer
* Flake8
* format
* Add example notebook (#1528)
* Add testing notebook
* format
* use smaller data
* remove large dataset
* update
* pylint
* flake8
* ignore docstyle in directories with test
* format
* format
* Add enviornment variable bootstrapping script (#1530)
* Add enviornment variables scripts
* format
* fix comment
* add docstrings
* fix comment
* feature: add utility function to capture local snapshot (#1524)
* local snapshot
* Update pip list command
* Remove function calls
* Address comments
* Address comments
* Change to make Model Trainer return a Model Object
* Fix
* Cleanup
* Support intelligent parameters (#1540)
* Support intelligent parameters
* fix codestyle
* Revert Image Spec (#1541)
* Cleanup ModelTrainer (#1542)
* General image builder (#1546)
* General image builder
* General image builder
* Fix codestyle
* Fix codestyle
* Move location
* Add warnings
* Add integ tests
* Fix integ test
* Fix integ test
* Fix region error
* Add region
* Latest Container Image (#1545)
* Latest Container Image
* Test Fixes
* Parameterized tests and some logic updates
* Test fixes
* Move to Image URI
* Fixes for unit test
* Fixes for unit test
* Fix codestyle error checks
* Cleanup ModelTrainer code (#1552)
* Updates
* feat: add pre-processing and post-processing logic to inference_spec (#1560)
* add pre-processing and post-processing logic to inference_spec
* fix format
* make accept_type and content_type optional
* remove accept_type and content_type from pre/post processing
* correct typo
* Add Distributed Training Support Model Trainer (#1536)
* Add path to set Additional Settings in ModelTrainer (#1555)
* Updates
* Mask Sensitive Env Logs in Container (#1568)
* Cleanup PR
* Codestyle fixes
* Update logic to use model parameter instead of model_path
* Fixes
* Fixes
* Tests
* Codestyle Fixes
* Codestyle Fixes
* Codestyle Fixes
* Codestyle Fixes

---------

Co-authored-by: Erick Benitez-Ramos <[email protected]>
Co-authored-by: pintaoz-aws <[email protected]>
Co-authored-by: Pravali Uppugunduri <[email protected]>
1 parent 694f8e9 commit a99ae84
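
In short, this commit wires ModelTrainer's output into ModelBuilder: train() now keeps the created TrainingJob on the trainer, and ModelBuilder.build() accepts a ModelTrainer, a sagemaker_core TrainingJob, or an Estimator as its model and resolves the S3 model artifacts from whichever is passed. A condensed sketch of the resulting flow, adapted from the notebook added below (the image URI and command are that notebook's example values, and the InferenceSpec/SchemaBuilder wiring shown there is omitted here):

from sagemaker import get_execution_role
from sagemaker.modules.configs import SourceCode
from sagemaker.modules.train.model_trainer import ModelTrainer
from sagemaker.serve import ModelBuilder

role = get_execution_role()
xgboost_image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"

# Train: ModelTrainer creates a TrainingJob and stores it as _latest_training_job.
model_trainer = ModelTrainer(
    training_image=xgboost_image,
    source_code=SourceCode(command="echo 'Hello World' && env"),
)
model_trainer.train()

# Build: ModelBuilder pulls the artifacts location from the trainer's latest job.
model_builder = ModelBuilder(
    model=model_trainer,  # or model_trainer._latest_training_job, or an Estimator
    role_arn=role,
    image_uri=xgboost_image,
    instance_type="ml.c6i.xlarge",
)
model = model_builder.build()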

4 files changed (+195 -15 lines)

New file (+150 -0):

@@ -0,0 +1,150 @@
{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "from sagemaker_core.main.shapes import TrainingJob\n",
    "\n",
    "from sagemaker import Session, get_execution_role\n",
    "\n",
    "sagemaker_session = Session()\n",
    "role = get_execution_role()\n",
    "region = sagemaker_session.boto_region_name\n",
    "bucket = sagemaker_session.default_bucket()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "\n",
    "from sagemaker.modules.configs import SourceCode\n",
    "from sagemaker.modules.train.model_trainer import ModelTrainer\n",
    "\n",
    "xgboost_image = \"433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest\"\n",
    "\n",
    "source_code = SourceCode(\n",
    "    command=\"echo 'Hello World' && env\",\n",
    ")\n",
    "model_trainer = ModelTrainer(\n",
    "    training_image=xgboost_image,\n",
    "    source_code=source_code,\n",
    ")\n",
    "\n",
    "model_trainer.train()"
   ],
   "id": "4b3a4f7d1713685f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n",
    "import pandas as pd\n",
    "from xgboost import XGBClassifier\n",
    "from sagemaker.serve.spec.inference_spec import InferenceSpec\n",
    "from sagemaker.serve import ModelBuilder\n",
    "\n",
    "data = {\n",
    "    'Name': ['Alice', 'Bob', 'Charlie']\n",
    "}\n",
    "df = pd.DataFrame(data)\n",
    "schema_builder = SchemaBuilder(sample_input=df, sample_output=df)\n",
    "\n",
    "\n",
    "class XGBoostSpec(InferenceSpec):\n",
    "    def load(self, model_dir: str):\n",
    "        print(model_dir)\n",
    "        model = XGBClassifier()\n",
    "        model.load_model(model_dir + \"/xgboost-model\")\n",
    "        return model\n",
    "\n",
    "    def invoke(self, input_object: object, model: object):\n",
    "        prediction_probabilities = model.predict_proba(input_object)\n",
    "        predictions = np.argmax(prediction_probabilities, axis=1)\n",
    "        return predictions\n",
    "\n",
    "model_builder = ModelBuilder(\n",
    "    model=model_trainer,  # ModelTrainer object passed onto ModelBuilder directly\n",
    "    role_arn=role,\n",
    "    image_uri=xgboost_image,\n",
    "    inference_spec=XGBoostSpec(),\n",
    "    schema_builder=schema_builder,\n",
    "    instance_type=\"ml.c6i.xlarge\"\n",
    ")\n",
    "model=model_builder.build()\n",
    "predictor=model_builder.deploy()\n",
    "\n",
    "predictor\n",
    "assert model.model_data == model_trainer._latest_training_job.model_artifacts.s3_model_artifacts\n",
    "\n",
    "print(model.model_data)"
   ],
   "id": "295a16ef277257a0",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "training_job: TrainingJob = model_trainer._latest_training_job\n",
    "\n",
    "model_builder = ModelBuilder(\n",
    "    model=training_job,  # Sagemaker core's TrainingJob object passed onto ModelBuilder directly\n",
    "    role_arn=role,\n",
    "    image_uri=xgboost_image,\n",
    "    schema_builder=schema_builder,\n",
    "    inference_spec=XGBoostSpec(),\n",
    "    instance_type=\"ml.c6i.xlarge\"\n",
    ")\n",
    "model=model_builder.build()\n",
    "\n",
    "assert model.model_data == training_job.model_artifacts.s3_model_artifacts\n",
    "\n",
    "print(model.model_data)"
   ],
   "id": "935ea8486278d7b1",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "757180da84407a1a",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

src/sagemaker/modules/train/model_trainer.py (+9 -5):
@@ -19,11 +19,12 @@
 from tempfile import TemporaryDirectory
 
 from typing import Optional, List, Union, Dict, Any
-from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
-
+from sagemaker_core.main import resources
 from sagemaker_core.resources import TrainingJob
 from sagemaker_core.shapes import AlgorithmSpecification
 
+from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
+
 from sagemaker import get_execution_role, Session
 from sagemaker.modules.configs import (
     Compute,
@@ -51,6 +52,7 @@
     CheckpointConfig,
     InputData,
 )
+
 from sagemaker.modules.distributed import (
     DistributedRunner,
     TorchrunSMP,
@@ -187,13 +189,17 @@ class ModelTrainer(BaseModel):
     hyperparameters: Optional[Dict[str, Any]] = None
     tags: Optional[List[Tag]] = None
 
+    # Created Artifacts
+    _latest_training_job: Optional[resources.TrainingJob] = None
+
     # Metrics settings
     _enable_sage_maker_metrics_time_series: Optional[bool] = PrivateAttr(default=False)
     _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
 
     # Debugger settings
     _debug_hook_config: Optional[DebugHookConfig] = PrivateAttr(default=None)
     _debug_rule_configurations: Optional[List[DebugRuleConfiguration]] = PrivateAttr(default=None)
+    _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
     _profiler_config: Optional[ProfilerConfig] = PrivateAttr(default=None)
     _profiler_rule_configurations: Optional[List[ProfilerRuleConfiguration]] = PrivateAttr(
         default=None
@@ -448,11 +454,9 @@ def train(
             infra_check_config=self._infra_check_config,
             session_chaining_config=self._session_chaining_config,
         )
-
+        self._latest_training_job = training_job
         if wait:
             training_job.wait(logs=logs)
-        if logs and not wait:
-            logger.warning("Not displaing the training container logs as 'wait' is set to False.")
 
     def create_input_data_channel(self, channel_name: str, data_source: DataSourceType) -> Channel:
         """Create an input data channel for the training job.

src/sagemaker/serve/builder/model_builder.py (+19 -9):
@@ -24,6 +24,9 @@
 
 from pathlib import Path
 
+from sagemaker_core.main.resources import TrainingJob
+
+from sagemaker.estimator import Estimator
 from sagemaker.enums import Tag
 from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
 from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -176,8 +179,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
            The schema builder can be omitted for HuggingFace models with task types TextGeneration,
            TextClassification, and QuestionAnswering. Omitting SchemaBuilder is in
            beta for FillMask, and AutomaticSpeechRecognition use-cases.
-        model (Optional[Union[object, str]): Model object (with ``predict`` method to perform
-            inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
+        model (Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]]):
+            Define object from which training artifacts can be extracted.
+            Either ``model`` or ``inference_spec``
            is required for the model builder to build the artifact.
        inference_spec (InferenceSpec): The inference spec file with your customized
            ``invoke`` and ``load`` functions.
@@ -268,14 +272,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
     schema_builder: Optional[SchemaBuilder] = field(
         default=None, metadata={"help": "Defines the i/o schema of the model"}
     )
-    model: Optional[Union[object, str]] = field(
+    model: Optional[Union[object, str, "ModelTrainer", TrainingJob, Estimator]] = field(
         default=None,
-        metadata={
-            "help": (
-                'Model object with "predict" method to perform inference '
-                "or HuggingFace/JumpStart Model ID"
-            )
-        },
+        metadata={"help": "Define object from which training artifacts can be extracted"}
     )
     inference_spec: InferenceSpec = field(
         default=None,
@@ -852,13 +851,24 @@ def build( # pylint: disable=R0911
         Returns:
             Type[Model]: A deployable ``Model`` object.
         """
+        from sagemaker.modules.train.model_trainer import ModelTrainer
         self.modes = dict()
 
         if mode:
             self.mode = mode
         if role_arn:
             self.role_arn = role_arn
 
+        if isinstance(self.model, TrainingJob):
+            self.model_path = self.model.model_artifacts.s3_model_artifacts
+            self.model = None
+        elif isinstance(self.model, ModelTrainer):
+            self.model_path = self.model._latest_training_job.model_artifacts.s3_model_artifacts
+            self.model = None
+        elif isinstance(self.model, Estimator):
+            self.model_path = self.model.output_path
+            self.model = None
+
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
 
         self.sagemaker_session.settings._local_download_dir = self.model_path
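
Per the new branches in build(), a TrainingJob or ModelTrainer supplies model_artifacts.s3_model_artifacts, while an Estimator supplies its output_path. A hedged caller-side sketch of the Estimator branch (xgboost_image, role, bucket, schema_builder, and XGBoostSpec are assumed to be defined as in the notebook above, and the estimator is assumed to have already run fit() so artifacts exist under output_path):

from sagemaker.estimator import Estimator
from sagemaker.serve import ModelBuilder

# Classic Estimator path: build() uses estimator.output_path as the artifacts location.
estimator = Estimator(
    image_uri=xgboost_image,
    role=role,
    instance_count=1,
    instance_type="ml.c6i.xlarge",
    output_path=f"s3://{bucket}/estimator-artifacts",
)
# estimator.fit(...) would run here before handing the estimator to ModelBuilder.

model_builder = ModelBuilder(
    model=estimator,  # same pattern as the ModelTrainer / TrainingJob cases
    role_arn=role,
    image_uri=xgboost_image,
    inference_spec=XGBoostSpec(),
    schema_builder=schema_builder,
    instance_type="ml.c6i.xlarge",
)
model = model_builder.build()
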

tests/unit/sagemaker/modules/train/test_model_trainer.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121
from unittest.mock import patch, MagicMock
2222

23+
from sagemaker_core.main.resources import TrainingJob
24+
2325
from sagemaker.session import Session
2426
from sagemaker.modules.train.model_trainer import ModelTrainer
2527
from sagemaker.modules.constants import (
@@ -316,6 +318,7 @@ def test_debugger_settings(mock_training_job, modules_session):
316318

317319
assert model_trainer._debug_hook_config == debug_hook_config
318320
assert model_trainer._debug_rule_configurations == debug_rule_config
321+
319322
assert model_trainer._profiler_config == profiler_config
320323
assert model_trainer._profiler_rule_configurations == profiler_rule_config
321324
assert model_trainer._tensor_board_output_config == tensor_board_output_config
@@ -485,7 +488,12 @@ def test_train_with_distributed_runner(
485488
assert test_case["distributed_runner"].model_dump(exclude_none=True) == (
486489
json.loads(runner_json_content)
487490
)
488-
491+
assert os.path.exists(expected_source_code_json_path)
492+
with open(expected_source_code_json_path, "r") as f:
493+
source_code_json_content = f.read()
494+
assert test_case["source_code"].model_dump(exclude_none=True) == (
495+
json.loads(source_code_json_content)
496+
)
489497
assert os.path.exists(expected_source_code_json_path)
490498
with open(expected_source_code_json_path, "r") as f:
491499
source_code_json_content = f.read()
@@ -495,3 +503,11 @@ def test_train_with_distributed_runner(
495503
finally:
496504
shutil.rmtree(tmp_dir.name)
497505
assert not os.path.exists(tmp_dir.name)
506+
507+
508+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
509+
def test_train_stores_created_training_job(mock_training_job, model_trainer):
510+
mock_training_job.create.return_value = TrainingJob(training_job_name="Created-job")
511+
model_trainer.train(wait=False)
512+
assert model_trainer._latest_training_job is not None
513+
assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job")
