Skip to content

Commit f2d9c2d

Browse files
author
Alex
committed
feature: support multiprocess feature group ingest (#2111)
1 parent 564a061 commit f2d9c2d

File tree

4 files changed

+274
-64
lines changed

4 files changed

+274
-64
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def read_version():
4343
"importlib-metadata>=1.4.0",
4444
"packaging>=20.0",
4545
"pandas",
46+
"pathos",
4647
]
4748

4849
# Specific use case dependencies

src/sagemaker/feature_store/feature_group.py

+155-35
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@
3030
from typing import Sequence, List, Dict, Any, Union
3131
from urllib.parse import urlparse
3232

33+
from multiprocessing.pool import AsyncResult
34+
import signal
3335
import attr
3436
import pandas as pd
3537
from pandas import DataFrame
3638

39+
import boto3
40+
from botocore.config import Config
41+
from pathos.multiprocessing import ProcessingPool
42+
3743
from sagemaker import Session
3844
from sagemaker.feature_store.feature_definition import (
3945
FeatureDefinition,
@@ -150,23 +156,27 @@ class IngestionManagerPandas:
150156
151157
Attributes:
152158
feature_group_name (str): name of the Feature Group.
153-
sagemaker_session (Session): instance of the Session class to perform boto calls.
159+
sagemaker_fs_runtime_client_config (Config): instance of the Config class
160+
for boto calls.
154161
data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
155162
max_workers (int): number of threads to create.
163+
max_processes (int): number of processes to create. Each process spawns
164+
``max_workers`` threads.
156165
"""
157166

158167
feature_group_name: str = attr.ib()
159-
sagemaker_session: Session = attr.ib()
160-
data_frame: DataFrame = attr.ib()
168+
sagemaker_fs_runtime_client_config: Config = attr.ib()
161169
max_workers: int = attr.ib(default=1)
162-
_futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
170+
max_processes: int = attr.ib(default=1)
171+
_async_result: AsyncResult = attr.ib(default=None)
172+
_processing_pool: ProcessingPool = attr.ib(default=None)
163173
_failed_indices: List[int] = attr.ib(factory=list)
164174

165175
@staticmethod
166176
def _ingest_single_batch(
167177
data_frame: DataFrame,
168178
feature_group_name: str,
169-
sagemaker_session: Session,
179+
client_config: Config,
170180
start_index: int,
171181
end_index: int,
172182
) -> List[int]:
@@ -175,13 +185,18 @@ def _ingest_single_batch(
175185
Args:
176186
data_frame (DataFrame): source DataFrame to be ingested.
177187
feature_group_name (str): name of the Feature Group.
178-
sagemaker_session (Session): session instance to perform boto calls.
188+
client_config (Config): Configuration for the sagemaker feature store runtime
189+
client to perform boto calls.
179190
start_index (int): starting position to ingest in this batch.
180191
end_index (int): ending position to ingest in this batch.
181192
182193
Returns:
183194
List of row indices that failed to be ingested.
184195
"""
196+
sagemaker_featurestore_runtime_client = boto3.Session().client(
197+
service_name="sagemaker-featurestore-runtime", config=client_config
198+
)
199+
185200
logger.info("Started ingesting index %d to %d", start_index, end_index)
186201
failed_rows = list()
187202
for row in data_frame[start_index:end_index].itertuples():
@@ -193,9 +208,9 @@ def _ingest_single_batch(
193208
if pd.notna(row[index])
194209
]
195210
try:
196-
sagemaker_session.put_record(
197-
feature_group_name=feature_group_name,
198-
record=[value.to_dict() for value in record],
211+
sagemaker_featurestore_runtime_client.put_record(
212+
FeatureGroupName=feature_group_name,
213+
Record=[value.to_dict() for value in record],
199214
)
200215
except Exception as e: # pylint: disable=broad-except
201216
logger.error("Failed to ingest row %d: %s", row[0], e)
@@ -204,7 +219,7 @@ def _ingest_single_batch(
204219

205220
@property
206221
def failed_rows(self) -> List[int]:
207-
"""Get rows that failed to ingest
222+
"""Get rows that failed to ingest.
208223
209224
Returns:
210225
List of row indices that failed to be ingested.
@@ -218,52 +233,134 @@ def wait(self, timeout=None):
218233
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
219234
if timeout is reached.
220235
"""
221-
self._failed_indices = list()
222-
for future in as_completed(self._futures, timeout=timeout):
223-
start, end = self._futures[future]
224-
result = future.result()
225-
if result:
226-
logger.error("Failed to ingest row %d to %d", start, end)
227-
else:
228-
logger.info("Successfully ingested row %d to %d", start, end)
229-
self._failed_indices += result
236+
try:
237+
results = self._async_result.get(timeout=timeout)
238+
except KeyboardInterrupt as i:
239+
# terminate workers abruptly on keyboard interrupt.
240+
self._processing_pool.terminate()
241+
self._processing_pool.close()
242+
self._processing_pool.clear()
243+
raise i
244+
else:
245+
# terminate normally
246+
self._processing_pool.close()
247+
self._processing_pool.clear()
248+
249+
self._failed_indices = [
250+
failed_index for failed_indices in results for failed_index in failed_indices
251+
]
230252

231253
if len(self._failed_indices) > 0:
232-
raise RuntimeError(
233-
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
254+
raise IngestionError(
255+
self._failed_indices,
256+
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
234257
)
235258

236-
def run(self, wait=True, timeout=None):
259+
def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
260+
"""Start the ingestion process with the specified number of processes.
261+
262+
Args:
263+
data_frame (DataFrame): source DataFrame to be ingested.
264+
wait (bool): whether to wait for the ingestion to finish or not.
265+
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
266+
if timeout is reached.
267+
"""
268+
batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
269+
270+
args = []
271+
for i in range(self.max_processes):
272+
start_index = min(i * batch_size, data_frame.shape[0])
273+
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
274+
args += [(data_frame[start_index:end_index], start_index, timeout)]
275+
276+
def init_worker():
277+
# ignore keyboard interrupts in child processes.
278+
signal.signal(signal.SIGINT, signal.SIG_IGN)
279+
280+
self._processing_pool = ProcessingPool(self.max_processes, init_worker)
281+
self._processing_pool.restart(force=True)
282+
283+
f = lambda x: self._run_multi_threaded(*x) # noqa: E731
284+
self._async_result = self._processing_pool.amap(f, args)
285+
286+
if wait:
287+
self.wait(timeout=timeout)
288+
289+
def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
237290
"""Start the ingestion process.
238291
239292
Args:
293+
data_frame (DataFrame): source DataFrame to be ingested.
294+
row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
295+
index of the parent where ``data_frame`` starts. Otherwise, 0.
240296
wait (bool): whether to wait for the ingestion to finish or not.
241297
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
242-
if timeout is reached.
298+
if timeout is reached.
299+
300+
Returns:
301+
List of row indices that failed to be ingested.
243302
"""
244303
executor = ThreadPoolExecutor(max_workers=self.max_workers)
245-
batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
304+
batch_size = math.ceil(data_frame.shape[0] / self.max_workers)
246305

247306
futures = {}
248307
for i in range(self.max_workers):
249-
start_index = min(i * batch_size, self.data_frame.shape[0])
250-
end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
308+
start_index = min(i * batch_size, data_frame.shape[0])
309+
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
251310
futures[
252311
executor.submit(
253312
self._ingest_single_batch,
254313
feature_group_name=self.feature_group_name,
255-
sagemaker_session=self.sagemaker_session,
256-
data_frame=self.data_frame,
314+
data_frame=data_frame,
257315
start_index=start_index,
258316
end_index=end_index,
317+
client_config=self.sagemaker_fs_runtime_client_config,
259318
)
260-
] = (start_index, end_index)
319+
] = (start_index + row_offset, end_index + row_offset)
320+
321+
failed_indices = list()
322+
for future in as_completed(futures, timeout=timeout):
323+
start, end = futures[future]
324+
result = future.result()
325+
if result:
326+
logger.error("Failed to ingest row %d to %d", start, end)
327+
else:
328+
logger.info("Successfully ingested row %d to %d", start, end)
329+
failed_indices += result
261330

262-
self._futures = futures
263-
if wait:
264-
self.wait(timeout=timeout)
265331
executor.shutdown(wait=False)
266332

333+
return failed_indices
334+
335+
def run(self, data_frame: DataFrame, wait=True, timeout=None):
336+
"""Start the ingestion process.
337+
338+
Args:
339+
data_frame (DataFrame): source DataFrame to be ingested.
340+
wait (bool): whether to wait for the ingestion to finish or not.
341+
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
342+
if timeout is reached.
343+
"""
344+
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
345+
346+
347+
class IngestionError(Exception):
348+
"""Exception raised for errors during ingestion.
349+
350+
Attributes:
351+
failed_rows: list of indices from the data frame for which ingestion failed.
352+
message: explanation of the error
353+
"""
354+
355+
def __init__(self, failed_rows, message):
356+
super(IngestionError, self).__init__(message)
357+
self.failed_rows = failed_rows
358+
self.message = message
359+
360+
def __str__(self) -> str:
361+
"""String representation of the error."""
362+
return f"{self.failed_rows} -> {self.message}"
363+
267364

268365
@attr.s
269366
class FeatureGroup:
@@ -447,6 +544,7 @@ def ingest(
447544
self,
448545
data_frame: DataFrame,
449546
max_workers: int = 1,
547+
max_processes: int = 1,
450548
wait: bool = True,
451549
timeout: Union[int, float] = None,
452550
) -> IngestionManagerPandas:
@@ -455,23 +553,45 @@ def ingest(
455553
``max_worker`` number of thread will be created to work on different partitions of
456554
the ``data_frame`` in parallel.
457555
556+
``max_processes`` number of processes will be created to work on different partitions
557+
of the ``data_frame`` in parallel, each with ``max_workers`` threads.
558+
559+
The ingest function will attempt to ingest all records in the data frame. If ``wait``
560+
is True, then an exception is thrown after all records have been processed. If ``wait``
561+
is False, then a later call to the returned instance IngestionManagerPandas' ``wait()``
562+
function will throw an exception.
563+
564+
Zero based indices of rows that failed to be ingested can be found in the exception.
565+
They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
566+
the exception is thrown.
567+
458568
Args:
459569
data_frame (DataFrame): data_frame to be ingested to feature store.
460570
max_workers (int): number of threads to be created.
571+
max_processes (int): number of processes to be created. Each process spawns
572+
``max_workers`` threads.
461573
wait (bool): whether to wait for the ingestion to finish or not.
462574
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
463575
if timeout is reached.
464576
465577
Returns:
466578
An instance of IngestionManagerPandas.
467579
"""
580+
if max_processes <= 0:
581+
raise RuntimeError("max_processes must be greater than 0.")
582+
583+
if max_workers <= 0:
584+
raise RuntimeError("max_workers must be greater than 0.")
585+
468586
manager = IngestionManagerPandas(
469587
feature_group_name=self.name,
470-
sagemaker_session=self.sagemaker_session,
471-
data_frame=data_frame,
588+
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
472589
max_workers=max_workers,
590+
max_processes=max_processes,
473591
)
474-
manager.run(wait=wait, timeout=timeout)
592+
593+
manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
594+
475595
return manager
476596

477597
def athena_query(self) -> AthenaQuery:

tests/integ/test_feature_store.py

+27
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,33 @@ def test_ingest_without_string_feature(
265265
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
266266

267267

268+
def test_ingest_multi_process(
269+
feature_store_session,
270+
role,
271+
feature_group_name,
272+
offline_store_s3_uri,
273+
pandas_data_frame,
274+
):
275+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
276+
feature_group.load_feature_definitions(data_frame=pandas_data_frame)
277+
278+
with cleanup_feature_group(feature_group):
279+
output = feature_group.create(
280+
s3_uri=offline_store_s3_uri,
281+
record_identifier_name="feature1",
282+
event_time_feature_name="feature3",
283+
role_arn=role,
284+
enable_online_store=True,
285+
)
286+
_wait_for_feature_group_create(feature_group)
287+
288+
feature_group.ingest(
289+
data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True
290+
)
291+
292+
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
293+
294+
268295
def _wait_for_feature_group_create(feature_group: FeatureGroup):
269296
status = feature_group.describe().get("FeatureGroupStatus")
270297
while status == "Creating":

0 commit comments

Comments (0)