
feature: support multiprocess feature group ingest (#2111) #2288


Merged: 2 commits, Apr 20, 2021
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@ def read_version():
"importlib-metadata>=1.4.0",
"packaging>=20.0",
"pandas",
"pathos",
]

# Specific use case dependencies
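Why ``pathos`` rather than the stdlib: ``multiprocessing`` serializes work with ``pickle``, which cannot handle the lambda and bound method that ``feature_group.py`` submits to the pool below, while ``pathos`` uses ``dill``, which can. A minimal standalone sketch of the ``ProcessingPool.amap`` pattern this PR relies on (not part of the diff):

from pathos.multiprocessing import ProcessingPool

# Map a lambda across two worker processes; a stdlib multiprocessing.Pool
# would fail here because pickle cannot serialize lambdas.
pool = ProcessingPool(nodes=2)
async_result = pool.amap(lambda x: x * x, [1, 2, 3])  # non-blocking map
print(async_result.get(timeout=30))  # [1, 4, 9]
pool.close()
pool.clear()
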
190 changes: 155 additions & 35 deletions src/sagemaker/feature_store/feature_group.py
@@ -30,10 +30,16 @@
from typing import Sequence, List, Dict, Any, Union
from urllib.parse import urlparse

from multiprocessing.pool import AsyncResult
import signal
import attr
import pandas as pd
from pandas import DataFrame

import boto3
from botocore.config import Config
from pathos.multiprocessing import ProcessingPool

from sagemaker import Session
from sagemaker.feature_store.feature_definition import (
FeatureDefinition,
@@ -150,23 +156,27 @@ class IngestionManagerPandas:

Attributes:
feature_group_name (str): name of the Feature Group.
sagemaker_session (Session): instance of the Session class to perform boto calls.
sagemaker_fs_runtime_client_config (Config): instance of the Config class
for boto calls.
data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
max_workers (int): number of threads to create.
max_processes (int): number of processes to create. Each process spawns
``max_workers`` threads.
"""

feature_group_name: str = attr.ib()
sagemaker_session: Session = attr.ib()
data_frame: DataFrame = attr.ib()
sagemaker_fs_runtime_client_config: Config = attr.ib()
max_workers: int = attr.ib(default=1)
_futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
max_processes: int = attr.ib(default=1)
_async_result: AsyncResult = attr.ib(default=None)
_processing_pool: ProcessingPool = attr.ib(default=None)
_failed_indices: List[int] = attr.ib(factory=list)

@staticmethod
def _ingest_single_batch(
data_frame: DataFrame,
feature_group_name: str,
sagemaker_session: Session,
client_config: Config,
start_index: int,
end_index: int,
) -> List[int]:
@@ -175,13 +185,18 @@ def _ingest_single_batch(
Args:
data_frame (DataFrame): source DataFrame to be ingested.
feature_group_name (str): name of the Feature Group.
sagemaker_session (Session): session instance to perform boto calls.
client_config (Config): Configuration for the sagemaker feature store runtime
client to perform boto calls.
start_index (int): starting position to ingest in this batch.
end_index (int): ending position to ingest in this batch.

Returns:
List of row indices that failed to be ingested.
"""
sagemaker_featurestore_runtime_client = boto3.Session().client(
service_name="sagemaker-featurestore-runtime", config=client_config
)

logger.info("Started ingesting index %d to %d", start_index, end_index)
failed_rows = list()
for row in data_frame[start_index:end_index].itertuples():
@@ -193,9 +208,9 @@
if pd.notna(row[index])
]
try:
sagemaker_session.put_record(
feature_group_name=feature_group_name,
record=[value.to_dict() for value in record],
sagemaker_featurestore_runtime_client.put_record(
FeatureGroupName=feature_group_name,
Record=[value.to_dict() for value in record],
)
except Exception as e: # pylint: disable=broad-except
logger.error("Failed to ingest row %d: %s", row[0], e)
@@ -204,7 +219,7 @@

@property
def failed_rows(self) -> List[int]:
"""Get rows that failed to ingest
"""Get rows that failed to ingest.

Returns:
List of row indices that failed to be ingested.
@@ -218,52 +233,134 @@ def wait(self, timeout=None):
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.
"""
self._failed_indices = list()
for future in as_completed(self._futures, timeout=timeout):
start, end = self._futures[future]
result = future.result()
if result:
logger.error("Failed to ingest row %d to %d", start, end)
else:
logger.info("Successfully ingested row %d to %d", start, end)
self._failed_indices += result
try:
results = self._async_result.get(timeout=timeout)
except KeyboardInterrupt as i:
# terminate workers abruptly on keyboard interrupt.
self._processing_pool.terminate()
self._processing_pool.close()
self._processing_pool.clear()
raise i
else:
# terminate normally
self._processing_pool.close()
self._processing_pool.clear()

self._failed_indices = [
failed_index for failed_indices in results for failed_index in failed_indices
]

if len(self._failed_indices) > 0:
raise RuntimeError(
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
raise IngestionError(
self._failed_indices,
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
)

def run(self, wait=True, timeout=None):
def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
"""Start the ingestion process with the specified number of processes.

Args:
data_frame (DataFrame): source DataFrame to be ingested.
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.
"""
batch_size = math.ceil(data_frame.shape[0] / self.max_processes)

args = []
for i in range(self.max_processes):
start_index = min(i * batch_size, data_frame.shape[0])
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
args += [(data_frame[start_index:end_index], start_index, timeout)]
Comment on lines +270 to +274
Contributor:

optional

Suggested change
args = []
for i in range(self.max_processes):
    start_index = min(i * batch_size, data_frame.shape[0])
    end_index = min(i * batch_size + batch_size, data_frame.shape[0])
    args += [(data_frame[start_index:end_index], start_index, timeout)]
args = [
    (
        data_frame[
            min(i * batch_size, data_frame.shape[0]):
            min(i * batch_size + batch_size, data_frame.shape[0])
        ],
        min(i * batch_size, data_frame.shape[0]),
        timeout,
    )
    for i in range(self.max_processes)
]

def init_worker():
# ignore keyboard interrupts in child processes.
signal.signal(signal.SIGINT, signal.SIG_IGN)

self._processing_pool = ProcessingPool(self.max_processes, init_worker)
self._processing_pool.restart(force=True)

f = lambda x: self._run_multi_threaded(*x) # noqa: E731
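# dill (used by pathos) can serialize this lambda and its bound method;
# the stdlib pickle used by multiprocessing cannot, which is why pathos
# is used for the pool.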
self._async_result = self._processing_pool.amap(f, args)

if wait:
self.wait(timeout=timeout)
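
For concreteness, the index arithmetic above splits the frame into at most ``max_processes`` contiguous batches; a standalone sketch with assumed values (10 rows, 3 processes):

import math

num_rows, max_processes = 10, 3
batch_size = math.ceil(num_rows / max_processes)  # 4

bounds = [
    (min(i * batch_size, num_rows), min(i * batch_size + batch_size, num_rows))
    for i in range(max_processes)
]
print(bounds)  # [(0, 4), (4, 8), (8, 10)] -- the last batch is simply shorter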

def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
"""Start the ingestion process.

Args:
data_frame (DataFrame): source DataFrame to be ingested.
row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
index of the parent where ``data_frame`` starts. Otherwise, 0.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.

Returns:
List of row indices that failed to be ingested.
"""
executor = ThreadPoolExecutor(max_workers=self.max_workers)
batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
batch_size = math.ceil(data_frame.shape[0] / self.max_workers)

futures = {}
for i in range(self.max_workers):
start_index = min(i * batch_size, self.data_frame.shape[0])
end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
start_index = min(i * batch_size, data_frame.shape[0])
end_index = min(i * batch_size + batch_size, data_frame.shape[0])
futures[
executor.submit(
self._ingest_single_batch,
feature_group_name=self.feature_group_name,
sagemaker_session=self.sagemaker_session,
data_frame=self.data_frame,
data_frame=data_frame,
start_index=start_index,
end_index=end_index,
client_config=self.sagemaker_fs_runtime_client_config,
)
] = (start_index, end_index)
] = (start_index + row_offset, end_index + row_offset)

failed_indices = list()
for future in as_completed(futures, timeout=timeout):
start, end = futures[future]
result = future.result()
if result:
logger.error("Failed to ingest row %d to %d", start, end)
else:
logger.info("Successfully ingested row %d to %d", start, end)
failed_indices += result

self._futures = futures
if wait:
self.wait(timeout=timeout)
executor.shutdown(wait=False)

return failed_indices

def run(self, data_frame: DataFrame, wait=True, timeout=None):
"""Start the ingestion process.

Args:
data_frame (DataFrame): source DataFrame to be ingested.
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.
"""
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)


class IngestionError(Exception):
"""Exception raised for errors during ingestion.

Attributes:
failed_rows: list of indices from the data frame for which ingestion failed.
message: explanation of the error.
"""

def __init__(self, failed_rows, message):
super(IngestionError, self).__init__(message)
self.failed_rows = failed_rows
self.message = message

def __str__(self) -> str:
"""String representation of the error."""
return f"{self.failed_rows} -> {self.message}"
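
A hedged usage sketch of the new exception (``feature_group`` and ``df`` are assumed to already exist, with ``df`` on a default ``RangeIndex`` so the failed indices map directly to positions):

from sagemaker.feature_store.feature_group import IngestionError

try:
    feature_group.ingest(data_frame=df, max_workers=4, max_processes=2, wait=True)
except IngestionError as error:
    # Retry only the rows that failed the first time.
    retry_frame = df.iloc[error.failed_rows]
    feature_group.ingest(data_frame=retry_frame, max_workers=4)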


@attr.s
class FeatureGroup:
@@ -447,6 +544,7 @@ def ingest(
self,
data_frame: DataFrame,
max_workers: int = 1,
max_processes: int = 1,
wait: bool = True,
timeout: Union[int, float] = None,
) -> IngestionManagerPandas:
@@ -455,23 +553,45 @@
``max_workers`` threads will be created to work on different partitions of
the ``data_frame`` in parallel.

``max_processes`` processes will be created to work on different partitions
of the ``data_frame`` in parallel, each running ``max_workers`` threads.

The ingest function will attempt to ingest all records in the data frame. If ``wait``
is True, an exception is thrown after all records have been processed if any of them
failed to ingest. If ``wait`` is False, a later call to the returned
``IngestionManagerPandas`` instance's ``wait()`` method will throw the exception instead.

Zero-based indices of rows that failed to be ingested can be found in the exception.
They can also be retrieved from the ``IngestionManagerPandas``' ``failed_rows``
property after the exception is thrown.

Args:
data_frame (DataFrame): data_frame to be ingested to feature store.
max_workers (int): number of threads to be created.
max_processes (int): number of processes to be created. Each process spawns
``max_workers`` threads.
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.

Returns:
An instance of IngestionManagerPandas.
"""
if max_processes <= 0:
raise RuntimeError("max_processes must be greater than 0.")

if max_workers <= 0:
raise RuntimeError("max_workers must be greater than 0.")

manager = IngestionManagerPandas(
feature_group_name=self.name,
sagemaker_session=self.sagemaker_session,
data_frame=data_frame,
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
max_workers=max_workers,
max_processes=max_processes,
)
manager.run(wait=wait, timeout=timeout)

manager.run(data_frame=data_frame, wait=wait, timeout=timeout)

return manager

def athena_query(self) -> AthenaQuery:
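Taken together, a hedged end-to-end sketch of the new ``ingest`` surface (the feature group name and frame contents are placeholders; the group is assumed to already exist with a matching schema):

import pandas as pd
from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup

feature_group = FeatureGroup(name="my-feature-group", sagemaker_session=Session())
df = pd.DataFrame({"feature1": ["id-1", "id-2"], "feature3": ["2021-04-20T00:00:00Z"] * 2})

manager = feature_group.ingest(
    data_frame=df,
    max_workers=3,    # threads per process
    max_processes=2,  # each process runs max_workers threads
    wait=False,       # return immediately; call wait() later
)
manager.wait(timeout=300)  # raises IngestionError if any rows failed
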
27 changes: 27 additions & 0 deletions tests/integ/test_feature_store.py
@@ -265,6 +265,33 @@ def test_ingest_without_string_feature(
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")


def test_ingest_multi_process(
feature_store_session,
role,
feature_group_name,
offline_store_s3_uri,
pandas_data_frame,
):
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
feature_group.load_feature_definitions(data_frame=pandas_data_frame)

with cleanup_feature_group(feature_group):
output = feature_group.create(
s3_uri=offline_store_s3_uri,
record_identifier_name="feature1",
event_time_feature_name="feature3",
role_arn=role,
enable_online_store=True,
)
_wait_for_feature_group_create(feature_group)

feature_group.ingest(
data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True
)

assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")


def _wait_for_feature_group_create(feature_group: FeatureGroup):
status = feature_group.describe().get("FeatureGroupStatus")
while status == "Creating":