Skip to content

Commit ee2c345

Browse files
author
Dan
authored
fix: fix tags in deploy call for generic estimators (#1146)
1 parent dba0d61 commit ee2c345

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def deploy(
606606
model_name=None,
607607
kms_key=None,
608608
data_capture_config=None,
609+
tags=None,
609610
**kwargs
610611
):
611612
"""Deploy the trained model to an Amazon SageMaker endpoint and return a
@@ -639,18 +640,18 @@ def deploy(
639640
model completes (default: True).
640641
model_name (str): Name to use for creating an Amazon SageMaker
641642
model. If not specified, the name of the training job is used.
642-
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
643-
endpoint. Example:
644-
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
645-
For more information about tags, see
646-
https://boto3.amazonaws.com/v1/documentation\
647-
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
648643
kms_key (str): The ARN of the KMS key that is used to encrypt the
649644
data on the storage volume attached to the instance hosting the
650645
endpoint.
651646
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
652647
configuration related to Endpoint data capture for use with
653648
Amazon SageMaker Model Monitoring. Default: None.
649+
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
650+
endpoint. Example:
651+
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
652+
For more information about tags, see
653+
https://boto3.amazonaws.com/v1/documentation\
654+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
654655
**kwargs: Passed to invocation of ``create_model()``.
655656
Implementations may customize ``create_model()`` to accept
656657
``**kwargs`` to customize model creation during deploy.
@@ -685,7 +686,7 @@ def deploy(
685686
accelerator_type=accelerator_type,
686687
endpoint_name=endpoint_name,
687688
update_endpoint=update_endpoint,
688-
tags=self.tags,
689+
tags=tags or self.tags,
689690
wait=wait,
690691
kms_key=kms_key,
691692
data_capture_config=data_capture_config,

tests/unit/test_estimator.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1702,7 +1702,7 @@ def test_unsupported_type_in_dict():
17021702
)
17031703

17041704

1705-
def test_fit_deploy_keep_tags(sagemaker_session):
1705+
def test_fit_deploy_tags_in_estimator(sagemaker_session):
17061706
tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
17071707
estimator = Estimator(
17081708
IMAGE_NAME,
@@ -1747,6 +1747,46 @@ def test_fit_deploy_keep_tags(sagemaker_session):
17471747
)
17481748

17491749

1750+
def test_fit_deploy_tags(sagemaker_session):
1751+
estimator = Estimator(
1752+
IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
1753+
)
1754+
1755+
estimator.fit()
1756+
1757+
tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
1758+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, tags=tags)
1759+
1760+
variant = [
1761+
{
1762+
"InstanceType": "c4.4xlarge",
1763+
"VariantName": "AllTraffic",
1764+
"ModelName": ANY,
1765+
"InitialVariantWeight": 1,
1766+
"InitialInstanceCount": 1,
1767+
}
1768+
]
1769+
1770+
job_name = estimator._current_job_name
1771+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
1772+
name=job_name,
1773+
production_variants=variant,
1774+
tags=tags,
1775+
kms_key=None,
1776+
wait=True,
1777+
data_capture_config_dict=None,
1778+
)
1779+
1780+
sagemaker_session.create_model.assert_called_with(
1781+
ANY,
1782+
"DummyRole",
1783+
{"ModelDataUrl": "s3://bucket/model.tar.gz", "Environment": {}, "Image": "fakeimage"},
1784+
enable_network_isolation=False,
1785+
vpc_config=None,
1786+
tags=tags,
1787+
)
1788+
1789+
17501790
def test_generic_to_fit_no_input(sagemaker_session):
17511791
e = Estimator(
17521792
IMAGE_NAME,

0 commit comments

Comments
 (0)