@@ -126,6 +126,60 @@ def test_write_config_file(LocalSession, tmpdir):
126
126
assert channel ['ChannelName' ] in input_data_config_data
127
127
128
128
129
+ @patch ('sagemaker.local.local_session.LocalSession' )
130
+ def test_write_config_files_input_content_type (LocalSession , tmpdir ):
131
+ sagemaker_container = _SageMakerContainer ('local' , 1 , 'my-image' )
132
+ sagemaker_container .container_root = str (tmpdir .mkdir ('container-root' ))
133
+ host = 'algo-1'
134
+
135
+ sagemaker .local .image ._create_config_file_directories (sagemaker_container .container_root , host )
136
+
137
+ container_root = sagemaker_container .container_root
138
+ config_file_root = os .path .join (container_root , host , 'input' , 'config' )
139
+
140
+ input_data_config_file = os .path .join (config_file_root , 'inputdataconfig.json' )
141
+
142
+ # write the config files, and then lets check they exist and have the right content.
143
+ input_data_config = [
144
+ {
145
+ 'ChannelName' : 'channel_a' ,
146
+ 'DataUri' : 'file:///tmp/source1' ,
147
+ 'ContentType' : 'text/csv' ,
148
+ 'DataSource' : {
149
+ 'FileDataSource' : {
150
+ 'FileDataDistributionType' : 'FullyReplicated' ,
151
+ 'FileUri' : 'file:///tmp/source1'
152
+ }
153
+ }
154
+ },
155
+ {
156
+ 'ChannelName' : 'channel_b' ,
157
+ 'DataUri' : 's3://my-own-bucket/prefix' ,
158
+ 'DataSource' : {
159
+ 'S3DataSource' : {
160
+ 'S3DataDistributionType' : 'FullyReplicated' ,
161
+ 'S3DataType' : 'S3Prefix' ,
162
+ 'S3Uri' : 's3://my-own-bucket/prefix'
163
+ }
164
+ }
165
+ }
166
+ ]
167
+ sagemaker_container .write_config_files (host , HYPERPARAMETERS , input_data_config )
168
+
169
+ assert os .path .exists (input_data_config_file )
170
+ parsed_input_config = json .load (open (input_data_config_file ))
171
+ # Validate Input Data Config
172
+ for channel in input_data_config :
173
+ assert channel ['ChannelName' ] in parsed_input_config
174
+
175
+ # Channel A has a content type
176
+ assert 'ContentType' in parsed_input_config ['channel_a' ]
177
+ assert parsed_input_config ['channel_a' ]['ContentType' ] == 'text/csv'
178
+
179
+ # Channel B does not have content type
180
+ assert 'ContentType' not in parsed_input_config ['channel_b' ]
181
+
182
+
129
183
@patch ('sagemaker.local.local_session.LocalSession' )
130
184
def test_retrieve_artifacts (LocalSession , tmpdir ):
131
185
sagemaker_container = _SageMakerContainer ('local' , 2 , 'my-image' )
0 commit comments