Skip to content

feature: Add profile_name support for Feature Store ingestion #2744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/sagemaker/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,15 @@ class IngestionManagerPandas:
max_workers (int): number of threads to create.
max_processes (int): number of processes to create. Each process spawns
``max_workers`` threads.
profile_name (str): the profile credential should be used for ``PutRecord``
(default: None).
"""

feature_group_name: str = attr.ib()
sagemaker_fs_runtime_client_config: Config = attr.ib()
max_workers: int = attr.ib(default=1)
max_processes: int = attr.ib(default=1)
profile_name: str = attr.ib(default=None)
_async_result: AsyncResult = attr.ib(default=None)
_processing_pool: ProcessingPool = attr.ib(default=None)
_failed_indices: List[int] = attr.ib(factory=list)
Expand All @@ -180,6 +183,7 @@ def _ingest_single_batch(
client_config: Config,
start_index: int,
end_index: int,
profile_name: str = None,
) -> List[int]:
"""Ingest a single batch of DataFrame rows into FeatureStore.

Expand All @@ -190,6 +194,8 @@ def _ingest_single_batch(
client to perform boto calls.
start_index (int): starting position to ingest in this batch.
end_index (int): ending position to ingest in this batch.
profile_name (str): the profile credential should be used for ``PutRecord``
(default: None).

Returns:
List of row indices that failed to be ingested.
Expand All @@ -198,7 +204,7 @@ def _ingest_single_batch(
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
client_config = copy.deepcopy(client_config)
client_config.retries = {"max_attempts": 10, "mode": "standard"}
sagemaker_featurestore_runtime_client = boto3.Session().client(
sagemaker_featurestore_runtime_client = boto3.Session(profile_name=profile_name).client(
service_name="sagemaker-featurestore-runtime", config=client_config
)

Expand Down Expand Up @@ -287,6 +293,7 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
data_frame[start_index:end_index],
start_index,
timeout,
self.profile_name,
)
]

Expand All @@ -311,6 +318,7 @@ def _run_multi_threaded(
data_frame: DataFrame,
row_offset=0,
timeout=None,
profile_name=None,
) -> List[int]:
"""Start the ingestion process.

Expand All @@ -321,6 +329,8 @@ def _run_multi_threaded(
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.
profile_name (str): the profile credential should be used for ``PutRecord``
(default: None).

Returns:
List of row indices that failed to be ingested.
Expand All @@ -342,6 +352,7 @@ def _run_multi_threaded(
start_index=start_index,
end_index=end_index,
client_config=sagemaker_fs_runtime_client_config,
profile_name=profile_name,
)
] = (start_index + row_offset, end_index + row_offset)

Expand Down Expand Up @@ -581,6 +592,7 @@ def ingest(
max_processes: int = 1,
wait: bool = True,
timeout: Union[int, float] = None,
profile_name: str = None,
) -> IngestionManagerPandas:
"""Ingest the content of a pandas DataFrame to feature store.

Expand All @@ -599,6 +611,11 @@ def ingest(
They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
the exception is thrown.

`profile_name` argument is an optional one. It will use the default credential if None is
passed. This `profile_name` is used in the sagemaker_featurestore_runtime client only. See
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more
about the default credential.

Args:
data_frame (DataFrame): data_frame to be ingested to feature store.
max_workers (int): number of threads to be created.
Expand All @@ -607,6 +624,8 @@ def ingest(
wait (bool): whether to wait for the ingestion to finish or not.
timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
if timeout is reached.
profile_name (str): the profile credential should be used for ``PutRecord``
(default: None).

Returns:
An instance of IngestionManagerPandas.
Expand All @@ -622,6 +641,7 @@ def ingest(
sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
max_workers=max_workers,
max_processes=max_processes,
profile_name=profile_name,
)

manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/sagemaker/feature_store/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pandas as pd
import pytest
from mock import Mock, patch, MagicMock
from botocore.exceptions import ProfileNotFound

from sagemaker.feature_store.feature_definition import (
FractionalFeatureDefinition,
Expand Down Expand Up @@ -227,6 +228,34 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
max_workers=10,
max_processes=1,
profile_name=None,
)
mock_ingestion_manager_instance.run.assert_called_once_with(
data_frame=df, wait=True, timeout=None
)


@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
def test_ingest_with_profile_name(
ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
):
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
fs_runtime_client_config_mock
)

feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))

mock_ingestion_manager_instance = Mock()
ingestion_manager_init.return_value = mock_ingestion_manager_instance
feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name")

ingestion_manager_init.assert_called_once_with(
feature_group_name="MyGroup",
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
max_workers=10,
max_processes=1,
profile_name="profile_name",
)
mock_ingestion_manager_instance.run.assert_called_once_with(
data_frame=df, wait=True, timeout=None
Expand Down Expand Up @@ -340,6 +369,25 @@ def test_ingestion_manager_run_failure():
assert manager.failed_rows == [1]


@patch(
"sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
)
def test_ingestion_manager_with_profile_name_run_failure():
df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
manager = IngestionManagerPandas(
feature_group_name="MyGroup",
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
max_workers=1,
profile_name="non_exist",
)

try:
manager.run(df)
except Exception as e:
assert "The config profile (non_exist) could not be found" in str(e)


@patch(
"sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
PicklableMock(return_value=[1]),
Expand Down