@@ -135,6 +135,8 @@ class ModelTrainer(BaseModel):
135
135
The SageMakerCore session. For convenience, can be imported like:
136
136
`from sagemaker.modules import Session`.
137
137
If not specified, a new session will be created.
138
+ If the default bucket for the artifacts needs to be updated, it can be done by
139
+ passing it in the Session object.
138
140
role (Optional(str)):
139
141
The IAM role ARN for the training job.
140
142
If not specified, the default SageMaker execution role will be used.
@@ -173,7 +175,8 @@ class ModelTrainer(BaseModel):
173
175
output_data_config (Optional[OutputDataConfig]):
174
176
The output data configuration. This is used to specify the output data location
175
177
for the training job.
176
- If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
178
+ If not specified in the session, will default to
179
+ `s3://<default_bucket>/<default_prefix>/<base_job_name>/`.
177
180
input_data_config (Optional[List[Union[Channel, InputData]]]):
178
181
The input data config for the training job.
179
182
Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -348,7 +351,7 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
348
351
configurable_attribute
349
352
)(
350
353
** default_config # pylint: disable=E1134
351
- ) # noqa
354
+ )
352
355
setattr (self , configurable_attribute , default_config )
353
356
354
357
def __del__ (self ):
@@ -461,7 +464,8 @@ def model_post_init(self, __context: Any):
461
464
session = self .sagemaker_session
462
465
base_job_name = self .base_job_name
463
466
self .output_data_config = OutputDataConfig (
464
- s3_output_path = f"s3://{ session .default_bucket ()} /{ base_job_name } " ,
467
+ s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
468
+ f"/{ base_job_name } " ,
465
469
compression_type = "GZIP" ,
466
470
kms_key_id = None ,
467
471
)
@@ -473,6 +477,12 @@ def model_post_init(self, __context: Any):
473
477
if self .training_image :
474
478
logger .info (f"Training image URI: { self .training_image } " )
475
479
480
def _fetch_bucket_name_and_prefix(self, session: Session) -> str:
    """Return the session's default bucket, joined with its default prefix if set.

    Args:
        session (Session): The session providing ``default_bucket()`` and
            ``default_bucket_prefix``.

    Returns:
        str: ``<default_bucket>/<default_bucket_prefix>`` when the prefix is a
        non-empty value, otherwise just ``<default_bucket>``.
    """
    bucket = session.default_bucket()
    # Use truthiness rather than ``is not None``: an empty-string prefix would
    # otherwise produce "<bucket>/" and a doubled slash in the S3 path built by
    # the caller. This also matches the truthiness check applied to the same
    # attribute when uploading local input data.
    if session.default_bucket_prefix:
        return f"{bucket}/{session.default_bucket_prefix}"
    return bucket
485
+
476
486
@_telemetry_emitter (feature = Feature .MODEL_TRAINER , func_name = "model_trainer.train" )
477
487
@validate_call
478
488
def train (
@@ -497,12 +507,16 @@ def train(
497
507
Defaults to True.
498
508
"""
499
509
self ._populate_intelligent_defaults ()
510
+ current_training_job_name = _get_unique_name (self .base_job_name )
511
+ input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
500
512
if input_data_config :
501
513
self .input_data_config = input_data_config
502
514
503
515
input_data_config = []
504
516
if self .input_data_config :
505
- input_data_config = self ._get_input_data_config (self .input_data_config )
517
+ input_data_config = self ._get_input_data_config (
518
+ self .input_data_config , input_data_key_prefix
519
+ )
506
520
507
521
string_hyper_parameters = {}
508
522
if self .hyperparameters :
@@ -524,7 +538,9 @@ def train(
524
538
# The source code will be mounted at /opt/ml/input/data/sm_code in the container
525
539
if self .source_code .source_dir :
526
540
source_code_channel = self .create_input_data_channel (
527
- SM_CODE , self .source_code .source_dir
541
+ channel_name = SM_CODE ,
542
+ data_source = self .source_code .source_dir ,
543
+ key_prefix = input_data_key_prefix ,
528
544
)
529
545
input_data_config .append (source_code_channel )
530
546
@@ -542,7 +558,11 @@ def train(
542
558
self ._write_distributed_json (tmp_dir = drivers_dir , distributed = self .distributed )
543
559
544
560
# Create an input channel for drivers packaged by the sdk
545
- sm_drivers_channel = self .create_input_data_channel (SM_DRIVERS , drivers_dir .name )
561
+ sm_drivers_channel = self .create_input_data_channel (
562
+ channel_name = SM_DRIVERS ,
563
+ data_source = drivers_dir .name ,
564
+ key_prefix = input_data_key_prefix ,
565
+ )
546
566
input_data_config .append (sm_drivers_channel )
547
567
548
568
# If source_code is provided, we will always use
@@ -567,7 +587,7 @@ def train(
567
587
568
588
if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB :
569
589
training_job = TrainingJob .create (
570
- training_job_name = _get_unique_name ( self . base_job_name ) ,
590
+ training_job_name = current_training_job_name ,
571
591
algorithm_specification = algorithm_specification ,
572
592
hyper_parameters = string_hyper_parameters ,
573
593
input_data_config = input_data_config ,
@@ -621,14 +641,22 @@ def train(
621
641
)
622
642
local_container .train (wait )
623
643
624
- def create_input_data_channel (self , channel_name : str , data_source : DataSourceType ) -> Channel :
644
+ def create_input_data_channel (
645
+ self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
646
+ ) -> Channel :
625
647
"""Create an input data channel for the training job.
626
648
627
649
Args:
628
650
channel_name (str): The name of the input data channel.
629
651
data_source (DataSourceType): The data source for the input data channel.
630
652
DataSourceType can be an S3 URI string, local file path string,
631
653
S3DataSource object, or FileSystemDataSource object.
654
+ key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
655
+ Only applicable when data_source is a local file path string.
656
+ If not specified, local data will be uploaded to:
657
+ s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/
658
+ If specified, local data will be uploaded to:
659
+ s3://<default_bucket_path>/<key_prefix>/<channel_name>/
632
660
"""
633
661
channel = None
634
662
if isinstance (data_source , str ):
@@ -644,6 +672,10 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
644
672
),
645
673
input_mode = "File" ,
646
674
)
675
+ if key_prefix :
676
+ logger .warning (
677
+ "key_prefix is only applicable when data_source is a local file path."
678
+ )
647
679
elif _is_valid_path (data_source ):
648
680
if self .training_mode == Mode .LOCAL_CONTAINER :
649
681
channel = Channel (
@@ -657,10 +689,17 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
657
689
input_mode = "File" ,
658
690
)
659
691
else :
692
+ key_prefix = (
693
+ f"{ key_prefix } /{ channel_name } "
694
+ if key_prefix
695
+ else f"{ self .base_job_name } /input/{ channel_name } "
696
+ )
697
+ if self .sagemaker_session .default_bucket_prefix :
698
+ key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
660
699
s3_uri = self .sagemaker_session .upload_data (
661
700
path = data_source ,
662
701
bucket = self .sagemaker_session .default_bucket (),
663
- key_prefix = f" { self . base_job_name } /input/ { channel_name } " ,
702
+ key_prefix = key_prefix ,
664
703
)
665
704
channel = Channel (
666
705
channel_name = channel_name ,
@@ -687,7 +726,9 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
687
726
return channel
688
727
689
728
def _get_input_data_config (
690
- self , input_data_channels : Optional [List [Union [Channel , InputData ]]]
729
+ self ,
730
+ input_data_channels : Optional [List [Union [Channel , InputData ]]],
731
+ key_prefix : Optional [str ] = None ,
691
732
) -> List [Channel ]:
692
733
"""Get the input data configuration for the training job.
693
734
@@ -706,7 +747,7 @@ def _get_input_data_config(
706
747
channels .append (input_data )
707
748
elif isinstance (input_data , InputData ):
708
749
channel = self .create_input_data_channel (
709
- input_data .channel_name , input_data .data_source
750
+ input_data .channel_name , input_data .data_source , key_prefix = key_prefix
710
751
)
711
752
channels .append (channel )
712
753
else :
@@ -850,7 +891,7 @@ def from_recipe(
850
891
An array of key-value pairs. You can use tags to categorize your AWS resources
851
892
in different ways, for example, by purpose, owner, or environment.
852
893
sagemaker_session (Optional[Session]):
853
- The SageMaker session.
894
+ The SageMakerCore session.
854
895
If not specified, a new session will be created.
855
896
role (Optional[str]):
856
897
The IAM role ARN for the training job.
0 commit comments