Skip to content

Commit 93a867c

Browse files
imingtsou authored and mizanfiu committed
feat: Add to_dataframe method in DatasetBuilder (aws#729)
1 parent c06ed9b commit 93a867c

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

src/sagemaker/feature_store/dataset_builder.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import absolute_import
1818

1919
import datetime
20-
from typing import Any, Dict, List, Sequence, Union
20+
from typing import Any, Dict, List, Sequence, Tuple, Union
2121

2222
import attr
2323
import pandas as pd
@@ -248,7 +248,7 @@ def with_event_time_range(
248248
self._event_time_ending_timestamp = ending_timestamp
249249
return self
250250

251-
def to_csv(self):
251+
def to_csv(self) -> Tuple[str, str]:
252252
"""Get query string and result in .csv format
253253
254254
Returns:
@@ -314,6 +314,23 @@ def to_csv(self):
314314
), query_result.get("QueryExecution", None).get("Query", None)
315315
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
316316

317+
def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
    """Get query string and result in pandas.DataFrame.

    Runs ``to_csv``, downloads the resulting .csv file into the current
    working directory and loads it with pandas.

    Returns:
        The query string executed.
        The pandas.DataFrame object.
    """
    # to_csv() returns the result file first and the executed query second,
    # so unpack in that order; swapping them would hand the SQL text to the
    # downloader as an S3 URI and return the file location as the "query".
    csv_file, query_string = self.to_csv()
    s3.S3Downloader.download(
        s3_uri=csv_file,
        local_path="./",
        kms_key=self._kms_key_id,
        sagemaker_session=self._sagemaker_session,
    )
    # TODO: do we need to clean up local file?
    # NOTE(review): presumably the download is stored in the working directory
    # under the object's base name; the read below relies on that - confirm.
    local_file_name = csv_file.split("/")[-1]
    return query_string, pd.read_csv(local_file_name)
333+
317334
def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
318335
"""Internal method for constructing SQL WHERE query string by parameters.
319336
@@ -425,6 +442,8 @@ def _construct_query_string(
425442
]
426443
)
427444
query_string += join_subquery_string
445+
if self._number_of_records:
446+
query_string += f"\nLIMIT {self._number_of_records}"
428447
return query_string
429448

430449
def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str):

tests/unit/sagemaker/feature_store/test_dataset_builder.py

+54
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
import datetime
1616

17+
import pandas as pd
1718
import pytest
1819
from mock import Mock
1920

2021
from sagemaker.feature_store.dataset_builder import DatasetBuilder
22+
from sagemaker.feature_store.feature_group import FeatureGroup
2123

2224

2325
@pytest.fixture
@@ -30,6 +32,58 @@ def feature_group_mock():
3032
return Mock()
3133

3234

35+
def test_with_feature_group_throw_runtime_error(sagemaker_session_mock):
    """Expect a RuntimeError when the described group has an empty OfflineStoreConfig."""
    feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
    dataset_builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group,
        output_path="file/to/path",
    )
    # An OfflineStoreConfig without any catalog details is the failure trigger.
    sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}}
    with pytest.raises(RuntimeError) as error:
        dataset_builder.with_feature_group(
            feature_group, "target-feature", ["feature-1", "feature-2"]
        )
    # Inspect the raised exception itself: str() of the ExceptionInfo wrapper
    # is not guaranteed to be the bare exception message.
    assert "No metastore is configured with FeatureGroup MyFeatureGroup." in str(error.value)
48+
49+
50+
def test_with_feature_group(sagemaker_session_mock):
    """Check that with_feature_group records one merge entry built from the described group."""
    feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
    dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]})
    feature_group.load_feature_definitions(dataframe)
    dataset_builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group,
        output_path="file/to/path",
    )
    # Canned describe_feature_group response supplying catalog and identifier names.
    sagemaker_session_mock.describe_feature_group.return_value = {
        "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}},
        "RecordIdentifierFeatureName": "feature-1",
        "EventTimeFeatureName": "feature-2",
    }

    dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"])

    assert len(dataset_builder._feature_groups_to_be_merged) == 1
    # Every remaining assertion looks at the single recorded entry.
    merged = dataset_builder._feature_groups_to_be_merged[0]
    assert merged.features == ["feature-1", "feature-2"]
    assert merged.included_feature_names == ["feature-1", "feature-2"]
    assert merged.database == "database"
    assert merged.table_name == "table"
    assert merged.record_identifier_feature_name == "feature-1"
    assert merged.event_time_identifier_feature_name == "feature-2"
    assert merged.target_feature_name_in_base == "target-feature"
85+
86+
3387
def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock):
3488
dataset_builder = DatasetBuilder(
3589
sagemaker_session=sagemaker_session_mock,

0 commit comments

Comments
 (0)