Skip to content

Commit 14674eb

Browse files
ericpark3JoseJuan98
authored andcommitted
fix: use FeatureGroup's Session in nonconcurrency ingestion (aws#3617)
1 parent 8d138ff commit 14674eb

File tree

2 files changed

+115
-22
lines changed

2 files changed

+115
-22
lines changed

src/sagemaker/feature_store/feature_group.py

+79-19
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ class IngestionManagerPandas:
171171
feature_group_name (str): name of the Feature Group.
172172
sagemaker_fs_runtime_client_config (Config): instance of the Config class
173173
for boto calls.
174+
sagemaker_session (Session): session instance to perform boto calls.
174175
data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
175176
max_workers (int): number of threads to create.
176177
max_processes (int): number of processes to create. Each process spawns
@@ -180,7 +181,8 @@ class IngestionManagerPandas:
180181
"""
181182

182183
feature_group_name: str = attr.ib()
183-
sagemaker_fs_runtime_client_config: Config = attr.ib()
184+
sagemaker_fs_runtime_client_config: Config = attr.ib(default=None)
185+
sagemaker_session: Session = attr.ib(default=None)
184186
max_workers: int = attr.ib(default=1)
185187
max_processes: int = attr.ib(default=1)
186188
profile_name: str = attr.ib(default=None)
@@ -216,29 +218,20 @@ def _ingest_single_batch(
216218
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
217219
client_config = copy.deepcopy(client_config)
218220
client_config.retries = {"max_attempts": 10, "mode": "standard"}
219-
sagemaker_featurestore_runtime_client = boto3.Session(profile_name=profile_name).client(
221+
sagemaker_fs_runtime_client = boto3.Session(profile_name=profile_name).client(
220222
service_name="sagemaker-featurestore-runtime", config=client_config
221223
)
222224

223225
logger.info("Started ingesting index %d to %d", start_index, end_index)
224226
failed_rows = list()
225227
for row in data_frame[start_index:end_index].itertuples():
226-
record = [
227-
FeatureValue(
228-
feature_name=data_frame.columns[index - 1],
229-
value_as_string=str(row[index]),
230-
)
231-
for index in range(1, len(row))
232-
if pd.notna(row[index])
233-
]
234-
try:
235-
sagemaker_featurestore_runtime_client.put_record(
236-
FeatureGroupName=feature_group_name,
237-
Record=[value.to_dict() for value in record],
238-
)
239-
except Exception as e: # pylint: disable=broad-except
240-
logger.error("Failed to ingest row %d: %s", row[0], e)
241-
failed_rows.append(row[0])
228+
IngestionManagerPandas._ingest_row(
229+
data_frame=data_frame,
230+
row=row,
231+
feature_group_name=feature_group_name,
232+
sagemaker_fs_runtime_client=sagemaker_fs_runtime_client,
233+
failed_rows=failed_rows,
234+
)
242235
return failed_rows
243236

244237
@property
@@ -280,6 +273,69 @@ def wait(self, timeout=None):
280273
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
281274
)
282275

276+
@staticmethod
277+
def _ingest_row(
278+
data_frame: DataFrame,
279+
row: int,
280+
feature_group_name: str,
281+
sagemaker_fs_runtime_client: Session,
282+
failed_rows: List[int],
283+
):
284+
"""Ingest a single Dataframe row into FeatureStore.
285+
286+
Args:
287+
data_frame (DataFrame): source DataFrame to be ingested.
288+
row (int): current row that is being ingested
289+
feature_group_name (str): name of the Feature Group.
290+
sagemaker_featurestore_runtime_client (Session): session instance to perform boto calls.
291+
failed_rows (List[int]): list of indices from the data frame for which ingestion failed.
292+
293+
294+
Returns:
295+
int of row indices that failed to be ingested.
296+
"""
297+
record = [
298+
FeatureValue(
299+
feature_name=data_frame.columns[index - 1],
300+
value_as_string=str(row[index]),
301+
)
302+
for index in range(1, len(row))
303+
if pd.notna(row[index])
304+
]
305+
try:
306+
sagemaker_fs_runtime_client.put_record(
307+
FeatureGroupName=feature_group_name,
308+
Record=[value.to_dict() for value in record],
309+
)
310+
except Exception as e: # pylint: disable=broad-except
311+
logger.error("Failed to ingest row %d: %s", row[0], e)
312+
failed_rows.append(row[0])
313+
314+
def _run_single_process_single_thread(self, data_frame: DataFrame):
315+
"""Ingest a utilizing single process and single thread.
316+
317+
Args:
318+
data_frame (DataFrame): source DataFrame to be ingested.
319+
"""
320+
logger.info("Started ingesting index %d to %d")
321+
failed_rows = list()
322+
sagemaker_fs_runtime_client = self.sagemaker_session.sagemaker_featurestore_runtime_client
323+
for row in data_frame.itertuples():
324+
IngestionManagerPandas._ingest_row(
325+
data_frame=data_frame,
326+
row=row,
327+
feature_group_name=self.feature_group_name,
328+
sagemaker_fs_runtime_client=sagemaker_fs_runtime_client,
329+
failed_rows=failed_rows,
330+
)
331+
self._failed_indices = failed_rows
332+
333+
if len(self._failed_indices) > 0:
334+
raise IngestionError(
335+
self._failed_indices,
336+
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
337+
)
338+
283339
def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
284340
"""Start the ingestion process with the specified number of processes.
285341
@@ -391,7 +447,10 @@ def run(self, data_frame: DataFrame, wait=True, timeout=None):
391447
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
392448
if timeout is reached.
393449
"""
394-
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
450+
if self.max_workers == 1 and self.max_processes == 1 and self.profile_name is None:
451+
self._run_single_process_single_thread(data_frame=data_frame)
452+
else:
453+
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
395454

396455

397456
class IngestionError(Exception):
@@ -755,6 +814,7 @@ def ingest(
755814

756815
manager = IngestionManagerPandas(
757816
feature_group_name=self.name,
817+
sagemaker_session=self.sagemaker_session,
758818
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
759819
max_workers=max_workers,
760820
max_processes=max_processes,

tests/unit/sagemaker/feature_store/test_feature_group.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
307307

308308
ingestion_manager_init.assert_called_once_with(
309309
feature_group_name="MyGroup",
310+
sagemaker_session=sagemaker_session_mock,
310311
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
311312
max_workers=10,
312313
max_processes=1,
@@ -317,6 +318,32 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
317318
)
318319

319320

321+
@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
322+
def test_ingest_default(ingestion_manager_init, sagemaker_session_mock):
323+
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
324+
fs_runtime_client_config_mock
325+
)
326+
327+
feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
328+
df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
329+
330+
mock_ingestion_manager_instance = Mock()
331+
ingestion_manager_init.return_value = mock_ingestion_manager_instance
332+
feature_group.ingest(data_frame=df)
333+
334+
ingestion_manager_init.assert_called_once_with(
335+
feature_group_name="MyGroup",
336+
sagemaker_session=sagemaker_session_mock,
337+
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
338+
max_workers=1,
339+
max_processes=1,
340+
profile_name=None,
341+
)
342+
mock_ingestion_manager_instance.run.assert_called_once_with(
343+
data_frame=df, wait=True, timeout=None
344+
)
345+
346+
320347
@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
321348
def test_ingest_with_profile_name(
322349
ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
@@ -334,6 +361,7 @@ def test_ingest_with_profile_name(
334361

335362
ingestion_manager_init.assert_called_once_with(
336363
feature_group_name="MyGroup",
364+
sagemaker_session=sagemaker_session_mock,
337365
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
338366
max_workers=10,
339367
max_processes=1,
@@ -403,6 +431,7 @@ def test_ingestion_manager_run_success():
403431
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
404432
manager = IngestionManagerPandas(
405433
feature_group_name="MyGroup",
434+
sagemaker_session=sagemaker_session_mock,
406435
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
407436
max_workers=10,
408437
)
@@ -421,6 +450,7 @@ def test_ingestion_manager_run_multi_process_with_multi_thread_success(
421450
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
422451
manager = IngestionManagerPandas(
423452
feature_group_name="MyGroup",
453+
sagemaker_session=sagemaker_session_mock,
424454
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
425455
max_workers=2,
426456
max_processes=2,
@@ -436,16 +466,17 @@ def test_ingestion_manager_run_failure():
436466
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
437467
manager = IngestionManagerPandas(
438468
feature_group_name="MyGroup",
469+
sagemaker_session=sagemaker_session_mock,
439470
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
440-
max_workers=1,
471+
max_workers=2,
441472
)
442473

443474
with pytest.raises(IngestionError) as error:
444475
manager.run(df)
445476

446477
assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
447-
assert error.value.failed_rows == [1]
448-
assert manager.failed_rows == [1]
478+
assert error.value.failed_rows == [1, 1]
479+
assert manager.failed_rows == [1, 1]
449480

450481

451482
@patch(
@@ -456,6 +487,7 @@ def test_ingestion_manager_with_profile_name_run_failure():
456487
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
457488
manager = IngestionManagerPandas(
458489
feature_group_name="MyGroup",
490+
sagemaker_session=sagemaker_session_mock,
459491
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
460492
max_workers=1,
461493
profile_name="non_exist",
@@ -475,6 +507,7 @@ def test_ingestion_manager_run_multi_process_failure():
475507
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
476508
manager = IngestionManagerPandas(
477509
feature_group_name="MyGroup",
510+
sagemaker_session=None,
478511
sagemaker_fs_runtime_client_config=None,
479512
max_workers=2,
480513
max_processes=2,

0 commit comments

Comments
 (0)