
Commit 0eaa7c0
verayu43, ahsan-z-khan, and shreyapandit authored
fix: multiprocess issue in feature_group.py (#2573)
Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: Shreya Pandit <[email protected]>
1 parent 9ca12bc commit 0eaa7c0
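The change is easier to read with the failure mode in mind: the ProcessingPool (from pathos) must serialize every task it sends to a child process, and a bound method such as self._run_multi_threaded drags the whole IngestionManagerPandas instance, including non-picklable boto3 session/client state, into that pickle. The commit therefore makes the workers static and passes only picklable values (max_workers, the feature group name, the botocore client Config). Below is a minimal sketch of the same pattern using the stdlib pool; all names are hypothetical stand-ins, not SageMaker code:

import multiprocessing
import threading


class Ingestor:
    """Toy stand-in for an ingestion manager (hypothetical)."""

    def __init__(self, name, rows):
        self.name = name
        self.rows = rows
        # Stand-in for an unpicklable member (e.g. a live boto3 client):
        self.client = threading.Lock()

    # Submitting a bound method would force the pool to pickle `self`,
    # which fails on `self.client`:
    #   TypeError: cannot pickle '_thread.lock' object
    @staticmethod
    def _work(name, batch):
        # Receives only picklable arguments; rebuilds anything heavy itself.
        return [f"{name}:{row}" for row in batch]

    def run(self):
        args = [(self.name, self.rows[:2]), (self.name, self.rows[2:])]
        with multiprocessing.Pool(processes=2) as pool:
            return pool.starmap(Ingestor._work, args)


if __name__ == "__main__":
    print(Ingestor("my-feature-group", [1, 2, 3, 4]).run())
    # [['my-feature-group:1', 'my-feature-group:2'],
    #  ['my-feature-group:3', 'my-feature-group:4']]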

File tree

1 file changed: +32 −10 lines

src/sagemaker/feature_store/feature_group.py (+32 −10)
@@ -207,7 +207,8 @@ def _ingest_single_batch(
         for row in data_frame[start_index:end_index].itertuples():
             record = [
                 FeatureValue(
-                    feature_name=data_frame.columns[index - 1], value_as_string=str(row[index])
+                    feature_name=data_frame.columns[index - 1],
+                    value_as_string=str(row[index]),
                 )
                 for index in range(1, len(row))
                 if pd.notna(row[index])
@@ -270,13 +271,24 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                 if timeout is reached.
         """
+        # pylint: disable=I1101
         batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+        # pylint: enable=I1101

         args = []
         for i in range(self.max_processes):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
-            args += [(data_frame[start_index:end_index], start_index, timeout)]
+            args += [
+                (
+                    self.max_workers,
+                    self.feature_group_name,
+                    self.sagemaker_fs_runtime_client_config,
+                    data_frame[start_index:end_index],
+                    start_index,
+                    timeout,
+                )
+            ]

         def init_worker():
             # ignore keyboard interrupts in child processes.
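One detail in this hunk worth a worked check is the batch arithmetic: math.ceil plus the min() clamps partition the frame without ever slicing out of range. For example (plain Python, numbers chosen purely for illustration):

import math

n_rows, max_processes = 10, 4
batch_size = math.ceil(n_rows / max_processes)  # ceil(10 / 4) == 3

batches = []
for i in range(max_processes):
    start_index = min(i * batch_size, n_rows)
    end_index = min(i * batch_size + batch_size, n_rows)
    batches.append((start_index, end_index))

print(batches)  # [(0, 3), (3, 6), (6, 9), (9, 10)]
# Every row lands in exactly one batch; if max_processes exceeded the row
# count, surplus workers would get empty (10, 10) slices instead of errors.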
@@ -285,13 +297,21 @@ def init_worker():
         self._processing_pool = ProcessingPool(self.max_processes, init_worker)
         self._processing_pool.restart(force=True)

-        f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+        f = lambda x: IngestionManagerPandas._run_multi_threaded(*x)  # noqa: E731
         self._async_result = self._processing_pool.amap(f, args)

         if wait:
             self.wait(timeout=timeout)

-    def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
+    @staticmethod
+    def _run_multi_threaded(
+        max_workers: int,
+        feature_group_name: str,
+        sagemaker_fs_runtime_client_config: Config,
+        data_frame: DataFrame,
+        row_offset=0,
+        timeout=None,
+    ) -> List[int]:
         """Start the ingestion process.

         Args:
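Note what actually crosses the process boundary after this change: not a client, only its botocore Config, which is plain picklable data. Each worker can then build a fresh Feature Store runtime client from it. A sketch of that per-worker reconstruction follows; the helper name and session setup are assumptions, while the service name and put_record call are standard boto3:

import boto3
from botocore.config import Config


def make_runtime_client(client_config: Config):
    # A boto3 client holds live network state and cannot be pickled, so it
    # is created here, inside the worker, from the Config that was shipped
    # across the process boundary.
    session = boto3.Session()
    return session.client(
        service_name="sagemaker-featurestore-runtime",
        config=client_config,
    )


# Hypothetical use inside a worker's batch loop:
#   client = make_runtime_client(client_config)
#   client.put_record(FeatureGroupName=feature_group_name, Record=record)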
@@ -305,21 +325,23 @@ def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None)
         Returns:
             List of row indices that failed to be ingested.
         """
-        executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        batch_size = math.ceil(data_frame.shape[0] / self.max_workers)
+        executor = ThreadPoolExecutor(max_workers=max_workers)
+        # pylint: disable=I1101
+        batch_size = math.ceil(data_frame.shape[0] / max_workers)
+        # pylint: enable=I1101

         futures = {}
-        for i in range(self.max_workers):
+        for i in range(max_workers):
             start_index = min(i * batch_size, data_frame.shape[0])
             end_index = min(i * batch_size + batch_size, data_frame.shape[0])
             futures[
                 executor.submit(
-                    self._ingest_single_batch,
-                    feature_group_name=self.feature_group_name,
+                    IngestionManagerPandas._ingest_single_batch,
+                    feature_group_name=feature_group_name,
                     data_frame=data_frame,
                     start_index=start_index,
                     end_index=end_index,
-                    client_config=self.sagemaker_fs_runtime_client_config,
+                    client_config=sagemaker_fs_runtime_client_config,
                 )
             ] = (start_index + row_offset, end_index + row_offset)
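Finally, the futures dict is keyed by Future and valued with the absolute row range (start_index + row_offset, end_index + row_offset), so a failure in any thread of any child process can be reported as concrete row indices of the original DataFrame. A minimal sketch of how such a map can be drained, mirroring the structure above (the function itself is hypothetical):

from concurrent.futures import as_completed


def collect_failed_rows(futures):
    # futures: dict mapping each Future to its (first_row, end_row) range.
    failed_rows = []
    for future in as_completed(futures):
        start, end = futures[future]
        try:
            future.result()  # re-raises any exception from the batch
        except Exception:
            failed_rows.extend(range(start, end))  # mark the whole batch
    return failed_rows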
