from typing import Sequence, List, Dict, Any, Union
from urllib.parse import urlparse

+ from multiprocessing.pool import AsyncResult
+ import signal
import attr
import pandas as pd
from pandas import DataFrame

+ import boto3
+ from botocore.config import Config
+ from pathos.multiprocessing import ProcessingPool
+
from sagemaker import Session
from sagemaker.feature_store.feature_definition import (
    FeatureDefinition,
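
The imports added above pull in ``pathos`` — whose ``ProcessingPool`` serializes work with dill, which is what lets this change ship a lambda to worker processes later in the diff — plus ``AsyncResult`` for type annotations and ``signal`` for interrupt handling. A minimal sketch (mine, not part of the diff) of the ``amap``/``get`` pattern the new code relies on:

# Sketch only: ProcessingPool.amap() dispatches work asynchronously and returns
# an async result whose .get(timeout=...) blocks until all workers finish.
from pathos.multiprocessing import ProcessingPool

def square(x):
    return x * x

pool = ProcessingPool(2)                # two worker processes
async_result = pool.amap(square, [1, 2, 3, 4])
print(async_result.get(timeout=30))     # [1, 4, 9, 16]
pool.close()
pool.clear()
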
@@ -150,23 +156,27 @@ class IngestionManagerPandas:

    Attributes:
        feature_group_name (str): name of the Feature Group.
-         sagemaker_session (Session): instance of the Session class to perform boto calls.
+         sagemaker_fs_runtime_client_config (Config): instance of the Config class
+             for boto calls.
        data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
        max_workers (int): number of threads to create.
+         max_processes (int): number of processes to create. Each process spawns
+             ``max_workers`` threads.
    """

    feature_group_name: str = attr.ib()
-     sagemaker_session: Session = attr.ib()
-     data_frame: DataFrame = attr.ib()
+     sagemaker_fs_runtime_client_config: Config = attr.ib()
    max_workers: int = attr.ib(default=1)
-     _futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
+     max_processes: int = attr.ib(default=1)
+     _async_result: AsyncResult = attr.ib(default=None)
+     _processing_pool: ProcessingPool = attr.ib(default=None)
    _failed_indices: List[int] = attr.ib(factory=list)

    @staticmethod
    def _ingest_single_batch(
        data_frame: DataFrame,
        feature_group_name: str,
-         sagemaker_session: Session,
+         client_config: Config,
        start_index: int,
        end_index: int,
    ) -> List[int]:
@@ -175,13 +185,18 @@ def _ingest_single_batch(
        Args:
            data_frame (DataFrame): source DataFrame to be ingested.
            feature_group_name (str): name of the Feature Group.
-             sagemaker_session (Session): session instance to perform boto calls.
+             client_config (Config): configuration for the SageMaker Feature Store runtime
+                 client to perform boto calls.
            start_index (int): starting position to ingest in this batch.
            end_index (int): ending position to ingest in this batch.

        Returns:
            List of row indices that failed to be ingested.
        """
+         sagemaker_featurestore_runtime_client = boto3.Session().client(
+             service_name="sagemaker-featurestore-runtime", config=client_config
+         )
+
        logger.info("Started ingesting index %d to %d", start_index, end_index)
        failed_rows = list()
        for row in data_frame[start_index:end_index].itertuples():
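
The loop opened above converts each row to a record, dropping NaN-valued positions. A toy illustration (mine, column names hypothetical) of the ``itertuples()``/``pd.notna()`` filtering used here:

# itertuples() yields (Index, col0, col1, ...); row[0] is the DataFrame index,
# and NaN positions are skipped just as the pd.notna() check below does.
import pandas as pd

df = pd.DataFrame({"a": [1.0, None], "b": ["x", "y"]})
for row in df[0:2].itertuples():
    values = [row[i] for i in range(1, len(row)) if pd.notna(row[i])]
    print(row[0], values)
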
@@ -193,9 +208,9 @@ def _ingest_single_batch(
                if pd.notna(row[index])
            ]
            try:
-                 sagemaker_session.put_record(
-                     feature_group_name=feature_group_name,
-                     record=[value.to_dict() for value in record],
+                 sagemaker_featurestore_runtime_client.put_record(
+                     FeatureGroupName=feature_group_name,
+                     Record=[value.to_dict() for value in record],
                )
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Failed to ingest row %d: %s", row[0], e)
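
Since the hunk above replaces the SDK's snake_case wrapper with a direct runtime-client call, here is a standalone sketch of that boto3 call; the feature group name and record values are hypothetical. Note that the boto3 API takes CapWords parameters (``FeatureGroupName``, ``Record``):

# Sketch of the direct sagemaker-featurestore-runtime call used above.
import boto3
from botocore.config import Config

client = boto3.Session().client(
    service_name="sagemaker-featurestore-runtime",
    config=Config(retries={"max_attempts": 10, "mode": "standard"}),
)
client.put_record(
    FeatureGroupName="my-feature-group",  # hypothetical
    Record=[
        {"FeatureName": "customer_id", "ValueAsString": "42"},
        {"FeatureName": "event_time", "ValueAsString": "2021-01-01T00:00:00Z"},
    ],
)
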
@@ -204,7 +219,7 @@ def _ingest_single_batch(

    @property
    def failed_rows(self) -> List[int]:
-         """Get rows that failed to ingest
+         """Get rows that failed to ingest.

        Returns:
            List of row indices that failed to be ingested.
@@ -218,52 +233,134 @@ def wait(self, timeout=None):
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                if timeout is reached.
        """
-         self._failed_indices = list()
-         for future in as_completed(self._futures, timeout=timeout):
-             start, end = self._futures[future]
-             result = future.result()
-             if result:
-                 logger.error("Failed to ingest row %d to %d", start, end)
-             else:
-                 logger.info("Successfully ingested row %d to %d", start, end)
-             self._failed_indices += result
+         try:
+             results = self._async_result.get(timeout=timeout)
+         except KeyboardInterrupt as i:
+             # terminate workers abruptly on keyboard interrupt.
+             self._processing_pool.terminate()
+             self._processing_pool.close()
+             self._processing_pool.clear()
+             raise i
+         else:
+             # terminate normally.
+             self._processing_pool.close()
+             self._processing_pool.clear()
+
+         self._failed_indices = [
+             failed_index for failed_indices in results for failed_index in failed_indices
+         ]

        if len(self._failed_indices) > 0:
-             raise RuntimeError(
-                 f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
+             raise IngestionError(
+                 self._failed_indices,
+                 f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
            )

-     def run(self, wait=True, timeout=None):
+     def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
+         """Start the ingestion process with the specified number of processes.
+
+         Args:
+             data_frame (DataFrame): source DataFrame to be ingested.
+             wait (bool): whether to wait for the ingestion to finish or not.
+             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                 if timeout is reached.
+         """
+         batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+
+         args = []
+         for i in range(self.max_processes):
+             start_index = min(i * batch_size, data_frame.shape[0])
+             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
+             args += [(data_frame[start_index:end_index], start_index, timeout)]
+
+         def init_worker():
+             # ignore keyboard interrupts in child processes.
+             signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+         self._processing_pool = ProcessingPool(self.max_processes, init_worker)
+         self._processing_pool.restart(force=True)
+
+         f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+         self._async_result = self._processing_pool.amap(f, args)
+
+         if wait:
+             self.wait(timeout=timeout)
+
+     def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
        """Start the ingestion process.

        Args:
+             data_frame (DataFrame): source DataFrame to be ingested.
+             row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
+                 index of the parent where ``data_frame`` starts. Otherwise, 0.
            wait (bool): whether to wait for the ingestion to finish or not.
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
-                 if timeout is reached.
+                 if timeout is reached.
+
+         Returns:
+             List of row indices that failed to be ingested.
        """
        executor = ThreadPoolExecutor(max_workers=self.max_workers)
-         batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
+         batch_size = math.ceil(data_frame.shape[0] / self.max_workers)

        futures = {}
        for i in range(self.max_workers):
-             start_index = min(i * batch_size, self.data_frame.shape[0])
-             end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
+             start_index = min(i * batch_size, data_frame.shape[0])
+             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
            futures[
                executor.submit(
                    self._ingest_single_batch,
                    feature_group_name=self.feature_group_name,
-                     sagemaker_session=self.sagemaker_session,
-                     data_frame=self.data_frame,
+                     data_frame=data_frame,
                    start_index=start_index,
                    end_index=end_index,
+                     client_config=self.sagemaker_fs_runtime_client_config,
                )
-             ] = (start_index, end_index)
+             ] = (start_index + row_offset, end_index + row_offset)
+
+         failed_indices = list()
+         for future in as_completed(futures, timeout=timeout):
+             start, end = futures[future]
+             result = future.result()
+             if result:
+                 logger.error("Failed to ingest row %d to %d", start, end)
+             else:
+                 logger.info("Successfully ingested row %d to %d", start, end)
+             failed_indices += result

-         self._futures = futures
-         if wait:
-             self.wait(timeout=timeout)
        executor.shutdown(wait=False)

+         return failed_indices
+
+     def run(self, data_frame: DataFrame, wait=True, timeout=None):
+         """Start the ingestion process.
+
+         Args:
+             data_frame (DataFrame): source DataFrame to be ingested.
+             wait (bool): whether to wait for the ingestion to finish or not.
+             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                 if timeout is reached.
+         """
+         self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
+
+
+ class IngestionError(Exception):
+     """Exception raised for errors during ingestion.
+
+     Attributes:
+         failed_rows: list of indices from the data frame for which ingestion failed.
+         message: explanation of the error.
+     """
+
+     def __init__(self, failed_rows, message):
+         super(IngestionError, self).__init__(message)
+         self.failed_rows = failed_rows
+         self.message = message
+
+     def __str__(self) -> str:
+         """String representation of the error."""
+         return f"{self.failed_rows} -> {self.message}"
+

@attr.s
class FeatureGroup:
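
One pattern in the hunk above deserves a note: child processes install ``signal.SIG_IGN`` for SIGINT, so only the parent receives ``KeyboardInterrupt`` and can terminate the pool instead of leaving orphaned workers. A standalone sketch of that pattern, with a hypothetical ``work()`` task:

# Ctrl-C pattern: children ignore SIGINT; the parent terminates the pool.
import signal
import time
from pathos.multiprocessing import ProcessingPool

def init_worker():
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def work(x):
    time.sleep(1)
    return x

pool = ProcessingPool(2, init_worker)
result = pool.amap(work, [0, 1, 2, 3])
try:
    print(result.get(timeout=60))
except KeyboardInterrupt:
    pool.terminate()  # abrupt stop, mirroring wait() above
    pool.close()
    pool.clear()
    raise
else:
    pool.close()
    pool.clear()
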
@@ -447,6 +544,7 @@ def ingest(
        self,
        data_frame: DataFrame,
        max_workers: int = 1,
+         max_processes: int = 1,
        wait: bool = True,
        timeout: Union[int, float] = None,
    ) -> IngestionManagerPandas:
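
A design note before the body of ``ingest()``: the manager below is handed the runtime client's ``meta.config`` (a ``botocore.config.Config``) instead of the ``Session`` it previously held. My reading — an assumption, not stated in the diff — is that this is about pickling: everything sent to a worker process must be picklable, and a plain ``Config`` is, while live boto3 sessions and clients are not; hence ``_ingest_single_batch`` rebuilding a client per process. A small sketch of that assumption:

# Assumption being illustrated: Config pickles cleanly, a live client does not.
import pickle

import boto3
from botocore.config import Config

config = Config(region_name="us-west-2")
restored = pickle.loads(pickle.dumps(config))  # plain data: round-trips fine
assert restored.region_name == "us-west-2"

client = boto3.Session().client("sagemaker-featurestore-runtime", config=config)
# pickle.dumps(client)  # would raise TypeError: clients hold unpicklable state
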
@@ -455,23 +553,45 @@ def ingest(
        ``max_workers`` number of threads will be created to work on different partitions of
        the ``data_frame`` in parallel.

+         ``max_processes`` number of processes will be created to work on different partitions
+         of the ``data_frame`` in parallel, each with ``max_workers`` threads.
+
+         The ingest function will attempt to ingest all records in the data frame. If ``wait``
+         is True, then an exception is thrown if any records fail to be ingested, after all
+         records have been processed. If ``wait`` is False, a later call to ``wait()`` on the
+         returned IngestionManagerPandas instance will throw the exception instead.
+
+         Zero-based indices of rows that failed to be ingested can be found in the exception.
+         They can also be read from the IngestionManagerPandas ``failed_rows`` property after
+         the exception is thrown.
+
        Args:
            data_frame (DataFrame): DataFrame to be ingested into the feature store.
            max_workers (int): number of threads to be created.
+             max_processes (int): number of processes to be created. Each process spawns
+                 ``max_workers`` threads.
            wait (bool): whether to wait for the ingestion to finish or not.
            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                if timeout is reached.

        Returns:
            An instance of IngestionManagerPandas.
        """
+         if max_processes <= 0:
+             raise RuntimeError("max_processes must be greater than 0.")
+
+         if max_workers <= 0:
+             raise RuntimeError("max_workers must be greater than 0.")
+
        manager = IngestionManagerPandas(
            feature_group_name=self.name,
-             sagemaker_session=self.sagemaker_session,
-             data_frame=data_frame,
+             sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
            max_workers=max_workers,
+             max_processes=max_processes,
        )
-         manager.run(wait=wait, timeout=timeout)
+
+         manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
+
        return manager

    def athena_query(self) -> AthenaQuery:
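
Finally, an end-to-end usage sketch of the new ``ingest()`` surface. It assumes this module is ``sagemaker.feature_store.feature_group`` (as the class names suggest) and that a feature group named ``my-feature-group`` already exists with matching feature definitions; all names and data are hypothetical:

import pandas as pd

from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup, IngestionError

df = pd.DataFrame({"customer_id": [1, 2, 3], "event_time": [1.61e9] * 3})
group = FeatureGroup(name="my-feature-group", sagemaker_session=Session())

# Two processes, each spawning four ingestion threads; don't block on the call.
manager = group.ingest(data_frame=df, max_workers=4, max_processes=2, wait=False)
try:
    manager.wait()
except IngestionError as err:
    # zero-based indices of rows that failed; also available afterwards
    # via the manager.failed_rows property
    print(err.failed_rows)
    raise
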