Skip to content

Commit a3f5874

Browse files
authored
feat: attach method for jumpstart estimator (#4074)
1 parent 34d3961 commit a3f5874

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

src/sagemaker/estimator.py

+21
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14291429
Instance of the calling ``Estimator`` Class with the attached
14301430
training job.
14311431
"""
1432+
return cls._attach(
1433+
training_job_name=training_job_name,
1434+
sagemaker_session=sagemaker_session,
1435+
model_channel_name=model_channel_name,
1436+
)
1437+
1438+
@classmethod
1439+
def _attach(
1440+
cls,
1441+
training_job_name: str,
1442+
sagemaker_session: Optional[str] = None,
1443+
model_channel_name: str = "model",
1444+
additional_kwargs: Optional[Dict[str, Any]] = None,
1445+
) -> "EstimatorBase":
1446+
"""Creates an Estimator bound to an existing training job.
1447+
1448+
Additional kwargs are allowed for instantiating Estimator.
1449+
"""
14321450
sagemaker_session = sagemaker_session or Session()
14331451

14341452
job_details = sagemaker_session.sagemaker_client.describe_training_job(
@@ -1440,6 +1458,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14401458
)["Tags"]
14411459
init_params.update(tags=tags)
14421460

1461+
if additional_kwargs:
1462+
init_params.update(additional_kwargs)
1463+
14431464
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
14441465
estimator.latest_training_job = _TrainingJob(
14451466
sagemaker_session=sagemaker_session, job_name=training_job_name

src/sagemaker/jumpstart/estimator.py

+57
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.inputs import FileSystemInput, TrainingInput
2828
from sagemaker.instance_group import InstanceGroup
2929
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
30+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
3031
from sagemaker.jumpstart.enums import JumpStartScriptScope
3132
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3233

@@ -655,6 +656,62 @@ def fit(
655656

656657
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
657658

659+
@classmethod
def attach(
    cls,
    training_job_name: str,
    model_id: str,
    model_version: str = "*",
    sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_channel_name: str = "model",
) -> "JumpStartEstimator":
    """Bind a new ``JumpStartEstimator`` to an already-existing training job.

    Once attached, and provided the training job has completed, the
    estimator can be ``deploy()``-ed to create a SageMaker Endpoint and
    return a ``Predictor``.

    If the training job is still in progress, this call blocks until the
    job finishes. Training logs are not streamed while waiting; call
    ``logs()`` on the returned estimator to view them.

    Examples:
        >>> my_estimator.fit(wait=False)
        >>> training_job_name = my_estimator.latest_training_job.name
        Later on:
        >>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
        >>> attached_estimator.logs()
        >>> attached_estimator.deploy()

    Args:
        training_job_name (str): Name of the training job to attach to.
        model_id (str): JumpStart model ID associated with the training job.
        model_version (str): Optional. Version of the JumpStart model
            associated with the training job. (Default: "*").
        sagemaker_session (sagemaker.session.Session): Optional. Session
            object which manages interactions with Amazon SageMaker APIs
            and any other AWS services needed.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_channel_name (str): Optional. Name of the channel where
            pre-trained model data will be downloaded (default: 'model').
            If the training job has no channel with this name, the option
            is ignored.

    Returns:
        Instance of the calling ``JumpStartEstimator`` class with the
        attached training job.
    """
    # JumpStart-specific constructor arguments, handed to the shared
    # ``_attach`` helper through its ``additional_kwargs`` parameter.
    jumpstart_init_kwargs = {"model_id": model_id, "model_version": model_version}

    return cls._attach(
        training_job_name=training_job_name,
        sagemaker_session=sagemaker_session,
        model_channel_name=model_channel_name,
        additional_kwargs=jumpstart_init_kwargs,
    )
714+
658715
def deploy(
659716
self,
660717
initial_instance_count: Optional[int] = None,

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

+8
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def test_jumpstart_estimator(setup):
6464
}
6565
)
6666

67+
# test that we can create a JumpStartEstimator from existing job with `attach`
68+
estimator = JumpStartEstimator.attach(
69+
training_job_name=estimator.latest_training_job.name,
70+
model_id=model_id,
71+
model_version=model_version,
72+
sagemaker_session=get_sm_session(),
73+
)
74+
6775
# uses ml.p3.2xlarge instance
6876
predictor = estimator.deploy(
6977
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],

0 commit comments

Comments
 (0)