1
+ from __future__ import annotations
2
+
3
+ from typing import Optional , Union , overload
4
+
5
+ import numpy as np
1
6
from sklearn .base import BaseEstimator , TransformerMixin
2
7
from sklearn .utils .validation import check_is_fitted
8
+ from typing_extensions import Literal
9
+
3
10
from ._functional_data import FData
11
+ from ._typing import ArrayLike , GridPointsLike
12
+ from .extrapolation import ExtrapolationLike
4
13
from .grid import FDataGrid
5
14
6
15
7
- class EvaluationTransformer (BaseEstimator , TransformerMixin ):
16
+ class EvaluationTransformer (
17
+ BaseEstimator , # type:ignore
18
+ TransformerMixin , # type:ignore
19
+ ):
8
20
r"""
9
21
Transformer returning the evaluations of FData objects as a matrix.
10
22
@@ -25,10 +37,9 @@ class EvaluationTransformer(BaseEstimator, TransformerMixin):
25
37
the parameter has no efect. Defaults to False.
26
38
27
39
Attributes:
28
- shape_ (tuple): original shape of coefficients per sample.
40
+ shape\_ (tuple): original shape of coefficients per sample.
29
41
30
42
Examples:
31
-
32
43
>>> from skfda.representation import (FDataGrid, FDataBasis,
33
44
... EvaluationTransformer)
34
45
>>> from skfda.representation.basis import Monomial
@@ -82,32 +93,68 @@ class EvaluationTransformer(BaseEstimator, TransformerMixin):
82
93
83
94
"""
84
95
85
- def __init__ (self , eval_points = None , * ,
86
- extrapolation = None , grid = False ):
96
+ @overload
97
+ def __init__ (
98
+ self ,
99
+ eval_points : ArrayLike ,
100
+ * ,
101
+ extrapolation : Optional [ExtrapolationLike ] = None ,
102
+ grid : Literal [False ] = False ,
103
+ ) -> None :
104
+ pass
105
+
106
+ @overload
107
+ def __init__ (
108
+ self ,
109
+ eval_points : GridPointsLike ,
110
+ * ,
111
+ extrapolation : Optional [ExtrapolationLike ] = None ,
112
+ grid : Literal [True ],
113
+ ) -> None :
114
+ pass
115
+
116
+ def __init__ (
117
+ self ,
118
+ eval_points : Union [ArrayLike , GridPointsLike , None ] = None ,
119
+ * ,
120
+ extrapolation : Optional [ExtrapolationLike ] = None ,
121
+ grid : bool = False ,
122
+ ):
87
123
self .eval_points = eval_points
88
124
self .extrapolation = extrapolation
89
125
self .grid = grid
90
126
91
- def fit (self , X : FData , y = None ):
127
+ def fit ( # noqa: D102
128
+ self ,
129
+ X : FData ,
130
+ y : None = None ,
131
+ ) -> EvaluationTransformer :
92
132
93
133
if self .eval_points is None and not isinstance (X , FDataGrid ):
94
- raise ValueError ("If no eval_points are passed, the functions "
95
- "should be FDataGrid objects." )
134
+ raise ValueError (
135
+ "If no eval_points are passed, the functions "
136
+ "should be FDataGrid objects." ,
137
+ )
96
138
97
139
self ._is_fitted = True
98
140
99
141
return self
100
142
101
- def transform (self , X , y = None ):
143
+ def transform ( # noqa: D102
144
+ self ,
145
+ X : FData ,
146
+ y : None = None ,
147
+ ) -> np .ndarray :
102
148
103
149
check_is_fitted (self , '_is_fitted' )
104
150
105
151
if self .eval_points is None :
106
152
evaluation = X .data_matrix .copy ()
107
153
else :
108
- evaluation = X (self .eval_points ,
109
- extrapolation = self .extrapolation , grid = self .grid )
110
-
111
- evaluation = evaluation .reshape ((X .n_samples , - 1 ))
154
+ evaluation = X ( # type: ignore
155
+ self .eval_points ,
156
+ extrapolation = self .extrapolation ,
157
+ grid = self .grid ,
158
+ )
112
159
113
- return evaluation
160
+ return evaluation . reshape (( X . n_samples , - 1 ))
0 commit comments