@@ -55,7 +55,7 @@
 
 from sagemaker.modules.distributed import (
     DistributedRunner,
-    TorchrunSMP,
+    Torchrun,
 )
 from sagemaker.modules.utils import (
     _get_repo_name_from_image,
@@ -85,6 +85,7 @@
     EXECUTE_BASIC_SCRIPT_DRIVER,
 )
 from sagemaker.modules import logger
+from sagemaker.modules.train.sm_recipes.utils import get_args_from_recipe, _determine_device_type
 
 
 class ModelTrainer(BaseModel):
@@ -213,6 +214,14 @@ class ModelTrainer(BaseModel):
     _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
     _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
 
+    _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
+
+    def __del__(self):
+        """Destructor method to clean up the temporary directory."""
+        # Clean up the temporary directory if it exists
+        if self._temp_recipe_train_dir is not None:
+            self._temp_recipe_train_dir.cleanup()
+
     def _validate_training_image_and_algorithm_name(
         self, training_image: Optional[str], algorithm_name: Optional[str]
     ):
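A note on the cleanup pattern introduced here (and repeated later in `train()`): `tempfile.TemporaryDirectory.cleanup()` detaches its finalizer before removing the tree, so calling it a second time is a no-op. A minimal standalone sketch of that double-cleanup pattern; `RecipeHolder` and `release()` are illustrative stand-ins, not part of this change:

```python
import os
from tempfile import TemporaryDirectory


class RecipeHolder:
    """Illustrative stand-in for ModelTrainer's temp-dir handling."""

    def __init__(self):
        # Staged recipe artifacts would live here
        self._temp_dir = TemporaryDirectory()

    def release(self):
        # Explicit cleanup, mirroring the call at the end of train()
        if self._temp_dir is not None:
            self._temp_dir.cleanup()

    def __del__(self):
        # Safety net, mirroring ModelTrainer.__del__; cleanup() is a no-op
        # if release() already ran, because the finalizer was detached.
        if self._temp_dir is not None:
            self._temp_dir.cleanup()


holder = RecipeHolder()
path = holder._temp_dir.name
holder.release()
assert not os.path.exists(path)  # directory is gone; the later __del__ is harmless
```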
@@ -383,9 +392,9 @@ def train(
                 distributed_runner=self.distributed_runner,
             )
 
-            if isinstance(self.distributed_runner, TorchrunSMP):
-                mp_parameters = self.distributed_runner._to_mp_parameters_dict()
-                string_hyper_parameters["mp_parameters"] = safe_serialize(mp_parameters)
+            if isinstance(self.distributed_runner, Torchrun) and self.distributed_runner.smp:
+                mp_parameters = self.distributed_runner.smp._to_mp_hyperparameters()
+                string_hyper_parameters.update(mp_parameters)
 
             self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code)
             self._write_distributed_runner_json(
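With this change, SMP settings are merged into the ordinary hyperparameters via `dict.update` rather than stored as one serialized `mp_parameters` entry, and they hang off the `Torchrun` runner instead of a dedicated `TorchrunSMP` class. A hedged configuration sketch, assuming `sagemaker.modules.distributed` also exports an `SMP` config model; the field names are assumptions, not shown in this diff:

```python
from sagemaker.modules.distributed import Torchrun, SMP  # `SMP` export is an assumption

# When train() runs, the isinstance(..., Torchrun) branch above merges
# smp._to_mp_hyperparameters() into the job's string hyperparameters.
runner = Torchrun(
    smp=SMP(
        tensor_parallel_degree=2,  # assumed field name
        hybrid_shard_degree=4,     # assumed field name
    )
)
```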
@@ -455,6 +464,11 @@ def train(
             session_chaining_config=self._session_chaining_config,
         )
         self._latest_training_job = training_job
+
+        # Clean up the temporary directory if it exists
+        if self._temp_recipe_train_dir is not None:
+            self._temp_recipe_train_dir.cleanup()
+
         if wait:
             training_job.wait(logs=logs)
 
@@ -748,3 +762,77 @@ def with_additional_settings(
         self._session_chaining_config = session_chaining_config
         self._remote_debug_config = remote_debug_config
         return self
+
+    @classmethod
+    def from_recipe(
+        cls,
+        training_recipe: str,
+        compute: Compute,
+        recipe_overrides: Optional[Dict[str, Any]] = None,
+        training_image: Optional[str] = None,
+        session: Optional[Session] = None,
+        role: Optional[str] = None,
+        base_job_name: Optional[str] = None,
+        **kwargs,
+    ) -> "ModelTrainer":
+        """Create a ModelTrainer from a training recipe.
+
+        Args:
+            training_recipe (str):
+                The training recipe to use for training the model. This must be the name of
+                a SageMaker training recipe or a path to a local training recipe .yaml file.
+            compute (Compute):
+                The compute configuration. This is used to specify the compute resources for
+                the training job. If not specified, it will default to 1 instance of ml.m5.xlarge.
+            recipe_overrides (Optional[Dict[str, Any]]):
+                The recipe overrides. This is used to override the default recipe parameters.
+            training_image (Optional[str]):
+                The training image URI to use for the training job container. If not specified,
+                the training image will be determined from the recipe.
+            session (Optional[Session]):
+                The SageMaker session.
+                If not specified, a new session will be created.
+            role (Optional[str]):
+                The IAM role ARN for the training job.
+                If not specified, the default SageMaker execution role will be used.
+            base_job_name (Optional[str]):
+                The base name for the training job.
+                If not specified, a default name will be generated using the algorithm name
+                or training image.
+            kwargs:
+                Additional keyword arguments to pass to the ModelTrainer constructor.
+
+        """
+        if compute.instance_type is None:
+            raise ValueError(
+                "Must set `instance_type` in `compute` when using training recipes."
+            )
+        device_type = _determine_device_type(compute.instance_type)
+        if device_type == "cpu":
+            raise ValueError(
+                "Training recipes are not supported for CPU instances. "
+                + "Please provide a GPU or Trainium instance type."
+            )
+
+        if session is None:
+            session = Session()
+            logger.warning("Session not provided. Using default Session.")
+        if role is None:
+            role = get_execution_role()
+            logger.warning(f"Role not provided. Using default role:\n{role}")
+
+        model_trainer_args, recipe_train_dir = get_args_from_recipe(
+            training_recipe=training_recipe,
+            recipe_overrides=recipe_overrides,
+            compute=compute,
+            session=session,
+        )
+        if training_image is not None:
+            model_trainer_args["training_image"] = training_image
+
+        model_trainer = cls(
+            session=session, role=role, base_job_name=base_job_name, **model_trainer_args, **kwargs
+        )
+
+        model_trainer._temp_recipe_train_dir = recipe_train_dir
+        return model_trainer
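For context, a hedged usage sketch of the new classmethod; the recipe path, overrides, and instance type below are illustrative placeholders:

```python
from sagemaker.modules.configs import Compute
from sagemaker.modules.train import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="recipes/llama3-8b-pretrain.yaml",  # local recipe .yaml (placeholder)
    compute=Compute(instance_type="ml.p4d.24xlarge", instance_count=1),
    recipe_overrides={"trainer": {"max_steps": 50}},  # shape depends on the recipe
)
trainer.train()  # the staged recipe directory is cleaned up once the job is created
```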