Skip to content

Commit 1474cc2

Browse files
ericpark3Namrata Madan
authored and
Namrata Madan
committed
fix: use FeatureGroup's Session in nonconcurrency ingestion (aws#3617)
1 parent a8947d6 commit 1474cc2

File tree

2 files changed

+115
-22
lines changed

2 files changed

+115
-22
lines changed

src/sagemaker/feature_store/feature_group.py

+79-19
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class IngestionManagerPandas:
165165
feature_group_name (str): name of the Feature Group.
166166
sagemaker_fs_runtime_client_config (Config): instance of the Config class
167167
for boto calls.
168+
sagemaker_session (Session): session instance to perform boto calls.
168169
data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
169170
max_workers (int): number of threads to create.
170171
max_processes (int): number of processes to create. Each process spawns
@@ -174,7 +175,8 @@ class IngestionManagerPandas:
174175
"""
175176

176177
feature_group_name: str = attr.ib()
177-
sagemaker_fs_runtime_client_config: Config = attr.ib()
178+
sagemaker_fs_runtime_client_config: Config = attr.ib(default=None)
179+
sagemaker_session: Session = attr.ib(default=None)
178180
max_workers: int = attr.ib(default=1)
179181
max_processes: int = attr.ib(default=1)
180182
profile_name: str = attr.ib(default=None)
@@ -210,29 +212,20 @@ def _ingest_single_batch(
210212
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
211213
client_config = copy.deepcopy(client_config)
212214
client_config.retries = {"max_attempts": 10, "mode": "standard"}
213-
sagemaker_featurestore_runtime_client = boto3.Session(profile_name=profile_name).client(
215+
sagemaker_fs_runtime_client = boto3.Session(profile_name=profile_name).client(
214216
service_name="sagemaker-featurestore-runtime", config=client_config
215217
)
216218

217219
logger.info("Started ingesting index %d to %d", start_index, end_index)
218220
failed_rows = list()
219221
for row in data_frame[start_index:end_index].itertuples():
220-
record = [
221-
FeatureValue(
222-
feature_name=data_frame.columns[index - 1],
223-
value_as_string=str(row[index]),
224-
)
225-
for index in range(1, len(row))
226-
if pd.notna(row[index])
227-
]
228-
try:
229-
sagemaker_featurestore_runtime_client.put_record(
230-
FeatureGroupName=feature_group_name,
231-
Record=[value.to_dict() for value in record],
232-
)
233-
except Exception as e: # pylint: disable=broad-except
234-
logger.error("Failed to ingest row %d: %s", row[0], e)
235-
failed_rows.append(row[0])
222+
IngestionManagerPandas._ingest_row(
223+
data_frame=data_frame,
224+
row=row,
225+
feature_group_name=feature_group_name,
226+
sagemaker_fs_runtime_client=sagemaker_fs_runtime_client,
227+
failed_rows=failed_rows,
228+
)
236229
return failed_rows
237230

238231
@property
@@ -274,6 +267,69 @@ def wait(self, timeout=None):
274267
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
275268
)
276269

270+
@staticmethod
271+
def _ingest_row(
272+
data_frame: DataFrame,
273+
row: int,
274+
feature_group_name: str,
275+
sagemaker_fs_runtime_client: Session,
276+
failed_rows: List[int],
277+
):
278+
"""Ingest a single Dataframe row into FeatureStore.
279+
280+
Args:
281+
data_frame (DataFrame): source DataFrame to be ingested.
282+
row (int): current row that is being ingested
283+
feature_group_name (str): name of the Feature Group.
284+
sagemaker_featurestore_runtime_client (Session): session instance to perform boto calls.
285+
failed_rows (List[int]): list of indices from the data frame for which ingestion failed.
286+
287+
288+
Returns:
289+
int of row indices that failed to be ingested.
290+
"""
291+
record = [
292+
FeatureValue(
293+
feature_name=data_frame.columns[index - 1],
294+
value_as_string=str(row[index]),
295+
)
296+
for index in range(1, len(row))
297+
if pd.notna(row[index])
298+
]
299+
try:
300+
sagemaker_fs_runtime_client.put_record(
301+
FeatureGroupName=feature_group_name,
302+
Record=[value.to_dict() for value in record],
303+
)
304+
except Exception as e: # pylint: disable=broad-except
305+
logger.error("Failed to ingest row %d: %s", row[0], e)
306+
failed_rows.append(row[0])
307+
308+
def _run_single_process_single_thread(self, data_frame: DataFrame):
309+
"""Ingest a utilizing single process and single thread.
310+
311+
Args:
312+
data_frame (DataFrame): source DataFrame to be ingested.
313+
"""
314+
logger.info("Started ingesting index %d to %d")
315+
failed_rows = list()
316+
sagemaker_fs_runtime_client = self.sagemaker_session.sagemaker_featurestore_runtime_client
317+
for row in data_frame.itertuples():
318+
IngestionManagerPandas._ingest_row(
319+
data_frame=data_frame,
320+
row=row,
321+
feature_group_name=self.feature_group_name,
322+
sagemaker_fs_runtime_client=sagemaker_fs_runtime_client,
323+
failed_rows=failed_rows,
324+
)
325+
self._failed_indices = failed_rows
326+
327+
if len(self._failed_indices) > 0:
328+
raise IngestionError(
329+
self._failed_indices,
330+
f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
331+
)
332+
277333
def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
278334
"""Start the ingestion process with the specified number of processes.
279335
@@ -385,7 +441,10 @@ def run(self, data_frame: DataFrame, wait=True, timeout=None):
385441
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
386442
if timeout is reached.
387443
"""
388-
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
444+
if self.max_workers == 1 and self.max_processes == 1 and self.profile_name is None:
445+
self._run_single_process_single_thread(data_frame=data_frame)
446+
else:
447+
self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
389448

390449

391450
class IngestionError(Exception):
@@ -749,6 +808,7 @@ def ingest(
749808

750809
manager = IngestionManagerPandas(
751810
feature_group_name=self.name,
811+
sagemaker_session=self.sagemaker_session,
752812
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
753813
max_workers=max_workers,
754814
max_processes=max_processes,

tests/unit/sagemaker/feature_store/test_feature_group.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
307307

308308
ingestion_manager_init.assert_called_once_with(
309309
feature_group_name="MyGroup",
310+
sagemaker_session=sagemaker_session_mock,
310311
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
311312
max_workers=10,
312313
max_processes=1,
@@ -317,6 +318,32 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
317318
)
318319

319320

321+
@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
322+
def test_ingest_default(ingestion_manager_init, sagemaker_session_mock):
323+
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
324+
fs_runtime_client_config_mock
325+
)
326+
327+
feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
328+
df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
329+
330+
mock_ingestion_manager_instance = Mock()
331+
ingestion_manager_init.return_value = mock_ingestion_manager_instance
332+
feature_group.ingest(data_frame=df)
333+
334+
ingestion_manager_init.assert_called_once_with(
335+
feature_group_name="MyGroup",
336+
sagemaker_session=sagemaker_session_mock,
337+
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
338+
max_workers=1,
339+
max_processes=1,
340+
profile_name=None,
341+
)
342+
mock_ingestion_manager_instance.run.assert_called_once_with(
343+
data_frame=df, wait=True, timeout=None
344+
)
345+
346+
320347
@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
321348
def test_ingest_with_profile_name(
322349
ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
@@ -334,6 +361,7 @@ def test_ingest_with_profile_name(
334361

335362
ingestion_manager_init.assert_called_once_with(
336363
feature_group_name="MyGroup",
364+
sagemaker_session=sagemaker_session_mock,
337365
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
338366
max_workers=10,
339367
max_processes=1,
@@ -403,6 +431,7 @@ def test_ingestion_manager_run_success():
403431
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
404432
manager = IngestionManagerPandas(
405433
feature_group_name="MyGroup",
434+
sagemaker_session=sagemaker_session_mock,
406435
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
407436
max_workers=10,
408437
)
@@ -421,6 +450,7 @@ def test_ingestion_manager_run_multi_process_with_multi_thread_success(
421450
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
422451
manager = IngestionManagerPandas(
423452
feature_group_name="MyGroup",
453+
sagemaker_session=sagemaker_session_mock,
424454
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
425455
max_workers=2,
426456
max_processes=2,
@@ -436,16 +466,17 @@ def test_ingestion_manager_run_failure():
436466
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
437467
manager = IngestionManagerPandas(
438468
feature_group_name="MyGroup",
469+
sagemaker_session=sagemaker_session_mock,
439470
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
440-
max_workers=1,
471+
max_workers=2,
441472
)
442473

443474
with pytest.raises(IngestionError) as error:
444475
manager.run(df)
445476

446477
assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
447-
assert error.value.failed_rows == [1]
448-
assert manager.failed_rows == [1]
478+
assert error.value.failed_rows == [1, 1]
479+
assert manager.failed_rows == [1, 1]
449480

450481

451482
@patch(
@@ -456,6 +487,7 @@ def test_ingestion_manager_with_profile_name_run_failure():
456487
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
457488
manager = IngestionManagerPandas(
458489
feature_group_name="MyGroup",
490+
sagemaker_session=sagemaker_session_mock,
459491
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
460492
max_workers=1,
461493
profile_name="non_exist",
@@ -475,6 +507,7 @@ def test_ingestion_manager_run_multi_process_failure():
475507
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
476508
manager = IngestionManagerPandas(
477509
feature_group_name="MyGroup",
510+
sagemaker_session=None,
478511
sagemaker_fs_runtime_client_config=None,
479512
max_workers=2,
480513
max_processes=2,

0 commit comments

Comments
 (0)