|
22 | 22 | from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
|
23 | 23 | from sagemaker.debugger.profiler_config import ProfilerConfig
|
24 | 24 |
|
25 |
| -from sagemaker.estimator import Estimator |
| 25 | +from sagemaker.estimator import _TrainingJob, Estimator |
26 | 26 | from sagemaker.explainer.explainer_config import ExplainerConfig
|
27 | 27 | from sagemaker.inputs import FileSystemInput, TrainingInput
|
28 | 28 | from sagemaker.instance_group import InstanceGroup
|
29 | 29 | from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
|
| 30 | +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
30 | 31 | from sagemaker.jumpstart.enums import JumpStartScriptScope
|
31 | 32 | from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
|
32 | 33 |
|
@@ -655,6 +656,73 @@ def fit(
|
655 | 656 |
|
656 | 657 | return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
|
657 | 658 |
|
| 659 | + @classmethod |
| 660 | + def attach( |
| 661 | + cls, |
| 662 | + training_job_name: str, |
| 663 | + model_id: str, |
| 664 | + model_version: str = "*", |
| 665 | + sagemaker_session=None, |
| 666 | + model_channel_name="model", |
| 667 | + ): |
| 668 | + """Attach to an existing training job. |
| 669 | +
|
| 670 | + Create an Estimator bound to an existing training job. |
| 671 | + After attaching, if the training job has a Complete status, |
| 672 | + it can be ``deploy()`` ed to create a SageMaker Endpoint and return |
| 673 | + a ``Predictor``. |
| 674 | +
|
| 675 | + If the training job is in progress, attach will block until the training job |
| 676 | + completes, but logs of the training job will not display. To see the logs |
| 677 | + content, please call ``logs()`` |
| 678 | +
|
| 679 | + Examples: |
| 680 | + >>> my_estimator.fit(wait=False) |
| 681 | + >>> training_job_name = my_estimator.latest_training_job.name |
| 682 | + Later on: |
| 683 | + >>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id) |
| 684 | + >>> attached_estimator.logs() |
| 685 | + >>> attached_estimator.deploy() |
| 686 | +
|
| 687 | + Args: |
| 688 | + training_job_name (str): The name of the training job to attach to. |
| 689 | + model_id (str): The name of the JumpStart model id associated with the |
| 690 | + training job. |
| 691 | + model_version (str): Optional. The version of the JumpStart model id |
| 692 | + associated with the training job. (Default: "*"). |
| 693 | + sagemaker_session (sagemaker.session.Session): Optional. Session object which |
| 694 | + manages interactions with Amazon SageMaker APIs and any other |
| 695 | + AWS services needed. If not specified, the estimator creates one |
| 696 | + using the default AWS configuration chain. |
| 697 | + model_channel_name (str): Optional. Name of the channel where pre-trained |
| 698 | + model data will be downloaded (default: 'model'). If no channel |
| 699 | + with the same name exists in the training job, this option will |
| 700 | + be ignored. |
| 701 | +
|
| 702 | + Returns: |
| 703 | + Instance of the calling ``Estimator`` Class with the attached |
| 704 | + training job. |
| 705 | + """ |
| 706 | + sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 707 | + |
| 708 | + job_details = sagemaker_session.sagemaker_client.describe_training_job( |
| 709 | + TrainingJobName=training_job_name |
| 710 | + ) |
| 711 | + init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) |
| 712 | + tags = sagemaker_session.sagemaker_client.list_tags( |
| 713 | + ResourceArn=job_details["TrainingJobArn"] |
| 714 | + )["Tags"] |
| 715 | + init_params.update(tags=tags) |
| 716 | + init_params.update(model_id=model_id, model_version=model_version) |
| 717 | + |
| 718 | + estimator = cls(sagemaker_session=sagemaker_session, **init_params) |
| 719 | + estimator.latest_training_job = _TrainingJob( |
| 720 | + sagemaker_session=sagemaker_session, job_name=training_job_name |
| 721 | + ) |
| 722 | + estimator._current_job_name = estimator.latest_training_job.name |
| 723 | + estimator.latest_training_job.wait(logs="None") |
| 724 | + return estimator |
| 725 | + |
658 | 726 | def deploy(
|
659 | 727 | self,
|
660 | 728 | initial_instance_count: Optional[int] = None,
|
|
0 commit comments