Skip to content

Commit 5e3da8d

Browse files
authored
feat: discretize table (#327)
Closes #143.

### Summary of Changes

* Added a class `Discretizer` in `safeds.data.tabular.transformation` that wraps the [`KBinsDiscretizer` of `scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html).
* Made the class a subclass of `TableTransformer`.
* For now, `__init__` only has the parameter `number_of_bins`, which controls how many bins are created.
* If `number_of_bins` is less than 2, a `ValueError` is raised.
1 parent 388ab2d commit 5e3da8d

File tree

3 files changed

+508
-0
lines changed

3 files changed

+508
-0
lines changed

src/safeds/data/tabular/transformation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Classes for transforming tabular data."""
22

3+
from ._discretizer import Discretizer
34
from ._imputer import Imputer
45
from ._label_encoder import LabelEncoder
56
from ._one_hot_encoder import OneHotEncoder
@@ -14,5 +15,6 @@
1415
"InvertibleTableTransformer",
1516
"TableTransformer",
1617
"RangeScaler",
18+
"Discretizer",
1719
"StandardScaler",
1820
]
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from __future__ import annotations
2+
3+
from sklearn.preprocessing import KBinsDiscretizer as sk_KBinsDiscretizer
4+
5+
from safeds.data.tabular.containers import Table
6+
from safeds.data.tabular.transformation._table_transformer import TableTransformer
7+
from safeds.exceptions import NonNumericColumnError, TransformerNotFittedError, UnknownColumnNameError
8+
9+
10+
class Discretizer(TableTransformer):
    """
    The Discretizer bins continuous data into intervals.

    It wraps scikit-learn's `KBinsDiscretizer`, configured with `encode="ordinal"`.

    Parameters
    ----------
    number_of_bins: int
        The number of bins to be created. It is passed unchanged as `n_bins` to the wrapped `KBinsDiscretizer`,
        which expects an integer.

    Raises
    ------
    ValueError
        If the given number_of_bins is less than 2.
    """

    def __init__(self, number_of_bins: int = 5):
        # Both attributes stay None until `fit` returns a fitted copy.
        self._column_names: list[str] | None = None
        self._wrapped_transformer: sk_KBinsDiscretizer | None = None

        if number_of_bins < 2:
            raise ValueError("Parameter 'number_of_bins' must be >= 2.")
        self._number_of_bins = number_of_bins

    @staticmethod
    def _check_columns(table: Table, column_names: list[str]) -> None:
        """
        Ensure that all given columns exist in the table and are numeric.

        Parameters
        ----------
        table : Table
            The table to check.
        column_names : list[str]
            The names of the columns that must be present and numeric.

        Raises
        ------
        UnknownColumnNameError
            If one of the columns is not in the table.
        NonNumericColumnError
            If one of the columns is non-numeric.
        """
        missing_columns = set(column_names) - set(table.column_names)
        if missing_columns:
            # Report the missing names in the order in which they were requested.
            position = {name: index for index, name in enumerate(column_names)}
            raise UnknownColumnNameError(sorted(missing_columns, key=position.__getitem__))

        for name in column_names:
            column_type = table.get_column(name).type
            if not column_type.is_numeric():
                raise NonNumericColumnError(f"{name} is of type {column_type}.")

    def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
        """
        Learn a transformation for a set of columns in a table.

        This transformer is not modified.

        Parameters
        ----------
        table : Table
            The table used to fit the transformer.
        column_names : list[str] | None
            The list of columns from the table used to fit the transformer. If `None`, all columns are used.

        Returns
        -------
        fitted_transformer : Discretizer
            The fitted transformer.

        Raises
        ------
        ValueError
            If the table is empty.
        NonNumericColumnError
            If one of the columns, that should be fitted is non-numeric.
        UnknownColumnNameError
            If one of the columns, that should be fitted is not in the table.
        """
        if table.number_of_rows == 0:
            raise ValueError("The Discretizer cannot be fitted because the table contains 0 rows")

        # Copy the names so that later changes to the caller's list cannot alter the fitted transformer.
        column_names = list(table.column_names if column_names is None else column_names)
        self._check_columns(table, column_names)

        wrapped_transformer = sk_KBinsDiscretizer(n_bins=self._number_of_bins, encode="ordinal")
        wrapped_transformer.fit(table._data[column_names])

        # Return a new, fitted instance instead of mutating this one.
        result = Discretizer(self._number_of_bins)
        result._wrapped_transformer = wrapped_transformer
        result._column_names = column_names

        return result

    def transform(self, table: Table) -> Table:
        """
        Apply the learned transformation to a table.

        The table is not modified.

        Parameters
        ----------
        table : Table
            The table to which the learned transformation is applied.

        Returns
        -------
        transformed_table : Table
            The transformed table.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        ValueError
            If the table is empty.
        UnknownColumnNameError
            If one of the columns, that should be transformed is not in the table.
        NonNumericColumnError
            If one of the columns, that should be transformed is non-numeric.
        """
        # Transformer has not been fitted yet
        if self._wrapped_transformer is None or self._column_names is None:
            raise TransformerNotFittedError

        if table.number_of_rows == 0:
            raise ValueError("The table cannot be transformed because it contains 0 rows")

        # The input table must contain all columns used to fit the transformer, and they must be numeric
        self._check_columns(table, self._column_names)

        # Work on a copy so the data of the input table is left untouched.
        data = table._data.copy()
        data.columns = table.column_names
        data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names])
        return Table._from_pandas_dataframe(data)

    def is_fitted(self) -> bool:
        """
        Check if the transformer is fitted.

        Returns
        -------
        is_fitted : bool
            Whether the transformer is fitted.
        """
        # Same condition that `transform` uses to decide whether it can run.
        return self._wrapped_transformer is not None and self._column_names is not None

    def get_names_of_added_columns(self) -> list[str]:
        """
        Get the names of all new columns that have been added by the Discretizer.

        Returns
        -------
        added_columns : list[str]
            A list of names of the added columns, ordered as they will appear in the table. This list is always
            empty.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if not self.is_fitted():
            raise TransformerNotFittedError
        return []

    # (Must implement abstract method, cannot instantiate class otherwise.)
    def get_names_of_changed_columns(self) -> list[str]:
        """
        Get the names of all columns that may have been changed by the Discretizer.

        Returns
        -------
        changed_columns : list[str]
            The list of (potentially) changed column names, as passed to fit.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if self._column_names is None:
            raise TransformerNotFittedError
        # Hand out a copy so callers cannot modify the fitted state through the returned list.
        return list(self._column_names)

    def get_names_of_removed_columns(self) -> list[str]:
        """
        Get the names of all columns that have been removed by the Discretizer.

        Returns
        -------
        removed_columns : list[str]
            A list of names of the removed columns, ordered as they appear in the table the Discretizer was fitted
            on. This list is always empty.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if not self.is_fitted():
            raise TransformerNotFittedError
        return []

0 commit comments

Comments
 (0)