86
86
EXEUCTE_TORCHRUN_DRIVER ,
87
87
EXECUTE_BASIC_SCRIPT_DRIVER ,
88
88
)
89
+ from sagemaker .telemetry .telemetry_logging import _telemetry_emitter
90
+ from sagemaker .telemetry .constants import Feature
89
91
from sagemaker .modules import logger
90
92
from sagemaker .modules .train .sm_recipes .utils import _get_args_from_recipe , _determine_device_type
91
93
@@ -117,7 +119,7 @@ class ModelTrainer(BaseModel):
117
119
```
118
120
119
121
Attributes:
120
- session (Optiona(Session)):
122
+ sagemaker_session (Optional(Session)):
121
123
The SageMaker session.
122
124
If not specified, a new session will be created.
123
125
role (Optional(str)):
@@ -181,7 +183,7 @@ class ModelTrainer(BaseModel):
181
183
model_config = ConfigDict (arbitrary_types_allowed = True , extra = "forbid" )
182
184
183
185
training_mode : Mode = Mode .SAGEMAKER_TRAINING_JOB
184
- session : Optional [Session ] = None
186
+ sagemaker_session : Optional [Session ] = None
185
187
role : Optional [str ] = None
186
188
base_job_name : Optional [str ] = None
187
189
source_code : Optional [SourceCode ] = None
@@ -299,12 +301,12 @@ def model_post_init(self, __context: Any):
299
301
self ._validate_distributed_runner (self .source_code , self .distributed_runner )
300
302
301
303
if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB :
302
- if self .session is None :
303
- self .session = Session ()
304
- logger .warning ("Session not provided. Using default Session." )
304
+ if self .sagemaker_session is None :
305
+ self .sagemaker_session = Session ()
306
+ logger .warning ("SageMaker session not provided. Using default Session." )
305
307
306
308
if self .role is None :
307
- self .role = get_execution_role (sagemaker_session = self .session )
309
+ self .role = get_execution_role (sagemaker_session = self .sagemaker_session )
308
310
logger .warning (f"Role not provided. Using default role:\n { self .role } " )
309
311
310
312
if self .base_job_name is None :
@@ -333,7 +335,7 @@ def model_post_init(self, __context: Any):
333
335
)
334
336
335
337
if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB and self .output_data_config is None :
336
- session = self .session
338
+ session = self .sagemaker_session
337
339
base_job_name = self .base_job_name
338
340
self .output_data_config = OutputDataConfig (
339
341
s3_output_path = f"s3://{ session .default_bucket ()} /{ base_job_name } " ,
@@ -348,6 +350,7 @@ def model_post_init(self, __context: Any):
348
350
if self .training_image :
349
351
logger .info (f"Training image URI: { self .training_image } " )
350
352
353
+ @_telemetry_emitter (feature = Feature .MODEL_TRAINER , func_name = "model_trainer.train" )
351
354
@validate_call
352
355
def train (
353
356
self ,
@@ -451,7 +454,7 @@ def train(
451
454
resource_config = resource_config ,
452
455
vpc_config = vpc_config ,
453
456
# Public Instance Attributes
454
- session = self .session .boto_session ,
457
+ session = self .sagemaker_session .boto_session ,
455
458
role_arn = self .role ,
456
459
tags = self .tags ,
457
460
stopping_condition = self .stopping_condition ,
@@ -494,7 +497,7 @@ def train(
494
497
instance_count = resource_config .instance_count ,
495
498
image = algorithm_specification .training_image ,
496
499
container_root = self .local_container_root ,
497
- sagemaker_session = self .session ,
500
+ sagemaker_session = self .sagemaker_session ,
498
501
container_entrypoint = algorithm_specification .container_entrypoint ,
499
502
container_arguments = algorithm_specification .container_arguments ,
500
503
input_data_config = input_data_config ,
@@ -539,9 +542,9 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
539
542
input_mode = "File" ,
540
543
)
541
544
else :
542
- s3_uri = self .session .upload_data (
545
+ s3_uri = self .sagemaker_session .upload_data (
543
546
path = data_source ,
544
- bucket = self .session .default_bucket (),
547
+ bucket = self .sagemaker_session .default_bucket (),
545
548
key_prefix = f"{ self .base_job_name } /input/{ channel_name } " ,
546
549
)
547
550
channel = Channel (
@@ -821,7 +824,7 @@ def from_recipe(
821
824
training_input_mode : Optional [str ] = "File" ,
822
825
environment : Optional [Dict [str , str ]] = None ,
823
826
tags : Optional [List [Tag ]] = None ,
824
- session : Optional [Session ] = None ,
827
+ sagemaker_session : Optional [Session ] = None ,
825
828
role : Optional [str ] = None ,
826
829
base_job_name : Optional [str ] = None ,
827
830
) -> "ModelTrainer" :
@@ -863,7 +866,7 @@ def from_recipe(
863
866
tags (Optional[List[Tag]]):
864
867
An array of key-value pairs. You can use tags to categorize your AWS resources
865
868
in different ways, for example, by purpose, owner, or environment.
866
- session (Optional[Session]):
869
+ sagemaker_session (Optional[Session]):
867
870
The SageMaker session.
868
871
If not specified, a new session will be created.
869
872
role (Optional[str]):
@@ -885,9 +888,9 @@ def from_recipe(
885
888
+ "Please provide a GPU or Tranium instance type."
886
889
)
887
890
888
- if session is None :
889
- session = Session ()
890
- logger .warning ("Session not provided. Using default Session." )
891
+ if sagemaker_session is None :
892
+ sagemaker_session = Session ()
893
+ logger .warning ("SageMaker session not provided. Using default Session." )
891
894
if role is None :
892
895
role = get_execution_role (sagemaker_session = session )
893
896
logger .warning (f"Role not provided. Using default role:\n { role } " )
@@ -903,13 +906,13 @@ def from_recipe(
903
906
recipe_overrides = recipe_overrides ,
904
907
requirements = requirements ,
905
908
compute = compute ,
906
- region_name = session .boto_region_name ,
909
+ region_name = sagemaker_session .boto_region_name ,
907
910
)
908
911
if training_image is not None :
909
912
model_trainer_args ["training_image" ] = training_image
910
913
911
914
model_trainer = cls (
912
- session = session ,
915
+ sagemaker_session = sagemaker_session ,
913
916
role = role ,
914
917
base_job_name = base_job_name ,
915
918
training_image_config = training_image_config ,
0 commit comments