From 7b86786861cf3d620f2df9f8e99552ee14aba5e8 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 17 Aug 2023 20:45:34 +0000 Subject: [PATCH 1/4] feat: attach method for jumpstart estimator --- src/sagemaker/jumpstart/estimator.py | 70 ++++++++++++++++++- .../estimator/test_jumpstart_estimator.py | 8 +++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 3948bf5775..489903f6a0 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -22,11 +22,12 @@ from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig from sagemaker.debugger.profiler_config import ProfilerConfig -from sagemaker.estimator import Estimator +from sagemaker.estimator import _TrainingJob, Estimator from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -655,6 +656,73 @@ def fit( return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) + @classmethod + def attach( + cls, + training_job_name: str, + model_id: str, + model_version: str = "*", + sagemaker_session=None, + model_channel_name="model", + ): + """Attach to an existing training job. + + Create an Estimator bound to an existing training job. + After attaching, if the training job has a Complete status, + it can be ``deploy()`` ed to create a SageMaker Endpoint and return + a ``Predictor``. + + If the training job is in progress, attach will block until the training job + completes, but logs of the training job will not display. To see the logs + content, please call ``logs()`` + + Examples: + >>> my_estimator.fit(wait=False) + >>> training_job_name = my_estimator.latest_training_job.name + Later on: + >>> attached_estimator = JumpStartEstimator.attach(training_job_name, model_id) + >>> attached_estimator.logs() + >>> attached_estimator.deploy() + + Args: + training_job_name (str): The name of the training job to attach to. + model_id (str): The name of the JumpStart model id associated with the + training job. + model_version (str): Optional. The version of the JumpStart model id + associated with the training job. (Default: "*"). + sagemaker_session (sagemaker.session.Session): Optional. Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, the estimator creates one + using the default AWS configuration chain. + model_channel_name (str): Optional. Name of the channel where pre-trained + model data will be downloaded (default: 'model'). If no channel + with the same name exists in the training job, this option will + be ignored. + + Returns: + Instance of the calling ``Estimator`` Class with the attached + training job. + """ + sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + job_details = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job_name + ) + init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) + tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=job_details["TrainingJobArn"] + )["Tags"] + init_params.update(tags=tags) + init_params.update(model_id=model_id, model_version=model_version) + + estimator = cls(sagemaker_session=sagemaker_session, **init_params) + estimator.latest_training_job = _TrainingJob( + sagemaker_session=sagemaker_session, job_name=training_job_name + ) + estimator._current_job_name = estimator.latest_training_job.name + estimator.latest_training_job.wait(logs="None") + return estimator + def deploy( self, initial_instance_count: Optional[int] = None, diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index f56d562812..d5a7d7ff00 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -64,6 +64,14 @@ def test_jumpstart_estimator(setup): } ) + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], From 5f1b7731f3a620bfc3db00bc9bd9883939d281aa Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 18 Aug 2023 13:32:32 +0000 Subject: [PATCH 2/4] fix: docstring, default args --- src/sagemaker/jumpstart/estimator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 489903f6a0..01a0d7936a 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -662,8 +662,8 @@ def attach( training_job_name: str, model_id: str, model_version: str = "*", - sagemaker_session=None, - model_channel_name="model", + sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_channel_name: str = "model", ): """Attach to an existing training job. @@ -694,16 +694,16 @@ def attach( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). model_channel_name (str): Optional. Name of the channel where pre-trained model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. Returns: - Instance of the calling ``Estimator`` Class with the attached + Instance of the calling ``JumpStartEstimator`` Class with the attached training job. """ - sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION job_details = sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=training_job_name From bc00b8f7df1b8976f7bd05ff71b7b6f4526bbf03 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 18 Aug 2023 16:17:14 +0000 Subject: [PATCH 3/4] chore: cleanup code --- src/sagemaker/estimator.py | 16 ++++++++++++++++ src/sagemaker/jumpstart/estimator.py | 27 ++++++++------------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 409ec57c79..d1997b6c9f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1429,6 +1429,19 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m Instance of the calling ``Estimator`` Class with the attached training job. """ + return cls._attach( + training_job_name=training_job_name, + sagemaker_session=sagemaker_session, + model_channel_name=model_channel_name, + ) + + def _attach( + cls, + training_job_name: str, + sagemaker_session: Optional[str] = None, + model_channel_name: str = "model", + additional_kwargs: Optional[Dict[str, Any]] = None, + ) -> "EstimatorBase": sagemaker_session = sagemaker_session or Session() job_details = sagemaker_session.sagemaker_client.describe_training_job( @@ -1440,6 +1453,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m )["Tags"] init_params.update(tags=tags) + if additional_kwargs: + init_params.update(additional_kwargs) + estimator = cls(sagemaker_session=sagemaker_session, **init_params) estimator.latest_training_job = _TrainingJob( sagemaker_session=sagemaker_session, job_name=training_job_name diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 01a0d7936a..eae91f2a8d 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -22,7 +22,7 @@ from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig from sagemaker.debugger.profiler_config import ProfilerConfig -from sagemaker.estimator import _TrainingJob, Estimator +from sagemaker.estimator import Estimator from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.instance_group import InstanceGroup @@ -664,10 +664,10 @@ def attach( model_version: str = "*", sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", - ): + ) -> "JumpStartEstimator": """Attach to an existing training job. - Create an Estimator bound to an existing training job. + Create a JumpStartEstimator bound to an existing training job. After attaching, if the training job has a Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``. @@ -705,23 +705,12 @@ def attach( training job. """ - job_details = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=training_job_name - ) - init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) - tags = sagemaker_session.sagemaker_client.list_tags( - ResourceArn=job_details["TrainingJobArn"] - )["Tags"] - init_params.update(tags=tags) - init_params.update(model_id=model_id, model_version=model_version) - - estimator = cls(sagemaker_session=sagemaker_session, **init_params) - estimator.latest_training_job = _TrainingJob( - sagemaker_session=sagemaker_session, job_name=training_job_name + return cls._attach( + training_job_name=training_job_name, + sagemaker_session=sagemaker_session, + model_channel_name=model_channel_name, + additional_kwargs={"model_id": model_id, "model_version": model_version}, ) - estimator._current_job_name = estimator.latest_training_job.name - estimator.latest_training_job.wait(logs="None") - return estimator def deploy( self, From 57290d2088cdad36795d68d6ff592e2c1db9c621 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 18 Aug 2023 17:19:24 +0000 Subject: [PATCH 4/4] fix: pylint --- src/sagemaker/estimator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index d1997b6c9f..8b09a17556 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1435,6 +1435,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m model_channel_name=model_channel_name, ) + @classmethod def _attach( cls, training_job_name: str, @@ -1442,6 +1443,10 @@ def _attach( model_channel_name: str = "model", additional_kwargs: Optional[Dict[str, Any]] = None, ) -> "EstimatorBase": + """Creates an Estimator bound to an existing training job. + + Additional kwargs are allowed for instantiating Estimator. + """ sagemaker_session = sagemaker_session or Session() job_details = sagemaker_session.sagemaker_client.describe_training_job(