Skip to content

feat: attach method for jumpstart estimator #4074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
from sagemaker.debugger.profiler_config import ProfilerConfig

from sagemaker.estimator import Estimator
from sagemaker.estimator import _TrainingJob, Estimator
from sagemaker.explainer.explainer_config import ExplainerConfig
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.instance_group import InstanceGroup
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG

Expand Down Expand Up @@ -655,6 +656,73 @@ def fit(

return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())

@classmethod
def attach(
cls,
training_job_name: str,
model_id: str,
model_version: str = "*",
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return typing please

"""Attach to an existing training job.

Create an Estimator bound to an existing training job.
After attaching, if the training job has a Complete status,
it can be ``deploy()`` ed to create a SageMaker Endpoint and return
a ``Predictor``.

If the training job is in progress, attach will block until the training job
completes, but logs of the training job will not display. To see the logs
content, please call ``logs()``

Examples:
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name
Later on:
>>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
>>> attached_estimator.logs()
>>> attached_estimator.deploy()

Args:
training_job_name (str): The name of the training job to attach to.
model_id (str): The name of the JumpStart model id associated with the
training job.
model_version (str): Optional. The version of the JumpStart model id
associated with the training job. (Default: "*").
sagemaker_session (sagemaker.session.Session): Optional. Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
model_channel_name (str): Optional. Name of the channel where pre-trained
model data will be downloaded (default: 'model'). If no channel
with the same name exists in the training job, this option will
be ignored.

Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
training job.
"""

job_details = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)
init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
tags = sagemaker_session.sagemaker_client.list_tags(
ResourceArn=job_details["TrainingJobArn"]
)["Tags"]
init_params.update(tags=tags)
init_params.update(model_id=model_id, model_version=model_version)

estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(
sagemaker_session=sagemaker_session, job_name=training_job_name
)
estimator._current_job_name = estimator.latest_training_job.name
estimator.latest_training_job.wait(logs="None")
return estimator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are re-writing the whole function, do you think this will be maintainable?

How about at least abstracting away the logic in the base class:

def _attach(cls, ..., addl_kwargs: Optional[Dict[str, Any]] = None):

        job_details = sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=training_job_name
        )
        init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
        tags = sagemaker_session.sagemaker_client.list_tags(
            ResourceArn=job_details["TrainingJobArn"]
        )["Tags"]
        init_params.update(tags=tags)
        if addl_kwargs:
          init_params.update(addl_kwargs)

        estimator = cls(sagemaker_session=sagemaker_session, **init_params)
        estimator.latest_training_job = _TrainingJob(
            sagemaker_session=sagemaker_session, job_name=training_job_name
        )
        estimator._current_job_name = estimator.latest_training_job.name
        estimator.latest_training_job.wait(logs="None")
        return estimator

Then in Estimator:

def attach(cls, ...):
  return cls._attach()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. I wanted to avoid modifying the base class, but this seems simple enough


def deploy(
self,
initial_instance_count: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def test_jumpstart_estimator(setup):
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
Expand Down