from __future__ import absolute_import

import datetime
- from typing import Sequence, Union
+ from typing import Any, Dict, Sequence, Union

import attr
import pandas as pd

- from sagemaker import Session
+ from sagemaker import Session, s3, utils
from sagemaker.feature_store.feature_group import FeatureGroup

@@ -166,6 +166,30 @@ def to_csv(self):
            The S3 path of the .csv file.
            The query string executed.
        """
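+         # A pandas.DataFrame base is staged to S3 and queried through a temporary
+         # Athena table; a FeatureGroup base is queried through its own data catalog.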
+         if isinstance(self._base, pd.DataFrame):
+             temp_id = utils.unique_name_from_base("dataframe-base")
+             local_filename = f"{temp_id}.csv"
+             desired_s3_folder = f"{self._output_path}/{temp_id}"
+             self._base.to_csv(local_filename, index=False, header=False)
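+             # Upload the headerless CSV into its own S3 folder; that folder becomes the
+             # LOCATION of the temporary Athena table created below.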
+             s3.S3Uploader.upload(
+                 local_path=local_filename,
+                 desired_s3_uri=desired_s3_folder,
+                 sagemaker_session=self._sagemaker_session,
+                 kms_key=self._kms_key_id,
+             )
+             temp_table_name = f"dataframe_{temp_id}"
+             self._create_temp_table(temp_table_name, desired_s3_folder)
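+             # Build and run the same dataset query used for a FeatureGroup base,
+             # pointed at the temporary table.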
+             base_features = list(self._base.columns)
+             query_string = self._construct_query_string(
+                 temp_table_name,
+                 "sagemaker_featurestore",
+                 base_features,
+             )
+             query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
+             # TODO: clean up local file and temp table
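+             # Return the S3 location of the query output and the executed query string.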
+             return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
+                 "OutputLocation", None
+             ), query_result.get("QueryExecution", None).get("Query", None)
        if isinstance(self._base, FeatureGroup):
            # TODO: handle pagination and input feature validation
            base_feature_group = self._base.describe()
@@ -186,37 +210,19 @@ def to_csv(self):
                for feature in base_feature_group.get("FeatureDefinitions", None)
            ]
-             query = self._sagemaker_session.start_query_execution(
-                 catalog=data_catalog_config.get("Catalog", None)
-                 if disable_glue
-                 else "AwsDataCatalog",
-                 database=data_catalog_config.get("Database", None),
-                 query_string=self._construct_query_string(
-                     data_catalog_config.get("TableName", None),
-                     data_catalog_config.get("Database", None),
-                     base_features,
-                 ),
-                 output_location=self._output_path,
-                 kms_key=self._kms_key_id,
-             )
-             query_id = query.get("QueryExecutionId", None)
-             self._sagemaker_session.wait_for_athena_query(
-                 query_execution_id=query_id,
+             query_string = self._construct_query_string(
+                 data_catalog_config.get("TableName", None),
+                 data_catalog_config.get("Database", None),
+                 base_features,
            )
-             query_state = (
-                 self._sagemaker_session.get_query_execution(
-                     query_execution_id=query_id,
-                 )
-                 .get("QueryExecution", None)
-                 .get("Status", None)
-                 .get("State", None)
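+             # Delegate execution and status checking to the shared _run_query helper.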
+             query_result = self._run_query(
+                 query_string,
+                 data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog",
+                 data_catalog_config.get("Database", None),
            )
-             if query_state != "SUCCEEDED":
-                 raise RuntimeError(f"Failed to execute query {query_id}.")
-
-             return query_state.get("QueryExecution", None).get("ResultConfiguration", None).get(
+             return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
                "OutputLocation", None
-             ), query_state.get("QueryExecution", None).get("Query", None)
+             ), query_result.get("QueryExecution", None).get("Query", None)

        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")

    def _construct_query_string(
@@ -259,3 +265,78 @@ def _construct_query_string(
        if not self._include_deleted_records:
            query_string += "AND NOT is_deleted\n"
        return query_string
+
+     def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
+         """Internal method for creating a temporary Athena table for the base pandas.DataFrame.
+
+         Args:
+             temp_table_name (str): The Athena table name for the base pandas.DataFrame.
+             desired_s3_folder (str): The S3 URI of the folder containing the data.
+         """
+         columns_string = ", ".join(
+             [self._construct_athena_table_column_string(column) for column in self._base.columns]
+         )
+         serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"'
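+         # Create an external table over the uploaded CSV folder using the OpenCSVSerde.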
+         query_string = (
+             f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) "
+             + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
+             + f"WITH SERDEPROPERTIES ({serde_properties}) "
+             + f"LOCATION '{desired_s3_folder}';"
+         )
+         self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
+
+     def _construct_athena_table_column_string(self, column: str) -> str:
+         """Internal method for constructing the Athena column string.
+
+         Args:
+             column (str): The column name from the pandas.DataFrame.
+         Returns:
+             The Athena column string.
+
+         Raises:
+             RuntimeError: The type of the pandas.DataFrame column is not supported yet.
+         """
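+         # Map the pandas dtype of the column to the closest Athena column type.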
+         dataframe_type = self._base[column].dtypes
+         if dataframe_type == "object":
+             column_type = "STRING"
+         elif dataframe_type == "int64":
+             column_type = "INT"
+         elif dataframe_type == "float64":
+             column_type = "DOUBLE"
+         elif dataframe_type == "bool":
+             column_type = "BOOLEAN"
+         elif dataframe_type == "datetime64[ns]":
+             column_type = "TIMESTAMP"
+         else:
+             raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
+         return f"{column} {column_type}"
+
+     def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
+         """Internal method to execute an Athena query, wait for it to finish, and get the result.
+
+         Args:
+             query_string (str): The SQL query statements to be executed.
+             catalog (str): The name of the data catalog used in the query execution.
+             database (str): The name of the database used in the query execution.
+         Returns:
+             The query result.
+
+         Raises:
+             RuntimeError: The Athena query failed.
+         """
+         query = self._sagemaker_session.start_query_execution(
+             catalog=catalog,
+             database=database,
+             query_string=query_string,
+             output_location=self._output_path,
+             kms_key=self._kms_key_id,
+         )
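+         # Block until the query reaches a terminal state, then fail unless it succeeded.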
+         query_id = query.get("QueryExecutionId", None)
+         self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
+         query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
+         query_state = (
+             query_result.get("QueryExecution", None).get("Status", None).get("State", None)
+         )
+         if query_state != "SUCCEEDED":
+             raise RuntimeError(f"Failed to execute query {query_id}.")
+         return query_result