Skip to content

Commit 2422a36

Browse files
committed
black autoreformat
1 parent c5f2a3e commit 2422a36

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

src/sagemaker/feature_store/feature_group.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,17 @@ def as_dataframe(self) -> DataFrame:
129129
Returns:
130130
A pandas DataFrame contains the query result.
131131
"""
132-
query_state = self.get_query_execution().get("QueryExecution").get("Status").get("State")
132+
query_state = (
133+
self.get_query_execution().get("QueryExecution").get("Status").get("State")
134+
)
133135
if query_state != "SUCCEEDED":
134136
if query_state in ("QUEUED", "RUNNING"):
135137
raise RuntimeError(
136138
f"Current query {self._current_query_execution_id} is still being executed."
137139
)
138-
raise RuntimeError(f"Failed to execute query {self._current_query_execution_id}")
140+
raise RuntimeError(
141+
f"Failed to execute query {self._current_query_execution_id}"
142+
)
139143

140144
output_filename = os.path.join(
141145
tempfile.gettempdir(), f"{self._current_query_execution_id}.csv"
@@ -195,7 +199,10 @@ def _ingest_single_batch(
195199
List of row indices that failed to be ingested.
196200
"""
197201
retry_config = client_config.retries
198-
if "max_attempts" not in retry_config and "total_max_attempts" not in retry_config:
202+
if (
203+
"max_attempts" not in retry_config
204+
and "total_max_attempts" not in retry_config
205+
):
199206
client_config = copy.deepcopy(client_config)
200207
client_config.retries = {"max_attempts": 10, "mode": "standard"}
201208
sagemaker_featurestore_runtime_client = boto3.Session().client(
@@ -207,7 +214,8 @@ def _ingest_single_batch(
207214
for row in data_frame[start_index:end_index].itertuples():
208215
record = [
209216
FeatureValue(
210-
feature_name=data_frame.columns[index - 1], value_as_string=str(row[index])
217+
feature_name=data_frame.columns[index - 1],
218+
value_as_string=str(row[index]),
211219
)
212220
for index in range(1, len(row))
213221
if pd.notna(row[index])
@@ -252,7 +260,9 @@ def wait(self, timeout=None):
252260
self._processing_pool.clear()
253261

254262
self._failed_indices = [
255-
failed_index for failed_indices in results for failed_index in failed_indices
263+
failed_index
264+
for failed_indices in results
265+
for failed_index in failed_indices
256266
]
257267

258268
if len(self._failed_indices) > 0:
@@ -469,7 +479,8 @@ def create(
469479
record_identifier_name=record_identifier_name,
470480
event_time_feature_name=event_time_feature_name,
471481
feature_definitions=[
472-
feature_definition.to_dict() for feature_definition in self.feature_definitions
482+
feature_definition.to_dict()
483+
for feature_definition in self.feature_definitions
473484
],
474485
role_arn=role_arn,
475486
description=description,
@@ -478,12 +489,16 @@ def create(
478489

479490
# online store configuration
480491
if enable_online_store:
481-
online_store_config = OnlineStoreConfig(enable_online_store=enable_online_store)
492+
online_store_config = OnlineStoreConfig(
493+
enable_online_store=enable_online_store
494+
)
482495
if online_store_kms_key_id is not None:
483-
online_store_config.online_store_security_config = OnlineStoreSecurityConfig(
484-
kms_key_id=online_store_kms_key_id
496+
online_store_config.online_store_security_config = (
497+
OnlineStoreSecurityConfig(kms_key_id=online_store_kms_key_id)
485498
)
486-
create_feature_store_args.update({"online_store_config": online_store_config.to_dict()})
499+
create_feature_store_args.update(
500+
{"online_store_config": online_store_config.to_dict()}
501+
)
487502

488503
# offline store configuration
489504
if s3_uri:
@@ -627,19 +642,25 @@ def athena_query(self) -> AthenaQuery:
627642
An instance of AthenaQuery initialized with data catalog configurations.
628643
"""
629644
response = self.describe()
630-
data_catalog_config = response.get("OfflineStoreConfig").get("DataCatalogConfig", None)
645+
data_catalog_config = response.get("OfflineStoreConfig").get(
646+
"DataCatalogConfig", None
647+
)
631648
disable_glue = data_catalog_config.get("DisableGlueTableCreation", False)
632649
if data_catalog_config:
633650
query = AthenaQuery(
634-
catalog=data_catalog_config["Catalog"] if disable_glue else "AwsDataCatalog",
651+
catalog=data_catalog_config["Catalog"]
652+
if disable_glue
653+
else "AwsDataCatalog",
635654
database=data_catalog_config["Database"],
636655
table_name=data_catalog_config["TableName"],
637656
sagemaker_session=self.sagemaker_session,
638657
)
639658
return query
640659
raise RuntimeError("No metastore is configured with this feature group.")
641660

642-
def as_hive_ddl(self, database: str = "sagemaker_featurestore", table_name: str = None) -> str:
661+
def as_hive_ddl(
662+
self, database: str = "sagemaker_featurestore", table_name: str = None
663+
) -> str:
643664
"""Generate Hive DDL commands to define or change structure of tables or databases in Hive.
644665
645666
Schema of the table is generated based on the feature definitions. Columns are named

0 commit comments

Comments
 (0)