diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py
index 122c176217..02e683efd3 100644
--- a/src/sagemaker/feature_store/feature_group.py
+++ b/src/sagemaker/feature_store/feature_group.py
@@ -207,7 +207,8 @@ def _ingest_single_batch(
         for row in data_frame[start_index:end_index].itertuples():
             record = [
                 FeatureValue(
-                    feature_name=data_frame.columns[index - 1], value_as_string=str(row[index])
+                    feature_name=data_frame.columns[index - 1],
+                    value_as_string=str(row[index]),
                 )
                 for index in range(1, len(row))
                 if pd.notna(row[index])
@@ -270,13 +271,24 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                 if timeout is reached.
         """
+        # pylint: disable=I1101
         batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+        # pylint: enable=I1101

         args = []
         for i in range(self.max_processes):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
-            args += [(data_frame[start_index:end_index], start_index, timeout)]
+            args += [
+                (
+                    self.max_workers,
+                    self.feature_group_name,
+                    self.sagemaker_fs_runtime_client_config,
+                    data_frame[start_index:end_index],
+                    start_index,
+                    timeout,
+                )
+            ]

         def init_worker():
             # ignore keyboard interrupts in child processes.
@@ -285,13 +297,21 @@ def init_worker():
         self._processing_pool = ProcessingPool(self.max_processes, init_worker)
         self._processing_pool.restart(force=True)

-        f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+        f = lambda x: IngestionManagerPandas._run_multi_threaded(*x)  # noqa: E731

         self._async_result = self._processing_pool.amap(f, args)

         if wait:
             self.wait(timeout=timeout)

-    def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
+    @staticmethod
+    def _run_multi_threaded(
+        max_workers: int,
+        feature_group_name: str,
+        sagemaker_fs_runtime_client_config: Config,
+        data_frame: DataFrame,
+        row_offset=0,
+        timeout=None,
+    ) -> List[int]:
         """Start the ingestion process.

         Args:
@@ -305,21 +325,23 @@ def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None)
         Returns:
             List of row indices that failed to be ingested.
         """
-        executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        batch_size = math.ceil(data_frame.shape[0] / self.max_workers)
+        executor = ThreadPoolExecutor(max_workers=max_workers)
+        # pylint: disable=I1101
+        batch_size = math.ceil(data_frame.shape[0] / max_workers)
+        # pylint: enable=I1101

         futures = {}
-        for i in range(self.max_workers):
+        for i in range(max_workers):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
             futures[
                 executor.submit(
-                    self._ingest_single_batch,
-                    feature_group_name=self.feature_group_name,
+                    IngestionManagerPandas._ingest_single_batch,
+                    feature_group_name=feature_group_name,
                     data_frame=data_frame,
                     start_index=start_index,
                     end_index=end_index,
-                    client_config=self.sagemaker_fs_runtime_client_config,
+                    client_config=sagemaker_fs_runtime_client_config,
                 )
             ] = (start_index + row_offset, end_index + row_offset)
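
The core change above is converting `_run_multi_threaded` from an instance method to a `@staticmethod` and passing its state (`max_workers`, the feature group name, the runtime client config) explicitly through the `args` tuples, so that the process pool never has to serialize the whole `IngestionManagerPandas` object into its child processes. Below is a minimal sketch of the same pattern, using the stdlib `multiprocessing.Pool` rather than pathos' `ProcessingPool`; `MiniIngestionManager`, `_ingest_batch`, and everything else here are illustrative stand-ins, not the SDK's actual API:

```python
import math
from multiprocessing import Pool
from typing import List, Tuple

import pandas as pd


class MiniIngestionManager:
    """Toy stand-in for IngestionManagerPandas (names are illustrative only)."""

    def __init__(self, feature_group_name: str, max_processes: int = 2):
        self.feature_group_name = feature_group_name
        self.max_processes = max_processes

    @staticmethod
    def _ingest_batch(args: Tuple[str, int, pd.DataFrame]) -> List[int]:
        # Runs in a child process: everything it needs arrives via `args`,
        # so no `self` (and no hard-to-pickle client objects) has to cross
        # the process boundary.
        feature_group_name, row_offset, frame = args
        failed: List[int] = []
        for index, row in enumerate(frame.itertuples(index=False)):
            try:
                _ = [str(v) for v in row if pd.notna(v)]  # stand-in for PutRecord
            except Exception:
                failed.append(row_offset + index)
        return failed

    def ingest(self, data_frame: pd.DataFrame) -> List[int]:
        # Split the frame into one positional slice per process, packing the
        # per-batch state into plain tuples, as the diff does with `args`.
        batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
        args = [
            (self.feature_group_name, start, data_frame[start:start + batch_size])
            for start in range(0, data_frame.shape[0], batch_size)
        ]
        with Pool(self.max_processes) as pool:
            # Referencing the staticmethod through the class, not `self`,
            # keeps the callable itself trivially picklable.
            results = pool.map(MiniIngestionManager._ingest_batch, args)
        return [row for batch in results for row in batch]


if __name__ == "__main__":
    df = pd.DataFrame({"id": range(10), "value": range(10)})
    manager = MiniIngestionManager("my-feature-group")
    print("failed rows:", manager.ingest(df))
```

In the diff itself the same state travels the same way: the `args` tuples built in `_run_multi_process` carry `max_workers`, `feature_group_name`, and the botocore client config into the static `_run_multi_threaded`, which then forwards them to the static `_ingest_single_batch` instead of reading them off `self`.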