@@ -129,13 +129,17 @@ def as_dataframe(self) -> DataFrame:
129
129
Returns:
130
130
A pandas DataFrame contains the query result.
131
131
"""
132
- query_state = self .get_query_execution ().get ("QueryExecution" ).get ("Status" ).get ("State" )
132
+ query_state = (
133
+ self .get_query_execution ().get ("QueryExecution" ).get ("Status" ).get ("State" )
134
+ )
133
135
if query_state != "SUCCEEDED" :
134
136
if query_state in ("QUEUED" , "RUNNING" ):
135
137
raise RuntimeError (
136
138
f"Current query { self ._current_query_execution_id } is still being executed."
137
139
)
138
- raise RuntimeError (f"Failed to execute query { self ._current_query_execution_id } " )
140
+ raise RuntimeError (
141
+ f"Failed to execute query { self ._current_query_execution_id } "
142
+ )
139
143
140
144
output_filename = os .path .join (
141
145
tempfile .gettempdir (), f"{ self ._current_query_execution_id } .csv"
@@ -195,7 +199,10 @@ def _ingest_single_batch(
195
199
List of row indices that failed to be ingested.
196
200
"""
197
201
retry_config = client_config .retries
198
- if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config :
202
+ if (
203
+ "max_attempts" not in retry_config
204
+ and "total_max_attempts" not in retry_config
205
+ ):
199
206
client_config = copy .deepcopy (client_config )
200
207
client_config .retries = {"max_attempts" : 10 , "mode" : "standard" }
201
208
sagemaker_featurestore_runtime_client = boto3 .Session ().client (
@@ -207,7 +214,8 @@ def _ingest_single_batch(
207
214
for row in data_frame [start_index :end_index ].itertuples ():
208
215
record = [
209
216
FeatureValue (
210
- feature_name = data_frame .columns [index - 1 ], value_as_string = str (row [index ])
217
+ feature_name = data_frame .columns [index - 1 ],
218
+ value_as_string = str (row [index ]),
211
219
)
212
220
for index in range (1 , len (row ))
213
221
if pd .notna (row [index ])
@@ -252,7 +260,9 @@ def wait(self, timeout=None):
252
260
self ._processing_pool .clear ()
253
261
254
262
self ._failed_indices = [
255
- failed_index for failed_indices in results for failed_index in failed_indices
263
+ failed_index
264
+ for failed_indices in results
265
+ for failed_index in failed_indices
256
266
]
257
267
258
268
if len (self ._failed_indices ) > 0 :
@@ -469,7 +479,8 @@ def create(
469
479
record_identifier_name = record_identifier_name ,
470
480
event_time_feature_name = event_time_feature_name ,
471
481
feature_definitions = [
472
- feature_definition .to_dict () for feature_definition in self .feature_definitions
482
+ feature_definition .to_dict ()
483
+ for feature_definition in self .feature_definitions
473
484
],
474
485
role_arn = role_arn ,
475
486
description = description ,
@@ -478,12 +489,16 @@ def create(
478
489
479
490
# online store configuration
480
491
if enable_online_store :
481
- online_store_config = OnlineStoreConfig (enable_online_store = enable_online_store )
492
+ online_store_config = OnlineStoreConfig (
493
+ enable_online_store = enable_online_store
494
+ )
482
495
if online_store_kms_key_id is not None :
483
- online_store_config .online_store_security_config = OnlineStoreSecurityConfig (
484
- kms_key_id = online_store_kms_key_id
496
+ online_store_config .online_store_security_config = (
497
+ OnlineStoreSecurityConfig ( kms_key_id = online_store_kms_key_id )
485
498
)
486
- create_feature_store_args .update ({"online_store_config" : online_store_config .to_dict ()})
499
+ create_feature_store_args .update (
500
+ {"online_store_config" : online_store_config .to_dict ()}
501
+ )
487
502
488
503
# offline store configuration
489
504
if s3_uri :
@@ -627,19 +642,25 @@ def athena_query(self) -> AthenaQuery:
627
642
An instance of AthenaQuery initialized with data catalog configurations.
628
643
"""
629
644
response = self .describe ()
630
- data_catalog_config = response .get ("OfflineStoreConfig" ).get ("DataCatalogConfig" , None )
645
+ data_catalog_config = response .get ("OfflineStoreConfig" ).get (
646
+ "DataCatalogConfig" , None
647
+ )
631
648
disable_glue = data_catalog_config .get ("DisableGlueTableCreation" , False )
632
649
if data_catalog_config :
633
650
query = AthenaQuery (
634
- catalog = data_catalog_config ["Catalog" ] if disable_glue else "AwsDataCatalog" ,
651
+ catalog = data_catalog_config ["Catalog" ]
652
+ if disable_glue
653
+ else "AwsDataCatalog" ,
635
654
database = data_catalog_config ["Database" ],
636
655
table_name = data_catalog_config ["TableName" ],
637
656
sagemaker_session = self .sagemaker_session ,
638
657
)
639
658
return query
640
659
raise RuntimeError ("No metastore is configured with this feature group." )
641
660
642
- def as_hive_ddl (self , database : str = "sagemaker_featurestore" , table_name : str = None ) -> str :
661
+ def as_hive_ddl (
662
+ self , database : str = "sagemaker_featurestore" , table_name : str = None
663
+ ) -> str :
643
664
"""Generate Hive DDL commands to define or change structure of tables or databases in Hive.
644
665
645
666
Schema of the table is generated based on the feature definitions. Columns are named
0 commit comments