Skip to content

Commit 40b034e

Browse files
committed
Typing evaluation transformer.
1 parent a5339f6 commit 40b034e

File tree

1 file changed

+61
-14
lines changed

1 file changed

+61
-14
lines changed

skfda/representation/_evaluation_trasformer.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional, Union, overload
4+
5+
import numpy as np
16
from sklearn.base import BaseEstimator, TransformerMixin
27
from sklearn.utils.validation import check_is_fitted
8+
from typing_extensions import Literal
9+
310
from ._functional_data import FData
11+
from ._typing import ArrayLike, GridPointsLike
12+
from .extrapolation import ExtrapolationLike
413
from .grid import FDataGrid
514

615

7-
class EvaluationTransformer(BaseEstimator, TransformerMixin):
16+
class EvaluationTransformer(
17+
BaseEstimator, # type:ignore
18+
TransformerMixin, # type:ignore
19+
):
820
r"""
921
Transformer returning the evaluations of FData objects as a matrix.
1022
@@ -25,10 +37,9 @@ class EvaluationTransformer(BaseEstimator, TransformerMixin):
2537
the parameter has no efect. Defaults to False.
2638
2739
Attributes:
28-
shape_ (tuple): original shape of coefficients per sample.
40+
shape\_ (tuple): original shape of coefficients per sample.
2941
3042
Examples:
31-
3243
>>> from skfda.representation import (FDataGrid, FDataBasis,
3344
... EvaluationTransformer)
3445
>>> from skfda.representation.basis import Monomial
@@ -82,32 +93,68 @@ class EvaluationTransformer(BaseEstimator, TransformerMixin):
8293
8394
"""
8495

85-
def __init__(self, eval_points=None, *,
86-
extrapolation=None, grid=False):
96+
@overload
97+
def __init__(
98+
self,
99+
eval_points: ArrayLike,
100+
*,
101+
extrapolation: Optional[ExtrapolationLike] = None,
102+
grid: Literal[False] = False,
103+
) -> None:
104+
pass
105+
106+
@overload
107+
def __init__(
108+
self,
109+
eval_points: GridPointsLike,
110+
*,
111+
extrapolation: Optional[ExtrapolationLike] = None,
112+
grid: Literal[True],
113+
) -> None:
114+
pass
115+
116+
def __init__(
117+
self,
118+
eval_points: Union[ArrayLike, GridPointsLike, None] = None,
119+
*,
120+
extrapolation: Optional[ExtrapolationLike] = None,
121+
grid: bool = False,
122+
):
87123
self.eval_points = eval_points
88124
self.extrapolation = extrapolation
89125
self.grid = grid
90126

91-
def fit(self, X: FData, y=None):
127+
def fit( # noqa: D102
128+
self,
129+
X: FData,
130+
y: None = None,
131+
) -> EvaluationTransformer:
92132

93133
if self.eval_points is None and not isinstance(X, FDataGrid):
94-
raise ValueError("If no eval_points are passed, the functions "
95-
"should be FDataGrid objects.")
134+
raise ValueError(
135+
"If no eval_points are passed, the functions "
136+
"should be FDataGrid objects.",
137+
)
96138

97139
self._is_fitted = True
98140

99141
return self
100142

101-
def transform(self, X, y=None):
143+
def transform( # noqa: D102
144+
self,
145+
X: FData,
146+
y: None = None,
147+
) -> np.ndarray:
102148

103149
check_is_fitted(self, '_is_fitted')
104150

105151
if self.eval_points is None:
106152
evaluation = X.data_matrix.copy()
107153
else:
108-
evaluation = X(self.eval_points,
109-
extrapolation=self.extrapolation, grid=self.grid)
110-
111-
evaluation = evaluation.reshape((X.n_samples, -1))
154+
evaluation = X( # type: ignore
155+
self.eval_points,
156+
extrapolation=self.extrapolation,
157+
grid=self.grid,
158+
)
112159

113-
return evaluation
160+
return evaluation.reshape((X.n_samples, -1))

0 commit comments

Comments
 (0)