@@ -27,6 +27,10 @@
from sagemaker.feature_store.feature_group import FeatureGroup


+_DEFAULT_CATALOG = "AwsDataCatalog"
+_DEFAULT_DATABASE = "sagemaker_featurestore"
+
+
@attr.s
class FeatureGroupToBeMerged:
    """FeatureGroup metadata which will be used for SQL join.
@@ -88,11 +92,13 @@ def construct_feature_group_to_be_merged(
    table_name = data_catalog_config.get("TableName", None)
    database = data_catalog_config.get("Database", None)
    disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
-    catalog = data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog"
+    catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
    features = [
        feature.get("FeatureName", None)
        for feature in feature_group_metadata.get("FeatureDefinitions", None)
    ]
+    if not included_feature_names:
+        included_feature_names = features
    return FeatureGroupToBeMerged(
        features,
        included_feature_names,
@@ -167,7 +173,7 @@ class DatasetBuilder:
        "int64": "INT",
        "float64": "DOUBLE",
        "bool": "BOOLEAN",
-        "datetime64": "TIMESTAMP",
+        "datetime64[ns]": "TIMESTAMP",
    }

    def with_feature_group(
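Note on the mapping change above: pandas reports a datetime column's dtype as datetime64[ns], so once the dtype is stringified (see the str(dataframe_type) change further down) only the bracketed form matches. A minimal standalone check, not part of this diff:

import pandas as pd

# A datetime column stringifies to "datetime64[ns]", matching the new map key;
# the bare "datetime64" key used previously never matches this value.
df = pd.DataFrame({"event_time": pd.to_datetime(["2022-01-01", "2022-01-02"])})
print(str(df["event_time"].dtypes))  # -> datetime64[ns]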
@@ -300,12 +306,17 @@ def to_csv(self) -> Tuple[str, str]:
            self._create_temp_table(temp_table_name, desired_s3_folder)
            base_features = list(self._base.columns)
            query_string = self._construct_query_string(
-                temp_table_name,
-                "AwsDataCatalog",
-                "sagemaker_featurestore",
-                base_features,
+                FeatureGroupToBeMerged(
+                    base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    _DEFAULT_CATALOG,
+                    _DEFAULT_DATABASE,
+                    temp_table_name,
+                    self._record_identifier_feature_name,
+                    self._event_time_identifier_feature_name,
+                )
            )
-            query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
+            query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
            # TODO: cleanup temp table, need more clarification, keep it for now
            return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                "OutputLocation", None
@@ -314,12 +325,7 @@ def to_csv(self) -> Tuple[str, str]:
            base_feature_group = construct_feature_group_to_be_merged(
                self._base, self._included_feature_names
            )
-            query_string = self._construct_query_string(
-                base_feature_group.table_name,
-                base_feature_group.catalog,
-                base_feature_group.database,
-                base_feature_group.features,
-            )
+            query_string = self._construct_query_string(base_feature_group)
            query_result = self._run_query(
                query_string,
                base_feature_group.catalog,
@@ -330,14 +336,14 @@ def to_csv(self) -> Tuple[str, str]:
            ), query_result.get("QueryExecution", {}).get("Query", None)
        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")

-    def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
+    def to_dataframe(self) -> Tuple[pd.DataFrame, str]:
        """Get query string and result in pandas.Dataframe

        Returns:
            The pandas.DataFrame object.
            The query string executed.
        """
-        query_string, csv_file = self.to_csv()
+        csv_file, query_string = self.to_csv()
        s3.S3Downloader.download(
            s3_uri=csv_file,
            local_path="./",
@@ -347,9 +353,11 @@ def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
        local_file_name = csv_file.split("/")[-1]
        df = pd.read_csv(local_file_name)
        os.remove(local_file_name)
-        return query_string, df
+        return df, query_string

-    def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
+    def _construct_where_query_string(
+        self, suffix: str, event_time_identifier_feature_name: str
+    ) -> str:
        """Internal method for constructing SQL WHERE query string by parameters.

        Args:
@@ -380,7 +388,7 @@ def _construct_where_query_string(self, suffix: str, event_time_identifier_featu
            return ""
        return "WHERE " + "\nAND ".join(where_conditions)

-    def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str):
+    def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
        """Internal method for constructing SQL query string by parameters.

        Args:
@@ -421,27 +429,14 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
            suffix, feature_group.event_time_identifier_feature_name
        )

-    def _construct_query_string(
-        self, base_table_name: str, catalog: str, database: str, base_features: list
-    ) -> str:
+    def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str:
        """Internal method for constructing SQL query string by parameters.

        Args:
-            base_table_name (str): The Athena table name of base FeatureGroup or pandas.DataFrame.
-            database (str): The Athena database of the base table.
-            base_features (list): The list of features of the base table.
+            base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata.
        Returns:
            The query string.
        """
-        base = FeatureGroupToBeMerged(
-            base_features,
-            self._included_feature_names,
-            catalog,
-            database,
-            base_table_name,
-            self._record_identifier_feature_name,
-            self._event_time_identifier_feature_name,
-        )
        base_table_query_string = self._construct_table_query(base, "base")
        query_string = f"WITH fg_base AS ({base_table_query_string})"
        if len(self._feature_groups_to_be_merged) > 0:
@@ -451,7 +446,7 @@ def _construct_query_string(
                    for i, feature_group in enumerate(self._feature_groups_to_be_merged)
                ]
            )
-            query_string += with_subquery_string
+            query_string += f" {with_subquery_string}\n"
        query_string += "SELECT *\nFROM fg_base"
        if len(self._feature_groups_to_be_merged) > 0:
            join_subquery_string = "".join(
@@ -465,7 +460,7 @@ def _construct_query_string(
            query_string += f"\nLIMIT {self._number_of_records}"
        return query_string

-    def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str):
+    def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
        """Internal method for constructing SQL JOIN query string by parameters.

        Args:
@@ -504,7 +499,7 @@ def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
            + f"WITH SERDEPROPERTIES ({serde_properties}) "
            + f"LOCATION '{desired_s3_folder}';"
        )
-        self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
+        self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)

    def _construct_athena_table_column_string(self, column: str) -> str:
        """Internal method for constructing string of Athena column.
@@ -518,9 +513,9 @@ def _construct_athena_table_column_string(self, column: str) -> str:
            RuntimeError: The type of pandas.Dataframe column is not support yet.
        """
        dataframe_type = self._base[column].dtypes
-        if dataframe_type not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
+        if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
            raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
-        return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(dataframe_type, None)}"
+        return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"

    def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
        """Internal method for execute Athena query, wait for query finish and get query result.