-
Notifications
You must be signed in to change notification settings - Fork 133
Default acc_dtype to floatX #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Some of the tests are failing:
Although |
Yes I think the function can't and shouldn't downcast the accumulation phase. If the input is float64 the same should be used for accumulation by default. The float32 dtype you're talking is the output type. Why did it choose float32 for the accumulator? |
I agree, but I think at most
I also agree, but the failing tests are for
I see. For
Presumably because we're using |
@ricardoV94 I found out why I'd like to add a regression test using the MWE you provided. If you could point me to a good place to include such a test, I would be grateful! |
There were more offending tests, so I changed the default in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good just some nitpick
For the regression test, probably in the same file where the rewrite is being tested |
Well, we've reached an impasse I think. For some tests, unless we upcast the accumulator to def test_mean_precision(self):
# Check that the default accumulator precision is sufficient
x = shared(np.asarray([1e8, 1, -1e8], dtype="float32"))
m = x.mean()
f = function([], m)
m_val = f()
assert np.allclose(m_val, 1.0 / 3) The only way this test will succeed is if the internal accumulator uses |
Closed in favor of #656 |
Description
This PR changes the default data type for the internal accumulator to
config.floatX
when the input data type isfloat
. The previous behavior always attempted to upcastacc_dtype
, which caused the graph to include higher precision floats even when the user requestedfloatX = 'float32'
.Related Issue
local_sum_make_vector
rewrite can introduce forbidden float64 operations at the graph level #653Checklist
Type of change