from __future__ import absolute_import

import datetime
+ import os

from typing import Any, Dict, List, Sequence, Tuple, Union

import attr
@@ -39,6 +40,7 @@ class FeatureGroupToBeMerged:
        features (List[str]): A list of strings representing feature names of this FeatureGroup.
        included_feature_names (Sequence[str]): A list of strings representing features to be
            included in the output.
+         catalog (str): A string representing the data catalog of this FeatureGroup.
        database (str): A string representing the database.
        table_name (str): A string representing the Athena table name of this FeatureGroup.
        record_identifier_feature_name (str): A string representing the record identifier feature.
@@ -50,13 +52,59 @@ class FeatureGroupToBeMerged:

    features: List[str] = attr.ib()
    included_feature_names: Sequence[str] = attr.ib()
+     catalog: str = attr.ib()
    database: str = attr.ib()
    table_name: str = attr.ib()
    record_identifier_feature_name: str = attr.ib()
    event_time_identifier_feature_name: str = attr.ib()
    target_feature_name_in_base: str = attr.ib(default=None)


+ def construct_feature_group_to_be_merged(
+     feature_group: FeatureGroup,
+     included_feature_names: Sequence[str],
+     target_feature_name_in_base: str = None,
+ ) -> FeatureGroupToBeMerged:
+     """Construct a FeatureGroupToBeMerged object from the provided parameters.
+
+     Args:
+         feature_group (FeatureGroup): A FeatureGroup object.
+         included_feature_names (Sequence[str]): A list of strings representing features to be
+             included in the output.
+         target_feature_name_in_base (str): A string representing the feature name in base which
+             will be used as the target join key (default: None).
+     Returns:
+         A FeatureGroupToBeMerged object.
+     """
+     feature_group_metadata = feature_group.describe()
+     data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
+         "DataCatalogConfig", None
+     )
+     if not data_catalog_config:
+         raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
+
+     record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
+     event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None)
+     table_name = data_catalog_config.get("TableName", None)
+     database = data_catalog_config.get("Database", None)
+     disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
+     catalog = data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog"
+     features = [
+         feature.get("FeatureName", None)
+         for feature in feature_group_metadata.get("FeatureDefinitions", [])
+     ]
+     return FeatureGroupToBeMerged(
+         features,
+         included_feature_names,
+         catalog,
+         database,
+         table_name,
+         record_identifier_feature_name,
+         event_time_identifier_feature_name,
+         target_feature_name_in_base,
+     )
+
+

@attr.s
class DatasetBuilder:
    """DatasetBuilder definition.
@@ -114,6 +162,14 @@ class DatasetBuilder:
    _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
    _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, default=[])

+     _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = {
+         "object": "STRING",
+         "int64": "INT",
+         "float64": "DOUBLE",
+         "bool": "BOOLEAN",
+         "datetime64": "TIMESTAMP",
+     }
+
    def with_feature_group(
        self,
        feature_group: FeatureGroup,
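The class-level map above replaces the if/elif dtype ladder later in this diff; a minimal standalone sketch of the translation it performs (the DataFrame here is illustrative):

# Mirrors the mapping added above; builds Athena column definitions from
# pandas dtypes for an illustrative DataFrame.
import pandas as pd

type_map = {"object": "STRING", "int64": "INT", "float64": "DOUBLE",
            "bool": "BOOLEAN", "datetime64": "TIMESTAMP"}
df = pd.DataFrame({"name": ["a"], "count": [1], "score": [0.5]})
for column in df.columns:
    print(f"{column} {type_map[str(df[column].dtypes)]}")  # name STRING, count INT, score DOUBLE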
@@ -131,38 +187,11 @@ def with_feature_group(
        Returns:
            This DatasetBuilder object.
        """
-         # TODO: handle pagination and input feature validation
-         # TODO: potential refactor with FeatureGroup base
-         feature_group_metadata = feature_group.describe()
-         data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", None).get(
-             "DataCatalogConfig", None
-         )
-         if not data_catalog_config:
-             raise RuntimeError(
-                 f"No metastore is configured with FeatureGroup {feature_group.name}."
-             )
-
-         record_identifier_feature_name = feature_group_metadata.get(
-             "RecordIdentifierFeatureName", None
-         )
-         event_time_identifier_feature_name = feature_group_metadata.get(
-             "EventTimeFeatureName", None
-         )
-         # TODO: back fill feature definitions due to UpdateFG
-         table_name = data_catalog_config.get("TableName", None)
-         database = data_catalog_config.get("Database", None)
-         features = [feature.feature_name for feature in feature_group.feature_definitions]
        if not target_feature_name_in_base:
            target_feature_name_in_base = self._record_identifier_feature_name
        self._feature_groups_to_be_merged.append(
-             FeatureGroupToBeMerged(
-                 features,
-                 included_feature_names,
-                 database,
-                 table_name,
-                 record_identifier_feature_name,
-                 event_time_identifier_feature_name,
-                 target_feature_name_in_base,
+             construct_feature_group_to_be_merged(
+                 feature_group, included_feature_names, target_feature_name_in_base
            )
        )
        return self
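Since the refactor keeps with_feature_group returning self, the fluent style still works; an illustrative sketch (builder and orders_fg are assumed to already exist and are not from this diff):

# Assumed pre-existing DatasetBuilder (builder) and FeatureGroup (orders_fg);
# each call appends a FeatureGroupToBeMerged built by the shared helper.
builder = builder.with_feature_group(orders_fg, included_feature_names=["order_id", "total"])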
@@ -257,61 +286,48 @@ def to_csv(self) -> Tuple[str, str]:
        """
        if isinstance(self._base, pd.DataFrame):
            temp_id = utils.unique_name_from_base("dataframe-base")
-             local_filename = f"{temp_id}.csv"
+             local_file_name = f"{temp_id}.csv"
            desired_s3_folder = f"{self._output_path}/{temp_id}"
-             self._base.to_csv(local_filename, index=False, header=False)
+             self._base.to_csv(local_file_name, index=False, header=False)
            s3.S3Uploader.upload(
-                 local_path=local_filename,
+                 local_path=local_file_name,
                desired_s3_uri=desired_s3_folder,
                sagemaker_session=self._sagemaker_session,
                kms_key=self._kms_key_id,
            )
+             os.remove(local_file_name)
            temp_table_name = f"dataframe_{temp_id}"
            self._create_temp_table(temp_table_name, desired_s3_folder)
            base_features = list(self._base.columns)
            query_string = self._construct_query_string(
                temp_table_name,
+                 "AwsDataCatalog",
                "sagemaker_featurestore",
                base_features,
            )
            query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
-             # TODO: cleanup local file and temp table
-             return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
+             # TODO: clean up the temp table; needs more clarification, so keep it for now
+             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                "OutputLocation", None
-             ), query_result.get("QueryExecution", None).get("Query", None)
+             ), query_result.get("QueryExecution", {}).get("Query", None)
        if isinstance(self._base, FeatureGroup):
-             # TODO: handle pagination and input feature validation
-             base_feature_group = self._base.describe()
-             data_catalog_config = base_feature_group.get("OfflineStoreConfig", None).get(
-                 "DataCatalogConfig", None
-             )
-             if not data_catalog_config:
-                 raise RuntimeError("No metastore is configured with the base FeatureGroup.")
-             disable_glue = base_feature_group.get("DisableGlueTableCreation", False)
-             self._record_identifier_feature_name = base_feature_group.get(
-                 "RecordIdentifierFeatureName", None
+             base_feature_group = construct_feature_group_to_be_merged(
+                 self._base, self._included_feature_names
            )
-             self._event_time_identifier_feature_name = base_feature_group.get(
-                 "EventTimeFeatureName", None
-             )
-             base_features = [
-                 feature.get("FeatureName", None)
-                 for feature in base_feature_group.get("FeatureDefinitions", None)
-             ]
-
            query_string = self._construct_query_string(
-                 data_catalog_config.get("TableName", None),
-                 data_catalog_config.get("Database", None),
-                 base_features,
+                 base_feature_group.table_name,
+                 base_feature_group.catalog,
+                 base_feature_group.database,
+                 base_feature_group.features,
            )
            query_result = self._run_query(
                query_string,
-                 data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog",
-                 data_catalog_config.get("Database", None),
+                 base_feature_group.catalog,
+                 base_feature_group.database,
            )
-             return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
+             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                "OutputLocation", None
-             ), query_result.get("QueryExecution", None).get("Query", None)
+             ), query_result.get("QueryExecution", {}).get("Query", None)
        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")

    def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
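One pattern worth noting in this hunk: the chained lookups now default to {} instead of None, which keeps a missing key from blowing up mid-chain. A minimal demonstration with an illustrative empty response:

# With {} defaults, a response missing "QueryExecution" yields None instead of
# raising AttributeError partway through the chain; query_result is illustrative.
query_result = {}
output_location = (
    query_result.get("QueryExecution", {})
    .get("ResultConfiguration", {})
    .get("OutputLocation", None)
)
assert output_location is None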
@@ -328,8 +344,10 @@ def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
            kms_key=self._kms_key_id,
            sagemaker_session=self._sagemaker_session,
        )
-         # TODO: do we need to clean up local file?
-         return query_string, pd.read_csv(csv_file.split("/")[-1])
+         local_file_name = csv_file.split("/")[-1]
+         df = pd.read_csv(local_file_name)
+         os.remove(local_file_name)
+         return query_string, df

    def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
        """Internal method for constructing a SQL WHERE query string from parameters.
@@ -404,7 +422,7 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
        )

    def _construct_query_string(
-         self, base_table_name: str, database: str, base_features: list
+         self, base_table_name: str, catalog: str, database: str, base_features: list
    ) -> str:
        """Internal method for constructing a SQL query string from parameters.
@@ -418,6 +436,7 @@ def _construct_query_string(
        base = FeatureGroupToBeMerged(
            base_features,
            self._included_feature_names,
+             catalog,
            database,
            base_table_name,
            self._record_identifier_feature_name,
@@ -499,19 +518,9 @@ def _construct_athena_table_column_string(self, column: str) -> str:
            RuntimeError: The type of the pandas.DataFrame column is not supported yet.
        """
        dataframe_type = self._base[column].dtypes
-         if dataframe_type == "object":
-             column_type = "STRING"
-         elif dataframe_type == "int64":
-             column_type = "INT"
-         elif dataframe_type == "float64":
-             column_type = "DOUBLE"
-         elif dataframe_type == "bool":
-             column_type = "BOOLEAN"
-         elif dataframe_type == "datetime64":
-             column_type = "TIMESTAMP"
-         else:
+         if dataframe_type not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP:
            raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
-         return f"{column} {column_type}"
+         return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP[dataframe_type]}"

    def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
        """Internal method to execute an Athena query, wait for it to finish, and get the query result.
@@ -536,9 +545,7 @@ def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str
        query_id = query.get("QueryExecutionId", None)
        self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
        query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
-         query_state = (
-             query_result.get("QueryExecution", None).get("Status", None).get("State", None)
-         )
+         query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None)
        if query_state != "SUCCEEDED":
            raise RuntimeError(f"Failed to execute query {query_id}.")
        return query_result
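Pulling the pieces together, a hedged end-to-end sketch of the builder surface this diff touches (builder is an assumed, fully configured DatasetBuilder; the return shape matches to_dataframe above):

# builder is assumed to exist; to_dataframe runs the Athena query, downloads
# the result CSV, loads it, and (after this change) removes the local copy.
query_string, df = builder.to_dataframe()
print(query_string)
print(df.head())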