Skip to content

feat: attach method for jumpstart estimator #4074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
Instance of the calling ``Estimator`` Class with the attached
training job.
"""
return cls._attach(
training_job_name=training_job_name,
sagemaker_session=sagemaker_session,
model_channel_name=model_channel_name,
)

@classmethod
def _attach(
cls,
training_job_name: str,
sagemaker_session: Optional[str] = None,
model_channel_name: str = "model",
additional_kwargs: Optional[Dict[str, Any]] = None,
) -> "EstimatorBase":
"""Creates an Estimator bound to an existing training job.

Additional kwargs are allowed for instantiating Estimator.
"""
sagemaker_session = sagemaker_session or Session()

job_details = sagemaker_session.sagemaker_client.describe_training_job(
Expand All @@ -1440,6 +1458,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
)["Tags"]
init_params.update(tags=tags)

if additional_kwargs:
init_params.update(additional_kwargs)

estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(
sagemaker_session=sagemaker_session, job_name=training_job_name
Expand Down
57 changes: 57 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.instance_group import InstanceGroup
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG

Expand Down Expand Up @@ -655,6 +656,62 @@ def fit(

return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())

@classmethod
def attach(
cls,
training_job_name: str,
model_id: str,
model_version: str = "*",
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
) -> "JumpStartEstimator":
"""Attach to an existing training job.

Create a JumpStartEstimator bound to an existing training job.
After attaching, if the training job has a Complete status,
it can be ``deploy()`` ed to create a SageMaker Endpoint and return
a ``Predictor``.

If the training job is in progress, attach will block until the training job
completes, but logs of the training job will not display. To see the logs
content, please call ``logs()``

Examples:
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name
Later on:
>>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
>>> attached_estimator.logs()
>>> attached_estimator.deploy()

Args:
training_job_name (str): The name of the training job to attach to.
model_id (str): The name of the JumpStart model id associated with the
training job.
model_version (str): Optional. The version of the JumpStart model id
associated with the training job. (Default: "*").
sagemaker_session (sagemaker.session.Session): Optional. Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
model_channel_name (str): Optional. Name of the channel where pre-trained
model data will be downloaded (default: 'model'). If no channel
with the same name exists in the training job, this option will
be ignored.

Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
training job.
"""

return cls._attach(
training_job_name=training_job_name,
sagemaker_session=sagemaker_session,
model_channel_name=model_channel_name,
additional_kwargs={"model_id": model_id, "model_version": model_version},
)

def deploy(
self,
initial_instance_count: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def test_jumpstart_estimator(setup):
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
Expand Down