Commit 9c75b2b

nargokul authored and pintaoz-aws committed
Model Trainer Bucket improvements (#1618)
* Model Trainer Bucket improvements
* Address Comments
* Unit test fix
* Unit test fix
* Codestyle
* Codestyle
* Codestyle
* Fixes
* Fixes
* Fixes
* Fixes
* Fixes
1 parent 775a627 commit 9c75b2b

File tree: 2 files changed (+67 -19 lines changed)


src/sagemaker/modules/train/model_trainer.py  (+53 -12)
@@ -135,6 +135,8 @@ class ModelTrainer(BaseModel):
             The SageMakerCore session. For convinience, can be imported like:
             `from sagemaker.modules import Session`.
             If not specified, a new session will be created.
+            If the default bucket for the artifacts needs to be updated, it can be done by
+            passing it in the Session object.
         role (Optional(str)):
             The IAM role ARN for the training job.
             If not specified, the default SageMaker execution role will be used.
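For context on the docstring addition above, a minimal usage sketch follows. It assumes the modules `Session` accepts `default_bucket` and `default_bucket_prefix` keyword arguments (as the classic `sagemaker.session.Session` does); the bucket, prefix, and image URI are placeholders, not values from this change.

# Sketch only: route training artifacts to a custom bucket/prefix via the Session.
# Assumes Session(default_bucket=..., default_bucket_prefix=...) is supported,
# mirroring sagemaker.session.Session; names below are placeholders.
from sagemaker.modules import Session
from sagemaker.modules.train.model_trainer import ModelTrainer

session = Session(
    default_bucket="my-training-bucket",
    default_bucket_prefix="team-a",
)
trainer = ModelTrainer(
    training_image="<training-image-uri>",  # placeholder image URI
    sagemaker_session=session,
)
# Output then defaults to s3://my-training-bucket/team-a/<base_job_name>/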
@@ -173,7 +175,8 @@ class ModelTrainer(BaseModel):
         output_data_config (Optional[OutputDataConfig]):
             The output data configuration. This is used to specify the output data location
             for the training job.
-            If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
+            If not specified in the session, will default to
+            `s3://<default_bucket>/<default_prefix>/<base_job_name>/`.
         input_data_config (Optional[List[Union[Channel, InputData]]]):
             The input data config for the training job.
             Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
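A small worked example of the documented default output location (placeholder values only, matching the constants used in the unit tests below):

# Placeholder values; shows how the documented default output location is assembled.
default_bucket = "sagemaker-us-west-2-000000000000"
default_prefix = "sample-prefix"
base_job_name = "dummy-image-job"
print(f"s3://{default_bucket}/{default_prefix}/{base_job_name}")
# s3://sagemaker-us-west-2-000000000000/sample-prefix/dummy-image-job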
@@ -348,7 +351,7 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
                     configurable_attribute
                 )(
                     **default_config  # pylint: disable=E1134
-                )  # noqa
+                )
                 setattr(self, configurable_attribute, default_config)
 
     def __del__(self):
@@ -461,7 +464,8 @@ def model_post_init(self, __context: Any):
             session = self.sagemaker_session
             base_job_name = self.base_job_name
             self.output_data_config = OutputDataConfig(
-                s3_output_path=f"s3://{session.default_bucket()}/{base_job_name}",
+                s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
+                f"/{base_job_name}",
                 compression_type="GZIP",
                 kms_key_id=None,
             )
@@ -473,6 +477,12 @@ def model_post_init(self, __context: Any):
         if self.training_image:
             logger.info(f"Training image URI: {self.training_image}")
 
+    def _fetch_bucket_name_and_prefix(self, session: Session) -> str:
+        """Helper function to get the bucket name with the corresponding prefix if applicable"""
+        if session.default_bucket_prefix is not None:
+            return f"{session.default_bucket()}/{session.default_bucket_prefix}"
+        return session.default_bucket()
+
     @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train")
     @validate_call
     def train(
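A quick check of the helper's two branches, using a mock in place of a real Session. This is a sketch, not part of the commit's tests; the call goes through the class because `self` is unused inside the helper, so no ModelTrainer instance is needed.

# Sketch: exercise _fetch_bucket_name_and_prefix with a mocked Session.
from unittest.mock import MagicMock

from sagemaker.modules.train.model_trainer import ModelTrainer

session = MagicMock()
session.default_bucket.return_value = "my-bucket"

session.default_bucket_prefix = "team-a"
assert ModelTrainer._fetch_bucket_name_and_prefix(None, session) == "my-bucket/team-a"

session.default_bucket_prefix = None  # no prefix configured
assert ModelTrainer._fetch_bucket_name_and_prefix(None, session) == "my-bucket"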
@@ -497,12 +507,16 @@ def train(
                 Defaults to True.
         """
         self._populate_intelligent_defaults()
+        current_training_job_name = _get_unique_name(self.base_job_name)
+        input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
         if input_data_config:
             self.input_data_config = input_data_config
 
         input_data_config = []
         if self.input_data_config:
-            input_data_config = self._get_input_data_config(self.input_data_config)
+            input_data_config = self._get_input_data_config(
+                self.input_data_config, input_data_key_prefix
+            )
 
         string_hyper_parameters = {}
         if self.hyperparameters:
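The two added lines above give each run its own input area under the base job name. Written out with placeholder names (the unique job-name suffix produced by `_get_unique_name` is illustrative, not its real format):

# Placeholder values; the real suffix comes from _get_unique_name().
base_job_name = "my-job"
current_training_job_name = "my-job-2025-01-01-00-00-00-000"  # illustrative unique name
input_data_key_prefix = f"{base_job_name}/{current_training_job_name}/input"
print(input_data_key_prefix)
# my-job/my-job-2025-01-01-00-00-00-000/input
# Uploaded channels then land under:
#   s3://<default_bucket>[/<default_bucket_prefix>]/my-job/my-job-.../input/<channel_name>/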
@@ -524,7 +538,9 @@ def train(
             # The source code will be mounted at /opt/ml/input/data/sm_code in the container
             if self.source_code.source_dir:
                 source_code_channel = self.create_input_data_channel(
-                    SM_CODE, self.source_code.source_dir
+                    channel_name=SM_CODE,
+                    data_source=self.source_code.source_dir,
+                    key_prefix=input_data_key_prefix,
                 )
                 input_data_config.append(source_code_channel)
 
@@ -542,7 +558,11 @@ def train(
                 self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed)
 
                 # Create an input channel for drivers packaged by the sdk
-                sm_drivers_channel = self.create_input_data_channel(SM_DRIVERS, drivers_dir.name)
+                sm_drivers_channel = self.create_input_data_channel(
+                    channel_name=SM_DRIVERS,
+                    data_source=drivers_dir.name,
+                    key_prefix=input_data_key_prefix,
+                )
                 input_data_config.append(sm_drivers_channel)
 
                 # If source_code is provided, we will always use
@@ -567,7 +587,7 @@ def train(
 
         if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
             training_job = TrainingJob.create(
-                training_job_name=_get_unique_name(self.base_job_name),
+                training_job_name=current_training_job_name,
                 algorithm_specification=algorithm_specification,
                 hyper_parameters=string_hyper_parameters,
                 input_data_config=input_data_config,
@@ -621,14 +641,22 @@ def train(
             )
             local_container.train(wait)
 
-    def create_input_data_channel(self, channel_name: str, data_source: DataSourceType) -> Channel:
+    def create_input_data_channel(
+        self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
+    ) -> Channel:
         """Create an input data channel for the training job.
 
         Args:
             channel_name (str): The name of the input data channel.
             data_source (DataSourceType): The data source for the input data channel.
                 DataSourceType can be an S3 URI string, local file path string,
                 S3DataSource object, or FileSystemDataSource object.
+            key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
+                Only applicable when data_source is a local file path string.
+                If not specified, local data will be uploaded to:
+                    s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/
+                If specified, local data will be uploaded to:
+                    s3://<default_bucket_path>/<key_prefix>/<channel_name>/
         """
         channel = None
         if isinstance(data_source, str):
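To make the documented destinations concrete, a hedged usage sketch; the `InputData` import path and the local directory are assumptions, not part of this diff.

# Sketch only: a local directory is uploaded by the SDK when passed as input data.
# Import path for InputData is assumed; "./data/train" is a placeholder directory.
from sagemaker.modules.configs import InputData

train_channel = InputData(channel_name="train", data_source="./data/train")
# trainer.train(input_data_config=[train_channel])
#
# With the key_prefix supplied by train() (new behavior), the upload lands under:
#   s3://<default_bucket>[/<default_bucket_prefix>]/<base_job_name>/<training_job_name>/input/train/
# Without a key_prefix (direct create_input_data_channel call), it falls back to:
#   s3://<default_bucket>[/<default_bucket_prefix>]/<base_job_name>/input/train/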
@@ -644,6 +672,10 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceType) -> Channel:
                     ),
                     input_mode="File",
                 )
+                if key_prefix:
+                    logger.warning(
+                        "key_prefix is only applicable when data_source is a local file path."
+                    )
             elif _is_valid_path(data_source):
                 if self.training_mode == Mode.LOCAL_CONTAINER:
                     channel = Channel(
@@ -657,10 +689,17 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceType) -> Channel:
                         input_mode="File",
                     )
                 else:
+                    key_prefix = (
+                        f"{key_prefix}/{channel_name}"
+                        if key_prefix
+                        else f"{self.base_job_name}/input/{channel_name}"
+                    )
+                    if self.sagemaker_session.default_bucket_prefix:
+                        key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
                     s3_uri = self.sagemaker_session.upload_data(
                         path=data_source,
                         bucket=self.sagemaker_session.default_bucket(),
-                        key_prefix=f"{self.base_job_name}/input/{channel_name}",
+                        key_prefix=key_prefix,
                     )
                     channel = Channel(
                         channel_name=channel_name,
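The same resolution, pulled out as a standalone sketch with placeholder values:

# Placeholder values; mirrors the key_prefix resolution in the S3 upload branch above.
base_job_name = "my-job"
channel_name = "train"
key_prefix = "my-job/my-job-2025-01-01-00-00-00-000/input"  # from train(); may be None
default_bucket_prefix = "sample-prefix"  # from the Session; may be None

key = f"{key_prefix}/{channel_name}" if key_prefix else f"{base_job_name}/input/{channel_name}"
if default_bucket_prefix:
    key = f"{default_bucket_prefix}/{key}"
print(key)
# sample-prefix/my-job/my-job-2025-01-01-00-00-00-000/input/train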
@@ -687,7 +726,9 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceType) -> Channel:
         return channel
 
     def _get_input_data_config(
-        self, input_data_channels: Optional[List[Union[Channel, InputData]]]
+        self,
+        input_data_channels: Optional[List[Union[Channel, InputData]]],
+        key_prefix: Optional[str] = None,
     ) -> List[Channel]:
         """Get the input data configuration for the training job.
 
@@ -706,7 +747,7 @@ def _get_input_data_config(
                 channels.append(input_data)
             elif isinstance(input_data, InputData):
                 channel = self.create_input_data_channel(
-                    input_data.channel_name, input_data.data_source
+                    input_data.channel_name, input_data.data_source, key_prefix=key_prefix
                 )
                 channels.append(channel)
             else:
@@ -850,7 +891,7 @@ def from_recipe(
             An array of key-value pairs. You can use tags to categorize your AWS resources
             in different ways, for example, by purpose, owner, or environment.
         sagemaker_session (Optional[Session]):
-            The SageMaker session.
+            The SageMakerCore session.
             If not specified, a new session will be created.
         role (Optional[str]):
             The IAM role ARN for the training job.

tests/unit/sagemaker/modules/train/test_model_trainer.py  (+14 -7)
@@ -59,9 +59,10 @@
 DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
 DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000"
 DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role"
+DEFAULT_BUCKET_PREFIX = "sample-prefix"
 DEFAULT_COMPUTE_CONFIG = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1)
 DEFAULT_OUTPUT_DATA_CONFIG = OutputDataConfig(
-    s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}",
+    s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}",
     compression_type="GZIP",
     kms_key_id=None,
 )
@@ -85,6 +86,7 @@ def modules_session():
         session_instance = session_mock.return_value
         session_instance.default_bucket.return_value = DEFAULT_BUCKET
         session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE
+        session_instance.default_bucket_prefix = DEFAULT_BUCKET_PREFIX
         session_instance.boto_session = MagicMock(spec="boto3.session.Session")
         yield session_instance
 
@@ -170,8 +172,9 @@ def test_train_with_default_params(mock_training_job, model_trainer):
 
 @patch("sagemaker.modules.train.model_trainer.TrainingJob")
 @patch("sagemaker.modules.train.model_trainer.resolve_value_from_config")
+@patch("sagemaker.modules.train.model_trainer.ModelTrainer.create_input_data_channel")
 def test_train_with_intelligent_defaults(
-    mock_resolve_value_from_config, mock_training_job, model_trainer
+    mock_create_input_data_channel, mock_resolve_value_from_config, mock_training_job, model_trainer
 ):
     source_code_path = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, "sourceCode")
 
@@ -229,7 +232,11 @@ def test_train_with_intelligent_defaults_training_job_space(
             max_pending_time_in_seconds=None,
         ),
         output_data_config=OutputDataConfig(
-            s3_output_path="s3://" "sagemaker-us-west-2" "-000000000000/d" "ummy-image-job",
+            s3_output_path="s3://"
+            "sagemaker-us-west-2"
+            "-000000000000/"
+            "sample-prefix/"
+            "dummy-image-job",
             kms_key_id=None,
             compression_type="GZIP",
         ),
@@ -258,7 +265,7 @@ def test_train_with_input_data_channels(mock_get_input_config, mock_training_job
 
         model_trainer.train(input_data_config=mock_input_data_config)
 
-        mock_get_input_config.assert_called_once_with(mock_input_data_config)
+        mock_get_input_config.assert_called_once_with(mock_input_data_config, ANY)
         mock_training_job.create.assert_called_once()
 
 
@@ -309,10 +316,11 @@ def test_train_with_input_data_channels(mock_get_input_config, mock_training_job
     ],
 )
 @patch("sagemaker.modules.train.model_trainer.Session.upload_data")
-def test_create_input_data_channel(mock_upload_data, model_trainer, test_case):
+@patch("sagemaker.modules.train.model_trainer.Session.default_bucket")
+def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_trainer, test_case):
     expected_s3_uri = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test"
     mock_upload_data.return_value = expected_s3_uri
-
+    mock_default_bucket.return_value = DEFAULT_BUCKET
     if not test_case["valid"]:
         with pytest.raises(ValueError):
             model_trainer.create_input_data_channel(
@@ -323,7 +331,6 @@ def test_create_input_data_channel(mock_upload_data, model_trainer, test_case):
                 test_case["channel_name"], test_case["data_source"]
             )
             assert channel.channel_name == test_case["channel_name"]
-
             if isinstance(test_case["data_source"], S3DataSource):
                 assert channel.data_source.s3_data_source == test_case["data_source"]
             elif isinstance(test_case["data_source"], FileSystemDataSource):
