-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: attach method for jumpstart estimator #4074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: attach method for jumpstart estimator #4074
Conversation
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Codecov Report
@@ Coverage Diff @@
## master #4074 +/- ##
==========================================
- Coverage 90.19% 89.45% -0.74%
==========================================
Files 1296 306 -990
Lines 115184 28480 -86704
==========================================
- Hits 103892 25478 -78414
+ Misses 11292 3002 -8290
|
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
src/sagemaker/jumpstart/estimator.py
Outdated
model_version: str = "*", | ||
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
model_channel_name: str = "model", | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return typing please
src/sagemaker/jumpstart/estimator.py
Outdated
) | ||
estimator._current_job_name = estimator.latest_training_job.name | ||
estimator.latest_training_job.wait(logs="None") | ||
return estimator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are re-writing the whole function, do you think this will be maintainable?
How about at least abstracting away the logic in the base class:
def _attach(cls, ..., addl_kwargs: Optional[Dict[str, Any]] = None):
job_details = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)
init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
tags = sagemaker_session.sagemaker_client.list_tags(
ResourceArn=job_details["TrainingJobArn"]
)["Tags"]
init_params.update(tags=tags)
if addl_kwargs:
init_params.update(addl_kwargs)
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(
sagemaker_session=sagemaker_session, job_name=training_job_name
)
estimator._current_job_name = estimator.latest_training_job.name
estimator.latest_training_job.wait(logs="None")
return estimator
Then in Estimator
:
def attach(cls, ...):
return cls._attach()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good idea. I wanted to avoid modifying the base class, but this seems simple enough
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Bypassing slow test as the failure is unrelated, and a previous iteration of slow test passed. |
Issue #, if available:
Description of changes:
Implement
attach()
method forJumpStartEstimator
. Now, the following workflow is possible:Testing done:
Integ test added
Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_base
to create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.