Skip to content

Commit 5036fb8

Browse files
zhaoqizqwangpintaoz-aws
authored andcommitted
[Updated] Add telemetry to ModelTrainer, Estimator and ModelBuilder (#1608)
* [Updated] Add telemetry to ModelTrainer, Estimator and ModelBuilder * fix unit test name * Fix unit test * Add handshake telemetry * restyle code format * remove debug logs * remove debug code
1 parent a9dd628 commit 5036fb8

File tree

8 files changed

+77
-34
lines changed

8 files changed

+77
-34
lines changed

src/sagemaker/estimator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@
106106
from sagemaker.workflow.entities import PipelineVariable
107107
from sagemaker.workflow.parameters import ParameterString
108108
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
109-
109+
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
110+
from sagemaker.telemetry.constants import Feature
110111

111112
logger = logging.getLogger(__name__)
112113

@@ -1297,6 +1298,7 @@ def latest_job_profiler_artifacts_path(self):
12971298
)
12981299
return None
12991300

1301+
@_telemetry_emitter(feature=Feature.ESTIMATOR, func_name="estimator.fit")
13001302
@runnable_by_pipeline
13011303
def fit(
13021304
self,

src/sagemaker/modules/train/model_trainer.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@
8686
EXEUCTE_TORCHRUN_DRIVER,
8787
EXECUTE_BASIC_SCRIPT_DRIVER,
8888
)
89+
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
90+
from sagemaker.telemetry.constants import Feature
8991
from sagemaker.modules import logger
9092
from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
9193

@@ -117,7 +119,7 @@ class ModelTrainer(BaseModel):
117119
```
118120
119121
Attributes:
120-
session (Optiona(Session)):
122+
sagemaker_session (Optiona(Session)):
121123
The SageMaker session.
122124
If not specified, a new session will be created.
123125
role (Optional(str)):
@@ -181,7 +183,7 @@ class ModelTrainer(BaseModel):
181183
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
182184

183185
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
184-
session: Optional[Session] = None
186+
sagemaker_session: Optional[Session] = None
185187
role: Optional[str] = None
186188
base_job_name: Optional[str] = None
187189
source_code: Optional[SourceCode] = None
@@ -299,12 +301,12 @@ def model_post_init(self, __context: Any):
299301
self._validate_distributed_runner(self.source_code, self.distributed_runner)
300302

301303
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
302-
if self.session is None:
303-
self.session = Session()
304-
logger.warning("Session not provided. Using default Session.")
304+
if self.sagemaker_session is None:
305+
self.sagemaker_session = Session()
306+
logger.warning("SageMaker session not provided. Using default Session.")
305307

306308
if self.role is None:
307-
self.role = get_execution_role(sagemaker_session=self.session)
309+
self.role = get_execution_role(sagemaker_session=self.sagemaker_session)
308310
logger.warning(f"Role not provided. Using default role:\n{self.role}")
309311

310312
if self.base_job_name is None:
@@ -333,7 +335,7 @@ def model_post_init(self, __context: Any):
333335
)
334336

335337
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
336-
session = self.session
338+
session = self.sagemaker_session
337339
base_job_name = self.base_job_name
338340
self.output_data_config = OutputDataConfig(
339341
s3_output_path=f"s3://{session.default_bucket()}/{base_job_name}",
@@ -348,6 +350,7 @@ def model_post_init(self, __context: Any):
348350
if self.training_image:
349351
logger.info(f"Training image URI: {self.training_image}")
350352

353+
@_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train")
351354
@validate_call
352355
def train(
353356
self,
@@ -451,7 +454,7 @@ def train(
451454
resource_config=resource_config,
452455
vpc_config=vpc_config,
453456
# Public Instance Attributes
454-
session=self.session.boto_session,
457+
session=self.sagemaker_session.boto_session,
455458
role_arn=self.role,
456459
tags=self.tags,
457460
stopping_condition=self.stopping_condition,
@@ -494,7 +497,7 @@ def train(
494497
instance_count=resource_config.instance_count,
495498
image=algorithm_specification.training_image,
496499
container_root=self.local_container_root,
497-
sagemaker_session=self.session,
500+
sagemaker_session=self.sagemaker_session,
498501
container_entrypoint=algorithm_specification.container_entrypoint,
499502
container_arguments=algorithm_specification.container_arguments,
500503
input_data_config=input_data_config,
@@ -539,9 +542,9 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
539542
input_mode="File",
540543
)
541544
else:
542-
s3_uri = self.session.upload_data(
545+
s3_uri = self.sagemaker_session.upload_data(
543546
path=data_source,
544-
bucket=self.session.default_bucket(),
547+
bucket=self.sagemaker_session.default_bucket(),
545548
key_prefix=f"{self.base_job_name}/input/{channel_name}",
546549
)
547550
channel = Channel(
@@ -821,7 +824,7 @@ def from_recipe(
821824
training_input_mode: Optional[str] = "File",
822825
environment: Optional[Dict[str, str]] = None,
823826
tags: Optional[List[Tag]] = None,
824-
session: Optional[Session] = None,
827+
sagemaker_session: Optional[Session] = None,
825828
role: Optional[str] = None,
826829
base_job_name: Optional[str] = None,
827830
) -> "ModelTrainer":
@@ -863,7 +866,7 @@ def from_recipe(
863866
tags (Optional[List[Tag]]):
864867
An array of key-value pairs. You can use tags to categorize your AWS resources
865868
in different ways, for example, by purpose, owner, or environment.
866-
session (Optional[Session]):
869+
sagemaker_session (Optional[Session]):
867870
The SageMaker session.
868871
If not specified, a new session will be created.
869872
role (Optional[str]):
@@ -885,9 +888,9 @@ def from_recipe(
885888
+ "Please provide a GPU or Tranium instance type."
886889
)
887890

888-
if session is None:
889-
session = Session()
890-
logger.warning("Session not provided. Using default Session.")
891+
if sagemaker_session is None:
892+
sagemaker_session = Session()
893+
logger.warning("SageMaker session not provided. Using default Session.")
891894
if role is None:
892895
role = get_execution_role(sagemaker_session=session)
893896
logger.warning(f"Role not provided. Using default role:\n{role}")
@@ -903,13 +906,13 @@ def from_recipe(
903906
recipe_overrides=recipe_overrides,
904907
requirements=requirements,
905908
compute=compute,
906-
region_name=session.boto_region_name,
909+
region_name=sagemaker_session.boto_region_name,
907910
)
908911
if training_image is not None:
909912
model_trainer_args["training_image"] = training_image
910913

911914
model_trainer = cls(
912-
session=session,
915+
sagemaker_session=sagemaker_session,
913916
role=role,
914917
base_job_name=base_job_name,
915918
training_image_config=training_image_config,

src/sagemaker/serve/builder/model_builder.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,24 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
836836
self.env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"})
837837
self.dependencies.update({"requirements": mlflow_model_dependency_path})
838838

839+
@_capture_telemetry("ModelBuilder.build_training_job")
840+
def _collect_training_job_model_telemetry(self):
841+
return
842+
843+
@_capture_telemetry("ModelBuilder.build_model_trainer")
844+
def _collect_model_trainer_model_telemetry(self):
845+
return
846+
847+
@_capture_telemetry("ModelBuilder.build_estimator")
848+
def _collect_estimator_model_telemetry(self):
849+
return
850+
839851
# Model Builder is a class to build the model for deployment.
840852
# It supports three modes of deployment
841853
# 1/ SageMaker Endpoint
842854
# 2/ Local launch with container
843855
# 3/ In process mode with Transformers server in beta release
856+
@_capture_telemetry("ModelBuilder.build")
844857
def build( # pylint: disable=R0911
845858
self,
846859
mode: Type[Mode] = None,
@@ -868,15 +881,20 @@ def build( # pylint: disable=R0911
868881
if role_arn:
869882
self.role_arn = role_arn
870883

884+
self.serve_settings = self._get_serve_setting()
885+
871886
if isinstance(self.model, TrainingJob):
872887
self.model_path = self.model.model_artifacts.s3_model_artifacts
873888
self.model = None
889+
self._collect_training_job_model_telemetry()
874890
elif isinstance(self.model, ModelTrainer):
875891
self.model_path = self.model._latest_training_job.model_artifacts.s3_model_artifacts
876892
self.model = None
893+
self._collect_model_trainer_model_telemetry()
877894
elif isinstance(self.model, Estimator):
878895
self.model_path = self.model.output_path
879896
self.model = None
897+
self._collect_estimator_model_telemetry()
880898

881899
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
882900

@@ -899,7 +917,6 @@ def build( # pylint: disable=R0911
899917
self.sagemaker_session.sagemaker_client._user_agent_creator.to_string
900918
)
901919

902-
self.serve_settings = self._get_serve_setting()
903920
self._is_custom_image_uri = self.image_uri is not None
904921

905922
self._handle_mlflow_input()
@@ -1017,6 +1034,7 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710
10171034
if self.model_server == ModelServer.MMS:
10181035
return self._build_for_transformers()
10191036

1037+
@_capture_telemetry("ModelBuilder.save")
10201038
def save(
10211039
self,
10221040
save_path: Optional[str] = None,
@@ -1566,6 +1584,7 @@ def _optimize_prepare_for_hf(self):
15661584
)
15671585
self.pysdk_model.env.update(env)
15681586

1587+
@_capture_telemetry("ModelBuilder.deploy")
15691588
def deploy(
15701589
self,
15711590
endpoint_name: str = None,

src/sagemaker/serve/utils/telemetry_logger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def wrapper(self, *args, **kwargs):
171171

172172
extra += f"&x-latency={round(elapsed, 2)}"
173173

174-
if not self.serve_settings.telemetry_opt_out:
174+
if hasattr(self, "serve_settings") and not self.serve_settings.telemetry_opt_out:
175175
_send_telemetry(
176176
status,
177177
MODE_TO_CODE[str(self.mode)],

src/sagemaker/telemetry/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class Feature(Enum):
2525
SDK_DEFAULTS = 1
2626
LOCAL_MODE = 2
2727
REMOTE_FUNCTION = 3
28+
MODEL_TRAINER = 4
29+
ESTIMATOR = 5
2830

2931
def __str__(self): # pylint: disable=E0307
3032
"""Return the feature name."""

src/sagemaker/telemetry/telemetry_logging.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
str(Feature.SDK_DEFAULTS): 1,
5353
str(Feature.LOCAL_MODE): 2,
5454
str(Feature.REMOTE_FUNCTION): 3,
55+
str(Feature.MODEL_TRAINER): 4,
56+
str(Feature.ESTIMATOR): 5,
5557
}
5658

5759
STATUS_TO_CODE = {
@@ -61,7 +63,13 @@
6163

6264

6365
def _telemetry_emitter(feature: str, func_name: str):
64-
"""Decorator to emit telemetry logs for SageMaker Python SDK functions"""
66+
"""
67+
Decorator to emit telemetry logs for SageMaker Python SDK functions. This class needs
68+
sagemaker_session object as a member. Default session object is a pysdk v2 Session object
69+
in this repo. When collecting telemetry for classes using sagemaker-core Session object,
70+
we should be aware of its differences, such as sagemaker_session.sagemaker_config does not
71+
exist in new Session class.
72+
"""
6573

6674
def decorator(func):
6775
@functools.wraps(func)
@@ -95,10 +103,18 @@ def wrapper(*args, **kwargs):
95103
# Construct the feature list to track feature combinations
96104
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
97105

98-
if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
106+
if (
107+
hasattr(sagemaker_session, "sagemaker_config")
108+
and sagemaker_session.sagemaker_config
109+
and feature != Feature.SDK_DEFAULTS
110+
):
99111
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])
100112

101-
if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
113+
if (
114+
hasattr(sagemaker_session, "local_mode")
115+
and sagemaker_session.local_mode
116+
and feature != Feature.LOCAL_MODE
117+
):
102118
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])
103119

104120
# Construct the extra info to track platform and environment usage metadata
@@ -111,7 +127,7 @@ def wrapper(*args, **kwargs):
111127
)
112128

113129
# Add endpoint ARN to the extra info if available
114-
if sagemaker_session.endpoint_arn:
130+
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
115131
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
116132

117133
start_timer = perf_counter()
@@ -171,8 +187,9 @@ def _send_telemetry_request(
171187
) -> None:
172188
"""Make GET request to an empty object in S3 bucket"""
173189
try:
174-
accountId = _get_accountId(session)
175-
region = _get_region_or_default(session)
190+
accountId = _get_accountId(session) if session else "NotAvailable"
191+
# telemetry will be sent to us-west-2 if no session availale
192+
region = _get_region_or_default(session) if session else DEFAULT_AWS_REGION
176193
url = _construct_url(
177194
accountId,
178195
region,

src/sagemaker/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,7 @@ def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict =
11601160
Returns:
11611161
object: The corresponding default value in the configuration file.
11621162
"""
1163-
if sagemaker_session:
1163+
if sagemaker_session and hasattr(sagemaker_session, "sagemaker_config"):
11641164
config_to_check = sagemaker_session.sagemaker_config
11651165
else:
11661166
config_to_check = sagemaker_config

tests/unit/sagemaker/modules/train/test_model_trainer.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ def model_trainer():
146146
def test_model_trainer_param_validation(test_case, modules_session):
147147
if test_case["should_throw"]:
148148
with pytest.raises(ValueError):
149-
ModelTrainer(**test_case["init_params"], session=modules_session)
149+
ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session)
150150
else:
151-
trainer = ModelTrainer(**test_case["init_params"], session=modules_session)
151+
trainer = ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session)
152152
assert trainer is not None
153153
assert trainer.training_image == DEFAULT_IMAGE
154154
assert trainer.compute == DEFAULT_COMPUTE_CONFIG
@@ -261,7 +261,7 @@ def test_metric_settings(mock_training_job, modules_session):
261261

262262
model_trainer = ModelTrainer(
263263
training_image=image_uri,
264-
session=modules_session,
264+
sagemaker_session=modules_session,
265265
role=role,
266266
).with_metric_settings(
267267
enable_sage_maker_metrics_time_series=True, metric_definitions=[metric_definition]
@@ -306,7 +306,7 @@ def test_debugger_settings(mock_training_job, modules_session):
306306

307307
model_trainer = ModelTrainer(
308308
training_image=image_uri,
309-
session=modules_session,
309+
sagemaker_session=modules_session,
310310
role=role,
311311
).with_debugger_settings(
312312
debug_hook_config=debug_hook_config,
@@ -367,7 +367,7 @@ def test_additional_settings(mock_training_job, modules_session):
367367
)
368368
model_trainer = ModelTrainer(
369369
training_image=image_uri,
370-
session=modules_session,
370+
sagemaker_session=modules_session,
371371
role=role,
372372
).with_additional_settings(
373373
retry_strategy=retry_strategy,
@@ -467,7 +467,7 @@ def test_train_with_distributed_runner(
467467

468468
try:
469469
model_trainer = ModelTrainer(
470-
session=modules_session,
470+
sagemaker_session=modules_session,
471471
training_image=DEFAULT_IMAGE,
472472
source_code=test_case["source_code"],
473473
distributed_runner=test_case["distributed_runner"],

0 commit comments

Comments
 (0)