Skip to content

Commit b0e537b

Browse files
imingtsoumizanfiu
authored and committed
feat: Add pandas.DataFrame as base case (aws#708)
1 parent 9fb4ae2 commit b0e537b

File tree

1 file changed

+111
-30
lines changed

1 file changed

+111
-30
lines changed

src/sagemaker/feature_store/dataset_builder.py

+111-30
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
from __future__ import absolute_import
1818

1919
import datetime
20-
from typing import Sequence, Union
20+
from typing import Any, Dict, Sequence, Union
2121

2222
import attr
2323
import pandas as pd
2424

25-
from sagemaker import Session
25+
from sagemaker import Session, s3, utils
2626
from sagemaker.feature_store.feature_group import FeatureGroup
2727

2828

@@ -166,6 +166,30 @@ def to_csv(self):
166166
The S3 path of the .csv file.
167167
The query string executed.
168168
"""
169+
if isinstance(self._base, pd.DataFrame):
170+
temp_id = utils.unique_name_from_base("dataframe-base")
171+
local_filename = f"{temp_id}.csv"
172+
desired_s3_folder = f"{self._output_path}/{temp_id}"
173+
self._base.to_csv(local_filename, index=False, header=False)
174+
s3.S3Uploader.upload(
175+
local_path=local_filename,
176+
desired_s3_uri=desired_s3_folder,
177+
sagemaker_session=self._sagemaker_session,
178+
kms_key=self._kms_key_id,
179+
)
180+
temp_table_name = f"dataframe_{temp_id}"
181+
self._create_temp_table(temp_table_name, desired_s3_folder)
182+
base_features = list(self._base.columns)
183+
query_string = self._construct_query_string(
184+
temp_table_name,
185+
"sagemaker_featurestore",
186+
base_features,
187+
)
188+
query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
189+
# TODO: cleanup local file and temp table
190+
return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
191+
"OutputLocation", None
192+
), query_result.get("QueryExecution", None).get("Query", None)
169193
if isinstance(self._base, FeatureGroup):
170194
# TODO: handle pagination and input feature validation
171195
base_feature_group = self._base.describe()
@@ -186,37 +210,19 @@ def to_csv(self):
186210
for feature in base_feature_group.get("FeatureDefinitions", None)
187211
]
188212

189-
query = self._sagemaker_session.start_query_execution(
190-
catalog=data_catalog_config.get("Catalog", None)
191-
if disable_glue
192-
else "AwsDataCatalog",
193-
database=data_catalog_config.get("Database", None),
194-
query_string=self._construct_query_string(
195-
data_catalog_config.get("TableName", None),
196-
data_catalog_config.get("Database", None),
197-
base_features,
198-
),
199-
output_location=self._output_path,
200-
kms_key=self._kms_key_id,
201-
)
202-
query_id = query.get("QueryExecutionId", None)
203-
self._sagemaker_session.wait_for_athena_query(
204-
query_execution_id=query_id,
213+
query_string = self._construct_query_string(
214+
data_catalog_config.get("TableName", None),
215+
data_catalog_config.get("Database", None),
216+
base_features,
205217
)
206-
query_state = (
207-
self._sagemaker_session.get_query_execution(
208-
query_execution_id=query_id,
209-
)
210-
.get("QueryExecution", None)
211-
.get("Status", None)
212-
.get("State", None)
218+
query_result = self._run_query(
219+
query_string,
220+
data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog",
221+
data_catalog_config.get("Database", None),
213222
)
214-
if query_state != "SUCCEEDED":
215-
raise RuntimeError(f"Failed to execute query {query_id}.")
216-
217-
return query_state.get("QueryExecution", None).get("ResultConfiguration", None).get(
223+
return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
218224
"OutputLocation", None
219-
), query_state.get("QueryExecution", None).get("Query", None)
225+
), query_result.get("QueryExecution", None).get("Query", None)
220226
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
221227

222228
def _construct_query_string(
@@ -259,3 +265,78 @@ def _construct_query_string(
259265
if not self._include_deleted_records:
260266
query_string += "AND NOT is_deleted\n"
261267
return query_string
268+
269+
def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
270+
"""Internal method for creating a temp Athena table for the base pandas.Dataframe.
271+
272+
Args:
273+
temp_table_name (str): The Athena table name of base pandas.DataFrame.
274+
desired_s3_folder (str): The S3 URI of the folder of the data.
275+
"""
276+
columns_string = ", ".join(
277+
[self._construct_athena_table_column_string(column) for column in self._base.columns]
278+
)
279+
serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"'
280+
query_string = (
281+
f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) "
282+
+ "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
283+
+ f"WITH SERDEPROPERTIES ({serde_properties}) "
284+
+ f"LOCATION '{desired_s3_folder}';"
285+
)
286+
self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
287+
288+
def _construct_athena_table_column_string(self, column: str) -> str:
289+
"""Internal method for constructing string of Athena column.
290+
291+
Args:
292+
column (str): The column name from pandas.Dataframe.
293+
Returns:
294+
The Athena column string.
295+
296+
Raises:
297+
RuntimeError: The type of pandas.Dataframe column is not support yet.
298+
"""
299+
dataframe_type = self._base[column].dtypes
300+
if dataframe_type == "object":
301+
column_type = "STRING"
302+
elif dataframe_type == "int64":
303+
column_type = "INT"
304+
elif dataframe_type == "float64":
305+
column_type = "DOUBLE"
306+
elif dataframe_type == "bool":
307+
column_type = "BOOLEAN"
308+
elif dataframe_type == "datetime64":
309+
column_type = "TIMESTAMP"
310+
else:
311+
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
312+
return f"{column} {column_type}"
313+
314+
def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
315+
"""Internal method for execute Athena query, wait for query finish and get query result.
316+
317+
Args:
318+
query_string (str): The SQL query statements to be executed.
319+
catalog (str): The name of the data catalog used in the query execution.
320+
database (str): The name of the database used in the query execution.
321+
Returns:
322+
The query result.
323+
324+
Raises:
325+
RuntimeError: Athena query failed.
326+
"""
327+
query = self._sagemaker_session.start_query_execution(
328+
catalog=catalog,
329+
database=database,
330+
query_string=query_string,
331+
output_location=self._output_path,
332+
kms_key=self._kms_key_id,
333+
)
334+
query_id = query.get("QueryExecutionId", None)
335+
self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
336+
query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
337+
query_state = (
338+
query_result.get("QueryExecution", None).get("Status", None).get("State", None)
339+
)
340+
if query_state != "SUCCEEDED":
341+
raise RuntimeError(f"Failed to execute query {query_id}.")
342+
return query_result

0 commit comments

Comments
 (0)