Skip to content

Moved logsumexp #351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed

Moved logsumexp #351

wants to merge 7 commits into from

Conversation

KladeRe
Copy link

@KladeRe KladeRe commented Jun 20, 2023

Motivation for these changes

Closes #350

Implementation details

Moved the function logsumexp and its tests from math.py to special.py and all its dependencies.

Checklist

Major / Breaking Changes

  • ...

New features

  • logsumexp is now fully in special.py

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 20, 2023

Any tests and rewrites (like local_log_sum_exp) pertaining to logsumexp should also be moved from .../math.py to .../special.py

@KladeRe
Copy link
Author

KladeRe commented Jun 20, 2023

I now moved local_logsumexp to special.py and test_logsumexp to tests/tensor/test_special.py. Are there any other tests or functions which need to be moved?

@ricardoV94
Copy link
Member

Thanks @KladeRe, I think the only other thing are the tests here:

def test_local_log_sum_exp_maximum():
"""Test that the rewrite is applied by checking the presence of the maximum."""
x = tensor3("x")
check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
# If a transpose is applied to the sum
transpose_op = DimShuffle((False, False), (1, 0))
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
# If the sum is performed with keepdims=True
x = TensorType(dtype="floatX", shape=(None, 1, None))("x")
sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)
def test_local_log_sum_exp_near_one():
"""Test that the rewritten result is correct around 1.0."""
x = tensor3("x")
x_val = 1.0 + np.random.random((4, 3, 2)).astype(config.floatX) / 10.0
f = compile_graph_log_sum_exp(x, axis=(1,))
naive_ret = np.log(np.sum(np.exp(x_val), axis=1))
rewritten_ret = f(x_val)
assert np.allclose(naive_ret, rewritten_ret)
# If a transpose is applied
transpose_op = DimShuffle((False, False), (1, 0))
f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
rewritten_ret = f(x_val)
assert np.allclose(naive_ret, rewritten_ret)
def test_local_log_sum_exp_large():
"""Test that the rewrite result is correct for extreme value 100."""
x = vector("x")
f = compile_graph_log_sum_exp(x, axis=0)
x_val = np.array([-100.0, 100.0]).astype(config.floatX)
rewritten_ret = f(x_val)
assert np.allclose(rewritten_ret, 100.0)
def test_local_log_sum_exp_inf():
"""Test that when max = +-inf, the rewritten output still works correctly."""
x = vector("x")
f = compile_graph_log_sum_exp(x, axis=0)
assert f([-np.inf, -np.inf]) == -np.inf
assert f([np.inf, np.inf]) == np.inf
assert f([-np.inf, np.inf]) == np.inf

@KladeRe
Copy link
Author

KladeRe commented Jun 20, 2023

Alright, everything related to logsumexp is now either in special.py or its tests.

@ricardoV94 ricardoV94 marked this pull request as draft September 5, 2023 14:48
@KladeRe KladeRe closed this by deleting the head repository Jan 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Move logsumexp to tensor.special
2 participants