Skip to content

Fix for Estimator Training details on register #4788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from sagemaker.interactive_apps import SupportedInteractiveAppTypes
from sagemaker.interactive_apps.tensorboard import TensorBoardApp
from sagemaker.instance_group import InstanceGroup
from sagemaker.model_card.model_card import ModelCard, TrainingDetails
from sagemaker.utils import instance_supports_kms
from sagemaker.job import _Job
from sagemaker.jumpstart.utils import (
Expand Down Expand Up @@ -1797,8 +1798,17 @@ def register(
else:
if "model_kms_key" not in kwargs:
kwargs["model_kms_key"] = self.output_kms_key
model = self.create_model(image_uri=image_uri, **kwargs)
model = self.create_model(image_uri=image_uri, name=model_name, **kwargs)
model.name = model_name
if self.model_data is not None and model_card is None:
training_details = TrainingDetails.from_model_s3_artifacts(
model_artifacts=[self.model_data], sagemaker_session=self.sagemaker_session
)
model_card = ModelCard(
name="estimator_card",
training_details=training_details,
sagemaker_session=self.sagemaker_session,
)
return model.register(
content_types,
response_types,
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,10 @@ def register(
model_package_group_name = utils.base_name_from_image(
self.image_uri, default_base_name=ModelPackage.__name__
)
if model_package_group_name is not None:
if (
model_package_group_name is not None
and model_type is not JumpStartModelType.PROPRIETARY
):
container_def = self.prepare_container_def(accept_eula=accept_eula)
container_def = update_container_with_inference_params(
framework=framework,
Expand Down
60 changes: 60 additions & 0 deletions tests/integ/test_byo_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io
import json
import os

import numpy as np

import pytest
import sagemaker.amazon.common as smac


import sagemaker
from sagemaker import image_uris
from sagemaker.estimator import Estimator
from sagemaker.s3 import S3Uploader
from sagemaker.serializers import SimpleBaseSerializer
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets
Expand Down Expand Up @@ -102,6 +108,60 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
assert prediction["score"] is not None


@pytest.mark.release
def test_estimator_register_publish_training_details(sagemaker_session, region):

bucket = sagemaker_session.default_bucket()
prefix = "model-card-sample-notebook"

raw_data = (
(0.5, 0),
(0.75, 0),
(1.0, 0),
(1.25, 0),
(1.50, 0),
(1.75, 0),
(2.0, 0),
(2.25, 1),
(2.5, 0),
(2.75, 1),
(3.0, 0),
(3.25, 1),
(3.5, 0),
(4.0, 1),
(4.25, 1),
(4.5, 1),
(4.75, 1),
(5.0, 1),
(5.5, 1),
)
training_data = np.array(raw_data).astype("float32")
labels = training_data[:, 1]

# upload data to S3 bucket
buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(buf, training_data, labels)
buf.seek(0)
s3_train_data = f"s3://{bucket}/{prefix}/train"
S3Uploader.upload_bytes(b=buf, s3_uri=s3_train_data, sagemaker_session=sagemaker_session)
output_location = f"s3://{bucket}/{prefix}/output"
container = image_uris.retrieve("linear-learner", region)
estimator = Estimator(
container,
role="SageMakerRole",
instance_count=1,
instance_type="ml.m4.xlarge",
output_path=output_location,
sagemaker_session=sagemaker_session,
)
estimator.set_hyperparameters(
feature_dim=2, mini_batch_size=10, predictor_type="binary_classifier"
)
estimator.fit({"train": s3_train_data})
print(f"Training job name: {estimator.latest_training_job.name}")
estimator.register()


def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, training_set):
image_uri = image_uris.retrieve("factorization-machines", region)
endpoint_name = unique_name_from_base("byo")
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4402,7 +4402,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
framework = "TENSORFLOW"
framework_version = "2.9"
nearest_model_name = "resnet50"

model_card = {"ModelCardStatus": ModelCardStatusEnum.DRAFT, "ModelCardContent": "{}"}
estimator.register(
content_types=content_types,
response_types=response_types,
Expand All @@ -4425,6 +4425,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
"marketplace_cert": False,
"sample_payload_url": sample_payload_url,
"task": task,
"model_card": model_card,
}
sagemaker_session.create_model_package_from_containers.assert_called_with(
**expected_create_model_package_request
Expand Down Expand Up @@ -4454,6 +4455,7 @@ def test_register_inference_image(sagemaker_session):
framework = "TENSORFLOW"
framework_version = "2.9"
nearest_model_name = "resnet50"
model_card = {"ModelCardStatus": ModelCardStatusEnum.DRAFT, "ModelCardContent": "{}"}

estimator.register(
content_types=content_types,
Expand All @@ -4480,6 +4482,7 @@ def test_register_inference_image(sagemaker_session):
"marketplace_cert": False,
"sample_payload_url": sample_payload_url,
"task": task,
"model_card": model_card,
}
sagemaker_session.create_model_package_from_containers.assert_called_with(
**expected_create_model_package_request
Expand Down