Skip to content

Commit 82ad736

Browse files
committed
feat: Add to_csv method in DatasetBuilder (aws#699)
1 parent 74b9049 commit 82ad736

File tree

5 files changed

+162
-16
lines changed

5 files changed

+162
-16
lines changed

src/sagemaker/feature_store/dataset_builder.py

+104
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import attr
2323
import pandas as pd
2424

25+
from sagemaker import Session
2526
from sagemaker.feature_store.feature_group import FeatureGroup
2627

2728

@@ -33,6 +34,7 @@ class DatasetBuilder:
3334
an output path and a KMS key ID.
3435
3536
Attributes:
37+
_sagemaker_session (Session): Session instance to perform boto calls.
3638
_base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
3739
pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
3840
_output_path (str): An S3 URI which stores the output .csv file.
@@ -59,6 +61,7 @@ class DatasetBuilder:
5961
dataset will be before it.
6062
"""
6163

64+
_sagemaker_session: Session = attr.ib()
6265
_base: Union[FeatureGroup, pd.DataFrame] = attr.ib()
6366
_output_path: str = attr.ib()
6467
_record_identifier_feature_name: str = attr.ib(default=None)
@@ -155,3 +158,104 @@ def with_event_time_range(
155158
self._event_time_starting_timestamp = starting_timestamp
156159
self._event_time_ending_timestamp = ending_timestamp
157160
return self
161+
162+
def to_csv(self):
163+
"""Get query string and result in .csv format
164+
165+
Returns:
166+
The S3 path of the .csv file.
167+
The query string executed.
168+
"""
169+
if isinstance(self._base, FeatureGroup):
170+
# TODO: handle pagination and input feature validation
171+
base_feature_group = self._base.describe()
172+
data_catalog_config = base_feature_group.get("OfflineStoreConfig", None).get(
173+
"DataCatalogConfig", None
174+
)
175+
if not data_catalog_config:
176+
raise RuntimeError("No metastore is configured with the base FeatureGroup.")
177+
disable_glue = base_feature_group.get("DisableGlueTableCreation", False)
178+
self._record_identifier_feature_name = base_feature_group.get(
179+
"RecordIdentifierFeatureName", None
180+
)
181+
self._event_time_identifier_feature_name = base_feature_group.get(
182+
"EventTimeFeatureName", None
183+
)
184+
base_features = [
185+
feature.get("FeatureName", None)
186+
for feature in base_feature_group.get("FeatureDefinitions", None)
187+
]
188+
189+
query = self._sagemaker_session.start_query_execution(
190+
catalog=data_catalog_config.get("Catalog", None)
191+
if disable_glue
192+
else "AwsDataCatalog",
193+
database=data_catalog_config.get("Database", None),
194+
query_string=self._construct_query_string(
195+
data_catalog_config.get("TableName", None),
196+
data_catalog_config.get("Database", None),
197+
base_features,
198+
),
199+
output_location=self._output_path,
200+
kms_key=self._kms_key_id,
201+
)
202+
query_id = query.get("QueryExecutionId", None)
203+
self._sagemaker_session.wait_for_athena_query(
204+
query_execution_id=query_id,
205+
)
206+
query_state = (
207+
self._sagemaker_session.get_query_execution(
208+
query_execution_id=query_id,
209+
)
210+
.get("QueryExecution", None)
211+
.get("Status", None)
212+
.get("State", None)
213+
)
214+
if query_state != "SUCCEEDED":
215+
raise RuntimeError(f"Failed to execute query {query_id}.")
216+
217+
return query_state.get("QueryExecution", None).get("ResultConfiguration", None).get(
218+
"OutputLocation", None
219+
), query_state.get("QueryExecution", None).get("Query", None)
220+
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
221+
222+
def _construct_query_string(
223+
self, base_table_name: str, database: str, base_features: list
224+
) -> str:
225+
"""Internal method for constructing SQL query string by parameters.
226+
227+
Args:
228+
base_table_name (str): The Athena table name of base FeatureGroup or pandas.DataFrame.
229+
database (str): The Athena database of the base table.
230+
base_features (list): The list of features of the base table.
231+
Returns:
232+
The query string.
233+
"""
234+
included_features = ", ".join(
235+
[
236+
f'base."{include_feature_name}"'
237+
for include_feature_name in self._included_feature_names
238+
]
239+
)
240+
query_string = f"SELECT {included_features}\n"
241+
if self._include_duplicated_records:
242+
query_string += f'FROM "{database}"."{base_table_name}" base\n'
243+
if not self._include_deleted_records:
244+
query_string += "WHERE NOT is_deleted\n"
245+
else:
246+
base_features.remove(self._event_time_identifier_feature_name)
247+
dedup_features = ", ".join([f'dedup_base."{feature}"' for feature in base_features])
248+
query_string += (
249+
"FROM (\n"
250+
+ "SELECT *, row_number() OVER (\n"
251+
+ f"PARTITION BY {dedup_features}\n"
252+
+ f'ORDER BY dedup_base."{self._event_time_identifier_feature_name}" '
253+
+ 'DESC, dedup_base."api_invocation_time" DESC, dedup_base."write_time" DESC\n'
254+
+ ") AS row_base\n"
255+
+ f'FROM "{database}"."{base_table_name}" dedup_base\n'
256+
+ ") AS base\n"
257+
+ "WHERE row_base = 1\n"
258+
)
259+
if not self._include_deleted_records:
260+
query_string += "AND NOT is_deleted\n"
261+
return query_string

src/sagemaker/feature_store/feature_store.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class FeatureStore:
4040

4141
sagemaker_session: Session = attr.ib(default=Session)
4242

43-
@staticmethod
4443
def create_dataset(
44+
self,
4545
base: Union[FeatureGroup, pd.DataFrame],
4646
output_path: str,
4747
record_identifier_feature_name: str = None,
@@ -76,6 +76,7 @@ def create_dataset(
7676
+ "identifier feature name if specify DataFrame as base."
7777
)
7878
return DatasetBuilder(
79+
self.sagemaker_session,
7980
base,
8081
output_path,
8182
record_identifier_feature_name,

tests/unit/sagemaker/feature_store/test_dataset_builder.py

+49-14
Original file line numberDiff line numberDiff line change
@@ -20,50 +20,85 @@
2020
from sagemaker.feature_store.dataset_builder import DatasetBuilder
2121

2222

23+
@pytest.fixture
def sagemaker_session_mock():
    """Return a mocked SageMaker Session for DatasetBuilder unit tests."""
    return Mock()
26+
27+
2328
@pytest.fixture
def feature_group_mock():
    """Return a mocked FeatureGroup to serve as the DatasetBuilder base."""
    return Mock()
2631

2732

def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock):
    """point_in_time_accurate_join() should set the corresponding flag."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    builder.point_in_time_accurate_join()
    assert builder._point_in_time_accurate_join
3241

3342

def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock):
    """include_duplicated_records() should set the corresponding flag."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    builder.include_duplicated_records()
    assert builder._include_duplicated_records
3851

3952

def test_include_deleted_records(sagemaker_session_mock, feature_group_mock):
    """include_deleted_records() should set the corresponding flag."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    builder.include_deleted_records()
    assert builder._include_deleted_records
4461

4562

def test_with_number_of_recent_records_by_record_identifier(
    sagemaker_session_mock, feature_group_mock
):
    """The recent-records limit should be stored on the builder."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    builder.with_number_of_recent_records_by_record_identifier(5)
    assert builder._number_of_recent_records == 5
5073

5174

def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock):
    """The query-result record limit should be stored on the builder."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    builder.with_number_of_records_from_query_results(100)
    assert builder._number_of_records == 100
5683

5784

def test_as_of(sagemaker_session_mock, feature_group_mock):
    """as_of() should record the given timestamp as the write-time upper bound."""
    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=feature_group_mock,
        output_path="file/to/path",
    )
    timestamp = datetime.datetime.now()
    builder.as_of(timestamp)
    assert builder._write_time_ending_timestamp == timestamp
6394

6495

65-
def test_with_event_time_range(feature_group_mock):
66-
dataset_builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
96+
def test_with_event_time_range(sagemaker_session_mock, feature_group_mock):
97+
dataset_builder = DatasetBuilder(
98+
sagemaker_session=sagemaker_session_mock,
99+
base=feature_group_mock,
100+
output_path="file/to/path",
101+
)
67102
start = datetime.datetime.now()
68103
end = start + datetime.timedelta(minutes=1)
69104
dataset_builder.with_event_time_range(start, end)

tests/unit/sagemaker/feature_store/test_feature_group.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -500,13 +500,16 @@ def query(sagemaker_session_mock):
500500

501501
def test_athena_query_run(sagemaker_session_mock, query):
    """AthenaQuery.run should forward all arguments (incl. workgroup) to the session."""
    sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
    query.run(
        query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup"
    )
    # The session call must carry the catalog/database from the query object
    # plus every run() argument, with kms_key defaulting to None.
    sagemaker_session_mock.start_query_execution.assert_called_with(
        catalog="catalog",
        database="database",
        query_string="query",
        output_location="s3://some-bucket/some-path",
        kms_key=None,
        workgroup="workgroup",
    )
    # run() splits the output location into bucket and key prefix.
    assert "some-bucket" == query._result_bucket
    assert "some-path" == query._result_file_prefix

tests/unit/sagemaker/feature_store/test_feature_store.py

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock):
3939
base=feature_group_mock,
4040
output_path="file/to/path",
4141
)
42+
assert dataset_builder._sagemaker_session == sagemaker_session_mock
4243
assert dataset_builder._base == feature_group_mock
4344
assert dataset_builder._output_path == "file/to/path"
4445

@@ -51,6 +52,7 @@ def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock):
5152
output_path="file/to/path",
5253
kms_key_id="kms-key-id",
5354
)
55+
assert dataset_builder._sagemaker_session == sagemaker_session_mock
5456
assert dataset_builder._base == feature_group_mock
5557
assert dataset_builder._included_feature_names == ["feature_1", "feature_2"]
5658
assert dataset_builder._output_path == "file/to/path"
@@ -67,6 +69,7 @@ def test_create_dataset_with_dataframe(sagemaker_session_mock):
6769
output_path="file/to/path",
6870
kms_key_id="kms-key-id",
6971
)
72+
assert dataset_builder._sagemaker_session == sagemaker_session_mock
7073
assert dataset_builder._base.equals(DATAFRAME)
7174
assert dataset_builder._record_identifier_feature_name == "feature_1"
7275
assert dataset_builder._event_time_identifier_feature_name == "feature_2"

0 commit comments

Comments
 (0)