
Commit dac0647

black autoreformat take 2
1 parent 2422a36 commit dac0647
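
All changes in this commit are mechanical line joins of the kind black produces when a wrapped statement fits within the configured line length. The exact black version and settings used are not recorded on this page, so the snippet below is only an assumed illustration of reproducing one of the rewrites programmatically; the 100-character limit is inferred from the joined lines in the diff.

import black

# One removed/added pair from the diff below: black joins the wrapped call
# because it fits within the (assumed) 100-character line length and carries
# no magic trailing comma.
snippet = 'raise RuntimeError(\n    f"Failed to execute query {query_id}"\n)\n'
print(black.format_str(snippet, mode=black.Mode(line_length=100)))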

File tree

1 file changed, +12 -32 lines changed


src/sagemaker/feature_store/feature_group.py

+12 -32
@@ -129,17 +129,13 @@ def as_dataframe(self) -> DataFrame:
         Returns:
             A pandas DataFrame contains the query result.
         """
-        query_state = (
-            self.get_query_execution().get("QueryExecution").get("Status").get("State")
-        )
+        query_state = self.get_query_execution().get("QueryExecution").get("Status").get("State")
         if query_state != "SUCCEEDED":
             if query_state in ("QUEUED", "RUNNING"):
                 raise RuntimeError(
                     f"Current query {self._current_query_execution_id} is still being executed."
                 )
-            raise RuntimeError(
-                f"Failed to execute query {self._current_query_execution_id}"
-            )
+            raise RuntimeError(f"Failed to execute query {self._current_query_execution_id}")
 
         output_filename = os.path.join(
             tempfile.gettempdir(), f"{self._current_query_execution_id}.csv"
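
The chained .get() calls in the rewritten line walk Athena's GetQueryExecution response. For reference, that response is shaped roughly like the dictionary below; only the fields as_dataframe() actually reads are shown, and the query execution id is a placeholder.

# Minimal shape of an Athena GetQueryExecution response, as consumed above.
response = {
    "QueryExecution": {
        "QueryExecutionId": "11111111-2222-3333-4444-555555555555",
        # State is one of QUEUED, RUNNING, SUCCEEDED, FAILED, or CANCELLED.
        "Status": {"State": "SUCCEEDED"},
    }
}
query_state = response.get("QueryExecution").get("Status").get("State")
assert query_state == "SUCCEEDED"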
@@ -199,10 +195,7 @@ def _ingest_single_batch(
             List of row indices that failed to be ingested.
         """
         retry_config = client_config.retries
-        if (
-            "max_attempts" not in retry_config
-            and "total_max_attempts" not in retry_config
-        ):
+        if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
             client_config = copy.deepcopy(client_config)
             client_config.retries = {"max_attempts": 10, "mode": "standard"}
         sagemaker_featurestore_runtime_client = boto3.Session().client(
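
The condition above checks a botocore client Config for caller-supplied retry settings before falling back to ten standard-mode retries. A minimal standalone sketch of the same pattern, assuming only botocore's documented Config(retries=...) options; the service name matches the client created in the surrounding code, but credentials and region still come from the caller's AWS configuration.

import copy

import boto3
from botocore.config import Config

client_config = Config()  # no retries configured by the caller
retry_config = client_config.retries or {}
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
    # Copy before mutating so the caller's Config object is left untouched.
    client_config = copy.deepcopy(client_config)
    client_config.retries = {"max_attempts": 10, "mode": "standard"}

runtime_client = boto3.Session().client(
    "sagemaker-featurestore-runtime", config=client_config
)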
@@ -260,9 +253,7 @@ def wait(self, timeout=None):
         self._processing_pool.clear()
 
         self._failed_indices = [
-            failed_index
-            for failed_indices in results
-            for failed_index in failed_indices
+            failed_index for failed_indices in results for failed_index in failed_indices
         ]
 
         if len(self._failed_indices) > 0:
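
The joined comprehension flattens the per-batch lists of failed row indices returned by the worker pool into one flat list; a small self-contained illustration:

# results holds one list of failed row indices per ingestion batch.
results = [[0, 3], [], [7]]
failed_indices = [failed_index for failed_indices in results for failed_index in failed_indices]
assert failed_indices == [0, 3, 7]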
@@ -479,8 +470,7 @@ def create(
             record_identifier_name=record_identifier_name,
             event_time_feature_name=event_time_feature_name,
             feature_definitions=[
-                feature_definition.to_dict()
-                for feature_definition in self.feature_definitions
+                feature_definition.to_dict() for feature_definition in self.feature_definitions
             ],
             role_arn=role_arn,
             description=description,
@@ -489,16 +479,12 @@ def create(
 
         # online store configuration
         if enable_online_store:
-            online_store_config = OnlineStoreConfig(
-                enable_online_store=enable_online_store
-            )
+            online_store_config = OnlineStoreConfig(enable_online_store=enable_online_store)
             if online_store_kms_key_id is not None:
-                online_store_config.online_store_security_config = (
-                    OnlineStoreSecurityConfig(kms_key_id=online_store_kms_key_id)
+                online_store_config.online_store_security_config = OnlineStoreSecurityConfig(
+                    kms_key_id=online_store_kms_key_id
                 )
-            create_feature_store_args.update(
-                {"online_store_config": online_store_config.to_dict()}
-            )
+            create_feature_store_args.update({"online_store_config": online_store_config.to_dict()})
 
         # offline store configuration
         if s3_uri:
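
The two create() hunks above touch the arguments describing the feature definitions and the online store. A hedged end-to-end sketch of calling create() with those arguments, assuming the public FeatureGroup and feature definition classes from this module; the bucket, role ARN, and KMS alias are placeholders and the feature definitions are illustrative only.

from sagemaker import Session
from sagemaker.feature_store.feature_definition import (
    FractionalFeatureDefinition,
    StringFeatureDefinition,
)
from sagemaker.feature_store.feature_group import FeatureGroup

feature_group = FeatureGroup(
    name="my-feature-group",
    sagemaker_session=Session(),
    feature_definitions=[
        StringFeatureDefinition("record_id"),
        FractionalFeatureDefinition("event_time"),
    ],
)
feature_group.create(
    s3_uri="s3://my-bucket/offline-store",
    record_identifier_name="record_id",
    event_time_feature_name="event_time",
    role_arn="arn:aws:iam::111122223333:role/SageMakerFeatureStoreAccess",
    enable_online_store=True,
    online_store_kms_key_id="alias/my-feature-store-key",
    description="Example feature group",
)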
@@ -642,25 +628,19 @@ def athena_query(self) -> AthenaQuery:
             An instance of AthenaQuery initialized with data catalog configurations.
         """
         response = self.describe()
-        data_catalog_config = response.get("OfflineStoreConfig").get(
-            "DataCatalogConfig", None
-        )
+        data_catalog_config = response.get("OfflineStoreConfig").get("DataCatalogConfig", None)
         disable_glue = data_catalog_config.get("DisableGlueTableCreation", False)
         if data_catalog_config:
             query = AthenaQuery(
-                catalog=data_catalog_config["Catalog"]
-                if disable_glue
-                else "AwsDataCatalog",
+                catalog=data_catalog_config["Catalog"] if disable_glue else "AwsDataCatalog",
                 database=data_catalog_config["Database"],
                 table_name=data_catalog_config["TableName"],
                 sagemaker_session=self.sagemaker_session,
             )
             return query
         raise RuntimeError("No metastore is configured with this feature group.")
 
-    def as_hive_ddl(
-        self, database: str = "sagemaker_featurestore", table_name: str = None
-    ) -> str:
+    def as_hive_ddl(self, database: str = "sagemaker_featurestore", table_name: str = None) -> str:
         """Generate Hive DDL commands to define or change structure of tables or databases in Hive.
 
         Schema of the table is generated based on the feature definitions. Columns are named
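
athena_query() and as_hive_ddl() are the two entry points reformatted above. A hedged usage sketch, assuming a feature group whose offline store already has a Glue/Athena metastore configured; the group name, table name, and S3 output location are placeholders.

from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup

feature_group = FeatureGroup(name="my-feature-group", sagemaker_session=Session())

# Raises RuntimeError("No metastore is configured ...") when the offline store
# has no data catalog configuration.
query = feature_group.athena_query()
query.run(
    query_string=f'SELECT * FROM "{query.database}"."{query.table_name}"',
    output_location="s3://my-bucket/athena-results/",
)
query.wait()  # as_dataframe() refuses to run until the query has SUCCEEDED
df = query.as_dataframe()

# Hive DDL for external tooling; both arguments fall back to the defaults shown
# in the rewritten signature above.
ddl = feature_group.as_hive_ddl(database="sagemaker_featurestore", table_name="my_table")
print(ddl)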
