Skip to content

Commit bd96ec5

Browse files
mizanfiu, Eric Zou, imingtsou, bdchatham, jiapinw
authored
feature: Feature Store dataset builder, delete_record, get_record, list_feature_group (aws#3534)
Co-authored-by: Eric Zou <[email protected]> Co-authored-by: Yiming Zou <[email protected]> Co-authored-by: Brandon Chatham <[email protected]> Co-authored-by: jiapinw <[email protected]>
1 parent a3efddf commit bd96ec5

File tree

9 files changed

+2979
-588
lines changed

9 files changed

+2979
-588
lines changed

src/sagemaker/feature_store/dataset_builder.py

+990
Large diffs are not rendered by default.

src/sagemaker/feature_store/feature_group.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -435,13 +435,14 @@ class FeatureGroup:
435435
"uint64",
436436
]
437437
_FLOAT_TYPES = ["float_", "float16", "float32", "float64"]
438-
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
438+
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
439439
type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES
440440
}
441-
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
441+
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
442442
{type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES}
443443
)
444-
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
444+
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
445+
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING
445446

446447
_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = {
447448
FeatureTypeEnum.INTEGRAL.value: "INT",
@@ -629,7 +630,7 @@ def load_feature_definitions(
629630
"""
630631
feature_definitions = []
631632
for column in data_frame:
632-
feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
633+
feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
633634
str(data_frame[column].dtype), None
634635
)
635636
if feature_type:
@@ -644,6 +645,23 @@ def load_feature_definitions(
644645
self.feature_definitions = feature_definitions
645646
return self.feature_definitions
646647

648+
def get_record(
    self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None
) -> Sequence[Dict[str, str]]:
    """Fetch one record of this FeatureGroup by its record identifier value.

    Args:
        record_identifier_value_as_string (String):
            the record identifier value, as a String, of the record to fetch.
        feature_names (Sequence[String]):
            names of the features to request; forwarded unchanged to the
            session (default: None).

    Returns:
        The value stored under the "Record" key of the session's response,
        or None when the response has no such key.
    """
    response = self.sagemaker_session.get_record(
        feature_group_name=self.name,
        record_identifier_value_as_string=record_identifier_value_as_string,
        feature_names=feature_names,
    )
    # The session returns the full response dict; callers only want the record.
    return response.get("Record")
664+
647665
def put_record(self, record: Sequence[FeatureValue]):
648666
"""Put a single record in the FeatureGroup.
649667
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]):
654672
feature_group_name=self.name, record=[value.to_dict() for value in record]
655673
)
656674

675+
def delete_record(
    self,
    record_identifier_value_as_string: str,
    event_time: str,
):
    """Remove one record from this FeatureGroup.

    Args:
        record_identifier_value_as_string (String):
            the record identifier value, as a String, of the record to delete.
        event_time (String):
            a timestamp-formatted String stating when the deletion event occurred.

    Returns:
        Whatever the session's ``delete_record`` call returns.
    """
    request = {
        "feature_group_name": self.name,
        "record_identifier_value_as_string": record_identifier_value_as_string,
        "event_time": event_time,
    }
    return self.sagemaker_session.delete_record(**request)
693+
657694
def ingest(
658695
self,
659696
data_frame: DataFrame,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Feature Store.
14+
15+
Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and
16+
manage features for machine learning (ML) models.
17+
"""
18+
from __future__ import absolute_import
19+
20+
import datetime
21+
from typing import Any, Dict, Sequence, Union
22+
23+
import attr
24+
import pandas as pd
25+
26+
from sagemaker import Session
27+
from sagemaker.feature_store.dataset_builder import DatasetBuilder
28+
from sagemaker.feature_store.feature_group import FeatureGroup
29+
30+
31+
@attr.s
class FeatureStore:
    """FeatureStore definition.

    This class instantiates a FeatureStore object that comprises a SageMaker session instance.

    Attributes:
        sagemaker_session (Session): session instance to perform boto calls. When omitted,
            a new ``Session()`` is created for this FeatureStore.
    """

    # ``factory`` builds a Session *instance* for each FeatureStore. ``default=Session``
    # would store the Session class itself, and every delegated call below would then
    # fail because no instance is bound as ``self``.
    sagemaker_session: Session = attr.ib(factory=Session)

    def create_dataset(
        self,
        base: Union[FeatureGroup, pd.DataFrame],
        output_path: str,
        record_identifier_feature_name: str = None,
        event_time_identifier_feature_name: str = None,
        included_feature_names: Sequence[str] = None,
        kms_key_id: str = None,
    ) -> DatasetBuilder:
        """Create a Dataset Builder for generating a Dataset.

        Args:
            base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
                pandas.DataFrame and will be used to merge other FeatureGroups and generate a
                Dataset.
            output_path (str): An S3 URI which stores the output .csv file.
            record_identifier_feature_name (str): A string representing the record identifier
                feature if base is a DataFrame (default: None).
            event_time_identifier_feature_name (str): A string representing the event time
                identifier feature if base is a DataFrame (default: None).
            included_feature_names (List[str]): A list of features to be included in the output
                (default: None).
            kms_key_id (str): A KMS key id. If set, will be used to encrypt the result file
                (default: None).

        Returns:
            A DatasetBuilder constructed from this FeatureStore's session and the given
            arguments.

        Raises:
            ValueError: Base is a Pandas DataFrame and the record identifier feature name,
                the event time identifier feature name, or both are missing.
        """
        # Both identifier names are validated only for a DataFrame base; for any other
        # base they are passed through to the builder as given.
        if isinstance(base, pd.DataFrame):
            if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
                raise ValueError(
                    "You must provide a record identifier feature name and an event time "
                    "identifier feature name if specify DataFrame as base."
                )
        return DatasetBuilder(
            self.sagemaker_session,
            base,
            output_path,
            record_identifier_feature_name,
            event_time_identifier_feature_name,
            included_feature_names,
            kms_key_id,
        )

    def list_feature_groups(
        self,
        name_contains: str = None,
        feature_group_status_equals: str = None,
        offline_store_status_equals: str = None,
        creation_time_after: datetime.datetime = None,
        creation_time_before: datetime.datetime = None,
        sort_order: str = None,
        sort_by: str = None,
        max_results: int = None,
        next_token: str = None,
    ) -> Dict[str, Any]:
        """List all FeatureGroups satisfying given filters.

        Every argument is forwarded unchanged, by keyword, to the session's
        ``list_feature_groups``.

        Args:
            name_contains (str): A string that partially matches one or more FeatureGroups' names.
                Filters FeatureGroups by name.
            feature_group_status_equals (str): A FeatureGroup status.
                Filters FeatureGroups by FeatureGroup status.
            offline_store_status_equals (str): An OfflineStore status.
                Filters FeatureGroups by OfflineStore status.
            creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
                created after a specific date and time.
            creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
                created before a specific date and time.
            sort_order (str): The order in which FeatureGroups are listed.
            sort_by (str): The value on which the FeatureGroup list is sorted.
            max_results (int): The maximum number of results returned by ListFeatureGroups.
            next_token (str): A token to resume pagination of ListFeatureGroups results.

        Returns:
            Response dict from service.
        """
        return self.sagemaker_session.list_feature_groups(
            name_contains=name_contains,
            feature_group_status_equals=feature_group_status_equals,
            offline_store_status_equals=offline_store_status_equals,
            creation_time_after=creation_time_after,
            creation_time_before=creation_time_before,
            sort_order=sort_order,
            sort_by=sort_by,
            max_results=max_results,
            next_token=next_token,
        )

src/sagemaker/session.py

+93-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
312312
# For each object key, create the directory on the local machine if needed, and then
313313
# download the file.
314314
for key in keys:
315-
tail_s3_uri_path = os.path.basename(key_prefix)
315+
tail_s3_uri_path = os.path.basename(key)
316316
if not os.path.splitext(key_prefix)[1]:
317317
tail_s3_uri_path = os.path.relpath(key, key_prefix)
318318
destination_path = os.path.join(path, tail_s3_uri_path)
@@ -4341,6 +4341,56 @@ def update_feature_group(
43414341
FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
43424342
)
43434343

4344+
def list_feature_groups(
    self,
    name_contains=None,
    feature_group_status_equals=None,
    offline_store_status_equals=None,
    creation_time_after=None,
    creation_time_before=None,
    sort_order=None,
    sort_by=None,
    max_results=None,
    next_token=None,
) -> Dict[str, Any]:
    """List all FeatureGroups satisfying given filters.

    Every filter is optional. Only filters whose value is not None are sent to
    the service; falsy but non-None values (e.g. ``0`` or ``""``) are still sent.

    Args:
        name_contains (str): A string that partially matches one or more FeatureGroups' names.
            Filters FeatureGroups by name (default: None).
        feature_group_status_equals (str): A FeatureGroup status.
            Filters FeatureGroups by FeatureGroup status (default: None).
        offline_store_status_equals (str): An OfflineStore status.
            Filters FeatureGroups by OfflineStore status (default: None).
        creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
            created after a specific date and time (default: None).
        creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
            created before a specific date and time (default: None).
        sort_order (str): The order in which FeatureGroups are listed (default: None).
        sort_by (str): The value on which the FeatureGroup list is sorted (default: None).
        max_results (int): The maximum number of results returned by ListFeatureGroups
            (default: None).
        next_token (str): A token to resume pagination of ListFeatureGroups results
            (default: None).

    Returns:
        Response dict from service.
    """
    # Request-parameter name -> caller-supplied value, in the service's spelling.
    candidate_filters = {
        "NameContains": name_contains,
        "FeatureGroupStatusEquals": feature_group_status_equals,
        "OfflineStoreStatusEquals": offline_store_status_equals,
        "CreationTimeAfter": creation_time_after,
        "CreationTimeBefore": creation_time_before,
        "SortOrder": sort_order,
        "SortBy": sort_by,
        "MaxResults": max_results,
        "NextToken": next_token,
    }
    # Compare against None explicitly so that legitimate falsy values survive.
    list_feature_groups_args = {
        key: value for key, value in candidate_filters.items() if value is not None
    }
    return self.sagemaker_client.list_feature_groups(**list_feature_groups_args)
4393+
43444394
def update_feature_metadata(
43454395
self,
43464396
feature_group_name: str,
@@ -4408,6 +4458,48 @@ def put_record(
44084458
Record=record,
44094459
)
44104460

4461+
def delete_record(
    self,
    feature_group_name: str,
    record_identifier_value_as_string: str,
    event_time: str,
):
    """Delete one record from the named FeatureGroup.

    Args:
        feature_group_name (str): name of the FeatureGroup.
        record_identifier_value_as_string (str): the record identifier value, as a
            string, of the record to delete.
        event_time (str): a timestamp indicating when the deletion event occurred.

    Returns:
        Whatever the Feature Store runtime client's ``delete_record`` call returns.
    """
    request = {
        "FeatureGroupName": feature_group_name,
        "RecordIdentifierValueAsString": record_identifier_value_as_string,
        "EventTime": event_time,
    }
    runtime_client = self.sagemaker_featurestore_runtime_client
    return runtime_client.delete_record(**request)
4479+
4480+
def get_record(
    self,
    record_identifier_value_as_string: str,
    feature_group_name: str,
    feature_names: Sequence[str] = None,
) -> Dict[str, Sequence[Dict[str, str]]]:
    """Gets a single record in the FeatureGroup.

    Args:
        record_identifier_value_as_string (str): the record identifier value, as a
            string, of the record to fetch.
        feature_group_name (str): name of the FeatureGroup.
        feature_names (Sequence[str]): list of feature names to request. When None or
            empty, the request is sent without a ``FeatureNames`` entry (default: None).

    Returns:
        Response dict from the Feature Store runtime client's ``get_record`` call.
    """
    get_record_args = {
        "FeatureGroupName": feature_group_name,
        "RecordIdentifierValueAsString": record_identifier_value_as_string,
    }

    # Truthiness check on purpose: both None and an empty sequence mean that no
    # feature-name restriction is put on the request.
    if feature_names:
        get_record_args["FeatureNames"] = feature_names

    return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args)
4502+
44114503
def start_query_execution(
44124504
self,
44134505
catalog: str,

0 commit comments

Comments
 (0)