Skip to content

Commit 699f9b0

Browse files
imingtsou and Eric Zou
authored and committed
feat: Add DatasetBuilder class (aws#667)
Co-authored-by: Eric Zou <[email protected]>
1 parent 8c5a8c9 commit 699f9b0

File tree

4 files changed

+342
-1
lines changed

4 files changed

+342
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Dataset Builder
14+
15+
A Dataset Builder is a builder class for generating a dataset by providing conditions.
16+
"""
17+
from __future__ import absolute_import
18+
19+
import datetime
20+
from typing import Sequence, Union
21+
22+
import attr
23+
import pandas as pd
24+
25+
from sagemaker.feature_store.feature_group import FeatureGroup
26+
27+
28+
@attr.s
class DatasetBuilder:
    """Builder for generating a dataset from a FeatureGroup or DataFrame base.

    This class instantiates a DatasetBuilder object that comprises a base, a list of feature
    names, an output path and a KMS key ID, plus optional filters that shape the generated
    dataset. Every ``with_*``/``include_*``/``as_of`` method returns ``self`` so calls can be
    chained.

    Attributes:
        _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
            pandas.DataFrame and will be used to merge other FeatureGroups and generate a
            Dataset.
        _output_path (str): An S3 URI which stores the output .csv file.
        _record_identifier_feature_name (str): A string representing the record identifier
            feature if base is a DataFrame.
        _event_time_identifier_feature_name (str): A string representing the event time
            identifier feature if base is a DataFrame.
        _included_feature_names (Sequence[str]): A list of features to be included in the
            output.
        _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file.
        _point_in_time_accurate_join (bool): Whether to use a point-in-time accurate join.
        _include_duplicated_records (bool): Whether to include duplicated records.
        _include_deleted_records (bool): Whether to include deleted records.
        _number_of_recent_records (int): How many recent records will be returned for each
            record identifier (default: 1).
        _number_of_records (int): How many records will be returned in total.
        _write_time_ending_timestamp (datetime.datetime): A datetime that all records' write
            time in the dataset will be before it.
        _event_time_starting_timestamp (datetime.datetime): A datetime that all records' event
            time in the dataset will be after it.
        _event_time_ending_timestamp (datetime.datetime): A datetime that all records' event
            time in the dataset will be before it.
    """

    _base: Union[FeatureGroup, pd.DataFrame] = attr.ib()
    _output_path: str = attr.ib()
    # Union[..., None] spells out the implicit-Optional defaults explicitly (PEP 484).
    _record_identifier_feature_name: Union[str, None] = attr.ib(default=None)
    _event_time_identifier_feature_name: Union[str, None] = attr.ib(default=None)
    _included_feature_names: Union[Sequence[str], None] = attr.ib(default=None)
    _kms_key_id: Union[str, None] = attr.ib(default=None)

    # Filter state; excluded from __init__ (init=False) and settable only
    # through the chainable builder methods below.
    _point_in_time_accurate_join: bool = attr.ib(init=False, default=False)
    _include_duplicated_records: bool = attr.ib(init=False, default=False)
    _include_deleted_records: bool = attr.ib(init=False, default=False)
    _number_of_recent_records: int = attr.ib(init=False, default=1)
    _number_of_records: Union[int, None] = attr.ib(init=False, default=None)
    _write_time_ending_timestamp: Union[datetime.datetime, None] = attr.ib(
        init=False, default=None
    )
    _event_time_starting_timestamp: Union[datetime.datetime, None] = attr.ib(
        init=False, default=None
    )
    _event_time_ending_timestamp: Union[datetime.datetime, None] = attr.ib(
        init=False, default=None
    )

    def point_in_time_accurate_join(self):
        """Set join type as point in time accurate join.

        Returns:
            This DatasetBuilder object.
        """
        self._point_in_time_accurate_join = True
        return self

    def include_duplicated_records(self):
        """Include duplicated records in dataset.

        Returns:
            This DatasetBuilder object.
        """
        self._include_duplicated_records = True
        return self

    def include_deleted_records(self):
        """Include deleted records in dataset.

        Returns:
            This DatasetBuilder object.
        """
        self._include_deleted_records = True
        return self

    def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int):
        """Set number_of_recent_records field with provided input.

        Args:
            number_of_recent_records (int): An int that how many recent records will be
                returned for each record identifier.
        Returns:
            This DatasetBuilder object.
        """
        self._number_of_recent_records = number_of_recent_records
        return self

    def with_number_of_records_from_query_results(self, number_of_records: int):
        """Set number_of_records field with provided input.

        Args:
            number_of_records (int): An int that how many records will be returned.
        Returns:
            This DatasetBuilder object.
        """
        self._number_of_records = number_of_records
        return self

    def as_of(self, timestamp: datetime.datetime):
        """Set write_time_ending_timestamp field with provided input.

        Args:
            timestamp (datetime.datetime): A datetime that all records' write time in dataset
                will be before it.
        Returns:
            This DatasetBuilder object.
        """
        self._write_time_ending_timestamp = timestamp
        return self

    def with_event_time_range(
        self,
        starting_timestamp: Union[datetime.datetime, None] = None,
        ending_timestamp: Union[datetime.datetime, None] = None,
    ):
        """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs.

        Args:
            starting_timestamp (datetime.datetime): A datetime that all records' event time in
                dataset will be after it (default: None).
            ending_timestamp (datetime.datetime): A datetime that all records' event time in
                dataset will be before it (default: None).
        Returns:
            This DatasetBuilder object.
        """
        self._event_time_starting_timestamp = starting_timestamp
        self._event_time_ending_timestamp = ending_timestamp
        return self

src/sagemaker/feature_store/feature_store.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
from __future__ import absolute_import
1919

2020
import datetime
21-
from typing import Dict, Any
21+
from typing import Any, Dict, Sequence, Union
2222

2323
import attr
24+
import pandas as pd
2425

2526
from sagemaker import Session
27+
from sagemaker.feature_store.dataset_builder import DatasetBuilder
28+
from sagemaker.feature_store.feature_group import FeatureGroup
2629

2730

2831
@attr.s
@@ -37,6 +40,50 @@ class FeatureStore:
3740

3841
sagemaker_session: Session = attr.ib(default=Session)
3942

43+
@staticmethod
def create_dataset(
    base: Union[FeatureGroup, pd.DataFrame],
    output_path: str,
    record_identifier_feature_name: Union[str, None] = None,
    event_time_identifier_feature_name: Union[str, None] = None,
    included_feature_names: Union[Sequence[str], None] = None,
    kms_key_id: Union[str, None] = None,
) -> DatasetBuilder:
    """Create a Dataset Builder for generating a Dataset.

    Args:
        base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
            pandas.DataFrame and will be used to merge other FeatureGroups and generate a
            Dataset.
        output_path (str): An S3 URI which stores the output .csv file.
        record_identifier_feature_name (str): A string representing the record identifier
            feature if base is a DataFrame (default: None).
        event_time_identifier_feature_name (str): A string representing the event time
            identifier feature if base is a DataFrame (default: None).
        included_feature_names (List[str]): A list of features to be included in the output
            (default: None).
        kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
            (default: None).

    Returns:
        DatasetBuilder: A builder initialized with the provided arguments; refine it with its
        chainable methods before generating the dataset.

    Raises:
        ValueError: Base is a Pandas DataFrame but no record identifier feature name nor event
            time identifier feature name is provided.
    """
    if isinstance(base, pd.DataFrame):
        # A DataFrame base carries no Feature Store metadata, so the caller must
        # supply both identifier column names explicitly.
        if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
            raise ValueError(
                "You must provide a record identifier feature name and an event time "
                + "identifier feature name if specify DataFrame as base."
            )
    return DatasetBuilder(
        base,
        output_path,
        record_identifier_feature_name,
        event_time_identifier_feature_name,
        included_feature_names,
        kms_key_id,
    )
86+
4087
def list_feature_groups(
4188
self,
4289
name_contains: str = None,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import datetime
16+
17+
import pytest
18+
from mock import Mock
19+
20+
from sagemaker.feature_store.dataset_builder import DatasetBuilder
21+
22+
23+
@pytest.fixture
def feature_group_mock():
    """Return a Mock that stands in for a FeatureGroup used as the builder base."""
    return Mock()
26+
27+
28+
def test_point_in_time_accurate_join(feature_group_mock):
    """Calling point_in_time_accurate_join flips the corresponding flag."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    builder.point_in_time_accurate_join()
    assert builder._point_in_time_accurate_join
32+
33+
34+
def test_include_duplicated_records(feature_group_mock):
    """Calling include_duplicated_records flips the corresponding flag."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    builder.include_duplicated_records()
    assert builder._include_duplicated_records
38+
39+
40+
def test_include_deleted_records(feature_group_mock):
    """Calling include_deleted_records flips the corresponding flag."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    builder.include_deleted_records()
    assert builder._include_deleted_records
44+
45+
46+
def test_with_number_of_recent_records_by_record_identifier(feature_group_mock):
    """The requested recent-record count is stored on the builder."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    builder.with_number_of_recent_records_by_record_identifier(5)
    assert builder._number_of_recent_records == 5
50+
51+
52+
def test_with_number_of_records_from_query_results(feature_group_mock):
    """The requested total record count is stored on the builder."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    builder.with_number_of_records_from_query_results(100)
    assert builder._number_of_records == 100
56+
57+
58+
def test_as_of(feature_group_mock):
    """as_of stores the provided timestamp as the write-time upper bound."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    cutoff = datetime.datetime.now()
    builder.as_of(cutoff)
    assert builder._write_time_ending_timestamp == cutoff
63+
64+
65+
def test_with_event_time_range(feature_group_mock):
    """with_event_time_range stores both boundary timestamps."""
    builder = DatasetBuilder(base=feature_group_mock, output_path="file/to/path")
    range_start = datetime.datetime.now()
    range_end = range_start + datetime.timedelta(minutes=1)
    builder.with_event_time_range(range_start, range_end)
    assert builder._event_time_starting_timestamp == range_start
    assert builder._event_time_ending_timestamp == range_end

tests/unit/sagemaker/feature_store/test_feature_store.py

+66
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,83 @@
1414

1515
import datetime
1616

17+
import pandas as pd
1718
import pytest
1819
from mock import Mock
1920

2021
from sagemaker.feature_store.feature_store import FeatureStore
2122

23+
# Sample base DataFrame (three rows, two feature columns) shared by the
# create_dataset tests below.
DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]})
24+
2225

2326
@pytest.fixture
def sagemaker_session_mock():
    """Return a Mock that stands in for the SageMaker Session passed to FeatureStore."""
    return Mock()
2629

2730

31+
@pytest.fixture
def feature_group_mock():
    """Return a Mock that stands in for a FeatureGroup used as the dataset base."""
    return Mock()
34+
35+
36+
def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock):
    """The builder keeps the base and output path passed to create_dataset."""
    store = FeatureStore(sagemaker_session=sagemaker_session_mock)
    builder = store.create_dataset(base=feature_group_mock, output_path="file/to/path")
    assert builder._base == feature_group_mock
    assert builder._output_path == "file/to/path"
44+
45+
46+
def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock):
    """All optional arguments are forwarded to the DatasetBuilder."""
    store = FeatureStore(sagemaker_session=sagemaker_session_mock)
    builder = store.create_dataset(
        base=feature_group_mock,
        included_feature_names=["feature_1", "feature_2"],
        output_path="file/to/path",
        kms_key_id="kms-key-id",
    )
    assert builder._base == feature_group_mock
    assert builder._included_feature_names == ["feature_1", "feature_2"]
    assert builder._output_path == "file/to/path"
    assert builder._kms_key_id == "kms-key-id"
58+
59+
60+
def test_create_dataset_with_dataframe(sagemaker_session_mock):
    """Every argument is forwarded when the base is a pandas DataFrame."""
    store = FeatureStore(sagemaker_session=sagemaker_session_mock)
    builder = store.create_dataset(
        base=DATAFRAME,
        record_identifier_feature_name="feature_1",
        event_time_identifier_feature_name="feature_2",
        included_feature_names=["feature_1", "feature_2"],
        output_path="file/to/path",
        kms_key_id="kms-key-id",
    )
    assert builder._base.equals(DATAFRAME)
    assert builder._record_identifier_feature_name == "feature_1"
    assert builder._event_time_identifier_feature_name == "feature_2"
    assert builder._included_feature_names == ["feature_1", "feature_2"]
    assert builder._output_path == "file/to/path"
    assert builder._kms_key_id == "kms-key-id"
76+
77+
78+
def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock):
    """create_dataset raises ValueError for a DataFrame base without identifier names."""
    feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
    with pytest.raises(ValueError) as error:
        feature_store.create_dataset(
            base=DATAFRAME,
            included_feature_names=["feature_1", "feature_2"],
            output_path="file/to/path",
            kms_key_id="kms-key-id",
        )
    # Check the exception itself (error.value), not the ExceptionInfo wrapper:
    # str(excinfo) formatting is pytest-version dependent.
    assert (
        "You must provide a record identifier feature name and an event time identifier feature "
        + "name if specify DataFrame as base."
        in str(error.value)
    )
92+
93+
2894
def test_list_feature_groups_with_no_filter(sagemaker_session_mock):
2995
feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
3096
feature_store.list_feature_groups()

0 commit comments

Comments
 (0)