@@ -129,17 +129,13 @@ def as_dataframe(self) -> DataFrame:
129
129
Returns:
130
130
A pandas DataFrame contains the query result.
131
131
"""
132
- query_state = (
133
- self .get_query_execution ().get ("QueryExecution" ).get ("Status" ).get ("State" )
134
- )
132
+ query_state = self .get_query_execution ().get ("QueryExecution" ).get ("Status" ).get ("State" )
135
133
if query_state != "SUCCEEDED" :
136
134
if query_state in ("QUEUED" , "RUNNING" ):
137
135
raise RuntimeError (
138
136
f"Current query { self ._current_query_execution_id } is still being executed."
139
137
)
140
- raise RuntimeError (
141
- f"Failed to execute query { self ._current_query_execution_id } "
142
- )
138
+ raise RuntimeError (f"Failed to execute query { self ._current_query_execution_id } " )
143
139
144
140
output_filename = os .path .join (
145
141
tempfile .gettempdir (), f"{ self ._current_query_execution_id } .csv"
@@ -199,10 +195,7 @@ def _ingest_single_batch(
199
195
List of row indices that failed to be ingested.
200
196
"""
201
197
retry_config = client_config .retries
202
- if (
203
- "max_attempts" not in retry_config
204
- and "total_max_attempts" not in retry_config
205
- ):
198
+ if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config :
206
199
client_config = copy .deepcopy (client_config )
207
200
client_config .retries = {"max_attempts" : 10 , "mode" : "standard" }
208
201
sagemaker_featurestore_runtime_client = boto3 .Session ().client (
@@ -260,9 +253,7 @@ def wait(self, timeout=None):
260
253
self ._processing_pool .clear ()
261
254
262
255
self ._failed_indices = [
263
- failed_index
264
- for failed_indices in results
265
- for failed_index in failed_indices
256
+ failed_index for failed_indices in results for failed_index in failed_indices
266
257
]
267
258
268
259
if len (self ._failed_indices ) > 0 :
@@ -479,8 +470,7 @@ def create(
479
470
record_identifier_name = record_identifier_name ,
480
471
event_time_feature_name = event_time_feature_name ,
481
472
feature_definitions = [
482
- feature_definition .to_dict ()
483
- for feature_definition in self .feature_definitions
473
+ feature_definition .to_dict () for feature_definition in self .feature_definitions
484
474
],
485
475
role_arn = role_arn ,
486
476
description = description ,
@@ -489,16 +479,12 @@ def create(
489
479
490
480
# online store configuration
491
481
if enable_online_store :
492
- online_store_config = OnlineStoreConfig (
493
- enable_online_store = enable_online_store
494
- )
482
+ online_store_config = OnlineStoreConfig (enable_online_store = enable_online_store )
495
483
if online_store_kms_key_id is not None :
496
- online_store_config .online_store_security_config = (
497
- OnlineStoreSecurityConfig ( kms_key_id = online_store_kms_key_id )
484
+ online_store_config .online_store_security_config = OnlineStoreSecurityConfig (
485
+ kms_key_id = online_store_kms_key_id
498
486
)
499
- create_feature_store_args .update (
500
- {"online_store_config" : online_store_config .to_dict ()}
501
- )
487
+ create_feature_store_args .update ({"online_store_config" : online_store_config .to_dict ()})
502
488
503
489
# offline store configuration
504
490
if s3_uri :
@@ -642,25 +628,19 @@ def athena_query(self) -> AthenaQuery:
642
628
An instance of AthenaQuery initialized with data catalog configurations.
643
629
"""
644
630
response = self .describe ()
645
- data_catalog_config = response .get ("OfflineStoreConfig" ).get (
646
- "DataCatalogConfig" , None
647
- )
631
+ data_catalog_config = response .get ("OfflineStoreConfig" ).get ("DataCatalogConfig" , None )
648
632
disable_glue = data_catalog_config .get ("DisableGlueTableCreation" , False )
649
633
if data_catalog_config :
650
634
query = AthenaQuery (
651
- catalog = data_catalog_config ["Catalog" ]
652
- if disable_glue
653
- else "AwsDataCatalog" ,
635
+ catalog = data_catalog_config ["Catalog" ] if disable_glue else "AwsDataCatalog" ,
654
636
database = data_catalog_config ["Database" ],
655
637
table_name = data_catalog_config ["TableName" ],
656
638
sagemaker_session = self .sagemaker_session ,
657
639
)
658
640
return query
659
641
raise RuntimeError ("No metastore is configured with this feature group." )
660
642
661
- def as_hive_ddl (
662
- self , database : str = "sagemaker_featurestore" , table_name : str = None
663
- ) -> str :
643
+ def as_hive_ddl (self , database : str = "sagemaker_featurestore" , table_name : str = None ) -> str :
664
644
"""Generate Hive DDL commands to define or change structure of tables or databases in Hive.
665
645
666
646
Schema of the table is generated based on the feature definitions. Columns are named
0 commit comments