Skip to content

fix: return all failed row indices in feature_group.ingest #2193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions src/sagemaker/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _ingest_single_batch(
sagemaker_session: Session,
start_index: int,
end_index: int,
):
) -> List[int]:
"""Ingest a single batch of DataFrame rows into FeatureStore.

Args:
Expand All @@ -177,50 +177,62 @@ def _ingest_single_batch(
sagemaker_session (Session): session instance to perform boto calls.
start_index (int): starting position to ingest in this batch.
end_index (int): ending position to ingest in this batch.

Returns:
List of row indices that failed to be ingested.
"""
logger.info("Started ingesting index %d to %d", start_index, end_index)
for row in data_frame[start_index:end_index].itertuples(index=False):
failed_rows = list()
for row in data_frame[start_index:end_index].itertuples():
record = [
FeatureValue(
feature_name=data_frame.columns[index], value_as_string=str(row[index])
feature_name=data_frame.columns[index - 1], value_as_string=str(row[index])
)
for index in range(len(row))
for index in range(1, len(row))
if pd.notna(row[index])
]
sagemaker_session.put_record(
feature_group_name=feature_group_name, record=[value.to_dict() for value in record]
)
try:
sagemaker_session.put_record(
feature_group_name=feature_group_name,
record=[value.to_dict() for value in record],
)
except Exception as e: # pylint: disable=broad-except
logger.error("Failed to ingest row %d: %s", row[0], e)
failed_rows.append(row[0])
return failed_rows

def wait(self, timeout=None):
def wait(self, timeout=None) -> List[int]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spoke offline, I think the wait method returning a value seems a bit weird in my opinion.

"""Wait for the ingestion process to finish.

Args:
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.

Returns:
List of row indices that failed to be ingested.
"""
failed = False
failed = []
for future in as_completed(self._futures, timeout=timeout):
start, end = self._futures[future]
try:
future.result()
except Exception as e: # pylint: disable=broad-except
failed = True
logger.error("Failed to ingest row %d to %d: %s", start, end, e)
result = future.result()
if result:
logger.error("Failed to ingest row %d to %d", start, end)
else:
logger.info("Successfully ingested row %d to %d", start, end)
failed += result

if failed:
raise RuntimeError(
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
)
return failed

def run(self, wait=True, timeout=None):
def run(self, wait=True, timeout=None) -> List[int]:
"""Start the ingestion process.

Args:
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.

Returns:
List of row indices that failed to be ingested.
"""
executor = ThreadPoolExecutor(max_workers=self.max_workers)
batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
Expand All @@ -241,9 +253,11 @@ def run(self, wait=True, timeout=None):
] = (start_index, end_index)

self._futures = futures
failed = []
if wait:
self.wait(timeout=timeout)
failed = self.wait(timeout=timeout)
executor.shutdown(wait=False)
return failed


@attr.s
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def test_create_feature_store(
ingestion_manager = feature_group.ingest(
data_frame=pandas_data_frame, max_workers=3, wait=False
)
ingestion_manager.wait()
failed = ingestion_manager.wait()
assert 0 == len(failed)

# Query the integrated Glue table.
athena_query = feature_group.athena_query()
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/sagemaker/feature_store/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,17 @@ def test_ingestion_manager_run_success():

@patch(
"sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
MagicMock(side_effect=Exception("Failed!")),
MagicMock(return_value=[1]),
)
def test_ingestion_manager_run_failure():
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
manager = IngestionManagerPandas(
feature_group_name="MyGroup",
sagemaker_session=sagemaker_session_mock,
data_frame=df,
max_workers=10,
max_workers=1,
)
with pytest.raises(RuntimeError) as error:
manager.run()
assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
assert manager.run() == [1]


@pytest.fixture
Expand Down