@@ -207,7 +207,8 @@ def _ingest_single_batch(
         for row in data_frame[start_index:end_index].itertuples():
             record = [
                 FeatureValue(
-                    feature_name=data_frame.columns[index - 1], value_as_string=str(row[index])
+                    feature_name=data_frame.columns[index - 1],
+                    value_as_string=str(row[index]),
                 )
                 for index in range(1, len(row))
                 if pd.notna(row[index])
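Note on the indexing this hunk touches: `DataFrame.itertuples()` yields the row label as element 0, so column values start at `row[1]`; that is why the comprehension starts `range` at 1, pairs `row[index]` with `data_frame.columns[index - 1]`, and uses `pd.notna` to drop missing cells. A standalone sketch of that pattern (illustrative, not part of the diff):

```python
import pandas as pd

df = pd.DataFrame({"a": [1.0, None], "b": ["x", "y"]})

for row in df.itertuples():
    # row[0] is the index label; values start at row[1],
    # so row[index] lines up with df.columns[index - 1].
    record = [
        (df.columns[index - 1], str(row[index]))
        for index in range(1, len(row))
        if pd.notna(row[index])  # skip NaN/None cells, as the diff does
    ]
    print(record)  # [('a', '1.0'), ('b', 'x')] then [('b', 'y')]
```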
@@ -270,13 +271,24 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                 if timeout is reached.
         """
+        # pylint: disable=I1101
         batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+        # pylint: enable=I1101

         args = []
         for i in range(self.max_processes):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
-            args += [(data_frame[start_index:end_index], start_index, timeout)]
+            args += [
+                (
+                    self.max_workers,
+                    self.feature_group_name,
+                    self.sagemaker_fs_runtime_client_config,
+                    data_frame[start_index:end_index],
+                    start_index,
+                    timeout,
+                )
+            ]

         def init_worker():
             # ignore keyboard interrupts in child processes.
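The batching arithmetic above carves the frame into at most `max_processes` contiguous slices: `math.ceil` rounds the batch size up, and the `min(..., data_frame.shape[0])` clamps keep every slice inside the frame even when the row count does not divide evenly. A quick illustration with hypothetical numbers:

```python
import math

rows, max_processes = 10, 4
batch_size = math.ceil(rows / max_processes)  # 3

for i in range(max_processes):
    start_index = min(i * batch_size, rows)
    end_index = min(i * batch_size + batch_size, rows)
    print(i, (start_index, end_index))
# 0 (0, 3)
# 1 (3, 6)
# 2 (6, 9)
# 3 (9, 10)  <- final slice clamped to the remaining row
```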
@@ -285,13 +297,21 @@ def init_worker():
         self._processing_pool = ProcessingPool(self.max_processes, init_worker)
         self._processing_pool.restart(force=True)

-        f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+        f = lambda x: IngestionManagerPandas._run_multi_threaded(*x)  # noqa: E731
         self._async_result = self._processing_pool.amap(f, args)

         if wait:
             self.wait(timeout=timeout)

-    def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
+    @staticmethod
+    def _run_multi_threaded(
+        max_workers: int,
+        feature_group_name: str,
+        sagemaker_fs_runtime_client_config: Config,
+        data_frame: DataFrame,
+        row_offset=0,
+        timeout=None,
+    ) -> List[int]:
         """Start the ingestion process.

         Args:
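The `@staticmethod` conversion above is presumably about serialization: pathos's `ProcessingPool` has to ship the mapped callable and its arguments to child processes, and a bound method would drag the whole manager instance, boto client objects included, into that payload. Passing the needed state as plain arguments keeps the payload small and picklable. A minimal sketch of the same pattern, using a hypothetical `Manager` class rather than the SDK's:

```python
from pathos.multiprocessing import ProcessingPool


class Manager:
    def __init__(self, name, workers):
        self.name = name
        self.workers = workers

    @staticmethod
    def _work(name, workers, batch):
        # runs in a child process; everything it needs arrives as arguments
        return f"{name}: {len(batch)} rows via {workers} workers"

    def run(self, batches):
        # mirror the diff: pack instance state into plain tuples
        args = [(self.name, self.workers, b) for b in batches]
        pool = ProcessingPool(2)
        f = lambda x: Manager._work(*x)  # noqa: E731, mirrors the diff's unpacking
        return pool.amap(f, args).get()


if __name__ == "__main__":
    print(Manager("demo", 4).run([[1, 2], [3, 4, 5]]))
```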
@@ -305,21 +325,23 @@ def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None)
         Returns:
             List of row indices that failed to be ingested.
         """
-        executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        batch_size = math.ceil(data_frame.shape[0] / self.max_workers)
+        executor = ThreadPoolExecutor(max_workers=max_workers)
+        # pylint: disable=I1101
+        batch_size = math.ceil(data_frame.shape[0] / max_workers)
+        # pylint: enable=I1101

         futures = {}
-        for i in range(self.max_workers):
+        for i in range(max_workers):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
             futures[
                 executor.submit(
-                    self._ingest_single_batch,
-                    feature_group_name=self.feature_group_name,
+                    IngestionManagerPandas._ingest_single_batch,
+                    feature_group_name=feature_group_name,
                     data_frame=data_frame,
                     start_index=start_index,
                     end_index=end_index,
-                    client_config=self.sagemaker_fs_runtime_client_config,
+                    client_config=sagemaker_fs_runtime_client_config,
                 )
             ] = (start_index + row_offset, end_index + row_offset)
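Each submitted future is keyed to its global row range; `row_offset` shifts a per-process slice back into coordinates of the original frame so failures can be reported against the full DataFrame. A hypothetical consumption pattern, assuming each batch returns the row numbers it failed to ingest (the SDK's actual aggregation may differ):

```python
from concurrent.futures import as_completed

failed_indices = []
for future in as_completed(futures):
    start, end = futures[future]  # global row range of this batch
    try:
        failed_indices += future.result()  # assumed: failed rows, already offset
    except Exception:
        # if the whole batch raised, count its entire global range as failed
        failed_indices += list(range(start, end))
```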