@@ -750,6 +750,10 @@ def check_logcdf(
750
750
if not skip_paramdomain_inside_edge_test :
751
751
domains = paramdomains .copy ()
752
752
domains ["value" ] = domain
753
+
754
+ model , param_vars = build_model (pymc3_dist , domain , paramdomains )
755
+ pymc3_logcdf = model .fastfn (logpt (model ["value" ], cdf = True ))
756
+
753
757
if decimal is None :
754
758
decimal = select_by_precision (float64 = 6 , float32 = 3 )
755
759
@@ -758,17 +762,23 @@ def check_logcdf(
758
762
if skip_params_fn (params ):
759
763
continue
760
764
scipy_cdf = scipy_logcdf (** params )
765
+
766
+ scipy_eval = scipy_logcdf (** params )
761
767
value = params .pop ("value" )
762
- with Model () as m :
763
- dist = pymc3_dist ("y" , ** params )
768
+
769
+ # Update shared parameter variables in pymc3_logcdf function
770
+ for param_name , param_value in params .items ():
771
+ param_vars [param_name ].set_value (param_value )
772
+
773
+ pymc3_eval = pymc3_logcdf ({"value" : value })
774
+
764
775
params ["value" ] = value # for displaying in err_msg
765
- with aesara .config .change_flags (on_opt_error = "raise" , mode = Mode ("py" )):
766
- assert_almost_equal (
767
- logcdf (dist , value ).eval (),
768
- scipy_cdf ,
769
- decimal = decimal ,
770
- err_msg = str (params ),
771
- )
776
+ assert_almost_equal (
777
+ pymc3_eval ,
778
+ scipy_eval ,
779
+ decimal = decimal ,
780
+ err_msg = str (params ),
781
+ )
772
782
773
783
valid_value = domain .vals [0 ]
774
784
valid_params = {param : paramdomain .vals [0 ] for param , paramdomain in paramdomains .items ()}
0 commit comments