
Commit 7fcabef

[SPARK-44250][ML][PYTHON][CONNECT] Implement classification evaluator
### What changes were proposed in this pull request?

Implement classification evaluator.

### Why are the changes needed?

Distributed ML <> spark connect project.

### Does this PR introduce _any_ user-facing change?

Yes. `BinaryClassificationEvaluator` and `MulticlassClassificationEvaluator` are added.

### How was this patch tested?

Closes #41793 from WeichenXu123/classification-evaluator.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
1 parent: 7bc28d5 · commit: 7fcabef
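For context, a rough usage sketch of the new evaluator API, assuming an existing SparkSession bound to `spark`; the column names and sample values mirror the test file in this commit, and the snippet is an illustration rather than added documentation:

```python
# Hedged sketch: evaluating a binary classifier with the new connect evaluator.
# Assumes `spark` is an existing SparkSession and torch/torcheval are installed.
from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator

df = spark.createDataFrame(
    [(1, 0.2), (0, 0.6), (1, 0.8), (1, 0.7), (0, 0.4), (0, 0.3)],
    schema=["label", "prob"],
)

evaluator = BinaryClassificationEvaluator(
    metricName="areaUnderROC", labelCol="label", probabilityCol="prob"
)
print(evaluator.evaluate(df))             # ~0.6667 on this sample
print(evaluator.evaluate(df.toPandas()))  # a local pandas DataFrame also works
```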

File tree

2 files changed: +202 −36 lines

python/pyspark/ml/connect/evaluation.py

Lines changed: 126 additions & 35 deletions
```diff
@@ -14,22 +14,61 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import numpy as np
 
 import pandas as pd
-from typing import Any, Union
+from typing import Any, Union, List, Tuple
 
 from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol
+from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol
 from pyspark.ml.connect.base import Evaluator
 from pyspark.ml.connect.io_utils import ParamsReadWrite
 from pyspark.ml.connect.util import aggregate_dataframe
 from pyspark.sql import DataFrame
 
-import torch
-import torcheval.metrics as torchmetrics
 
+class _TorchMetricEvaluator(Evaluator):
 
-class RegressionEvaluator(Evaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
+    metricName: Param[str] = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name for the regression evaluator, valid values are 'mse' and 'r2'",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def _get_torch_metric(self) -> Any:
+        raise NotImplementedError()
+
+    def _get_input_cols(self) -> List[str]:
+        raise NotImplementedError()
+
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        raise NotImplementedError()
+
+    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
+        torch_metric = self._get_torch_metric()
+
+        def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
+            torch_metric.update(*self._get_metric_update_inputs(pandas_df))
+            return torch_metric
+
+        def merge_agg_state(state1: Any, state2: Any) -> Any:
+            state1.merge_state([state2])
+            return state1
+
+        def agg_state_to_result(state: Any) -> Any:
+            return state.compute().item()
+
+        return aggregate_dataframe(
+            dataset,
+            self._get_input_cols(),
+            local_agg_fn,
+            merge_agg_state,
+            agg_state_to_result,
+        )
+
+
+class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
     """
     Evaluator for Regression, which expects input columns prediction and label.
     Supported metrics are 'mse' and 'r2'.
@@ -41,14 +80,9 @@ def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
         super().__init__()
         self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
 
-    metricName: Param[str] = Param(
-        Params._dummy(),
-        "metricName",
-        "metric name for the regression evaluator, valid values are 'mse' and 'r2'",
-        typeConverter=TypeConverters.toString,
-    )
-
     def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
         metric_name = self.getOrDefault(self.metricName)
 
         if metric_name == "mse":
@@ -58,32 +92,89 @@ def _get_torch_metric(self) -> Any:
 
         raise ValueError(f"Unsupported regressor evaluator metric name: {metric_name}")
 
-    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
-        prediction_col = self.getPredictionCol()
-        label_col = self.getLabelCol()
+    def _get_input_cols(self) -> List[str]:
+        return [self.getPredictionCol(), self.getLabelCol()]
 
-        torch_metric = self._get_torch_metric()
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
 
-        def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
-            with torch.inference_mode():
-                preds_tensor = torch.tensor(pandas_df[prediction_col].values)
-                labels_tensor = torch.tensor(pandas_df[label_col].values)
-                torch_metric.update(preds_tensor, labels_tensor)
-                return torch_metric
+        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
 
-        def merge_agg_state(state1: Any, state2: Any) -> Any:
-            with torch.inference_mode():
-                state1.merge_state([state2])
-                return state1
 
-        def agg_state_to_result(state: Any) -> Any:
-            with torch.inference_mode():
-                return state.compute().item()
+class BinaryClassificationEvaluator(
+    _TorchMetricEvaluator, HasLabelCol, HasProbabilityCol, ParamsReadWrite
+):
+    """
+    Evaluator for binary classification, which expects input columns prediction and label.
+    Supported metrics are 'areaUnderROC' and 'areaUnderPR'.
 
-        return aggregate_dataframe(
-            dataset,
-            [prediction_col, label_col],
-            local_agg_fn,
-            merge_agg_state,
-            agg_state_to_result,
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, metricName: str, labelCol: str, probabilityCol: str) -> None:
+        super().__init__()
+        self._set(metricName=metricName, labelCol=labelCol, probabilityCol=probabilityCol)
+
+    def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
+        metric_name = self.getOrDefault(self.metricName)
+
+        if metric_name == "areaUnderROC":
+            return torchmetrics.BinaryAUROC()
+        if metric_name == "areaUnderPR":
+            return torchmetrics.BinaryAUPRC()
+
+        raise ValueError(f"Unsupported binary classification evaluator metric name: {metric_name}")
+
+    def _get_input_cols(self) -> List[str]:
+        return [self.getProbabilityCol(), self.getLabelCol()]
+
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
+
+        values = np.stack(dataset[self.getProbabilityCol()].values)  # type: ignore[call-overload]
+        preds_tensor = torch.tensor(values)
+        if preds_tensor.dim() == 2:
+            preds_tensor = preds_tensor[:, 1]
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
+
+
+class MulticlassClassificationEvaluator(
+    _TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite
+):
+    """
+    Evaluator for multiclass classification, which expects input columns prediction and label.
+    Supported metrics are 'accuracy'.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
+        super().__init__()
+        self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
+
+    def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
+        metric_name = self.getOrDefault(self.metricName)
+
+        if metric_name == "accuracy":
+            return torchmetrics.MulticlassAccuracy()
+
+        raise ValueError(
+            f"Unsupported multiclass classification evaluator metric name: {metric_name}"
         )
+
+    def _get_input_cols(self) -> List[str]:
+        return [self.getPredictionCol(), self.getLabelCol()]
+
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
+
+        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
```
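The refactor above moves the update/merge/compute plumbing into `_TorchMetricEvaluator._evaluate`, so each concrete evaluator only supplies a torcheval metric, its input columns, and the tensors to feed it. A minimal standalone sketch of that aggregation pattern (not part of the commit; the sample values mirror the binary test data below):

```python
# Hedged sketch: each partition updates its own metric state, states are
# merged pairwise, and the final value is computed once, as _evaluate does.
import torch
from torcheval.metrics import BinaryAUROC

# Pretend these are two partitions of (probability, label) pairs.
part1_probs = torch.tensor([0.2, 0.6, 0.8])
part1_labels = torch.tensor([1.0, 0.0, 1.0])
part2_probs = torch.tensor([0.7, 0.4, 0.3])
part2_labels = torch.tensor([1.0, 0.0, 0.0])

state1, state2 = BinaryAUROC(), BinaryAUROC()
state1.update(part1_probs, part1_labels)   # local_agg_fn on partition 1
state2.update(part2_probs, part2_labels)   # local_agg_fn on partition 2

state1.merge_state([state2])               # merge_agg_state
print(state1.compute().item())             # agg_state_to_result, ~0.6667 here
```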

python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py

Lines changed: 76 additions & 1 deletion
```diff
@@ -18,7 +18,11 @@
 import unittest
 import numpy as np
 
-from pyspark.ml.connect.evaluation import RegressionEvaluator
+from pyspark.ml.connect.evaluation import (
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+)
 from pyspark.sql import SparkSession
 
 
@@ -66,6 +70,77 @@ def test_regressor_evaluator(self):
         np.testing.assert_almost_equal(r2, expected_r2)
         np.testing.assert_almost_equal(r2_local, expected_r2)
 
+    def test_binary_classifier_evaluator(self):
+        df1 = self.spark.createDataFrame(
+            [
+                (1, 0.2, [0.8, 0.2]),
+                (0, 0.6, [0.4, 0.6]),
+                (1, 0.8, [0.2, 0.8]),
+                (1, 0.7, [0.3, 0.7]),
+                (0, 0.4, [0.6, 0.4]),
+                (0, 0.3, [0.7, 0.3]),
+            ],
+            schema=["label", "prob", "prob2"],
+        )
+
+        local_df1 = df1.toPandas()
+
+        for prob_col in ["prob", "prob2"]:
+            auroc_evaluator = BinaryClassificationEvaluator(
+                metricName="areaUnderROC",
+                labelCol="label",
+                probabilityCol=prob_col,
+            )
+
+            expected_auroc = 0.6667
+            auroc = auroc_evaluator.evaluate(df1)
+            auroc_local = auroc_evaluator.evaluate(local_df1)
+            np.testing.assert_almost_equal(auroc, expected_auroc, decimal=2)
+            np.testing.assert_almost_equal(auroc_local, expected_auroc, decimal=2)
+
+            auprc_evaluator = BinaryClassificationEvaluator(
+                metricName="areaUnderPR",
+                labelCol="label",
+                probabilityCol=prob_col,
+            )
+
+            expected_auprc = 0.8333
+            auprc = auprc_evaluator.evaluate(df1)
+            auprc_local = auprc_evaluator.evaluate(local_df1)
+            np.testing.assert_almost_equal(auprc, expected_auprc, decimal=2)
+            np.testing.assert_almost_equal(auprc_local, expected_auprc, decimal=2)
+
+    def test_multiclass_classifier_evaluator(self):
+        df1 = self.spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 1),
+                (2, 3),
+                (0, 0),
+                (0, 1),
+                (3, 1),
+                (3, 3),
+                (2, 2),
+                (1, 0),
+                (2, 2),
+            ],
+            schema=["label", "prediction"],
+        )
+
+        local_df1 = df1.toPandas()
+
+        accuracy_evaluator = MulticlassClassificationEvaluator(
+            metricName="accuracy",
+            labelCol="label",
+            predictionCol="prediction",
+        )
+
+        expected_accuracy = 0.600
+        accuracy = accuracy_evaluator.evaluate(df1)
+        accuracy_local = accuracy_evaluator.evaluate(local_df1)
+        np.testing.assert_almost_equal(accuracy, expected_accuracy, decimal=2)
+        np.testing.assert_almost_equal(accuracy_local, expected_accuracy, decimal=2)
+
 
 @unittest.skipIf(not have_torcheval, "torcheval is required")
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
```
0 commit comments
