Skip to content

Commit 9ce8f64

Browse files
author
Alex
committed
feature: support multiprocess feature group ingest (aws#2111)
1 parent 44fbcc9 commit 9ce8f64

File tree

4 files changed

+270
-63
lines changed

4 files changed

+270
-63
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def read_version():
4343
"importlib-metadata>=1.4.0",
4444
"packaging>=20.0",
4545
"pandas",
46+
"pathos",
4647
]
4748

4849
# Specific use case dependencies

src/sagemaker/feature_store/feature_group.py

+160-34
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@
3030
from typing import Sequence, List, Dict, Any, Union
3131
from urllib.parse import urlparse
3232

33+
from multiprocessing.pool import AsyncResult
34+
import signal
3335
import attr
3436
import pandas as pd
3537
from pandas import DataFrame
3638

39+
import boto3
40+
from botocore.config import Config
41+
from pathos.multiprocessing import ProcessingPool
42+
3743
from sagemaker import Session
3844
from sagemaker.feature_store.feature_definition import (
3945
FeatureDefinition,
@@ -150,23 +156,27 @@ class IngestionManagerPandas:
150156
151157
Attributes:
152158
feature_group_name (str): name of the Feature Group.
153-
sagemaker_session (Session): instance of the Session class to perform boto calls.
159+
sagemaker_fs_runtime_client_config (Config): instance of the Config class
160+
for boto calls.
154161
data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
155162
max_workers (int): number of threads to create.
163+
max_processes (int): number of processes to create. Each process spawns
164+
``max_workers`` threads.
156165
"""
157166

158167
feature_group_name: str = attr.ib()
159-
sagemaker_session: Session = attr.ib()
160-
data_frame: DataFrame = attr.ib()
168+
sagemaker_fs_runtime_client_config: Config = attr.ib()
161169
max_workers: int = attr.ib(default=1)
162-
_futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
170+
max_processes: int = attr.ib(default=1)
171+
_async_result: AsyncResult = attr.ib(default=None)
172+
_processing_pool: ProcessingPool = attr.ib(default=None)
163173
_failed_indices: List[int] = attr.ib(factory=list)
164174

165175
@staticmethod
166176
def _ingest_single_batch(
167177
data_frame: DataFrame,
168178
feature_group_name: str,
169-
sagemaker_session: Session,
179+
client_config: Config,
170180
start_index: int,
171181
end_index: int,
172182
) -> List[int]:
@@ -175,13 +185,19 @@ def _ingest_single_batch(
175185
Args:
176186
data_frame (DataFrame): source DataFrame to be ingested.
177187
feature_group_name (str): name of the Feature Group.
178-
sagemaker_session (Session): session instance to perform boto calls.
188+
client_config (Config): Configuration for the sagemaker feature store runtime
189+
client to perform boto calls.
179190
start_index (int): starting position to ingest in this batch.
180191
end_index (int): ending position to ingest in this batch.
181192
182193
Returns:
183194
List of row indices that failed to be ingested.
184195
"""
196+
sagemaker_featurestore_runtime_client = boto3.Session().client(
197+
service_name="sagemaker-featurestore-runtime",
198+
config=client_config
199+
)
200+
185201
logger.info("Started ingesting index %d to %d", start_index, end_index)
186202
failed_rows = list()
187203
for row in data_frame[start_index:end_index].itertuples():
@@ -193,9 +209,9 @@ def _ingest_single_batch(
193209
if pd.notna(row[index])
194210
]
195211
try:
196-
sagemaker_session.put_record(
197-
feature_group_name=feature_group_name,
198-
record=[value.to_dict() for value in record],
212+
sagemaker_featurestore_runtime_client.put_record(
213+
FeatureGroupName=feature_group_name,
214+
Record=[value.to_dict() for value in record]
199215
)
200216
except Exception as e: # pylint: disable=broad-except
201217
logger.error("Failed to ingest row %d: %s", row[0], e)
@@ -205,7 +221,6 @@ def _ingest_single_batch(
205221
@property
206222
def failed_rows(self) -> List[int]:
207223
"""Get rows that failed to ingest
208-
209224
Returns:
210225
List of row indices that failed to be ingested.
211226
"""
@@ -218,52 +233,138 @@ def wait(self, timeout=None):
218233
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
219234
if timeout is reached.
220235
"""
221-
self._failed_indices = list()
222-
for future in as_completed(self._futures, timeout=timeout):
223-
start, end = self._futures[future]
224-
result = future.result()
225-
if result:
226-
logger.error("Failed to ingest row %d to %d", start, end)
227-
else:
228-
logger.info("Successfully ingested row %d to %d", start, end)
229-
self._failed_indices += result
236+
try:
237+
results = self._async_result.get(timeout=timeout)
238+
except KeyboardInterrupt as i:
239+
# terminate workers abruptly on keyboard interrupt.
240+
self._processing_pool.terminate()
241+
self._processing_pool.close()
242+
self._processing_pool.clear()
243+
raise i
244+
else:
245+
# terminate normally
246+
self._processing_pool.close()
247+
self._processing_pool.clear()
248+
249+
self._failed_indices = [
250+
failed_index for failed_indices in results for failed_index in failed_indices
251+
]
230252

231253
if len(self._failed_indices) > 0:
232-
raise RuntimeError(
254+
raise IngestionError(
255+
self._failed_indices,
233256
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
234257
)
235258

236-
def run(self, wait=True, timeout=None):
259+
def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
260+
"""Start the ingestion process with the specified number of processes.
261+
262+
Args:
263+
data_frame (DataFrame): source DataFrame to be ingested.
264+
wait (bool): whether to wait for the ingestion to finish or not.
265+
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
266+
if timeout is reached.
267+
"""
268+
batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
269+
270+
args = []
271+
for i in range(self.max_processes):
272+
start_index = min(i * batch_size, data_frame.shape[0])
273+
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
274+
args += [(data_frame[start_index:end_index], start_index, True, timeout)]
275+
276+
def init_worker():
277+
# ignore keyboard interrupts in child processes.
278+
signal.signal(signal.SIGINT, signal.SIG_IGN)
279+
280+
self._processing_pool = ProcessingPool(self.max_processes, init_worker)
281+
self._processing_pool.restart(force=True)
282+
283+
f = lambda x: self._run_multi_threaded(*x) # noqa: E731
284+
self._async_result = self._processing_pool.amap(f, args)
285+
286+
if wait:
287+
self.wait(timeout=timeout)
288+
289+
def _run_multi_threaded(
290+
self,
291+
data_frame: DataFrame,
292+
row_offset=0,
293+
timeout=None
294+
) -> List[int]:
237295
"""Start the ingestion process.
238296
239297
Args:
298+
data_frame (DataFrame): source DataFrame to be ingested.
299+
row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
300+
index of the parent where ``data_frame`` starts. Otherwise, 0.
240301
wait (bool): whether to wait for the ingestion to finish or not.
241302
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
242-
if timeout is reached.
303+
if timeout is reached.
304+
305+
Returns:
306+
List of row indices that failed to be ingested.
243307
"""
244308
executor = ThreadPoolExecutor(max_workers=self.max_workers)
245-
batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
309+
batch_size = math.ceil(data_frame.shape[0] / self.max_workers)
246310

247311
futures = {}
248312
for i in range(self.max_workers):
249-
start_index = min(i * batch_size, self.data_frame.shape[0])
250-
end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
313+
start_index = min(i * batch_size, data_frame.shape[0])
314+
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
251315
futures[
252316
executor.submit(
253317
self._ingest_single_batch,
254318
feature_group_name=self.feature_group_name,
255-
sagemaker_session=self.sagemaker_session,
256-
data_frame=self.data_frame,
319+
data_frame=data_frame,
257320
start_index=start_index,
258321
end_index=end_index,
322+
client_config=self.sagemaker_fs_runtime_client_config,
259323
)
260-
] = (start_index, end_index)
324+
] = (start_index + row_offset, end_index + row_offset)
325+
326+
failed_indices = list()
327+
for future in as_completed(futures, timeout=timeout):
328+
start, end = futures[future]
329+
result = future.result()
330+
if result:
331+
logger.error("Failed to ingest row %d to %d", start, end)
332+
else:
333+
logger.info("Successfully ingested row %d to %d", start, end)
334+
failed_indices += result
261335

262-
self._futures = futures
263-
if wait:
264-
self.wait(timeout=timeout)
265336
executor.shutdown(wait=False)
266337

338+
return failed_indices
339+
340+
def run(self, data_frame: DataFrame, wait=True, timeout=None):
341+
"""Start the ingestion process.
342+
343+
Args:
344+
data_frame (DataFrame): source DataFrame to be ingested.
345+
wait (bool): whether to wait for the ingestion to finish or not.
346+
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
347+
if timeout is reached.
348+
"""
349+
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
350+
351+
352+
class IngestionError(Exception):
353+
"""Exception raised for errors during ingestion.
354+
355+
Attributes:
356+
failed_rows: list of indices from the data frame for which ingestion failed.
357+
message: explanation of the error
358+
"""
359+
360+
def __init__(self, failed_rows, message):
361+
super(IngestionError, self).__init__(message)
362+
self.failed_rows = failed_rows
363+
self.message = message
364+
365+
def __str__(self):
366+
return f'{self.failed_rows} -> {self.message}'
367+
267368

268369
@attr.s
269370
class FeatureGroup:
@@ -447,31 +548,56 @@ def ingest(
447548
self,
448549
data_frame: DataFrame,
449550
max_workers: int = 1,
551+
max_processes: int = 1,
450552
wait: bool = True,
451553
timeout: Union[int, float] = None,
452554
) -> IngestionManagerPandas:
453555
"""Ingest the content of a pandas DataFrame to feature store.
454556
455557
``max_worker`` number of thread will be created to work on different partitions of
456558
the ``data_frame`` in parallel.
559+
560+
``max_processes`` number of processes will be created to work on different partitions
561+
of the ``data_frame`` in parallel, each with ``max_worker`` threads.
562+
563+
The ingest function will attempt to ingest all records in the data frame. If ``wait``
564+
is True, then an exception is thrown after all records have been processed. If ``wait``
565+
is False, then a later call to the returned instance IngestionManagerPandas' ``wait()``
566+
function will throw an exception.
567+
568+
Zero based indices of rows that failed to be ingested can be found in the exception.
569+
They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
570+
the exception is thrown.
457571
458572
Args:
459573
data_frame (DataFrame): data_frame to be ingested to feature store.
460574
max_workers (int): number of threads to be created.
575+
max_processes (int): number of processes to be created. Each process spawns
576+
``max_worker`` number of threads.
461577
wait (bool): whether to wait for the ingestion to finish or not.
462578
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
463579
if timeout is reached.
464580
465581
Returns:
466582
An instance of IngestionManagerPandas.
467583
"""
584+
if max_processes <= 0:
585+
raise RuntimeError("max_processes must be greater than 0.")
586+
587+
if max_workers <= 0:
588+
raise RuntimeError("max_workers must be greater than 0.")
589+
468590
manager = IngestionManagerPandas(
469591
feature_group_name=self.name,
470-
sagemaker_session=self.sagemaker_session,
471-
data_frame=data_frame,
592+
sagemaker_fs_runtime_client_config=self.sagemaker_session
593+
.sagemaker_featurestore_runtime_client
594+
.meta.config,
472595
max_workers=max_workers,
596+
max_processes=max_processes,
473597
)
474-
manager.run(wait=wait, timeout=timeout)
598+
599+
manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
600+
475601
return manager
476602

477603
def athena_query(self) -> AthenaQuery:

tests/integ/test_feature_store.py

+27
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,33 @@ def test_ingest_without_string_feature(
265265
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
266266

267267

268+
def test_ingest_multi_process(
269+
feature_store_session,
270+
role,
271+
feature_group_name,
272+
offline_store_s3_uri,
273+
pandas_data_frame,
274+
):
275+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
276+
feature_group.load_feature_definitions(data_frame=pandas_data_frame)
277+
278+
with cleanup_feature_group(feature_group):
279+
output = feature_group.create(
280+
s3_uri=offline_store_s3_uri,
281+
record_identifier_name="feature1",
282+
event_time_feature_name="feature3",
283+
role_arn=role,
284+
enable_online_store=True,
285+
)
286+
_wait_for_feature_group_create(feature_group)
287+
288+
feature_group.ingest(
289+
data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True
290+
)
291+
292+
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
293+
294+
268295
def _wait_for_feature_group_create(feature_group: FeatureGroup):
269296
status = feature_group.describe().get("FeatureGroupStatus")
270297
while status == "Creating":

0 commit comments

Comments
 (0)