feat: attach method for jumpstart estimator #4074
Changes from 3 commits
@@ -22,11 +22,12 @@
 from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
 from sagemaker.debugger.profiler_config import ProfilerConfig

-from sagemaker.estimator import Estimator
+from sagemaker.estimator import _TrainingJob, Estimator
 from sagemaker.explainer.explainer_config import ExplainerConfig
 from sagemaker.inputs import FileSystemInput, TrainingInput
 from sagemaker.instance_group import InstanceGroup
 from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
 from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
 from sagemaker.jumpstart.enums import JumpStartScriptScope
 from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
@@ -655,6 +656,73 @@ def fit(

        return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())

    @classmethod
    def attach(
        cls,
        training_job_name: str,
        model_id: str,
        model_version: str = "*",
        sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
        model_channel_name: str = "model",
    ):
        """Attach to an existing training job.

        Create an Estimator bound to an existing training job.
        After attaching, if the training job has a Complete status,
        it can be ``deploy()`` ed to create a SageMaker Endpoint and return
        a ``Predictor``.

        If the training job is in progress, attach will block until the training job
        completes, but logs of the training job will not display. To see the logs
        content, please call ``logs()``

        Examples:
            >>> my_estimator.fit(wait=False)
            >>> training_job_name = my_estimator.latest_training_job.name
            Later on:
            >>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
            >>> attached_estimator.logs()
            >>> attached_estimator.deploy()

        Args:
            training_job_name (str): The name of the training job to attach to.
            model_id (str): The name of the JumpStart model id associated with the
                training job.
            model_version (str): Optional. The version of the JumpStart model id
                associated with the training job. (Default: "*").
            sagemaker_session (sagemaker.session.Session): Optional. Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
                (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
            model_channel_name (str): Optional. Name of the channel where pre-trained
                model data will be downloaded (default: 'model'). If no channel
                with the same name exists in the training job, this option will
                be ignored.

        Returns:
            Instance of the calling ``JumpStartEstimator`` Class with the attached
            training job.
        """

        job_details = sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=training_job_name
        )
        init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
        tags = sagemaker_session.sagemaker_client.list_tags(
            ResourceArn=job_details["TrainingJobArn"]
        )["Tags"]
        init_params.update(tags=tags)
        init_params.update(model_id=model_id, model_version=model_version)

        estimator = cls(sagemaker_session=sagemaker_session, **init_params)
        estimator.latest_training_job = _TrainingJob(
            sagemaker_session=sagemaker_session, job_name=training_job_name
        )
        estimator._current_job_name = estimator.latest_training_job.name
        estimator.latest_training_job.wait(logs="None")
        return estimator
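As a quick usage illustration of the method added above, a minimal sketch (the model id and training job name below are placeholders for this example, not values taken from the PR):

from sagemaker.jumpstart.estimator import JumpStartEstimator

# Placeholder values; substitute a real JumpStart model id and the name of a
# training job that was launched for that model.
model_id = "huggingface-eqa-bert-base-cased"  # hypothetical model id
training_job_name = "my-jumpstart-training-job"  # existing training job name

# Rebuild the estimator from the training job, then deploy it to an endpoint,
# mirroring the docstring example.
attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
attached_estimator.deploy()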
You are re-writing the whole function, do you think this will be maintainable? How about at least abstracting away the logic in the base class:

def _attach(cls, ..., addl_kwargs: Optional[Dict[str, Any]] = None):
    job_details = sagemaker_session.sagemaker_client.describe_training_job(
        TrainingJobName=training_job_name
    )
    init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
    tags = sagemaker_session.sagemaker_client.list_tags(
        ResourceArn=job_details["TrainingJobArn"]
    )["Tags"]
    init_params.update(tags=tags)
    if addl_kwargs:
        init_params.update(addl_kwargs)
    estimator = cls(sagemaker_session=sagemaker_session, **init_params)
    estimator.latest_training_job = _TrainingJob(
        sagemaker_session=sagemaker_session, job_name=training_job_name
    )
    estimator._current_job_name = estimator.latest_training_job.name
    estimator.latest_training_job.wait(logs="None")
    return estimator

Then in ``attach``:

def attach(cls, ...):
    return cls._attach()

That's a good idea. I wanted to avoid modifying the base class, but this seems simple enough.
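Under that suggestion, the JumpStart override might reduce to something like the following sketch (the ``_attach`` helper, its exact signature, and the keyword-forwarding shown here are assumptions based on the reviewer's proposal, not merged code):

    @classmethod
    def attach(
        cls,
        training_job_name: str,
        model_id: str,
        model_version: str = "*",
        sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
        model_channel_name: str = "model",
    ):
        # Delegate the generic describe_training_job / list_tags /
        # re-instantiate logic to the proposed base-class helper, passing only
        # the JumpStart-specific constructor kwargs through ``addl_kwargs``.
        return cls._attach(
            training_job_name=training_job_name,
            sagemaker_session=sagemaker_session,
            model_channel_name=model_channel_name,
            addl_kwargs={"model_id": model_id, "model_version": model_version},
        )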
    def deploy(
        self,
        initial_instance_count: Optional[int] = None,
return typing please
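If the request refers to the new ``attach`` classmethod, the annotated signature might look like this sketch (which method the comment targets and the final annotation style are assumptions, not confirmed by this excerpt):

    @classmethod
    def attach(
        cls,
        training_job_name: str,
        model_id: str,
        model_version: str = "*",
        sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
        model_channel_name: str = "model",
    ) -> "JumpStartEstimator":
        # The quoted forward reference lets the method name its own class as
        # the return type before the class body has finished being defined.
        ...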