Skip to content

Commit 05c8a80

Browse files
committed
Typing CoefficientsTransformer.
1 parent 40b034e commit 05c8a80

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed
Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
14
from sklearn.base import BaseEstimator, TransformerMixin
25
from sklearn.utils.validation import check_is_fitted
36

47
from ._fdatabasis import FDataBasis
58

69

7-
class CoefficientsTransformer(BaseEstimator, TransformerMixin):
8-
"""
10+
class CoefficientsTransformer(
11+
BaseEstimator, # type:ignore
12+
TransformerMixin, # type:ignore
13+
):
14+
r"""
915
Transformer returning the coefficients of FDataBasis objects as a matrix.
1016
1117
Attributes:
12-
shape_ (tuple): original shape of coefficients per sample.
18+
basis\_ (tuple): Basis used.
1319
1420
Examples:
1521
>>> from skfda.representation.basis import (FDataBasis, Monomial,
@@ -26,19 +32,24 @@ class CoefficientsTransformer(BaseEstimator, TransformerMixin):
2632
2733
"""
2834

29-
def fit(self, X: FDataBasis, y=None):
35+
def fit( # noqa: D102
36+
self,
37+
X: FDataBasis,
38+
y: None = None,
39+
) -> CoefficientsTransformer:
3040

31-
self.shape_ = X.coefficients.shape[1:]
41+
self.basis_ = X.basis
3242

3343
return self
3444

35-
def transform(self, X, y=None):
45+
def transform( # noqa: D102
46+
self,
47+
X: FDataBasis,
48+
y: None = None,
49+
) -> np.ndarray:
3650

3751
check_is_fitted(self)
3852

39-
assert X.coefficients.shape[1:] == self.shape_
40-
41-
coefficients = X.coefficients.copy()
42-
coefficients = coefficients.reshape((X.n_samples, -1))
53+
assert X.basis == self.basis_
4354

44-
return coefficients
55+
return X.coefficients.copy()

0 commit comments

Comments
 (0)