Skip to content

Commit 7b86786

Browse files
committed
feat: attach method for jumpstart estimator
1 parent cadd0a1 commit 7b86786

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

src/sagemaker/jumpstart/estimator.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
2323
from sagemaker.debugger.profiler_config import ProfilerConfig
2424

25-
from sagemaker.estimator import Estimator
25+
from sagemaker.estimator import _TrainingJob, Estimator
2626
from sagemaker.explainer.explainer_config import ExplainerConfig
2727
from sagemaker.inputs import FileSystemInput, TrainingInput
2828
from sagemaker.instance_group import InstanceGroup
2929
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
30+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
3031
from sagemaker.jumpstart.enums import JumpStartScriptScope
3132
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3233

@@ -655,6 +656,73 @@ def fit(
655656

656657
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
657658

659+
@classmethod
660+
def attach(
661+
cls,
662+
training_job_name: str,
663+
model_id: str,
664+
model_version: str = "*",
665+
sagemaker_session=None,
666+
model_channel_name="model",
667+
):
668+
"""Attach to an existing training job.
669+
670+
Create an Estimator bound to an existing training job.
671+
After attaching, if the training job has a Complete status,
672+
it can be ``deploy()`` ed to create a SageMaker Endpoint and return
673+
a ``Predictor``.
674+
675+
If the training job is in progress, attach will block until the training job
676+
completes, but logs of the training job will not display. To see the logs
677+
content, please call ``logs()``
678+
679+
Examples:
680+
>>> my_estimator.fit(wait=False)
681+
>>> training_job_name = my_estimator.latest_training_job.name
682+
Later on:
683+
>>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
684+
>>> attached_estimator.logs()
685+
>>> attached_estimator.deploy()
686+
687+
Args:
688+
training_job_name (str): The name of the training job to attach to.
689+
model_id (str): The name of the JumpStart model id associated with the
690+
training job.
691+
model_version (str): Optional. The version of the JumpStart model id
692+
associated with the training job. (Default: "*").
693+
sagemaker_session (sagemaker.session.Session): Optional. Session object which
694+
manages interactions with Amazon SageMaker APIs and any other
695+
AWS services needed. If not specified, the estimator creates one
696+
using the default AWS configuration chain.
697+
model_channel_name (str): Optional. Name of the channel where pre-trained
698+
model data will be downloaded (default: 'model'). If no channel
699+
with the same name exists in the training job, this option will
700+
be ignored.
701+
702+
Returns:
703+
Instance of the calling ``Estimator`` Class with the attached
704+
training job.
705+
"""
706+
sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
707+
708+
job_details = sagemaker_session.sagemaker_client.describe_training_job(
709+
TrainingJobName=training_job_name
710+
)
711+
init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
712+
tags = sagemaker_session.sagemaker_client.list_tags(
713+
ResourceArn=job_details["TrainingJobArn"]
714+
)["Tags"]
715+
init_params.update(tags=tags)
716+
init_params.update(model_id=model_id, model_version=model_version)
717+
718+
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
719+
estimator.latest_training_job = _TrainingJob(
720+
sagemaker_session=sagemaker_session, job_name=training_job_name
721+
)
722+
estimator._current_job_name = estimator.latest_training_job.name
723+
estimator.latest_training_job.wait(logs="None")
724+
return estimator
725+
658726
def deploy(
659727
self,
660728
initial_instance_count: Optional[int] = None,

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

+8
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def test_jumpstart_estimator(setup):
6464
}
6565
)
6666

67+
# test that we can create a JumpStartEstimator from existing job with `attach`
68+
estimator = JumpStartEstimator.attach(
69+
training_job_name=estimator.latest_training_job.name,
70+
model_id=model_id,
71+
model_version=model_version,
72+
sagemaker_session=get_sm_session(),
73+
)
74+
6775
# uses ml.p3.2xlarge instance
6876
predictor = estimator.deploy(
6977
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],

0 commit comments

Comments
 (0)