14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
#
17
+ import numpy as np
17
18
18
19
import pandas as pd
19
- from typing import Any , Union
20
+ from typing import Any , Union , List , Tuple
20
21
21
22
from pyspark .ml .param import Param , Params , TypeConverters
22
- from pyspark .ml .param .shared import HasLabelCol , HasPredictionCol
23
+ from pyspark .ml .param .shared import HasLabelCol , HasPredictionCol , HasProbabilityCol
23
24
from pyspark .ml .connect .base import Evaluator
24
25
from pyspark .ml .connect .io_utils import ParamsReadWrite
25
26
from pyspark .ml .connect .util import aggregate_dataframe
26
27
from pyspark .sql import DataFrame
27
28
28
- import torch
29
- import torcheval .metrics as torchmetrics
30
29
30
class _TorchMetricEvaluator(Evaluator):
    """
    Shared base for evaluators backed by a ``torcheval`` streaming metric.

    Subclasses provide the concrete metric object, the input columns, and the
    conversion from a pandas batch to the metric's ``update`` arguments; this
    base class drives the distributed aggregation.
    """

    # Metric-name param shared by all subclasses; each subclass documents its
    # own set of valid values in its class docstring.
    metricName: Param[str] = Param(
        Params._dummy(),
        "metricName",
        "metric name for the regression evaluator, valid values are 'mse' and 'r2'",
        typeConverter=TypeConverters.toString,
    )

    def _get_torch_metric(self) -> Any:
        """Return a fresh ``torcheval`` metric instance for ``metricName``."""
        raise NotImplementedError()

    def _get_input_cols(self) -> List[str]:
        """Return the dataframe columns this evaluator reads."""
        raise NotImplementedError()

    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
        """Convert a pandas batch into the tensors fed to ``metric.update``."""
        raise NotImplementedError()

    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
        """Aggregate the metric over ``dataset`` and return it as a float."""
        torch_metric = self._get_torch_metric()

        def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
            # Stream one local batch into the metric's running state.
            torch_metric.update(*self._get_metric_update_inputs(pandas_df))
            return torch_metric

        def merge_agg_state(state1: Any, state2: Any) -> Any:
            # Fold partial states from different partitions into one.
            state1.merge_state([state2])
            return state1

        def agg_state_to_result(state: Any) -> Any:
            # Materialize the final metric value as a plain Python scalar.
            return state.compute().item()

        return aggregate_dataframe(
            dataset,
            self._get_input_cols(),
            local_agg_fn,
            merge_agg_state,
            agg_state_to_result,
        )
71
+ class RegressionEvaluator (_TorchMetricEvaluator , HasLabelCol , HasPredictionCol , ParamsReadWrite ):
33
72
"""
34
73
Evaluator for Regression, which expects input columns prediction and label.
35
74
Supported metrics are 'mse' and 'r2'.
@@ -41,14 +80,9 @@ def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
41
80
super ().__init__ ()
42
81
self ._set (metricName = metricName , labelCol = labelCol , predictionCol = predictionCol )
43
82
44
- metricName : Param [str ] = Param (
45
- Params ._dummy (),
46
- "metricName" ,
47
- "metric name for the regression evaluator, valid values are 'mse' and 'r2'" ,
48
- typeConverter = TypeConverters .toString ,
49
- )
50
-
51
83
def _get_torch_metric (self ) -> Any :
84
+ import torcheval .metrics as torchmetrics
85
+
52
86
metric_name = self .getOrDefault (self .metricName )
53
87
54
88
if metric_name == "mse" :
@@ -58,32 +92,89 @@ def _get_torch_metric(self) -> Any:
58
92
59
93
raise ValueError (f"Unsupported regressor evaluator metric name: { metric_name } " )
60
94
61
- def _evaluate (self , dataset : Union ["DataFrame" , "pd.DataFrame" ]) -> float :
62
- prediction_col = self .getPredictionCol ()
63
- label_col = self .getLabelCol ()
95
+ def _get_input_cols (self ) -> List [str ]:
96
+ return [self .getPredictionCol (), self .getLabelCol ()]
64
97
65
- torch_metric = self ._get_torch_metric ()
98
+ def _get_metric_update_inputs (self , dataset : "pd.DataFrame" ) -> Tuple [Any , Any ]:
99
+ import torch
66
100
67
- def local_agg_fn (pandas_df : "pd.DataFrame" ) -> "pd.DataFrame" :
68
- with torch .inference_mode ():
69
- preds_tensor = torch .tensor (pandas_df [prediction_col ].values )
70
- labels_tensor = torch .tensor (pandas_df [label_col ].values )
71
- torch_metric .update (preds_tensor , labels_tensor )
72
- return torch_metric
101
+ preds_tensor = torch .tensor (dataset [self .getPredictionCol ()].values )
102
+ labels_tensor = torch .tensor (dataset [self .getLabelCol ()].values )
103
+ return preds_tensor , labels_tensor
73
104
74
- def merge_agg_state (state1 : Any , state2 : Any ) -> Any :
75
- with torch .inference_mode ():
76
- state1 .merge_state ([state2 ])
77
- return state1
78
105
79
- def agg_state_to_result (state : Any ) -> Any :
80
- with torch .inference_mode ():
81
- return state .compute ().item ()
106
class BinaryClassificationEvaluator(
    _TorchMetricEvaluator, HasLabelCol, HasProbabilityCol, ParamsReadWrite
):
    """
    Evaluator for binary classification, which expects input columns prediction and label.
    Supported metrics are 'areaUnderROC' and 'areaUnderPR'.

    .. versionadded:: 3.5.0
    """

    def __init__(self, metricName: str, labelCol: str, probabilityCol: str) -> None:
        super().__init__()
        self._set(metricName=metricName, labelCol=labelCol, probabilityCol=probabilityCol)

    def _get_torch_metric(self) -> Any:
        # Lazy import: torcheval is only required when an evaluation runs.
        import torcheval.metrics as torchmetrics

        metric_name = self.getOrDefault(self.metricName)

        if metric_name == "areaUnderROC":
            return torchmetrics.BinaryAUROC()
        if metric_name == "areaUnderPR":
            return torchmetrics.BinaryAUPRC()

        raise ValueError(f"Unsupported binary classification evaluator metric name: {metric_name}")

    def _get_input_cols(self) -> List[str]:
        return [self.getProbabilityCol(), self.getLabelCol()]

    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
        import torch

        # Stack the per-row probability arrays into one contiguous ndarray.
        prob_values = np.stack(dataset[self.getProbabilityCol()].values)  # type: ignore[call-overload]
        scores = torch.tensor(prob_values)
        if scores.dim() == 2:
            # A (n, 2) class-probability matrix: keep the positive-class column.
            scores = scores[:, 1]
        targets = torch.tensor(dataset[self.getLabelCol()].values)
        return scores, targets
144
+
145
+
146
class MulticlassClassificationEvaluator(
    _TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite
):
    """
    Evaluator for multiclass classification, which expects input columns prediction and label.
    Supported metrics are 'accuracy'.

    .. versionadded:: 3.5.0
    """

    def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
        super().__init__()
        self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)

    def _get_torch_metric(self) -> Any:
        # Lazy import: torcheval is only required when an evaluation runs.
        import torcheval.metrics as torchmetrics

        metric_name = self.getOrDefault(self.metricName)

        if metric_name == "accuracy":
            return torchmetrics.MulticlassAccuracy()

        raise ValueError(
            f"Unsupported multiclass classification evaluator metric name: {metric_name}"
        )

    def _get_input_cols(self) -> List[str]:
        return [self.getPredictionCol(), self.getLabelCol()]

    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
        import torch

        predictions = torch.tensor(dataset[self.getPredictionCol()].values)
        targets = torch.tensor(dataset[self.getLabelCol()].values)
        return predictions, targets
0 commit comments