Skip to content

Commit 66e18a4

Browse files
imingtsoumizanfiu
authored andcommitted
Unit test for DatasetBuilder (aws#734)
1 parent 98b60c0 commit 66e18a4

File tree

2 files changed

+315
-41
lines changed

2 files changed

+315
-41
lines changed

src/sagemaker/feature_store/dataset_builder.py

+33-38
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
from sagemaker.feature_store.feature_group import FeatureGroup
2828

2929

30+
_DEFAULT_CATALOG = "AwsDataCatalog"
31+
_DEFAULT_DATABASE = "sagemaker_featurestore"
32+
33+
3034
@attr.s
3135
class FeatureGroupToBeMerged:
3236
"""FeatureGroup metadata which will be used for SQL join.
@@ -88,11 +92,13 @@ def construct_feature_group_to_be_merged(
8892
table_name = data_catalog_config.get("TableName", None)
8993
database = data_catalog_config.get("Database", None)
9094
disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
91-
catalog = data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog"
95+
catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
9296
features = [
9397
feature.get("FeatureName", None)
9498
for feature in feature_group_metadata.get("FeatureDefinitions", None)
9599
]
100+
if not included_feature_names:
101+
included_feature_names = features
96102
return FeatureGroupToBeMerged(
97103
features,
98104
included_feature_names,
@@ -167,7 +173,7 @@ class DatasetBuilder:
167173
"int64": "INT",
168174
"float64": "DOUBLE",
169175
"bool": "BOOLEAN",
170-
"datetime64": "TIMESTAMP",
176+
"datetime64[ns]": "TIMESTAMP",
171177
}
172178

173179
def with_feature_group(
@@ -300,12 +306,17 @@ def to_csv(self) -> Tuple[str, str]:
300306
self._create_temp_table(temp_table_name, desired_s3_folder)
301307
base_features = list(self._base.columns)
302308
query_string = self._construct_query_string(
303-
temp_table_name,
304-
"AwsDataCatalog",
305-
"sagemaker_featurestore",
306-
base_features,
309+
FeatureGroupToBeMerged(
310+
base_features,
311+
self._included_feature_names if self._included_feature_names else base_features,
312+
_DEFAULT_CATALOG,
313+
_DEFAULT_DATABASE,
314+
temp_table_name,
315+
self._record_identifier_feature_name,
316+
self._event_time_identifier_feature_name,
317+
)
307318
)
308-
query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
319+
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
309320
# TODO: cleanup temp table, need more clarification, keep it for now
310321
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
311322
"OutputLocation", None
@@ -314,12 +325,7 @@ def to_csv(self) -> Tuple[str, str]:
314325
base_feature_group = construct_feature_group_to_be_merged(
315326
self._base, self._included_feature_names
316327
)
317-
query_string = self._construct_query_string(
318-
base_feature_group.table_name,
319-
base_feature_group.catalog,
320-
base_feature_group.database,
321-
base_feature_group.features,
322-
)
328+
query_string = self._construct_query_string(base_feature_group)
323329
query_result = self._run_query(
324330
query_string,
325331
base_feature_group.catalog,
@@ -330,14 +336,14 @@ def to_csv(self) -> Tuple[str, str]:
330336
), query_result.get("QueryExecution", {}).get("Query", None)
331337
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
332338

333-
def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
339+
def to_dataframe(self) -> Tuple[pd.DataFrame, str]:
334340
"""Get query string and result in pandas.Dataframe
335341
336342
Returns:
337343
The pandas.DataFrame object.
338344
The query string executed.
339345
"""
340-
query_string, csv_file = self.to_csv()
346+
csv_file, query_string = self.to_csv()
341347
s3.S3Downloader.download(
342348
s3_uri=csv_file,
343349
local_path="./",
@@ -347,9 +353,11 @@ def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
347353
local_file_name = csv_file.split("/")[-1]
348354
df = pd.read_csv(local_file_name)
349355
os.remove(local_file_name)
350-
return query_string, df
356+
return df, query_string
351357

352-
def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
358+
def _construct_where_query_string(
359+
self, suffix: str, event_time_identifier_feature_name: str
360+
) -> str:
353361
"""Internal method for constructing SQL WHERE query string by parameters.
354362
355363
Args:
@@ -380,7 +388,7 @@ def _construct_where_query_string(self, suffix: str, event_time_identifier_featu
380388
return ""
381389
return "WHERE " + "\nAND ".join(where_conditions)
382390

383-
def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str):
391+
def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
384392
"""Internal method for constructing SQL query string by parameters.
385393
386394
Args:
@@ -421,27 +429,14 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
421429
suffix, feature_group.event_time_identifier_feature_name
422430
)
423431

424-
def _construct_query_string(
425-
self, base_table_name: str, catalog: str, database: str, base_features: list
426-
) -> str:
432+
def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str:
427433
"""Internal method for constructing SQL query string by parameters.
428434
429435
Args:
430-
base_table_name (str): The Athena table name of base FeatureGroup or pandas.DataFrame.
431-
database (str): The Athena database of the base table.
432-
base_features (list): The list of features of the base table.
436+
base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata.
433437
Returns:
434438
The query string.
435439
"""
436-
base = FeatureGroupToBeMerged(
437-
base_features,
438-
self._included_feature_names,
439-
catalog,
440-
database,
441-
base_table_name,
442-
self._record_identifier_feature_name,
443-
self._event_time_identifier_feature_name,
444-
)
445440
base_table_query_string = self._construct_table_query(base, "base")
446441
query_string = f"WITH fg_base AS ({base_table_query_string})"
447442
if len(self._feature_groups_to_be_merged) > 0:
@@ -451,7 +446,7 @@ def _construct_query_string(
451446
for i, feature_group in enumerate(self._feature_groups_to_be_merged)
452447
]
453448
)
454-
query_string += with_subquery_string
449+
query_string += f"{with_subquery_string}\n"
455450
query_string += "SELECT *\nFROM fg_base"
456451
if len(self._feature_groups_to_be_merged) > 0:
457452
join_subquery_string = "".join(
@@ -465,7 +460,7 @@ def _construct_query_string(
465460
query_string += f"\nLIMIT {self._number_of_records}"
466461
return query_string
467462

468-
def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str):
463+
def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
469464
"""Internal method for constructing SQL JOIN query string by parameters.
470465
471466
Args:
@@ -504,7 +499,7 @@ def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
504499
+ f"WITH SERDEPROPERTIES ({serde_properties}) "
505500
+ f"LOCATION '{desired_s3_folder}';"
506501
)
507-
self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
502+
self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
508503

509504
def _construct_athena_table_column_string(self, column: str) -> str:
510505
"""Internal method for constructing string of Athena column.
@@ -518,9 +513,9 @@ def _construct_athena_table_column_string(self, column: str) -> str:
518513
RuntimeError: The type of pandas.Dataframe column is not support yet.
519514
"""
520515
dataframe_type = self._base[column].dtypes
521-
if dataframe_type not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
516+
if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
522517
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
523-
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(dataframe_type, None)}"
518+
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
524519

525520
def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
526521
"""Internal method for execute Athena query, wait for query finish and get query result.

0 commit comments

Comments
 (0)