Skip to content

Commit 3c08b59

Browse files
committed
Add profile_name support for Feature Store ingestion
1 parent 99f023e commit 3c08b59

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/sagemaker/feature_store/feature_group.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,14 @@ class IngestionManagerPandas:
163163
max_workers (int): number of threads to create.
164164
max_processes (int): number of processes to create. Each process spawns
165165
``max_workers`` threads.
166+
profile_name (str): the profile credential should be used for ``PutRecord``.
166167
"""
167168

168169
feature_group_name: str = attr.ib()
169170
sagemaker_fs_runtime_client_config: Config = attr.ib()
170171
max_workers: int = attr.ib(default=1)
171172
max_processes: int = attr.ib(default=1)
173+
profile_name: str = attr.ib(default=None)
172174
_async_result: AsyncResult = attr.ib(default=None)
173175
_processing_pool: ProcessingPool = attr.ib(default=None)
174176
_failed_indices: List[int] = attr.ib(factory=list)
@@ -180,6 +182,7 @@ def _ingest_single_batch(
180182
client_config: Config,
181183
start_index: int,
182184
end_index: int,
185+
profile_name: str = None,
183186
) -> List[int]:
184187
"""Ingest a single batch of DataFrame rows into FeatureStore.
185188
@@ -190,6 +193,7 @@ def _ingest_single_batch(
190193
client to perform boto calls.
191194
start_index (int): starting position to ingest in this batch.
192195
end_index (int): ending position to ingest in this batch.
196+
profile_name (str): the profile credential should be used for ``PutRecord``.
193197
194198
Returns:
195199
List of row indices that failed to be ingested.
@@ -198,7 +202,7 @@ def _ingest_single_batch(
198202
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
199203
client_config = copy.deepcopy(client_config)
200204
client_config.retries = {"max_attempts": 10, "mode": "standard"}
201-
sagemaker_featurestore_runtime_client = boto3.Session().client(
205+
sagemaker_featurestore_runtime_client = boto3.Session(profile_name=profile_name).client(
202206
service_name="sagemaker-featurestore-runtime", config=client_config
203207
)
204208

@@ -287,6 +291,7 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
287291
data_frame[start_index:end_index],
288292
start_index,
289293
timeout,
294+
self.profile_name,
290295
)
291296
]
292297

@@ -311,6 +316,7 @@ def _run_multi_threaded(
311316
data_frame: DataFrame,
312317
row_offset=0,
313318
timeout=None,
319+
profile_name=None,
314320
) -> List[int]:
315321
"""Start the ingestion process.
316322
@@ -321,6 +327,7 @@ def _run_multi_threaded(
321327
wait (bool): whether to wait for the ingestion to finish or not.
322328
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
323329
if timeout is reached.
330+
profile_name (str): the profile credential should be used for ``PutRecord``.
324331
325332
Returns:
326333
List of row indices that failed to be ingested.
@@ -342,6 +349,7 @@ def _run_multi_threaded(
342349
start_index=start_index,
343350
end_index=end_index,
344351
client_config=sagemaker_fs_runtime_client_config,
352+
profile_name=profile_name,
345353
)
346354
] = (start_index + row_offset, end_index + row_offset)
347355

@@ -581,6 +589,7 @@ def ingest(
581589
max_processes: int = 1,
582590
wait: bool = True,
583591
timeout: Union[int, float] = None,
592+
profile_name: str = None,
584593
) -> IngestionManagerPandas:
585594
"""Ingest the content of a pandas DataFrame to feature store.
586595
@@ -599,6 +608,11 @@ def ingest(
599608
They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
600609
the exception is thrown.
601610
611+
`profile_name` argument is an optional one. It will use the default credential if None is
612+
passed. This `profile_name` is used in the sagemaker_featurestore_runtime client only. See
613+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more
614+
about the default credential.
615+
602616
Args:
603617
data_frame (DataFrame): data_frame to be ingested to feature store.
604618
max_workers (int): number of threads to be created.
@@ -607,6 +621,8 @@ def ingest(
607621
wait (bool): whether to wait for the ingestion to finish or not.
608622
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
609623
if timeout is reached.
624+
profile_name (str): the profile credential should be used for ``PutRecord``
625+
(default: None).
610626
611627
Returns:
612628
An instance of IngestionManagerPandas.
@@ -622,6 +638,7 @@ def ingest(
622638
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
623639
max_workers=max_workers,
624640
max_processes=max_processes,
641+
profile_name=profile_name,
625642
)
626643

627644
manager.run(data_frame=data_frame, wait=wait, timeout=timeout)

tests/unit/sagemaker/feature_store/test_feature_store.py

+48
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pandas as pd
1818
import pytest
1919
from mock import Mock, patch, MagicMock
20+
from botocore.exceptions import ProfileNotFound
2021

2122
from sagemaker.feature_store.feature_definition import (
2223
FractionalFeatureDefinition,
@@ -227,6 +228,34 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
227228
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
228229
max_workers=10,
229230
max_processes=1,
231+
profile_name=None,
232+
)
233+
mock_ingestion_manager_instance.run.assert_called_once_with(
234+
data_frame=df, wait=True, timeout=None
235+
)
236+
237+
238+
@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
239+
def test_ingest_with_profile_name(
240+
ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
241+
):
242+
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
243+
fs_runtime_client_config_mock
244+
)
245+
246+
feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
247+
df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
248+
249+
mock_ingestion_manager_instance = Mock()
250+
ingestion_manager_init.return_value = mock_ingestion_manager_instance
251+
feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name")
252+
253+
ingestion_manager_init.assert_called_once_with(
254+
feature_group_name="MyGroup",
255+
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
256+
max_workers=10,
257+
max_processes=1,
258+
profile_name="profile_name",
230259
)
231260
mock_ingestion_manager_instance.run.assert_called_once_with(
232261
data_frame=df, wait=True, timeout=None
@@ -340,6 +369,25 @@ def test_ingestion_manager_run_failure():
340369
assert manager.failed_rows == [1]
341370

342371

372+
@patch(
373+
"sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
374+
MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
375+
)
376+
def test_ingestion_manager_with_profile_name_run_failure():
377+
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
378+
manager = IngestionManagerPandas(
379+
feature_group_name="MyGroup",
380+
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
381+
max_workers=1,
382+
profile_name="non_exist",
383+
)
384+
385+
try:
386+
manager.run(df)
387+
except Exception as e:
388+
assert "The config profile (non_exist) could not be found" in str(e)
389+
390+
343391
@patch(
344392
"sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
345393
PicklableMock(return_value=[1]),

0 commit comments

Comments
 (0)