|
27 | 27 | from sagemaker.inputs import FileSystemInput, TrainingInput
|
28 | 28 | from sagemaker.instance_group import InstanceGroup
|
29 | 29 | from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
|
| 30 | +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
30 | 31 | from sagemaker.jumpstart.enums import JumpStartScriptScope
|
31 | 32 | from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
|
32 | 33 |
|
@@ -655,6 +656,62 @@ def fit(
|
655 | 656 |
|
656 | 657 | return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
|
657 | 658 |
|
| 659 | + @classmethod |
| 660 | + def attach( |
| 661 | + cls, |
| 662 | + training_job_name: str, |
| 663 | + model_id: str, |
| 664 | + model_version: str = "*", |
| 665 | + sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
| 666 | + model_channel_name: str = "model", |
| 667 | + ) -> "JumpStartEstimator": |
| 668 | + """Attach to an existing training job. |
| 669 | +
|
| 670 | + Create a JumpStartEstimator bound to an existing training job. |
| 671 | + After attaching, if the training job has a Complete status, |
| 672 | + it can be ``deploy()`` ed to create a SageMaker Endpoint and return |
| 673 | + a ``Predictor``. |
| 674 | +
|
| 675 | + If the training job is in progress, attach will block until the training job |
| 676 | + completes, but logs of the training job will not display. To see the logs |
| 677 | + content, please call ``logs()`` |
| 678 | +
|
| 679 | + Examples: |
| 680 | + >>> my_estimator.fit(wait=False) |
| 681 | + >>> training_job_name = my_estimator.latest_training_job.name |
| 682 | + Later on: |
| 683 | + >>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id) |
| 684 | + >>> attached_estimator.logs() |
| 685 | + >>> attached_estimator.deploy() |
| 686 | +
|
| 687 | + Args: |
| 688 | + training_job_name (str): The name of the training job to attach to. |
| 689 | + model_id (str): The name of the JumpStart model id associated with the |
| 690 | + training job. |
| 691 | + model_version (str): Optional. The version of the JumpStart model id |
| 692 | + associated with the training job. (Default: "*"). |
| 693 | + sagemaker_session (sagemaker.session.Session): Optional. Session object which |
| 694 | + manages interactions with Amazon SageMaker APIs and any other |
| 695 | + AWS services needed. If not specified, the estimator creates one |
| 696 | + using the default AWS configuration chain. |
| 697 | + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). |
| 698 | + model_channel_name (str): Optional. Name of the channel where pre-trained |
| 699 | + model data will be downloaded (default: 'model'). If no channel |
| 700 | + with the same name exists in the training job, this option will |
| 701 | + be ignored. |
| 702 | +
|
| 703 | + Returns: |
| 704 | + Instance of the calling ``JumpStartEstimator`` Class with the attached |
| 705 | + training job. |
| 706 | + """ |
| 707 | + |
| 708 | + return cls._attach( |
| 709 | + training_job_name=training_job_name, |
| 710 | + sagemaker_session=sagemaker_session, |
| 711 | + model_channel_name=model_channel_name, |
| 712 | + additional_kwargs={"model_id": model_id, "model_version": model_version}, |
| 713 | + ) |
| 714 | + |
658 | 715 | def deploy(
|
659 | 716 | self,
|
660 | 717 | initial_instance_count: Optional[int] = None,
|
|
0 commit comments