import attr
import pandas as pd

+from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup


@@ -33,6 +34,7 @@ class DatasetBuilder:
    an output path and a KMS key ID.

    Attributes:
+        _sagemaker_session (Session): Session instance to perform boto calls.
        _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
            pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
        _output_path (str): An S3 URI which stores the output .csv file.
@@ -59,6 +61,7 @@ class DatasetBuilder:
            dataset will be before it.
    """

+    _sagemaker_session: Session = attr.ib()
    _base: Union[FeatureGroup, pd.DataFrame] = attr.ib()
    _output_path: str = attr.ib()
    _record_identifier_feature_name: str = attr.ib(default=None)
@@ -155,3 +158,104 @@ def with_event_time_range(
        self._event_time_starting_timestamp = starting_timestamp
        self._event_time_ending_timestamp = ending_timestamp
        return self
+
+    def to_csv(self):
+        """Get the query string and result in .csv format.
+
+        Returns:
+            The S3 path of the .csv file.
+            The query string executed.
+        """
+        if isinstance(self._base, FeatureGroup):
+            # TODO: handle pagination and input feature validation
+            base_feature_group = self._base.describe()
+            data_catalog_config = base_feature_group.get("OfflineStoreConfig", {}).get(
+                "DataCatalogConfig", None
+            )
+            if not data_catalog_config:
+                raise RuntimeError("No metastore is configured with the base FeatureGroup.")
+            disable_glue = base_feature_group.get("DisableGlueTableCreation", False)
+            self._record_identifier_feature_name = base_feature_group.get(
+                "RecordIdentifierFeatureName", None
+            )
+            self._event_time_identifier_feature_name = base_feature_group.get(
+                "EventTimeFeatureName", None
+            )
+            base_features = [
+                feature.get("FeatureName", None)
+                for feature in base_feature_group.get("FeatureDefinitions", [])
+            ]
+
+            query = self._sagemaker_session.start_query_execution(
+                catalog=data_catalog_config.get("Catalog", None)
+                if disable_glue
+                else "AwsDataCatalog",
+                database=data_catalog_config.get("Database", None),
+                query_string=self._construct_query_string(
+                    data_catalog_config.get("TableName", None),
+                    data_catalog_config.get("Database", None),
+                    base_features,
+                ),
+                output_location=self._output_path,
+                kms_key=self._kms_key_id,
+            )
+            query_id = query.get("QueryExecutionId", None)
+            self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
+            query_result = self._sagemaker_session.get_query_execution(
+                query_execution_id=query_id,
+            ).get("QueryExecution", {})
+            query_state = query_result.get("Status", {}).get("State", None)
+            if query_state != "SUCCEEDED":
+                raise RuntimeError(f"Failed to execute query {query_id}.")
+
+            return (
+                query_result.get("ResultConfiguration", {}).get("OutputLocation", None),
+                query_result.get("Query", None),
+            )
+        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
+
+    def _construct_query_string(
+        self, base_table_name: str, database: str, base_features: list
+    ) -> str:
+        """Internal method for constructing the SQL query string from the given parameters.
+
+        Args:
+            base_table_name (str): The Athena table name of the base FeatureGroup or
+                pandas.DataFrame.
+            database (str): The Athena database of the base table.
+            base_features (list): The list of features of the base table.
+        Returns:
+            The query string.
+        """
+        included_features = ", ".join(
+            [
+                f'base."{include_feature_name}"'
+                for include_feature_name in self._included_feature_names
+            ]
+        )
+        query_string = f"SELECT {included_features}\n"
+        if self._include_duplicated_records:
+            query_string += f'FROM "{database}"."{base_table_name}" base\n'
+            if not self._include_deleted_records:
+                query_string += "WHERE NOT is_deleted\n"
+        else:
+            # Deduplicate: keep only the most recent write for each distinct
+            # tuple of non event-time feature values.
+            base_features.remove(self._event_time_identifier_feature_name)
+            dedup_features = ", ".join([f'dedup_base."{feature}"' for feature in base_features])
+            query_string += (
+                "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + f"PARTITION BY {dedup_features}\n"
+                + f'ORDER BY dedup_base."{self._event_time_identifier_feature_name}" '
+                + 'DESC, dedup_base."api_invocation_time" DESC, dedup_base."write_time" DESC\n'
+                + ") AS row_base\n"
+                + f'FROM "{database}"."{base_table_name}" dedup_base\n'
+                + ") AS base\n"
+                + "WHERE row_base = 1\n"
+            )
+            if not self._include_deleted_records:
+                query_string += "AND NOT is_deleted\n"
+        return query_string
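
To make the deduplication branch concrete: with hypothetical inputs database="sagemaker_featurestore", base_table_name="customers-fg-1234", base_features=["customer_id", "spend", "event_time"], an event time identifier of "event_time", _included_feature_names=["customer_id", "spend"], and both _include_duplicated_records and _include_deleted_records left False, _construct_query_string emits:

    SELECT base."customer_id", base."spend"
    FROM (
    SELECT *, row_number() OVER (
    PARTITION BY dedup_base."customer_id", dedup_base."spend"
    ORDER BY dedup_base."event_time" DESC, dedup_base."api_invocation_time" DESC, dedup_base."write_time" DESC
    ) AS row_base
    FROM "sagemaker_featurestore"."customers-fg-1234" dedup_base
    ) AS base
    WHERE row_base = 1
    AND NOT is_deleted

Each partition groups rows with identical non event-time feature values; row_number() ordered by event time (then api_invocation_time and write_time) keeps only the most recent write, and the final predicate drops records flagged is_deleted.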
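A minimal usage sketch (not part of this diff), assuming DatasetBuilder is decorated with @attr.s so that attrs strips the leading underscore from each attr.ib field and exposes it as a plain __init__ parameter, that the module lives at sagemaker.feature_store.dataset_builder, and that an _included_feature_names field with a default exists among the attributes elided from the hunks above; the feature group name and bucket are hypothetical placeholders:

    import sagemaker
    from sagemaker.feature_store.dataset_builder import DatasetBuilder
    from sagemaker.feature_store.feature_group import FeatureGroup

    session = sagemaker.Session()
    # Hypothetical feature group whose offline store has a metastore configured.
    base_group = FeatureGroup(name="customers-fg", sagemaker_session=session)

    builder = DatasetBuilder(
        sagemaker_session=session,                        # _sagemaker_session
        base=base_group,                                  # _base
        output_path="s3://my-bucket/dataset-output",      # _output_path
        included_feature_names=["customer_id", "spend"],  # assumed field
    )
    csv_s3_uri, executed_query = builder.to_csv()

to_csv runs the constructed query through Athena and returns both the S3 location of the result file and the SQL it executed, so callers can audit exactly what produced the dataset.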