Skip to content

Commit 02cc559

Browse files
beniericpintaoz-aws
authored andcommitted
Add interface units for ModelTrainer (#1631)
1 parent 9c75b2b commit 02cc559

File tree

3 files changed

+317
-6
lines changed

3 files changed

+317
-6
lines changed

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
]
1111
},
1212
{
13-
"metadata": {},
1413
"cell_type": "markdown",
14+
"metadata": {},
1515
"source": [
1616
"# ModelTrainer\n",
1717
"The ModelTrainer is a new interface for training designed to tackle many of the challenges that exist in todays Estimator class. Some key features include:\n",

src/sagemaker/modules/train/model_trainer.py

+14
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,8 @@ def from_recipe(
839839
training_recipe: str,
840840
compute: Compute,
841841
recipe_overrides: Optional[Dict[str, Any]] = None,
842+
networking: Optional[Networking] = None,
843+
stopping_condition: Optional[StoppingCondition] = None,
842844
requirements: Optional[str] = None,
843845
training_image: Optional[str] = None,
844846
training_image_config: Optional[TrainingImageConfig] = None,
@@ -863,6 +865,13 @@ def from_recipe(
863865
the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
864866
recipe_overrides (Optional[Dict[str, Any]]):
865867
The recipe overrides. This is used to override the default recipe parameters.
868+
networking (Optional[Networking]):
869+
The networking configuration. This is used to specify the networking settings
870+
for the training job.
871+
stopping_condition (Optional[StoppingCondition]):
872+
The stopping condition. This is used to specify the different stopping
873+
conditions for the training job.
874+
If not specified, will default to 1 hour max run time.
866875
requirements (Optional[str]):
867876
The path to a requirements file to install in the training job container.
868877
training_image (Optional[str]):
@@ -912,6 +921,9 @@ def from_recipe(
912921
+ "Please provide a GPU or Tranium instance type."
913922
)
914923

924+
if training_image_config and training_image is None:
925+
raise ValueError("training_image must be provided when using training_image_config.")
926+
915927
if sagemaker_session is None:
916928
sagemaker_session = Session()
917929
logger.warning("SageMaker session not provided. Using default Session.")
@@ -939,6 +951,8 @@ def from_recipe(
939951
sagemaker_session=sagemaker_session,
940952
role=role,
941953
base_job_name=base_job_name,
954+
networking=networking,
955+
stopping_condition=stopping_condition,
942956
training_image_config=training_image_config,
943957
output_data_config=output_data_config,
944958
input_data_config=input_data_config,

0 commit comments

Comments
 (0)