Commit 6b90f89

beniericpintaoz-aws authored and committed
Add recipes examples (#1582)
1 parent 70ae24f commit 6b90f89

File tree

3 files changed: +251 -17 lines


src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+140 -15
@@ -79,19 +79,9 @@
 "    command=\"python custom_script.py\",\n",
 ")\n",
 "\n",
-"hyperparameters = {\n",
-"    \"secret_token\": \"123456\",\n",
-"}\n",
-"\n",
-"env_vars = {\n",
-"    \"PASSWORD\": \"123456\"\n",
-"}\n",
-"\n",
 "model_trainer = ModelTrainer(\n",
 "    training_image=pytorch_image,\n",
 "    source_code=source_code,\n",
-"    hyperparameters=hyperparameters,\n",
-"    environment=env_vars,\n",
 ")\n",
 "\n",
 "model_trainer.train(wait=False)"
@@ -386,10 +376,6 @@
 "from sagemaker.modules.configs import (\n",
 "    Compute, SourceCode, InputData\n",
 ")\n",
-"from sagemaker.modules.distributed import (\n",
-"    Torchrun,\n",
-"    MPI\n",
-")\n",
 "\n",
 "compute = Compute(\n",
 "    instance_count=2,\n",
@@ -420,6 +406,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from sagemaker.modules.distributed import (\n",
+"    Torchrun,\n",
+"    MPI,\n",
+"    SMP\n",
+")\n",
+"\n",
 "source_code = SourceCode(\n",
 "    source_dir=\"distributed-training/scripts\",\n",
 "    requirements=\"requirements.txt\",\n",
@@ -429,6 +421,14 @@
 "# Run using Torchrun\n",
 "torchrun = Torchrun()\n",
 "\n",
+"# Run using Torchrun with SMP\n",
+"torchrun_smp = Torchrun(\n",
+"    smp=SMP(\n",
+"        sm_activation_offloading=True,\n",
+"        activation_loading_horizon=2,\n",
+"    )\n",
+")\n",
+"\n",
 "# Run using MPI\n",
 "mpi = MPI(\n",
 "    mpi_additional_options=[\n",
@@ -482,7 +482,7 @@
 "outputs": [],
 "source": [
 "from sagemaker.modules.train import ModelTrainer\n",
-"from sagemaker.modules.configs import Compute, InputData\n",
+"from sagemaker.modules.configs import Compute\n",
 "\n",
 "recipe_overrides = {\n",
 "    \"run\": {\n",
@@ -536,6 +536,131 @@
 "source": [
 "Successful Run - https://tiny.amazon.com/14jxjrndx/IsenLink"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Custom Recipe"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sagemaker.modules.train import ModelTrainer\n",
+"from sagemaker.modules.configs import Compute\n",
+"\n",
+"training_image = \"059094755717.dkr.ecr.us-west-2.amazonaws.com/sagemaker-recipes-gpu\"\n",
+"\n",
+"model_trainer = ModelTrainer.from_recipe(\n",
+"    training_recipe=\"recipes/custom-recipe.yaml\",\n",
+"    training_image=training_image,\n",
+"    compute=Compute(instance_type=\"ml.p4d.24xlarge\")\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"model_trainer.train()"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Successful Run - https://tiny.amazon.com/dimbimx1/IsenLink"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Trainium Recipe"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sagemaker import session\n",
+"\n",
+"session = session.Session()\n",
+"base_job_name = \"trn-llama\"\n",
+"compiler_cache_bucket = f\"s3://{session.default_bucket()}/{base_job_name}/compiler-cache\"\n",
+"print(f\"Compiler cache: {compiler_cache_bucket}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sagemaker.modules.train import ModelTrainer\n",
+"from sagemaker.modules.configs import Compute, InputData, StoppingCondition\n",
+"\n",
+"recipe_overrides = {\n",
+"    \"data\": {\n",
+"        \"train_dir\": \"/opt/ml/input/data/train\",\n",
+"    },\n",
+"    \"model\": {\n",
+"        \"model_config\": \"/opt/ml/input/data/train/config.json\",\n",
+"    },\n",
+"    \"trainer\": {\n",
+"        \"max_epochs\": 1,\n",
+"    },\n",
+"    \"compiler_cache_url\": compiler_cache_bucket,\n",
+"}\n",
+"env = {\n",
+"    \"FI_EFA_FORK_SAFE\": \"1\"\n",
+"}\n",
+"\n",
+"training_image = \"059094755717.dkr.ecr.us-west-2.amazonaws.com/sagemaker-recipes-neuron\"\n",
+"\n",
+"model_trainer = ModelTrainer.from_recipe(\n",
+"    training_recipe=\"https://raw.githubusercontent.com/aws-neuron/neuronx-distributed-training/refs/heads/main/examples/conf/hf_llama3_8B_config.yaml\",\n",
+"    recipe_overrides=recipe_overrides,\n",
+"    training_image=training_image,\n",
+"    compute=Compute(\n",
+"        instance_type=\"ml.trn1.32xlarge\",\n",
+"        instance_count=2,\n",
+"    ),\n",
+"    stopping_condition=StoppingCondition(\n",
+"        max_runtime_in_seconds=86400\n",
+"    ),\n",
+"    environment=env\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"train = InputData(\n",
+"    channel_name=\"train\",\n",
+"    data_source=\"s3://sagemaker-recipes-059094755717-data/data_llama3/\",\n",
+")\n",
+"\n",
+"model_trainer.train(input_data_config=[train], wait=False)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Successful Run - https://tiny.amazon.com/125zldym8/IsenLink"
+]
 }
 ],
 "metadata": {
src/sagemaker/modules/testing_notebooks/recipes/custom-recipe.yaml (new file; the custom recipe referenced by the notebook cell above)

+110 -0
@@ -0,0 +1,110 @@
+run:
+  name: llama-8b
+  results_dir: /opt/ml/model
+  time_limit: 6-00:00:00
+  model_type: hf
+trainer:
+  devices: 8
+  num_nodes: 1
+  accelerator: gpu
+  precision: bf16
+  max_steps: 50
+  log_every_n_steps: 1
+  val_check_interval: 1
+  limit_val_batches: 0
+exp_manager:
+  exp_dir: ''
+  name: experiment
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    save_top_k: 0
+    every_n_train_steps: 10
+    monitor: step
+    mode: max
+    save_last: true
+  checkpoint_dir: /opt/ml/checkpoints
+  resume_from_checkpoint: null
+  auto_checkpoint:
+    enabled: false
+  export_full_model:
+    every_n_train_steps: 0
+    save_last: false
+  explicit_log_dir: /opt/ml/output/tensorboard
+use_smp_model: false
+distributed_backend: nccl
+model:
+  model_type: llama_v3
+  train_batch_size: 1
+  seed: 12345
+  grad_clip: 1.0
+  log_reduced_training_loss: true
+  context_parallel_degree: 1
+  moe: false
+  activation_checkpointing: true
+  activation_loading_horizon: 2
+  delayed_param: false
+  offload_activations: false
+  fsdp: true
+  sharding_strategy: hybrid_shard
+  forward_prefetch: true
+  shard_degree: 8
+  backward_fetch_policy: backward_pre
+  auto_wrap_policy: transformer_auto_wrap_policy
+  limit_all_gathers: false
+  use_orig_param: false
+  fp8: false
+  max_context_width: 8192
+  max_position_embeddings: 8192
+  num_hidden_layers: 32
+  hidden_size: 4096
+  num_attention_heads: 32
+  intermediate_size: 14336
+  initializer_range: 0.02
+  layernorm_epsilon: 1.0e-05
+  vocab_size: 128256
+  num_key_value_heads: null
+  use_flash_attention: true
+  rope_theta: 500000.0
+  rope_scaling:
+    rope_type: llama3
+    factor: 8.0
+    high_freq_factor: 4.0
+    low_freq_factor: 1.0
+    original_max_position_embeddings: 8192
+  do_finetune: true
+  hf_model_name_or_path: meta-llama/Llama-3.1-8B
+  hf_access_token: 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
+  peft:
+    peft_type: lora
+    rank: 32
+    alpha: 16
+    dropout: 0.1
+  precision: bf16
+  lr_decay_iters: 50
+  optim:
+    name: adamw
+    lr: 0.0001
+    weight_decay: 0.01
+    betas:
+    - 0.9
+    - 0.95
+    sched:
+      name: CosineAnnealing
+      warmup_steps: 0
+      constant_steps: 0
+      min_lr: 1.0e-06
+data:
+  train_dir: /opt/ml/input/data/train
+  val_dir: /opt/ml/input/data/val
+  dataset_type: hf
+  use_synthetic_data: true
+nsys_profile:
+  enabled: false
+  start_step: 10
+  end_step: 10
+  ranks:
+  - 0
+  gen_shape: false
+viztracer:
+  enabled: false
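The keys in this recipe map one-to-one onto the recipe_overrides dictionaries used in the notebook. As a minimal sketch, assuming the override semantics shown in the Trainium cell above (nested keys in recipe_overrides shadow the same paths in the recipe YAML), overriding fields of this recipe would look like:

from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute

# Hypothetical overrides for illustration; the keys mirror the YAML above.
recipe_overrides = {
    "trainer": {
        "max_steps": 100,  # would replace trainer.max_steps: 50
    },
    "data": {
        "use_synthetic_data": False,  # would replace data.use_synthetic_data: true
    },
}

model_trainer = ModelTrainer.from_recipe(
    training_recipe="recipes/custom-recipe.yaml",
    recipe_overrides=recipe_overrides,
    training_image="059094755717.dkr.ecr.us-west-2.amazonaws.com/sagemaker-recipes-gpu",
    compute=Compute(instance_type="ml.p4d.24xlarge")
)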

src/sagemaker/modules/train/model_trainer.py

+1 -2
@@ -191,7 +191,7 @@ class ModelTrainer(BaseModel):
     tags: Optional[List[Tag]] = None
 
     # Created Artifacts
-    _latest_training_job: Optional[resources.TrainingJob] = None
+    _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
 
     # Metrics settings
     _enable_sage_maker_metrics_time_series: Optional[bool] = PrivateAttr(default=False)
@@ -200,7 +200,6 @@
     # Debugger settings
     _debug_hook_config: Optional[DebugHookConfig] = PrivateAttr(default=None)
     _debug_rule_configurations: Optional[List[DebugRuleConfiguration]] = PrivateAttr(default=None)
-    _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
     _profiler_config: Optional[ProfilerConfig] = PrivateAttr(default=None)
     _profiler_rule_configurations: Optional[List[ProfilerRuleConfiguration]] = PrivateAttr(
         default=None
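The `PrivateAttr` change matters because ModelTrainer subclasses pydantic's BaseModel, where underscore-prefixed attributes are treated as private attributes. A minimal sketch of the behavior, assuming pydantic v2 semantics (the `Example` class is hypothetical, for illustration only):

from typing import Optional
from pydantic import BaseModel, PrivateAttr

class Example(BaseModel):
    # Declaring the default through PrivateAttr makes it explicit that this
    # is private state: excluded from validation, serialization, and the
    # model's public fields, but still initialized to the given default.
    _latest_job: Optional[str] = PrivateAttr(default=None)

e = Example()
assert e._latest_job is None                      # default applied by PrivateAttr
assert "_latest_job" not in Example.model_fields  # not a public pydantic field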
