from typing import Sequence, List, Dict, Any, Union
from urllib.parse import urlparse

+from multiprocessing.pool import AsyncResult
+import signal
import attr
import pandas as pd
from pandas import DataFrame

+import boto3
+from botocore.config import Config
+from pathos.multiprocessing import ProcessingPool
+
from sagemaker import Session
from sagemaker.feature_store.feature_definition import (
    FeatureDefinition,
@@ -150,23 +156,27 @@ class IngestionManagerPandas:
    Attributes:
        feature_group_name (str): name of the Feature Group.
-        sagemaker_session (Session): instance of the Session class to perform boto calls.
+        sagemaker_fs_runtime_client_config (Config): instance of the Config class
+            for boto calls.
        data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
        max_workers (int): number of threads to create.
+        max_processes (int): number of processes to create. Each process spawns
+            ``max_workers`` threads.
    """

    feature_group_name: str = attr.ib()
-    sagemaker_session: Session = attr.ib()
-    data_frame: DataFrame = attr.ib()
+    sagemaker_fs_runtime_client_config: Config = attr.ib()
    max_workers: int = attr.ib(default=1)
-    _futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
+    max_processes: int = attr.ib(default=1)
+    _async_result: AsyncResult = attr.ib(default=None)
+    _processing_pool: ProcessingPool = attr.ib(default=None)
    _failed_indices: List[int] = attr.ib(factory=list)

    @staticmethod
    def _ingest_single_batch(
        data_frame: DataFrame,
        feature_group_name: str,
-        sagemaker_session: Session,
+        client_config: Config,
        start_index: int,
        end_index: int,
    ) -> List[int]:
@@ -175,13 +185,19 @@ def _ingest_single_batch(
        Args:
            data_frame (DataFrame): source DataFrame to be ingested.
            feature_group_name (str): name of the Feature Group.
-            sagemaker_session (Session): session instance to perform boto calls.
+            client_config (Config): configuration for the sagemaker feature store runtime
+                client to perform boto calls.
            start_index (int): starting position to ingest in this batch.
            end_index (int): ending position to ingest in this batch.

        Returns:
            List of row indices that failed to be ingested.
        """
+        sagemaker_featurestore_runtime_client = boto3.Session().client(
+            service_name="sagemaker-featurestore-runtime",
+            config=client_config
+        )
+
        logger.info("Started ingesting index %d to %d", start_index, end_index)
        failed_rows = list()
        for row in data_frame[start_index:end_index].itertuples():
@@ -193,9 +209,9 @@ def _ingest_single_batch(
                if pd.notna(row[index])
            ]
            try:
-                sagemaker_session.put_record(
-                    feature_group_name=feature_group_name,
-                    record=[value.to_dict() for value in record],
+                sagemaker_featurestore_runtime_client.put_record(
+                    FeatureGroupName=feature_group_name,
+                    Record=[value.to_dict() for value in record]
                )
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Failed to ingest row %d: %s", row[0], e)
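
Two things worth calling out in this hunk: the boto client's `put_record` takes the service API's PascalCase parameters (`FeatureGroupName`, `Record`) rather than the snake_case names of the removed `Session.put_record` wrapper, and each worker now builds its own runtime client from a `botocore.config.Config`. A minimal sketch of why only the config crosses the process boundary (the retry setting is illustrative, and AWS region/credentials are assumed to be configured in the environment): a `Config` pickles cleanly, a boto3 client does not.

```python
# Sketch, not part of this diff: Config is plain data and survives pickling,
# so it can be shipped to worker processes, which then build their own client.
import pickle

import boto3
from botocore.config import Config

client_config = Config(retries={"max_attempts": 10})  # illustrative settings
round_trip = pickle.loads(pickle.dumps(client_config))  # works: Config pickles

client = boto3.Session().client(
    service_name="sagemaker-featurestore-runtime",
    config=client_config,
)
# pickle.dumps(client) would raise a TypeError: boto3 clients hold live
# connections and locks and are not picklable.
```
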
@@ -205,7 +221,6 @@ def _ingest_single_batch(
    @property
    def failed_rows(self) -> List[int]:
        """Get rows that failed to ingest
-
        Returns:
            List of row indices that failed to be ingested.
        """
@@ -218,52 +233,138 @@ def wait(self, timeout=None):
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                if timeout is reached.
        """
-        self._failed_indices = list()
-        for future in as_completed(self._futures, timeout=timeout):
-            start, end = self._futures[future]
-            result = future.result()
-            if result:
-                logger.error("Failed to ingest row %d to %d", start, end)
-            else:
-                logger.info("Successfully ingested row %d to %d", start, end)
-            self._failed_indices += result
+        try:
+            results = self._async_result.get(timeout=timeout)
+        except KeyboardInterrupt as i:
+            # terminate workers abruptly on keyboard interrupt.
+            self._processing_pool.terminate()
+            self._processing_pool.close()
+            self._processing_pool.clear()
+            raise i
+        else:
+            # terminate normally
+            self._processing_pool.close()
+            self._processing_pool.clear()
+
+        self._failed_indices = [
+            failed_index for failed_indices in results for failed_index in failed_indices
+        ]

        if len(self._failed_indices) > 0:
-            raise RuntimeError(
+            raise IngestionError(
+                self._failed_indices,
                f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
            )

-    def run(self, wait=True, timeout=None):
+    def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
+        """Start the ingestion process with the specified number of processes.
+
+        Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            wait (bool): whether to wait for the ingestion to finish or not.
+            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                if timeout is reached.
+        """
+        batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+
+        args = []
+        for i in range(self.max_processes):
+            start_index = min(i * batch_size, data_frame.shape[0])
+            end_index = min(i * batch_size + batch_size, data_frame.shape[0])
+            args += [(data_frame[start_index:end_index], start_index, timeout)]
+
+        def init_worker():
+            # ignore keyboard interrupts in child processes.
+            signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+        self._processing_pool = ProcessingPool(self.max_processes, init_worker)
+        self._processing_pool.restart(force=True)
+
+        f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+        self._async_result = self._processing_pool.amap(f, args)
+
+        if wait:
+            self.wait(timeout=timeout)
+
+    def _run_multi_threaded(
+        self,
+        data_frame: DataFrame,
+        row_offset=0,
+        timeout=None
+    ) -> List[int]:
        """Start the ingestion process.

        Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
+                index of the parent where ``data_frame`` starts. Otherwise, 0.
            wait (bool): whether to wait for the ingestion to finish or not.
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
-            if timeout is reached.
+                if timeout is reached.
+
+        Returns:
+            List of row indices that failed to be ingested.
        """
        executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
+        batch_size = math.ceil(data_frame.shape[0] / self.max_workers)

        futures = {}
        for i in range(self.max_workers):
-            start_index = min(i * batch_size, self.data_frame.shape[0])
-            end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
+            start_index = min(i * batch_size, data_frame.shape[0])
+            end_index = min(i * batch_size + batch_size, data_frame.shape[0])
            futures[
                executor.submit(
                    self._ingest_single_batch,
                    feature_group_name=self.feature_group_name,
-                    sagemaker_session=self.sagemaker_session,
-                    data_frame=self.data_frame,
+                    data_frame=data_frame,
                    start_index=start_index,
                    end_index=end_index,
+                    client_config=self.sagemaker_fs_runtime_client_config,
                )
-            ] = (start_index, end_index)
+            ] = (start_index + row_offset, end_index + row_offset)
+
+        failed_indices = list()
+        for future in as_completed(futures, timeout=timeout):
+            start, end = futures[future]
+            result = future.result()
+            if result:
+                logger.error("Failed to ingest row %d to %d", start, end)
+            else:
+                logger.info("Successfully ingested row %d to %d", start, end)
+            failed_indices += result

-        self._futures = futures
-        if wait:
-            self.wait(timeout=timeout)
        executor.shutdown(wait=False)

+        return failed_indices
+
+    def run(self, data_frame: DataFrame, wait=True, timeout=None):
+        """Start the ingestion process.
+
+        Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            wait (bool): whether to wait for the ingestion to finish or not.
+            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                if timeout is reached.
+        """
+        self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
+
+
+class IngestionError(Exception):
+    """Exception raised for errors during ingestion.
+
+    Attributes:
+        failed_rows: list of indices from the data frame for which ingestion failed.
+        message: explanation of the error
+    """
+
+    def __init__(self, failed_rows, message):
+        super(IngestionError, self).__init__(message)
+        self.failed_rows = failed_rows
+        self.message = message
+
+    def __str__(self):
+        return f'{self.failed_rows} -> {self.message}'
+

@attr.s
class FeatureGroup:
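
For reviewers unfamiliar with pathos, here is a minimal standalone sketch of the pool pattern that `_run_multi_process` and `wait()` use above; `square_batch` is a stand-in for `_run_multi_threaded`. `amap` returns a `multiprocessing.pool.AsyncResult` whose `get(timeout=...)` yields one result list per process, and SIGINT is ignored in the workers so Ctrl-C is handled only by the parent, which then terminates or closes the pool.

```python
# Minimal sketch of the ProcessingPool/amap pattern mirrored from the diff.
import signal
from pathos.multiprocessing import ProcessingPool


def init_worker():
    # ignore keyboard interrupts in child processes; the parent handles them.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def square_batch(batch):
    # stand-in for _run_multi_threaded: returns one result list per process.
    return [x * x for x in batch]


if __name__ == "__main__":
    pool = ProcessingPool(2, init_worker)
    pool.restart(force=True)  # reuse the pool even if a previous run closed it
    async_result = pool.amap(square_batch, [[1, 2], [3, 4]])
    try:
        results = async_result.get(timeout=60)  # [[1, 4], [9, 16]]
    except KeyboardInterrupt:
        pool.terminate()  # abrupt shutdown on Ctrl-C, as in wait()
        pool.close()
        pool.clear()
        raise
    else:
        pool.close()  # normal shutdown
        pool.clear()
    print(results)  # wait() flattens these per-process lists into one
```
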
@@ -447,31 +548,56 @@ def ingest(
        self,
        data_frame: DataFrame,
        max_workers: int = 1,
+        max_processes: int = 1,
        wait: bool = True,
        timeout: Union[int, float] = None,
    ) -> IngestionManagerPandas:
        """Ingest the content of a pandas DataFrame to feature store.

        ``max_workers`` number of threads will be created to work on different partitions of
        the ``data_frame`` in parallel.
+
+        ``max_processes`` number of processes will be created to work on different partitions
+        of the ``data_frame`` in parallel, each with ``max_workers`` threads.
+
+        The ingest function will attempt to ingest all records in the data frame. If ``wait``
+        is True and any records failed to be ingested, an exception is thrown after all
+        records have been processed. If ``wait`` is False, then a later call to the returned
+        instance IngestionManagerPandas' ``wait()`` function will throw the exception.
+
+        Zero-based indices of rows that failed to be ingested can be found in the exception.
+        They can also be found from the IngestionManagerPandas' ``failed_rows`` property after
+        the exception is thrown.

        Args:
            data_frame (DataFrame): data_frame to be ingested to feature store.
            max_workers (int): number of threads to be created.
+            max_processes (int): number of processes to be created. Each process spawns
+                ``max_workers`` number of threads.
            wait (bool): whether to wait for the ingestion to finish or not.
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                if timeout is reached.

        Returns:
            An instance of IngestionManagerPandas.
        """
+        if max_processes <= 0:
+            raise RuntimeError("max_processes must be greater than 0.")
+
+        if max_workers <= 0:
+            raise RuntimeError("max_workers must be greater than 0.")
+
        manager = IngestionManagerPandas(
            feature_group_name=self.name,
-            sagemaker_session=self.sagemaker_session,
-            data_frame=data_frame,
+            sagemaker_fs_runtime_client_config=self.sagemaker_session
+            .sagemaker_featurestore_runtime_client
+            .meta.config,
            max_workers=max_workers,
+            max_processes=max_processes,
        )
-        manager.run(wait=wait, timeout=timeout)
+
+        manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
+
        return manager

    def athena_query(self) -> AthenaQuery:
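
End to end, a hedged sketch of how the new surface is meant to be used; the feature group name and DataFrame are hypothetical, and the group is assumed to already exist with a schema matching the frame. Total parallelism is `max_processes * max_workers` concurrent batches (2 * 4 = 8 below).

```python
# Hypothetical usage of the new ingest()/IngestionError surface;
# "my-feature-group" and `df` are placeholders, not values from this diff.
import pandas as pd

from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup, IngestionError

df = pd.DataFrame({"record_id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})
feature_group = FeatureGroup(name="my-feature-group", sagemaker_session=Session())

# Kick off ingestion without blocking: 2 processes x 4 threads each.
manager = feature_group.ingest(
    data_frame=df, max_workers=4, max_processes=2, wait=False
)
try:
    manager.wait(timeout=600)
except IngestionError as e:
    # Zero-based indices into `df` of the rows that failed to ingest.
    print(e.failed_rows)
    print(manager.failed_rows)  # same indices, via the property
```
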