Skip to content

Commit 15e26c4

Browse files
grenmester and Jacky Lee authored
fix: make telemetry logger persist certain information (aws#1500)
* refactor telemetry logger
* refactor
* refactor
* pylint + UT
* add tag
* add remove tags
* handle tags again
* pylint

---------

Co-authored-by: Jacky Lee <[email protected]>
1 parent f1bc99e commit 15e26c4

File tree

7 files changed

+136
-66
lines changed

7 files changed

+136
-66
lines changed

src/sagemaker/model.py

+9
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
Tags,
7575
_resolve_routing_config,
7676
_validate_new_tags,
77+
remove_tag_with_key,
7778
)
7879
from sagemaker.async_inference import AsyncInferenceConfig
7980
from sagemaker.predictor_async import AsyncPredictor
@@ -426,6 +427,14 @@ def add_tags(self, tags: Tags) -> None:
426427
"""
427428
self._tags = _validate_new_tags(tags, self._tags)
428429

430+
def remove_tag_with_key(self, key: str) -> None:
    """Drop the tag whose key equals ``key`` from this model's tags.

    Delegates to the module-level ``remove_tag_with_key`` helper and stores
    whatever it returns back on ``self._tags``.

    Args:
        key (str): The key of the tag to remove.
    """
    remaining_tags = remove_tag_with_key(key, self._tags)
    self._tags = remaining_tags
437+
429438
@classmethod
430439
def attach(
431440
cls,

src/sagemaker/serve/builder/jumpstart_builder.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def __init__(self):
116116
self.model_metadata = None
117117
self.role_arn = None
118118
self.is_fine_tuned = None
119-
self.is_gated = None
119+
self.is_compiled = False
120+
self.is_quantized = False
121+
self.speculative_decoding_draft_model_source = None
120122

121123
@abstractmethod
122124
def _prepare_for_mode(self):
@@ -503,6 +505,18 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
503505

504506
self.pysdk_model.set_deployment_config(config_name, instance_type)
505507

508+
self.instance_type = instance_type
509+
510+
# JS-benchmarked models only include SageMaker-provided SD models
511+
if self.pysdk_model.additional_model_data_sources:
512+
self.speculative_decoding_draft_model_source = "sagemaker"
513+
self.pysdk_model.add_tags(
514+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
515+
)
516+
self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME)
517+
self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH)
518+
self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME)
519+
506520
def get_deployment_config(self) -> Optional[Dict[str, Any]]:
507521
"""Gets the deployment config to apply to the model.
508522
@@ -775,10 +789,8 @@ def _is_gated_model(self, model=None) -> bool:
775789
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
776790

777791
if s3_uri is None:
778-
self.is_gated = False
779-
else:
780-
self.is_gated = "private" in s3_uri
781-
return self.is_gated
792+
return False
793+
return "private" in s3_uri
782794

783795
def _set_additional_model_source(
784796
self,

src/sagemaker/serve/builder/model_builder.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from pathlib import Path
2525

26+
from sagemaker.enums import Tag
2627
from sagemaker.s3 import S3Downloader
2728

2829
from sagemaker import Session
@@ -69,6 +70,7 @@
6970
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
7071
from sagemaker.serve.utils.optimize_utils import (
7172
_generate_optimized_model,
73+
_extract_speculative_draft_model_provider,
7274
)
7375
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
7476
from sagemaker.serve.utils.hardware_detector import (
@@ -647,11 +649,6 @@ def _handle_mlflow_input(self):
647649
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
648650
artifact_path = self._get_artifact_path(mlflow_model_path)
649651
if not self._mlflow_metadata_exists(artifact_path):
650-
logger.info(
651-
"MLflow model metadata not detected in %s. ModelBuilder is not "
652-
"handling MLflow model input",
653-
mlflow_model_path,
654-
)
655652
return
656653

657654
self._initialize_for_mlflow(artifact_path)
@@ -1144,6 +1141,12 @@ def _model_builder_optimize_wrapper(
11441141
Returns:
11451142
Model: A deployable ``Model`` object.
11461143
"""
1144+
self.is_compiled = compilation_config is not None
1145+
self.is_quantized = quantization_config is not None
1146+
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
1147+
speculative_decoding_config
1148+
)
1149+
11471150
if quantization_config and compilation_config:
11481151
raise ValueError("Quantization config and compilation config are mutually exclusive.")
11491152

@@ -1180,4 +1183,8 @@ def _model_builder_optimize_wrapper(
11801183
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
11811184
return _generate_optimized_model(self.pysdk_model, job_status)
11821185

1186+
self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME)
1187+
if not speculative_decoding_config:
1188+
self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)
1189+
11831190
return self.pysdk_model

src/sagemaker/serve/utils/telemetry_logger.py

+45-49
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,38 @@ def wrapper(self, *args, **kwargs):
9494
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
9595
response = None
9696
caught_ex = None
97-
97+
status = "1"
98+
failure_reason = None
99+
failure_type = None
98100
extra = f"{func_name}"
99101

102+
start_timer = perf_counter()
103+
try:
104+
response = func(self, *args, **kwargs)
105+
except (
106+
ModelBuilderException,
107+
exceptions.CapacityError,
108+
exceptions.UnexpectedStatusException,
109+
exceptions.AsyncInferenceError,
110+
) as e:
111+
status = "0"
112+
caught_ex = e
113+
failure_reason = str(e)
114+
failure_type = e.__class__.__name__
115+
except Exception as e: # pylint: disable=W0703
116+
raise e
117+
118+
stop_timer = perf_counter()
119+
elapsed = stop_timer - start_timer
120+
100121
if self.model_server:
101122
extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}"
102123

103124
if self.image_uri:
104125
image_uri_tail = self.image_uri.split("/")[1]
105-
image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri)
126+
image_uri_option = _get_image_uri_option(
127+
self.image_uri, getattr(self, "_is_custom_image_uri", False)
128+
)
106129

107130
if self.image_uri:
108131
extra += f"&x-imageTag={image_uri_tail}"
@@ -128,63 +151,36 @@ def wrapper(self, *args, **kwargs):
128151

129152
if getattr(self, "is_fine_tuned", False):
130153
extra += "&x-fineTuned=1"
131-
if getattr(self, "is_gated", False):
132-
extra += "&x-gated=1"
133154

134-
if kwargs.get("compilation_config"):
155+
if getattr(self, "is_compiled", False):
135156
extra += "&x-compiled=1"
136-
if kwargs.get("quantization_config"):
157+
if getattr(self, "is_quantized", False):
137158
extra += "&x-quantized=1"
138-
if kwargs.get("speculative_decoding_config"):
139-
model_provider = kwargs["speculative_decoding_config"]["ModelProvider"]
159+
if getattr(self, "speculative_decoding_draft_model_source", False):
140160
model_provider_enum = (
141161
SpeculativeDecodingDraftModelSource.SAGEMAKER
142-
if model_provider.lower() == "sagemaker"
162+
if self.speculative_decoding_draft_model_source == "sagemaker"
143163
else SpeculativeDecodingDraftModelSource.CUSTOM
144164
)
145165
model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)]
146166
extra += f"&x-sdDraftModelSource={model_provider_value}"
147167

148-
start_timer = perf_counter()
149-
try:
150-
response = func(self, *args, **kwargs)
151-
stop_timer = perf_counter()
152-
elapsed = stop_timer - start_timer
153-
extra += f"&x-latency={round(elapsed, 2)}"
154-
if not self.serve_settings.telemetry_opt_out:
155-
_send_telemetry(
156-
"1",
157-
MODE_TO_CODE[str(self.mode)],
158-
self.sagemaker_session,
159-
None,
160-
None,
161-
extra,
162-
)
163-
except (
164-
ModelBuilderException,
165-
exceptions.CapacityError,
166-
exceptions.UnexpectedStatusException,
167-
exceptions.AsyncInferenceError,
168-
) as e:
169-
stop_timer = perf_counter()
170-
elapsed = stop_timer - start_timer
171-
extra += f"&x-latency={round(elapsed, 2)}"
172-
if not self.serve_settings.telemetry_opt_out:
173-
_send_telemetry(
174-
"0",
175-
MODE_TO_CODE[str(self.mode)],
176-
self.sagemaker_session,
177-
str(e),
178-
e.__class__.__name__,
179-
extra,
180-
)
181-
caught_ex = e
182-
except Exception as e: # pylint: disable=W0703
183-
caught_ex = e
184-
finally:
185-
if caught_ex:
186-
raise caught_ex
187-
return response # pylint: disable=W0150
168+
extra += f"&x-latency={round(elapsed, 2)}"
169+
170+
if not self.serve_settings.telemetry_opt_out:
171+
_send_telemetry(
172+
status,
173+
MODE_TO_CODE[str(self.mode)],
174+
self.sagemaker_session,
175+
failure_reason,
176+
failure_type,
177+
extra,
178+
)
179+
180+
if caught_ex:
181+
raise caught_ex
182+
183+
return response
188184

189185
return wrapper
190186

src/sagemaker/utils.py

+27
Original file line numberDiff line numberDiff line change
@@ -1873,3 +1873,30 @@ def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> O
18731873
curr_tags.append(new_tag)
18741874

18751875
return curr_tags
1876+
1877+
1878+
def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
    """Remove every tag whose ``"Key"`` equals ``key`` from ``tags``.

    The input is never mutated; a new container is built for the result.

    Args:
        key (str): The key of the tag to remove.
        tags (Optional[Tags]): The current tags. May be ``None``, a single
            tag dict, or a list of tag dicts.

    Returns:
        Optional[Tags]: The tags that remain after removal, collapsed to the
        smallest shape that holds them:

        * ``None`` if ``tags`` is ``None`` or no tag remains,
        * the single remaining tag dict if exactly one tag remains,
        * a list of tag dicts otherwise.
    """
    if tags is None:
        return tags

    # A lone tag dict is treated as a one-element list so both input shapes
    # share the same filtering path.
    candidates = [tags] if isinstance(tags, dict) else tags

    remaining = [tag for tag in candidates if tag["Key"] != key]

    if not remaining:
        return None
    # Collapse a single survivor back to a bare dict.
    if len(remaining) == 1:
        return remaining[0]
    return remaining

tests/unit/sagemaker/serve/utils/test_telemetry_logger.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,15 @@ def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_sen
314314
mock_model_builder.model_server = ModelServer.TORCHSERVE
315315
mock_model_builder.sagemaker_session.endpoint_arn = None
316316
mock_model_builder.is_fine_tuned = True
317-
mock_model_builder.is_gated = True
317+
mock_model_builder.is_compiled = True
318+
mock_model_builder.is_quantized = True
319+
mock_model_builder.speculative_decoding_draft_model_source = "sagemaker"
318320

319321
mock_speculative_decoding_config = MagicMock()
320322
mock_config = {"ModelProvider": "sagemaker"}
321323
mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__
322324

323-
mock_model_builder.mock_optimize(
324-
quantization_config=Mock(),
325-
compilation_config=Mock(),
326-
speculative_decoding_config=mock_speculative_decoding_config,
327-
)
325+
mock_model_builder.mock_optimize()
328326

329327
args = mock_send_telemetry.call_args.args
330328
latency = str(args[5]).split("latency=")[1]
@@ -333,7 +331,6 @@ def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_sen
333331
"&x-modelServer=1"
334332
f"&x-sdkVersion={SDK_VERSION}"
335333
f"&x-fineTuned=1"
336-
f"&x-gated=1"
337334
f"&x-compiled=1"
338335
f"&x-quantized=1"
339336
f"&x-sdDraftModelSource=1"

tests/unit/test_utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
_resolve_routing_config,
5858
tag_exists,
5959
_validate_new_tags,
60+
remove_tag_with_key,
6061
)
6162
from tests.unit.sagemaker.workflow.helpers import CustomStep
6263
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
@@ -2124,3 +2125,24 @@ def test_new_add_tags(self):
21242125
new_tag = {"Key": "project-2", "Value": "my-project-2"}
21252126

21262127
self.assertEqual(_validate_new_tags(new_tag, None), new_tag)
2128+
2129+
def test_remove_existing_tag(self):
    """Removing a key that is present drops that tag and keeps the rest in order."""
    first_tag = {"Key": "Tag1", "Value": "Value1"}
    second_tag = {"Key": "Tag2", "Value": "Value2"}
    third_tag = {"Key": "Tag3", "Value": "Value3"}

    remaining = remove_tag_with_key("Tag2", [first_tag, second_tag, third_tag])

    self.assertEqual(remaining, [first_tag, third_tag])
2137+
2138+
def test_remove_non_existent_tag(self):
    """Removing a key that is absent leaves the tags equal to the input."""
    current_tags = [
        {"Key": "Tag1", "Value": "Value1"},
        {"Key": "Tag2", "Value": "Value2"},
        {"Key": "Tag3", "Value": "Value3"},
    ]

    result = remove_tag_with_key("NonExistentTag", current_tags)

    self.assertEqual(result, current_tags)
2145+
2146+
def test_remove_only_tag(self):
    """Removing the sole tag yields ``None`` rather than an empty list."""
    sole_tag = {"Key": "Tag1", "Value": "Value1"}

    result = remove_tag_with_key("Tag1", [sole_tag])

    self.assertIsNone(result)

0 commit comments

Comments
 (0)