Skip to content

Commit 98b60c0

Browse files
imingtsoumizanfiu
authored and committed
Address TODOs (aws#731)
1 parent 93a867c commit 98b60c0

File tree

2 files changed

+86
-78
lines changed

2 files changed

+86
-78
lines changed

src/sagemaker/feature_store/dataset_builder.py

+85-78
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import absolute_import
1818

1919
import datetime
20+
import os
2021
from typing import Any, Dict, List, Sequence, Tuple, Union
2122

2223
import attr
@@ -39,6 +40,7 @@ class FeatureGroupToBeMerged:
3940
features (List[str]): A list of strings representing feature names of this FeatureGroup.
4041
included_feature_names (Sequence[str]): A list of strings representing features to be
4142
included in the output.
43+
catalog (str): A string representing the catalog.
4244
database (str): A string representing the database.
4345
table_name (str): A string representing the Athena table name of this FeatureGroup.
4446
record_identifier_feature_name (str): A string representing the record identifier feature.
@@ -50,13 +52,59 @@ class FeatureGroupToBeMerged:
5052

5153
features: List[str] = attr.ib()
5254
included_feature_names: Sequence[str] = attr.ib()
55+
catalog: str = attr.ib()
5356
database: str = attr.ib()
5457
table_name: str = attr.ib()
5558
record_identifier_feature_name: str = attr.ib()
5659
event_time_identifier_feature_name: str = attr.ib()
5760
target_feature_name_in_base: str = attr.ib(default=None)
5861

5962

63+
def construct_feature_group_to_be_merged(
64+
feature_group: FeatureGroup,
65+
included_feature_names: Sequence[str],
66+
target_feature_name_in_base: str = None,
67+
) -> FeatureGroupToBeMerged:
68+
"""Construct a FeatureGroupToBeMerged object by provided parameters.
69+
70+
Args:
71+
feature_group (FeatureGroup): A FeatureGroup object.
72+
included_feature_names (Sequence[str]): A list of strings representing features to be
73+
included in the output.
74+
target_feature_name_in_base (str): A string representing the feature name in base which
75+
will be used as target join key (default: None).
76+
Returns:
77+
A FeatureGroupToBeMerged object.
78+
"""
79+
feature_group_metadata = feature_group.describe()
80+
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
81+
"DataCatalogConfig", None
82+
)
83+
if not data_catalog_config:
84+
raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
85+
86+
record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
87+
event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None)
88+
table_name = data_catalog_config.get("TableName", None)
89+
database = data_catalog_config.get("Database", None)
90+
disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
91+
catalog = data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog"
92+
features = [
93+
feature.get("FeatureName", None)
94+
for feature in feature_group_metadata.get("FeatureDefinitions", None)
95+
]
96+
return FeatureGroupToBeMerged(
97+
features,
98+
included_feature_names,
99+
catalog,
100+
database,
101+
table_name,
102+
record_identifier_feature_name,
103+
event_time_identifier_feature_name,
104+
target_feature_name_in_base,
105+
)
106+
107+
60108
@attr.s
61109
class DatasetBuilder:
62110
"""DatasetBuilder definition.
@@ -114,6 +162,14 @@ class DatasetBuilder:
114162
_event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
115163
_feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, default=[])
116164

165+
_DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = {
166+
"object": "STRING",
167+
"int64": "INT",
168+
"float64": "DOUBLE",
169+
"bool": "BOOLEAN",
170+
"datetime64": "TIMESTAMP",
171+
}
172+
117173
def with_feature_group(
118174
self,
119175
feature_group: FeatureGroup,
@@ -131,38 +187,11 @@ def with_feature_group(
131187
Returns:
132188
This DatasetBuilder object.
133189
"""
134-
# TODO: handle pagination and input feature validation
135-
# TODO: potential refactor with FeatureGroup base
136-
feature_group_metadata = feature_group.describe()
137-
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", None).get(
138-
"DataCatalogConfig", None
139-
)
140-
if not data_catalog_config:
141-
raise RuntimeError(
142-
f"No metastore is configured with FeatureGroup {feature_group.name}."
143-
)
144-
145-
record_identifier_feature_name = feature_group_metadata.get(
146-
"RecordIdentifierFeatureName", None
147-
)
148-
event_time_identifier_feature_name = feature_group_metadata.get(
149-
"EventTimeFeatureName", None
150-
)
151-
# TODO: back fill feature definitions due to UpdateFG
152-
table_name = data_catalog_config.get("TableName", None)
153-
database = data_catalog_config.get("Database", None)
154-
features = [feature.feature_name for feature in feature_group.feature_definitions]
155190
if not target_feature_name_in_base:
156191
target_feature_name_in_base = self._record_identifier_feature_name
157192
self._feature_groups_to_be_merged.append(
158-
FeatureGroupToBeMerged(
159-
features,
160-
included_feature_names,
161-
database,
162-
table_name,
163-
record_identifier_feature_name,
164-
event_time_identifier_feature_name,
165-
target_feature_name_in_base,
193+
construct_feature_group_to_be_merged(
194+
feature_group, included_feature_names, target_feature_name_in_base
166195
)
167196
)
168197
return self
@@ -257,61 +286,48 @@ def to_csv(self) -> Tuple[str, str]:
257286
"""
258287
if isinstance(self._base, pd.DataFrame):
259288
temp_id = utils.unique_name_from_base("dataframe-base")
260-
local_filename = f"{temp_id}.csv"
289+
local_file_name = f"{temp_id}.csv"
261290
desired_s3_folder = f"{self._output_path}/{temp_id}"
262-
self._base.to_csv(local_filename, index=False, header=False)
291+
self._base.to_csv(local_file_name, index=False, header=False)
263292
s3.S3Uploader.upload(
264-
local_path=local_filename,
293+
local_path=local_file_name,
265294
desired_s3_uri=desired_s3_folder,
266295
sagemaker_session=self._sagemaker_session,
267296
kms_key=self._kms_key_id,
268297
)
298+
os.remove(local_file_name)
269299
temp_table_name = f"dataframe_{temp_id}"
270300
self._create_temp_table(temp_table_name, desired_s3_folder)
271301
base_features = list(self._base.columns)
272302
query_string = self._construct_query_string(
273303
temp_table_name,
304+
"AwsDataCatalog",
274305
"sagemaker_featurestore",
275306
base_features,
276307
)
277308
query_result = self._run_query(query_string, "AwsDataCatalog", "sagemaker_featurestore")
278-
# TODO: cleanup local file and temp table
279-
return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
309+
# TODO: cleanup temp table, need more clarification, keep it for now
310+
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
280311
"OutputLocation", None
281-
), query_result.get("QueryExecution", None).get("Query", None)
312+
), query_result.get("QueryExecution", {}).get("Query", None)
282313
if isinstance(self._base, FeatureGroup):
283-
# TODO: handle pagination and input feature validation
284-
base_feature_group = self._base.describe()
285-
data_catalog_config = base_feature_group.get("OfflineStoreConfig", None).get(
286-
"DataCatalogConfig", None
287-
)
288-
if not data_catalog_config:
289-
raise RuntimeError("No metastore is configured with the base FeatureGroup.")
290-
disable_glue = base_feature_group.get("DisableGlueTableCreation", False)
291-
self._record_identifier_feature_name = base_feature_group.get(
292-
"RecordIdentifierFeatureName", None
314+
base_feature_group = construct_feature_group_to_be_merged(
315+
self._base, self._included_feature_names
293316
)
294-
self._event_time_identifier_feature_name = base_feature_group.get(
295-
"EventTimeFeatureName", None
296-
)
297-
base_features = [
298-
feature.get("FeatureName", None)
299-
for feature in base_feature_group.get("FeatureDefinitions", None)
300-
]
301-
302317
query_string = self._construct_query_string(
303-
data_catalog_config.get("TableName", None),
304-
data_catalog_config.get("Database", None),
305-
base_features,
318+
base_feature_group.table_name,
319+
base_feature_group.catalog,
320+
base_feature_group.database,
321+
base_feature_group.features,
306322
)
307323
query_result = self._run_query(
308324
query_string,
309-
data_catalog_config.get("Catalog", None) if disable_glue else "AwsDataCatalog",
310-
data_catalog_config.get("Database", None),
325+
base_feature_group.catalog,
326+
base_feature_group.database,
311327
)
312-
return query_result.get("QueryExecution", None).get("ResultConfiguration", None).get(
328+
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
313329
"OutputLocation", None
314-
), query_result.get("QueryExecution", None).get("Query", None)
330+
), query_result.get("QueryExecution", {}).get("Query", None)
315331
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
316332

317333
def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
@@ -328,8 +344,10 @@ def to_dataframe(self) -> Tuple[str, pd.DataFrame]:
328344
kms_key=self._kms_key_id,
329345
sagemaker_session=self._sagemaker_session,
330346
)
331-
# TODO: do we need to clean up local file?
332-
return query_string, pd.read_csv(csv_file.split("/")[-1])
347+
local_file_name = csv_file.split("/")[-1]
348+
df = pd.read_csv(local_file_name)
349+
os.remove(local_file_name)
350+
return query_string, df
333351

334352
def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
335353
"""Internal method for constructing SQL WHERE query string by parameters.
@@ -404,7 +422,7 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
404422
)
405423

406424
def _construct_query_string(
407-
self, base_table_name: str, database: str, base_features: list
425+
self, base_table_name: str, catalog: str, database: str, base_features: list
408426
) -> str:
409427
"""Internal method for constructing SQL query string by parameters.
410428
@@ -418,6 +436,7 @@ def _construct_query_string(
418436
base = FeatureGroupToBeMerged(
419437
base_features,
420438
self._included_feature_names,
439+
catalog,
421440
database,
422441
base_table_name,
423442
self._record_identifier_feature_name,
@@ -499,19 +518,9 @@ def _construct_athena_table_column_string(self, column: str) -> str:
499518
RuntimeError: The type of pandas.Dataframe column is not support yet.
500519
"""
501520
dataframe_type = self._base[column].dtypes
502-
if dataframe_type == "object":
503-
column_type = "STRING"
504-
elif dataframe_type == "int64":
505-
column_type = "INT"
506-
elif dataframe_type == "float64":
507-
column_type = "DOUBLE"
508-
elif dataframe_type == "bool":
509-
column_type = "BOOLEAN"
510-
elif dataframe_type == "datetime64":
511-
column_type = "TIMESTAMP"
512-
else:
521+
if dataframe_type not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
513522
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
514-
return f"{column} {column_type}"
523+
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(dataframe_type, None)}"
515524

516525
def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
517526
"""Internal method for execute Athena query, wait for query finish and get query result.
@@ -536,9 +545,7 @@ def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str
536545
query_id = query.get("QueryExecutionId", None)
537546
self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
538547
query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
539-
query_state = (
540-
query_result.get("QueryExecution", None).get("Status", None).get("State", None)
541-
)
548+
query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None)
542549
if query_state != "SUCCEEDED":
543550
raise RuntimeError(f"Failed to execute query {query_id}.")
544551
return query_result

tests/unit/sagemaker/feature_store/test_dataset_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_with_feature_group(sagemaker_session_mock):
6060
"OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}},
6161
"RecordIdentifierFeatureName": "feature-1",
6262
"EventTimeFeatureName": "feature-2",
63+
"FeatureDefinitions": [{"FeatureName": "feature-1"}, {"FeatureName": "feature-2"}],
6364
}
6465
dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"])
6566
assert len(dataset_builder._feature_groups_to_be_merged) == 1

0 commit comments

Comments
 (0)