|
22 | 22 | import pytest
|
23 | 23 | from pandas import DataFrame
|
24 | 24 |
|
| 25 | +from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition |
25 | 26 | from sagemaker.feature_store.feature_group import FeatureGroup
|
26 |
| -from sagemaker.feature_store.inputs import FeatureValue |
| 27 | +from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter |
27 | 28 | from sagemaker.session import get_execution_role, Session
|
28 | 29 | from tests.integ.timeout import timeout
|
29 | 30 |
|
@@ -237,6 +238,83 @@ def test_create_feature_store(
|
237 | 238 | assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
|
238 | 239 |
|
239 | 240 |
|
| 241 | +def test_update_feature_group( |
| 242 | + feature_store_session, |
| 243 | + role, |
| 244 | + feature_group_name, |
| 245 | + offline_store_s3_uri, |
| 246 | + pandas_data_frame, |
| 247 | +): |
| 248 | + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) |
| 249 | + feature_group.load_feature_definitions(data_frame=pandas_data_frame) |
| 250 | + |
| 251 | + with cleanup_feature_group(feature_group): |
| 252 | + feature_group.create( |
| 253 | + s3_uri=offline_store_s3_uri, |
| 254 | + record_identifier_name="feature1", |
| 255 | + event_time_feature_name="feature3", |
| 256 | + role_arn=role, |
| 257 | + enable_online_store=True, |
| 258 | + ) |
| 259 | + _wait_for_feature_group_create(feature_group) |
| 260 | + |
| 261 | + new_feature_name = "new_feature" |
| 262 | + new_features = [FractionalFeatureDefinition(feature_name=new_feature_name)] |
| 263 | + feature_group.update(new_features) |
| 264 | + _wait_for_feature_group_update(feature_group) |
| 265 | + feature_definitions = feature_group.describe().get("FeatureDefinitions") |
| 266 | + assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) |
| 267 | + |
| 268 | + |
| 269 | +def test_feature_metadata( |
| 270 | + feature_store_session, |
| 271 | + role, |
| 272 | + feature_group_name, |
| 273 | + offline_store_s3_uri, |
| 274 | + pandas_data_frame, |
| 275 | +): |
| 276 | + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) |
| 277 | + feature_group.load_feature_definitions(data_frame=pandas_data_frame) |
| 278 | + |
| 279 | + with cleanup_feature_group(feature_group): |
| 280 | + feature_group.create( |
| 281 | + s3_uri=offline_store_s3_uri, |
| 282 | + record_identifier_name="feature1", |
| 283 | + event_time_feature_name="feature3", |
| 284 | + role_arn=role, |
| 285 | + enable_online_store=True, |
| 286 | + ) |
| 287 | + _wait_for_feature_group_create(feature_group) |
| 288 | + |
| 289 | + parameter_additions = [ |
| 290 | + FeatureParameter(key="key1", value="value1"), |
| 291 | + FeatureParameter(key="key2", value="value2"), |
| 292 | + ] |
| 293 | + description = "test description" |
| 294 | + feature_name = "feature1" |
| 295 | + feature_group.update_feature_metadata( |
| 296 | + feature_name=feature_name, |
| 297 | + description=description, |
| 298 | + parameter_additions=parameter_additions, |
| 299 | + ) |
| 300 | + describe_feature_metadata = feature_group.describe_feature_metadata( |
| 301 | + feature_name=feature_name |
| 302 | + ) |
| 303 | + print(describe_feature_metadata) |
| 304 | + assert description == describe_feature_metadata.get("Description") |
| 305 | + assert 2 == len(describe_feature_metadata.get("Parameters")) |
| 306 | + |
| 307 | + parameter_removals = ["key1"] |
| 308 | + feature_group.update_feature_metadata( |
| 309 | + feature_name=feature_name, parameter_removals=parameter_removals |
| 310 | + ) |
| 311 | + describe_feature_metadata = feature_group.describe_feature_metadata( |
| 312 | + feature_name=feature_name |
| 313 | + ) |
| 314 | + assert description == describe_feature_metadata.get("Description") |
| 315 | + assert 1 == len(describe_feature_metadata.get("Parameters")) |
| 316 | + |
| 317 | + |
240 | 318 | def test_ingest_without_string_feature(
|
241 | 319 | feature_store_session,
|
242 | 320 | role,
|
@@ -304,6 +382,18 @@ def _wait_for_feature_group_create(feature_group: FeatureGroup):
|
304 | 382 | print(f"FeatureGroup {feature_group.name} successfully created.")
|
305 | 383 |
|
306 | 384 |
|
| 385 | +def _wait_for_feature_group_update(feature_group: FeatureGroup): |
| 386 | + status = feature_group.describe().get("LastUpdateStatus").get("Status") |
| 387 | + while status == "InProgress": |
| 388 | + print("Waiting for Feature Group Update") |
| 389 | + time.sleep(5) |
| 390 | + status = feature_group.describe().get("LastUpdateStatus").get("Status") |
| 391 | + if status != "Successful": |
| 392 | + print(feature_group.describe()) |
| 393 | + raise RuntimeError(f"Failed to update feature group {feature_group.name}") |
| 394 | + print(f"FeatureGroup {feature_group.name} successfully updated.") |
| 395 | + |
| 396 | + |
307 | 397 | @contextmanager
|
308 | 398 | def cleanup_feature_group(feature_group: FeatureGroup):
|
309 | 399 | try:
|
|
0 commit comments