@@ -499,6 +499,10 @@ def load(self, path):
499
499
bestModelPath = os .path .join (path , 'bestModel' )
500
500
bestModel = DefaultParamsReader .loadParamsInstance (bestModelPath , self .sc )
501
501
avgMetrics = metadata ['avgMetrics' ]
502
+ if 'stdMetrics' in metadata :
503
+ stdMetrics = metadata ['stdMetrics' ]
504
+ else :
505
+ stdMetrics = None
502
506
persistSubModels = ('persistSubModels' in metadata ) and metadata ['persistSubModels' ]
503
507
504
508
if persistSubModels :
@@ -512,7 +516,9 @@ def load(self, path):
512
516
else :
513
517
subModels = None
514
518
515
- cvModel = CrossValidatorModel (bestModel , avgMetrics = avgMetrics , subModels = subModels )
519
+ cvModel = CrossValidatorModel (
520
+ bestModel , avgMetrics = avgMetrics , subModels = subModels , stdMetrics = stdMetrics
521
+ )
516
522
cvModel = cvModel ._resetUid (metadata ['uid' ])
517
523
cvModel .set (cvModel .estimator , estimator )
518
524
cvModel .set (cvModel .estimatorParamMaps , estimatorParamMaps )
@@ -536,6 +542,9 @@ def saveImpl(self, path):
536
542
.getValidatorModelWriterPersistSubModelsParam (self )
537
543
extraMetadata = {'avgMetrics' : instance .avgMetrics ,
538
544
'persistSubModels' : persistSubModels }
545
+ if instance .stdMetrics :
546
+ extraMetadata ['stdMetrics' ] = instance .stdMetrics
547
+
539
548
_ValidatorSharedReadWrite .saveImpl (path , instance , self .sc , extraMetadata = extraMetadata )
540
549
bestModelPath = os .path .join (path , 'bestModel' )
541
550
instance .bestModel .save (bestModelPath )
@@ -710,13 +719,19 @@ def setCollectSubModels(self, value):
710
719
"""
711
720
return self ._set (collectSubModels = value )
712
721
722
@staticmethod
def _gen_avg_and_std_metrics(metrics_all):
    """
    Reduce a fold-by-param-map grid of metrics to per-column statistics.

    Parameters
    ----------
    metrics_all : list of list of float
        Two-dimensional grid of metric values; ``_fit`` fills it as
        ``metrics_all[fold][paramMapIndex]``.

    Returns
    -------
    tuple
        ``(avg_metrics, std_metrics)`` - two lists holding, for every
        column of the grid, the mean and the standard deviation (NumPy's
        default ``ddof=0``) taken along the first axis.
    """
    # Convert once so both reductions work on the same array.
    fold_by_param = np.asanyarray(metrics_all)
    column_means = fold_by_param.mean(axis=0)
    column_stds = fold_by_param.std(axis=0)
    return list(column_means), list(column_stds)
727
+
713
728
def _fit (self , dataset ):
714
729
est = self .getOrDefault (self .estimator )
715
730
epm = self .getOrDefault (self .estimatorParamMaps )
716
731
numModels = len (epm )
717
732
eva = self .getOrDefault (self .evaluator )
718
733
nFolds = self .getOrDefault (self .numFolds )
719
- metrics = [0.0 ] * numModels
734
+ metrics_all = [[ 0.0 ] * numModels for i in range ( nFolds )]
720
735
721
736
pool = ThreadPool (processes = min (self .getParallelism (), numModels ))
722
737
subModels = None
@@ -733,19 +748,21 @@ def _fit(self, dataset):
733
748
inheritable_thread_target ,
734
749
_parallelFitTasks (est , train , eva , validation , epm , collectSubModelsParam ))
735
750
for j , metric , subModel in pool .imap_unordered (lambda f : f (), tasks ):
736
- metrics [ j ] += ( metric / nFolds )
751
+ metrics_all [ i ][ j ] = metric
737
752
if collectSubModelsParam :
738
753
subModels [i ][j ] = subModel
739
754
740
755
validation .unpersist ()
741
756
train .unpersist ()
742
757
758
+ metrics , std_metrics = CrossValidator ._gen_avg_and_std_metrics (metrics_all )
759
+
743
760
if eva .isLargerBetter ():
744
761
bestIndex = np .argmax (metrics )
745
762
else :
746
763
bestIndex = np .argmin (metrics )
747
764
bestModel = est .fit (dataset , epm [bestIndex ])
748
- return self ._copyValues (CrossValidatorModel (bestModel , metrics , subModels ))
765
+ return self ._copyValues (CrossValidatorModel (bestModel , metrics , subModels , std_metrics ))
749
766
750
767
def _kFold (self , dataset ):
751
768
nFolds = self .getOrDefault (self .numFolds )
@@ -875,15 +892,20 @@ def _to_java(self):
875
892
876
893
class CrossValidatorModel (Model , _CrossValidatorParams , MLReadable , MLWritable ):
877
894
"""
878
-
879
895
CrossValidatorModel contains the model with the highest average cross-validation
880
896
metric across folds and uses this model to transform input data. CrossValidatorModel
881
897
also tracks the metrics for each param map evaluated.
882
898
883
899
.. versionadded:: 1.4.0
900
+
901
+ Notes
902
+ -----
903
+ Since version 3.3.0, CrossValidatorModel contains a new attribute "stdMetrics",
904
+ which represents the standard deviation of metrics for each paramMap in
905
+ CrossValidator.estimatorParamMaps.
884
906
"""
885
907
886
- def __init__ (self , bestModel , avgMetrics = None , subModels = None ):
908
+ def __init__ (self , bestModel , avgMetrics = None , subModels = None , stdMetrics = None ):
887
909
super (CrossValidatorModel , self ).__init__ ()
888
910
#: best model from cross validation
889
911
self .bestModel = bestModel
@@ -892,6 +914,9 @@ def __init__(self, bestModel, avgMetrics=None, subModels=None):
892
914
self .avgMetrics = avgMetrics or []
893
915
#: sub model list from cross validation
894
916
self .subModels = subModels
917
+ #: standard deviation of metrics for each paramMap in
918
+ #: CrossValidator.estimatorParamMaps, in the corresponding order.
919
+ self .stdMetrics = stdMetrics or []
895
920
896
921
def _transform(self, dataset):
    """
    Transform ``dataset`` by delegating to the stored best model.

    Parameters
    ----------
    dataset
        Input handed unchanged to ``self.bestModel.transform``.

    Returns
    -------
    Whatever ``self.bestModel.transform(dataset)`` returns.
    """
    best = self.bestModel
    transformed = best.transform(dataset)
    return transformed
@@ -924,7 +949,9 @@ def copy(self, extra=None):
924
949
[sub_model .copy () for sub_model in fold_sub_models ]
925
950
for fold_sub_models in self .subModels
926
951
]
927
- return self ._copyValues (CrossValidatorModel (bestModel , avgMetrics , subModels ), extra = extra )
952
+ stdMetrics = list (self .stdMetrics )
953
+ return self ._copyValues (CrossValidatorModel (bestModel , avgMetrics , subModels , stdMetrics ),
954
+ extra = extra )
928
955
929
956
@since ("2.3.0" )
930
957
def write (self ):
0 commit comments