diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 58ae724074..2143da4e5c 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -580,7 +580,7 @@ def train( """Train a model using AWS SageMaker. Args: - input_data_config (Optional[Union[List[Channel], Dict[str, DataSourceType]]]): + input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel objects or a dictionary of channel names to DataSourceType. DataSourceType can be an S3 URI string, local file path string, @@ -596,11 +596,23 @@ def train( current_training_job_name = _get_unique_name(self.base_job_name) input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" - self.input_data_config = input_data_config or self.input_data_config or [] + final_input_data_config = self.input_data_config.copy() if self.input_data_config else [] + + if input_data_config: + # merge the inputs with method parameter taking precedence + existing_channels = {input.channel_name: input for input in final_input_data_config} + new_channels = [] + for new_input in input_data_config: + if new_input.channel_name in existing_channels: + existing_channels[new_input.channel_name] = new_input + else: + new_channels.append(new_input) + + final_input_data_config = list(existing_channels.values()) + new_channels - if self.input_data_config: - self.input_data_config = self._get_input_data_config( - self.input_data_config, input_data_key_prefix + if final_input_data_config: + final_input_data_config = self._get_input_data_config( + final_input_data_config, input_data_key_prefix ) if self.checkpoint_config and not self.checkpoint_config.s3_uri: @@ -643,7 +655,7 @@ def train( data_source=self.source_code.source_dir, key_prefix=input_data_key_prefix, ) - self.input_data_config.append(source_code_channel) + final_input_data_config.append(source_code_channel) self._prepare_train_script( tmp_dir=tmp_dir, @@ -664,7 +676,7 @@ def train( data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) - self.input_data_config.append(sm_drivers_channel) + final_input_data_config.append(sm_drivers_channel) # If source_code is provided, we will always use # the default container entrypoint and arguments @@ -691,7 +703,7 @@ def train( training_job_name=current_training_job_name, algorithm_specification=algorithm_specification, hyper_parameters=string_hyper_parameters, - input_data_config=self.input_data_config, + input_data_config=final_input_data_config, resource_config=resource_config, vpc_config=vpc_config, # Public Instance Attributes @@ -736,7 +748,7 @@ def train( sagemaker_session=self.sagemaker_session, container_entrypoint=algorithm_specification.container_entrypoint, container_arguments=algorithm_specification.container_arguments, - input_data_config=self.input_data_config, + input_data_config=final_input_data_config, hyper_parameters=string_hyper_parameters, environment=self.environment, ) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index b1348b5ac9..5d4722b8aa 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1258,3 +1258,44 @@ def mock_upload_data(path, bucket, key_prefix): assert kwargs["tensor_board_output_config"].s3_output_path == default_base_path assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard" + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_input_merge(mock_training_job, modules_session): + model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz") + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + input_data_config=[model_input], + ) + + train_input = InputData(channel_name="train", data_source="s3://bucket/data/train") + model_trainer.train(input_data_config=[train_input]) + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["input_data_config"] == [ + Channel( + channel_name="model", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri="s3://bucket/model/model.tar.gz", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name="train", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri="s3://bucket/data/train", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + ]