diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 6a9bdfd9d4..e8298f2b4b 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -160,6 +160,7 @@ class IngestionManagerPandas: data_frame: DataFrame = attr.ib() max_workers: int = attr.ib(default=1) _futures: Dict[Any, Any] = attr.ib(init=False, factory=dict) + _failed_indices: List[int] = attr.ib(factory=list) @staticmethod def _ingest_single_batch( @@ -168,7 +169,7 @@ def _ingest_single_batch( sagemaker_session: Session, start_index: int, end_index: int, - ): + ) -> List[int]: """Ingest a single batch of DataFrame rows into FeatureStore. Args: @@ -177,19 +178,38 @@ def _ingest_single_batch( sagemaker_session (Session): session instance to perform boto calls. start_index (int): starting position to ingest in this batch. end_index (int): ending position to ingest in this batch. + + Returns: + List of row indices that failed to be ingested. """ logger.info("Started ingesting index %d to %d", start_index, end_index) - for row in data_frame[start_index:end_index].itertuples(index=False): + failed_rows = list() + for row in data_frame[start_index:end_index].itertuples(): record = [ FeatureValue( - feature_name=data_frame.columns[index], value_as_string=str(row[index]) + feature_name=data_frame.columns[index - 1], value_as_string=str(row[index]) ) - for index in range(len(row)) + for index in range(1, len(row)) if pd.notna(row[index]) ] - sagemaker_session.put_record( - feature_group_name=feature_group_name, record=[value.to_dict() for value in record] - ) + try: + sagemaker_session.put_record( + feature_group_name=feature_group_name, + record=[value.to_dict() for value in record], + ) + except Exception as e: # pylint: disable=broad-except + logger.error("Failed to ingest row %d: %s", row[0], e) + failed_rows.append(row[0]) + return failed_rows + + @property + def failed_rows(self) -> List[int]: + """Get rows that failed to ingest + + Returns: + List of row indices that failed to be ingested. + """ + return self._failed_indices def wait(self, timeout=None): """Wait for the ingestion process to finish. @@ -198,18 +218,17 @@ def wait(self, timeout=None): timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised if timeout is reached. """ - failed = False + self._failed_indices = list() for future in as_completed(self._futures, timeout=timeout): start, end = self._futures[future] - try: - future.result() - except Exception as e: # pylint: disable=broad-except - failed = True - logger.error("Failed to ingest row %d to %d: %s", start, end, e) + result = future.result() + if result: + logger.error("Failed to ingest row %d to %d", start, end) else: logger.info("Successfully ingested row %d to %d", start, end) + self._failed_indices += result - if failed: + if len(self._failed_indices) > 0: raise RuntimeError( f"Failed to ingest some data into FeatureGroup {self.feature_group_name}" ) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index ae2a899b2f..c527116a3a 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -197,6 +197,7 @@ def test_create_feature_store( data_frame=pandas_data_frame, max_workers=3, wait=False ) ingestion_manager.wait() + assert 0 == len(ingestion_manager.failed_rows) # Query the integrated Glue table. athena_query = feature_group.athena_query() diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index d2902fe853..b69e19dce6 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -274,7 +274,7 @@ def test_ingestion_manager_run_success(): @patch( "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(side_effect=Exception("Failed!")), + MagicMock(return_value=[1]), ) def test_ingestion_manager_run_failure(): df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) @@ -282,11 +282,12 @@ def test_ingestion_manager_run_failure(): feature_group_name="MyGroup", sagemaker_session=sagemaker_session_mock, data_frame=df, - max_workers=10, + max_workers=1, ) with pytest.raises(RuntimeError) as error: manager.run() assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert manager.failed_rows == [1] @pytest.fixture