
Commit 8016123

WeichenXu123 authored and zhengruifeng committed
[SPARK-33592] Fix: PySpark ML Validator params in estimatorParamMaps may be lost after saving and reloading
### What changes were proposed in this pull request?

Fix: PySpark ML Validator params in `estimatorParamMaps` may be lost after saving and reloading.

When saving a validator's `estimatorParamMaps`, all nested stages in the tuned estimator are now checked in order to get the correct param parent. Two typical cases to test manually:

~~~python
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression()
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures, [10, 100]) \
    .addGrid(lr.maxIter, [100, 200]) \
    .build()
tvs = TrainValidationSplit(estimator=pipeline,
                           estimatorParamMaps=paramGrid,
                           evaluator=MulticlassClassificationEvaluator())

tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)
# check `loadedTvs.getEstimatorParamMaps()` restored correctly.
~~~

~~~python
lr = LogisticRegression()
ova = OneVsRest(classifier=lr)
grid = ParamGridBuilder().addGrid(lr.maxIter, [100, 200]).build()
evaluator = MulticlassClassificationEvaluator()
tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)

tvs.save(tvsPath)
loadedTvs = TrainValidationSplit.load(tvsPath)
# check `loadedTvs.getEstimatorParamMaps()` restored correctly.
~~~

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #30539 from WeichenXu123/fix_tuning_param_maps_io.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent aeb3649 · commit 8016123
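The save-path change that performs this corrected lookup lives in the tuning code (presumably `python/pyspark/ml/tuning.py`, one of the 9 changed files but not reproduced below). A minimal sketch of the idea, built from the two helpers this commit does show: `MetaAlgorithmReadWrite.getAllNestedStages` (exercised in the new `test_util.py`) and `Params._testOwnParam` (added in `param/__init__.py`). The function name `_find_param_owner` is illustrative, not part of the patch:

~~~python
from pyspark.ml.util import MetaAlgorithmReadWrite

def _find_param_owner(estimator, param):
    # Flatten the tuned estimator: getAllNestedStages walks meta-estimators
    # such as Pipeline and OneVsRest, returning the estimator itself plus
    # every nested stage.
    for stage in MetaAlgorithmReadWrite.getAllNestedStages(estimator):
        # A stage owns the param iff the param's parent is the stage's uid
        # and the stage declares a param of that name.
        if stage._testOwnParam(param.parent, param.name):
            return stage
    raise ValueError("Param %r is not owned by any stage of %s"
                     % (param.name, estimator.uid))
~~~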

File tree: 9 files changed, +268 -107 lines changed


dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
~~~diff
@@ -564,6 +564,7 @@ def __hash__(self):
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_util",
         "pyspark.ml.tests.test_wrapper",
     ],
     excluded_python_implementations=[
~~~
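With the module registered here, the new suite runs with the rest of the PySpark ML tests; assuming the usual Spark dev setup, it can also be run on its own via `python/run-tests --testnames pyspark.ml.tests.test_util`.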

python/pyspark/ml/classification.py

Lines changed: 1 addition & 45 deletions
~~~diff
@@ -36,7 +36,7 @@
 from pyspark.ml.util import JavaMLWritable, JavaMLReadable, HasTrainingSummary
 from pyspark.ml.wrapper import JavaParams, \
     JavaPredictor, JavaPredictionModel, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc
 from pyspark.ml.linalg import Vectors
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
@@ -2991,50 +2991,6 @@ def _to_java(self):
         _java_obj.setRawPredictionCol(self.getRawPredictionCol())
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
-                                             self.uid)
-        java_param = _java_obj.getParam(param.name)
-        if isinstance(value, JavaParams):
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                if param.name() == "classifier":
-                    paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
-                else:
-                    paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
-        return paramMap
-
 
 class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
     """
~~~

python/pyspark/ml/param/__init__.py

Lines changed: 6 additions & 0 deletions
~~~diff
@@ -437,6 +437,12 @@ def _resolveParam(self, param):
         else:
             raise ValueError("Cannot resolve %r as a param." % param)
 
+    def _testOwnParam(self, param_parent, param_name):
+        """
+        Test the ownership. Return True or False
+        """
+        return self.uid == param_parent and self.hasParam(param_name)
+
     @staticmethod
     def _dummy():
         """
~~~

python/pyspark/ml/pipeline.py

Lines changed: 2 additions & 51 deletions
~~~diff
@@ -21,8 +21,8 @@
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import MLReadable, MLWritable, JavaMLWriter, JavaMLReader, \
     DefaultParamsReader, DefaultParamsWriter, MLWriter, MLReader, JavaMLWritable
-from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc
@@ -190,55 +190,6 @@ def _to_java(self):
 
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
-        if isinstance(value, Params) and hasattr(value, "_to_java"):
-            # Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
-            # implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                java_obj = pair.value()
-                if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
-                    # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
-                    # and Estimator/Transformer class which implements `_from_java` static method
-                    # (such as OneVsRest, Pipeline class).
-                    py_obj = JavaParams._from_java(java_obj)
-                else:
-                    py_obj = _java2py(sc, java_obj)
-                paramMap[self.getParam(param.name())] = py_obj
-        return paramMap
-
 
 @inherit_doc
 class PipelineWriter(MLWriter):
~~~
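Both removed blocks were near-copies of the OneVsRest versions deleted from `classification.py` above, differing only in how they detected nested `PipelineStage` values. Centralizing that conversion is what lets the validator save path resolve params of arbitrarily nested stages, instead of each meta-estimator special-casing its own (see the sketch under the commit message).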

python/pyspark/ml/tests/test_tuning.py

Lines changed: 43 additions & 4 deletions
~~~diff
@@ -73,7 +73,21 @@ def test_addGrid(self):
                     .build())
 
 
-class CrossValidatorTests(SparkSessionTestCase):
+class ValidatorTestUtilsMixin:
+    def assert_param_maps_equal(self, paramMaps1, paramMaps2):
+        self.assertEqual(len(paramMaps1), len(paramMaps2))
+        for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2):
+            self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys()))
+            for param in paramMap1.keys():
+                v1 = paramMap1[param]
+                v2 = paramMap2[param]
+                if isinstance(v1, Params):
+                    self.assertEqual(v1.uid, v2.uid)
+                else:
+                    self.assertEqual(v1, v2)
+
+
+class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_copy(self):
         dataset = self.spark.createDataFrame([
@@ -256,7 +270,7 @@ def test_save_load_simple_estimator(self):
         loadedCV = CrossValidator.load(cvPath)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
-        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
 
         # test save/load of CrossValidatorModel
         cvModelPath = temp_path + "/cvModel"
@@ -351,6 +365,7 @@ def test_save_load_nested_estimator(self):
         cvPath = temp_path + "/cv"
         cv.save(cvPath)
         loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
 
@@ -367,6 +382,7 @@ def test_save_load_nested_estimator(self):
         cvModelPath = temp_path + "/cvModel"
         cvModel.save(cvModelPath)
         loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -401,6 +417,11 @@ def test_save_load_pipeline_estimator(self):
                                   estimatorParamMaps=paramGrid,
                                   evaluator=MulticlassClassificationEvaluator(),
                                   numFolds=2)  # use 3+ folds in practice
+        cvPath = temp_path + "/cv"
+        crossval.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel = crossval.fit(training)
@@ -421,6 +442,11 @@ def test_save_load_pipeline_estimator(self):
                                   estimatorParamMaps=paramGrid,
                                   evaluator=MulticlassClassificationEvaluator(),
                                   numFolds=2)  # use 3+ folds in practice
+        cv2Path = temp_path + "/cv2"
+        crossval2.save(cv2Path)
+        loadedCV2 = CrossValidator.load(cv2Path)
+        self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel2 = crossval2.fit(training)
@@ -511,7 +537,7 @@ def test_invalid_user_specified_folds(self):
             cv.fit(dataset_with_folds)
 
 
-class TrainValidationSplitTests(SparkSessionTestCase):
+class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_fit_minimize_metric(self):
         dataset = self.spark.createDataFrame([
@@ -632,7 +658,8 @@ def test_save_load_simple_estimator(self):
         loadedTvs = TrainValidationSplit.load(tvsPath)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
-        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+        self.assert_param_maps_equal(
+            loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
 
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
@@ -713,6 +740,7 @@ def test_save_load_nested_estimator(self):
         tvsPath = temp_path + "/tvs"
         tvs.save(tvsPath)
         loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
 
@@ -728,6 +756,7 @@ def test_save_load_nested_estimator(self):
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
         loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -761,6 +790,11 @@ def test_save_load_pipeline_estimator(self):
         tvs = TrainValidationSplit(estimator=pipeline,
                                    estimatorParamMaps=paramGrid,
                                    evaluator=MulticlassClassificationEvaluator())
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel = tvs.fit(training)
@@ -780,6 +814,11 @@ def test_save_load_pipeline_estimator(self):
         tvs2 = TrainValidationSplit(estimator=nested_pipeline,
                                     estimatorParamMaps=paramGrid,
                                     evaluator=MulticlassClassificationEvaluator())
+        tvs2Path = temp_path + "/tvs2"
+        tvs2.save(tvs2Path)
+        loadedTvs2 = TrainValidationSplit.load(tvs2Path)
+        self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel2 = tvs2.fit(training)
~~~

python/pyspark/ml/tests/test_util.py

Lines changed: 84 additions & 0 deletions
~~~diff
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression, OneVsRest
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.util import MetaAlgorithmReadWrite
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
+
+    def test_getAllNestedStages(self):
+        def _check_uid_set_equal(stages, expected_stages):
+            uids = set(map(lambda x: x.uid, stages))
+            expected_uids = set(map(lambda x: x.uid, expected_stages))
+            self.assertEqual(uids, expected_uids)
+
+        df1 = self.spark.createDataFrame([
+            (Vectors.dense([1., 2.]), 1.0),
+            (Vectors.dense([-1., -2.]), 0.0),
+        ], ['features', 'label'])
+        df2 = self.spark.createDataFrame([
+            (1., 2., 1.0),
+            (1., 2., 0.0),
+        ], ['a', 'b', 'label'])
+        vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features')
+        lr = LogisticRegression()
+        pipeline = Pipeline(stages=[vs, lr])
+        pipelineModel = pipeline.fit(df2)
+        ova = OneVsRest(classifier=lr)
+        ovaModel = ova.fit(df1)
+
+        ova_pipeline = Pipeline(stages=[vs, ova])
+        nested_pipeline = Pipeline(stages=[ova_pipeline])
+
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipeline),
+            [pipeline, vs, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
+            [pipelineModel] + pipelineModel.stages
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ova),
+            [ova, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ovaModel),
+            [ovaModel, lr] + ovaModel.models
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
+            [nested_pipeline, ova_pipeline, vs, ova, lr]
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
~~~
