diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 77678f4eca..41bdcd764c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -163,12 +163,15 @@ class IngestionManagerPandas: max_workers (int): number of threads to create. max_processes (int): number of processes to create. Each process spawns ``max_workers`` threads. + profile_name (str): name of the credential profile to use for ``PutRecord`` + (default: None). """ feature_group_name: str = attr.ib() sagemaker_fs_runtime_client_config: Config = attr.ib() max_workers: int = attr.ib(default=1) max_processes: int = attr.ib(default=1) + profile_name: str = attr.ib(default=None) _async_result: AsyncResult = attr.ib(default=None) _processing_pool: ProcessingPool = attr.ib(default=None) _failed_indices: List[int] = attr.ib(factory=list) @@ -180,6 +183,7 @@ def _ingest_single_batch( client_config: Config, start_index: int, end_index: int, + profile_name: str = None, ) -> List[int]: """Ingest a single batch of DataFrame rows into FeatureStore. @@ -190,6 +194,8 @@ def _ingest_single_batch( client to perform boto calls. start_index (int): starting position to ingest in this batch. end_index (int): ending position to ingest in this batch. + profile_name (str): name of the credential profile to use for ``PutRecord`` + (default: None). Returns: List of row indices that failed to be ingested. 
@@ -198,7 +204,7 @@ def _ingest_single_batch( if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config: client_config = copy.deepcopy(client_config) client_config.retries = {"max_attempts": 10, "mode": "standard"} - sagemaker_featurestore_runtime_client = boto3.Session().client( + sagemaker_featurestore_runtime_client = boto3.Session(profile_name=profile_name).client( service_name="sagemaker-featurestore-runtime", config=client_config ) @@ -287,6 +293,7 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None): data_frame[start_index:end_index], start_index, timeout, + self.profile_name, ) ] @@ -311,6 +318,7 @@ def _run_multi_threaded( data_frame: DataFrame, row_offset=0, timeout=None, + profile_name=None, ) -> List[int]: """Start the ingestion process. @@ -321,6 +329,8 @@ def _run_multi_threaded( wait (bool): whether to wait for the ingestion to finish or not. timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised if timeout is reached. + profile_name (str): name of the credential profile to use for ``PutRecord`` + (default: None). Returns: List of row indices that failed to be ingested. @@ -342,6 +352,7 @@ def _run_multi_threaded( start_index=start_index, end_index=end_index, client_config=sagemaker_fs_runtime_client_config, + profile_name=profile_name, ) ] = (start_index + row_offset, end_index + row_offset) @@ -581,6 +592,7 @@ def ingest( max_processes: int = 1, wait: bool = True, timeout: Union[int, float] = None, + profile_name: str = None, ) -> IngestionManagerPandas: """Ingest the content of a pandas DataFrame to feature store. @@ -599,6 +611,11 @@ def ingest( They can also be found from the IngestionManagerPandas' ``failed_rows`` function after the exception is thrown. + The `profile_name` argument is optional. The default credentials are used if None is + passed. This `profile_name` is used for the sagemaker_featurestore_runtime client only. 
See + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more + information about the default credentials. + Args: data_frame (DataFrame): data_frame to be ingested to feature store. max_workers (int): number of threads to be created. @@ -607,6 +624,8 @@ def ingest( wait (bool): whether to wait for the ingestion to finish or not. timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised if timeout is reached. + profile_name (str): name of the credential profile to use for ``PutRecord`` + (default: None). Returns: An instance of IngestionManagerPandas. @@ -622,6 +641,7 @@ def ingest( sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config, max_workers=max_workers, max_processes=max_processes, + profile_name=profile_name, ) manager.run(data_frame=data_frame, wait=wait, timeout=timeout) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 8cd38ede35..0192287c35 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -17,6 +17,7 @@ import pandas as pd import pytest from mock import Mock, patch, MagicMock +from botocore.exceptions import ProfileNotFound from sagemaker.feature_store.feature_definition import ( FractionalFeatureDefinition, @@ -227,6 +228,34 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, max_workers=10, max_processes=1, + profile_name=None, + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest_with_profile_name( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): + 
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name="profile_name", ) mock_ingestion_manager_instance.run.assert_called_once_with( data_frame=df, wait=True, timeout=None @@ -340,6 +369,25 @@ def test_ingestion_manager_run_failure(): assert manager.failed_rows == [1] +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(side_effect=ProfileNotFound(profile="non_exist")), +) +def test_ingestion_manager_with_profile_name_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + profile_name="non_exist", + ) + + try: + manager.run(df) + except Exception as e: + assert "The config profile (non_exist) could not be found" in str(e) + + @patch( "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", PicklableMock(return_value=[1]),