
Commit 13c4d5b

imingtsou authored and mizanfiu committed
Add integration tests for create_dataset (aws#743)
1 parent 7723c65 commit 13c4d5b
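For orientation, the new integration test drives the dataset builder through its public API. A minimal usage sketch of that flow, assuming the usual import paths; the feature group names and the S3 output path are placeholders, not values from this commit:

import sagemaker
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.feature_store import FeatureStore

session = sagemaker.Session()
base = FeatureGroup(name="my-base", sagemaker_session=session)
other = FeatureGroup(name="my-feature-group", sagemaker_session=session)

# Chained builder call exercised by the test below: use `base` as the base table,
# join `other` onto it, then materialize the dataset as a DataFrame plus the query string.
feature_store = FeatureStore(sagemaker_session=session)
df, query = (
    feature_store.create_dataset(base=base, output_path="s3://my-bucket/dataset-output")
    .with_feature_group(other)
    .to_dataframe()
)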

File tree
3 files changed (+232, -49 lines changed)


src/sagemaker/feature_store/dataset_builder.py (+37, -37)
@@ -16,14 +16,14 @@
 """
 from __future__ import absolute_import
 
+import copy
 import datetime
-import os
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
 import attr
 import pandas as pd
 
-from sagemaker import Session, s3, utils
+from sagemaker import Session, utils
 from sagemaker.feature_store.feature_group import FeatureGroup
 
 
@@ -166,7 +166,7 @@ class DatasetBuilder:
     _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
     _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None)
     _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
-    _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, default=[])
+    _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list)
 
     _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = {
         "object": "STRING",
@@ -193,8 +193,6 @@ def with_feature_group(
         Returns:
             This DatasetBuilder object.
         """
-        if not target_feature_name_in_base:
-            target_feature_name_in_base = self._record_identifier_feature_name
         self._feature_groups_to_be_merged.append(
             construct_feature_group_to_be_merged(
                 feature_group, included_feature_names, target_feature_name_in_base
@@ -292,18 +290,11 @@ def to_csv(self) -> Tuple[str, str]:
         """
         if isinstance(self._base, pd.DataFrame):
             temp_id = utils.unique_name_from_base("dataframe-base")
-            local_file_name = f"{temp_id}.csv"
-            desired_s3_folder = f"{self._output_path}/{temp_id}"
-            self._base.to_csv(local_file_name, index=False, header=False)
-            s3.S3Uploader.upload(
-                local_path=local_file_name,
-                desired_s3_uri=desired_s3_folder,
-                sagemaker_session=self._sagemaker_session,
-                kms_key=self._kms_key_id,
-            )
-            os.remove(local_file_name)
-            temp_table_name = f"dataframe_{temp_id}"
-            self._create_temp_table(temp_table_name, desired_s3_folder)
+            s3_file_name = f"{temp_id}.csv"
+            s3_folder = f"{self._output_path}/{temp_id}"
+            self._base.to_csv(f"{s3_folder}/{s3_file_name}", index=False, header=False)
+            temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
+            self._create_temp_table(temp_table_name, s3_folder)
             base_features = list(self._base.columns)
             query_string = self._construct_query_string(
                 FeatureGroupToBeMerged(
@@ -325,6 +316,10 @@ def to_csv(self) -> Tuple[str, str]:
             base_feature_group = construct_feature_group_to_be_merged(
                 self._base, self._included_feature_names
            )
+            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
+            self._event_time_identifier_feature_name = (
+                base_feature_group.event_time_identifier_feature_name
+            )
             query_string = self._construct_query_string(base_feature_group)
             query_result = self._run_query(
                 query_string,
@@ -344,15 +339,7 @@ def to_dataframe(self) -> Tuple[pd.DataFrame, str]:
             The query string executed.
         """
         csv_file, query_string = self.to_csv()
-        s3.S3Downloader.download(
-            s3_uri=csv_file,
-            local_path="./",
-            kms_key=self._kms_key_id,
-            sagemaker_session=self._sagemaker_session,
-        )
-        local_file_name = csv_file.split("/")[-1]
-        df = pd.read_csv(local_file_name)
-        os.remove(local_file_name)
+        df = pd.read_csv(csv_file)
         return df, query_string
 
     def _construct_where_query_string(
@@ -368,21 +355,25 @@ def _construct_where_query_string(
             The WHERE query string.
         """
         where_conditions = []
-        if not self._include_deleted_records:
+        if not self._include_duplicated_records:
+            where_conditions.append(f"row_{suffix} = 1")
+        if not self._include_deleted_records and not isinstance(self._base, pd.DataFrame):
             where_conditions.append("NOT is_deleted")
         if self._write_time_ending_timestamp:
             where_conditions.append(
-                f'table_{suffix}."write_time" <= {self._write_time_ending_timestamp}'
+                f'table_{suffix}."write_time" <= '
+                f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', "
+                f"'yyyy-mm-dd hh24:mi:ss')"
             )
         if self._event_time_starting_timestamp:
             where_conditions.append(
                 f'table_{suffix}."{event_time_identifier_feature_name}" >= '
-                + str(self._event_time_starting_timestamp)
+                + str(self._event_time_starting_timestamp.timestamp())
             )
         if self._event_time_ending_timestamp:
             where_conditions.append(
                 f'table_{suffix}."{event_time_identifier_feature_name}" <= '
-                + str(self._event_time_ending_timestamp)
+                + str(self._event_time_ending_timestamp.timestamp())
             )
         if len(where_conditions) == 0:
            return ""
@@ -410,20 +401,27 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
                 f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n'
             )
         else:
-            features = feature_group.features
+            features = copy.deepcopy(feature_group.features)
             features.remove(feature_group.event_time_identifier_feature_name)
             dedup_features = ", ".join([f'dedup_{suffix}."{feature}"' for feature in features])
+            rank_query_string = (
+                f'ORDER BY dedup_{suffix}."{feature_group.event_time_identifier_feature_name}" '
+                + f'DESC, dedup_{suffix}."api_invocation_time" DESC, dedup_{suffix}."write_time" '
+                + "DESC\n"
+            )
+            if isinstance(self._base, pd.DataFrame):
+                rank_query_string = (
+                    f'ORDER BY dedup_{suffix}."{feature_group.event_time_identifier_feature_name}" '
+                    + "DESC\n"
+                )
             query_string += (
                 "FROM (\n"
                 + "SELECT *, row_number() OVER (\n"
                 + f"PARTITION BY {dedup_features}\n"
-                + f'ORDER BY dedup_{suffix}."{feature_group.event_time_identifier_feature_name}" '
-                + f'DESC, dedup_{suffix}."api_invocation_time" DESC, '
-                + f'dedup_{suffix}."write_time" DESC\n'
+                + rank_query_string
                 + f") AS row_{suffix}\n"
                 + f'FROM "{feature_group.database}"."{feature_group.table_name}" dedup_{suffix}\n'
                 + f") AS table_{suffix}\n"
-                + f"WHERE row_{suffix} = 1\n"
             )
         return query_string + self._construct_where_query_string(
             suffix, feature_group.event_time_identifier_feature_name
@@ -446,8 +444,8 @@ def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str:
                     for i, feature_group in enumerate(self._feature_groups_to_be_merged)
                 ]
             )
-            query_string += f"{with_subquery_string}\n"
-        query_string += "SELECT *\nFROM fg_base"
+            query_string += with_subquery_string
+        query_string += "\nSELECT *\nFROM fg_base"
         if len(self._feature_groups_to_be_merged) > 0:
             join_subquery_string = "".join(
                 [
@@ -470,6 +468,8 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
         Returns:
             The JOIN query string.
         """
+        if not feature_group.target_feature_name_in_base:
+            feature_group.target_feature_name_in_base = self._record_identifier_feature_name
         join_condition_string = (
             f"\nJOIN fg_{suffix}\n"
             + f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
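In short, the dataset_builder.py changes drop the local temp file and the S3Uploader/S3Downloader round trip: to_csv now writes the base DataFrame straight into the S3 output folder and to_dataframe reads the result CSV back from its S3 URI. That pattern leans on pandas' handling of s3:// paths (via the optional s3fs dependency). A minimal sketch of that round trip, with the bucket name a placeholder:

import pandas as pd

df = pd.DataFrame({"base_id": [1, 2], "base_time": [187512346.0, 187512347.0]})
s3_folder = "s3://my-bucket/dataset-output/dataframe-base-example"
# Write directly to S3 with no local temp file; pandas routes s3:// paths through s3fs.
df.to_csv(f"{s3_folder}/data.csv", index=False, header=False)
# Read the CSV back straight from its S3 URI, as the new to_dataframe() does.
round_tripped = pd.read_csv(f"{s3_folder}/data.csv", header=None)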

tests/integ/test_feature_store.py (+185)
@@ -82,6 +82,11 @@ def feature_group_name():
     return f"my-feature-group-{int(time.time() * 10**7)}"
 
 
+@pytest.fixture
+def base_name():
+    return f"my-base-{int(time.time() * 10**7)}"
+
+
 @pytest.fixture
 def offline_store_s3_uri(feature_store_session, region_name):
     bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}"
@@ -109,6 +114,32 @@ def pandas_data_frame():
     return df
 
 
+@pytest.fixture
+def base_dataframe():
+    base_data = [
+        [1, 187512346.0, 123, 128],
+        [2, 187512347.0, 168, 258],
+        [3, 187512348.0, 125, 184],
+        [1, 187512349.0, 195, 206],
+    ]
+    return pd.DataFrame(
+        base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"]
+    )
+
+
+@pytest.fixture
+def feature_group_dataframe():
+    feature_group_data = [
+        [1, 187512246.0, 456, 325],
+        [2, 187512247.0, 729, 693],
+        [3, 187512348.0, 129, 901],
+        [1, 187512449.0, 289, 286],
+    ]
+    return pd.DataFrame(
+        feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"]
+    )
+
+
 @pytest.fixture
 def pandas_data_frame_without_string():
     df = pd.DataFrame(
@@ -527,6 +558,135 @@ def test_ingest_multi_process(
     assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
 
 
+def test_create_dataset_with_feature_group_base(
+    feature_store_session,
+    region_name,
+    role,
+    base_name,
+    feature_group_name,
+    offline_store_s3_uri,
+    base_dataframe,
+    feature_group_dataframe,
+):
+    base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session)
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    with cleanup_feature_group(base) and cleanup_feature_group(feature_group):
+        _create_feature_group_and_ingest_data(
+            base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
+        )
+        _create_feature_group_and_ingest_data(
+            feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
+        )
+        base_table_name = _get_athena_table_name_after_data_replication(
+            feature_store_session, base, offline_store_s3_uri
+        )
+        feature_group_table_name = _get_athena_table_name_after_data_replication(
+            feature_store_session, feature_group, offline_store_s3_uri
+        )
+
+        with timeout(minutes=10) and cleanup_offline_store(
+            base_table_name, feature_store_session
+        ) and cleanup_offline_store(feature_group_table_name, feature_store_session):
+            feature_store = FeatureStore(sagemaker_session=feature_store_session)
+            df, query_string = (
+                feature_store.create_dataset(base=base, output_path=offline_store_s3_uri)
+                .with_feature_group(feature_group)
+                .to_dataframe()
+            )
+            sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
+            merged_df = base_dataframe.merge(
+                feature_group_dataframe, left_on="base_id", right_on="fg_id"
+            )
+            expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True)
+            assert sorted_df.equals(expect_df)
+            assert (
+                query_string
+                == 'WITH fg_base AS (SELECT table_base."base_id", table_base."base_time", '
+                + 'table_base."base_feature_1", table_base."base_feature_2"\n'
+                + "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + 'PARTITION BY dedup_base."base_id", dedup_base."base_feature_1", '
+                + 'dedup_base."base_feature_2"\n'
+                + 'ORDER BY dedup_base."base_time" DESC, dedup_base."api_invocation_time" DESC, '
+                + 'dedup_base."write_time" DESC\n'
+                + ") AS row_base\n"
+                + f'FROM "sagemaker_featurestore"."{base_table_name}" dedup_base\n'
+                + ") AS table_base\n"
+                + "WHERE row_base = 1\n"
+                + "AND NOT is_deleted),\n"
+                + 'fg_0 AS (SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1", '
+                + 'table_0."fg_feature_2"\n'
+                + "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + 'PARTITION BY dedup_0."fg_id", dedup_0."fg_feature_1", dedup_0."fg_feature_2"\n'
+                + 'ORDER BY dedup_0."fg_time" DESC, dedup_0."api_invocation_time" DESC, '
+                + 'dedup_0."write_time" DESC\n'
+                + ") AS row_0\n"
+                + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" dedup_0\n'
+                + ") AS table_0\n"
+                + "WHERE row_0 = 1\n"
+                + "AND NOT is_deleted)\n"
+                + "SELECT *\n"
+                + "FROM fg_base\n"
+                + "JOIN fg_0\n"
+                + 'ON fg_base."base_id" = fg_0."fg_id"'
+            )
+
+
+def _create_feature_group_and_ingest_data(
+    feature_group: FeatureGroup,
+    dataframe: DataFrame,
+    offline_store_s3_uri: str,
+    record_identifier_name: str,
+    event_time_name: str,
+    role: str,
+):
+    feature_group.load_feature_definitions(data_frame=dataframe)
+    feature_group.create(
+        s3_uri=offline_store_s3_uri,
+        record_identifier_name=record_identifier_name,
+        event_time_feature_name=event_time_name,
+        role_arn=role,
+        enable_online_store=True,
+    )
+    _wait_for_feature_group_create(feature_group)
+
+    ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False)
+    ingestion_manager.wait()
+    assert 0 == len(ingestion_manager.failed_rows)
+
+
+def _get_athena_table_name_after_data_replication(
+    feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri
+):
+    feature_group_metadata = feature_group.describe()
+    resolved_output_s3_uri = (
+        feature_group_metadata.get("OfflineStoreConfig", None)
+        .get("S3StorageConfig", None)
+        .get("ResolvedOutputS3Uri", None)
+    )
+    s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "")
+    region_name = feature_store_session.boto_session.region_name
+    s3_client = feature_store_session.boto_session.client(
+        service_name="s3", region_name=region_name
+    )
+    while True:
+        objects_in_bucket = s3_client.list_objects(
+            Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix
+        )
+        if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1:
+            break
+        else:
+            print(f"Waiting for {feature_group.name} data in offline store...")
+            time.sleep(60)
+    print(f"{feature_group.name} data available.")
+    return (
+        feature_group_metadata.get("OfflineStoreConfig", None)
+        .get("DataCatalogConfig", None)
+        .get("TableName", None)
+    )
+
+
 def _wait_for_feature_group_create(feature_group: FeatureGroup):
     status = feature_group.describe().get("FeatureGroupStatus")
     while status == "Creating":
@@ -560,3 +720,28 @@ def cleanup_feature_group(feature_group: FeatureGroup):
             feature_group.delete()
         except Exception:
             raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}")
+
+
+@contextmanager
+def cleanup_offline_store(table_name: str, feature_store_session: Session):
+    try:
+        yield
+    finally:
+        try:
+            region_name = feature_store_session.boto_session.region_name
+            s3_client = feature_store_session.boto_session.client(
+                service_name="s3", region_name=region_name
+            )
+            account_id = feature_store_session.account_id()
+            bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}"
+            response = s3_client.list_objects_v2(
+                Bucket=bucket_name,
+                Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/",
+            )
+            files_in_folder = response["Contents"]
+            files_to_delete = []
+            for f in files_in_folder:
+                files_to_delete.append({"Key": f["Key"]})
+            s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete})
+        except Exception:
+            raise RuntimeError(f"Failed to delete data under {table_name}")
