Reuse output of Join in C backend #1340
Conversation
Codecov Report

❌ Attention: Your patch status has failed because the patch coverage (97.29%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff            @@
##             main    #1340     +/-   ##
=========================================
  Coverage   82.13%   82.13%
=========================================
  Files         211      211
  Lines       49773    49734      -39
  Branches     8830     8816      -14
=========================================
- Hits        40879    40850      -29
+ Misses       6714     6709       -5
+ Partials     2180     2175       -5
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (3)
tests/tensor/test_basic.py:1766
- The exception expectation was changed from IndexError to ValueError. Please verify that all related code and tests are updated accordingly.
with pytest.raises(ValueError):
pytensor/tensor/basic.py:2415
- The error message for out-of-bounds axis errors has been updated to match numpy's standard. Confirm that this change is intentional and that no downstream dependencies rely on the old message.
numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2
pytensor/link/numba/dispatch/tensor_basic.py:122
- Switching to axis.item() assumes the axis input is a 0-d array; please ensure that all cases (including negative axes and non-array types) are correctly handled by this update.
return np.concatenate(tensors, axis.item())
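For context on that last point, a quick illustration in plain NumPy (not PyTensor code) of how .item() behaves: it extracts a Python scalar from a 0-d array (including negative axis values, which np.concatenate normalizes itself), but raises for anything with more than one element.

import numpy as np

tensors = [np.ones((2, 3)), np.zeros((1, 3))]

axis = np.asarray(0)                  # 0-d array: .item() -> plain int 0
print(np.concatenate(tensors, axis.item()).shape)   # (3, 3)

axis = np.asarray(-2)                 # negative axis: normalized by concatenate
print(np.concatenate(tensors, axis.item()).shape)   # (3, 3)

axis = np.asarray([0, 1])             # more than one element: .item() raises
try:
    axis.item()
except ValueError as exc:
    print(exc)                        # can only convert an array of size 1 ...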
pytensor/tensor/basic.py (Outdated)

      Joined tensors must have the same rank

      >>> pt.join(0, x, u)
      Traceback (most recent call last):
-     TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1].
+     TypeError: Only tensors with the same number of dimensions can be joined
I don't love that we're losing information in the traceback, is that 100% necessary?
No, reintroduced
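For reference, a small check (an illustrative snippet mirroring the doctest above) of what the reintroduced message looks like when joining tensors of different ranks:

import pytensor.tensor as pt

x = pt.matrix("x")
u = pt.vector("u")
try:
    pt.join(0, x, u)
except TypeError as exc:
    print(exc)
    # Only tensors with the same number of dimensions can be joined.
    # Input ndims were: [2, 1].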
pytensor/tensor/basic.py (Outdated)

      raise TypeError(
          "Join cannot handle arguments of dimension 0."
-         " Use `stack` to join scalar values."
+         " Use `stack` to join scalar values and/or increase rank of scalars."
Is "rank" the right word here? I only know it in the context of matrices as a synonym for column span.
Maybe "Use stack
to join scalar values into higher dimensional objects" ?
reworded
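A small usage sketch of the behavior under discussion: join refuses 0-d inputs, while stack is the intended way to combine scalars into a higher-dimensional tensor.

import pytensor.tensor as pt

a = pt.scalar("a")
b = pt.scalar("b")

try:
    pt.join(0, a, b)        # raises TypeError: Join cannot handle dimension 0
except TypeError as exc:
    print(exc)

v = pt.stack([a, b])        # a length-2 vector built from the two scalars
print(v.ndim)               # 1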
      # Most times axis is constant, inline it
      # This is safe to do because the hash of the c_code includes the constant signature
What do you mean by "constant signature"?
The values of the constant are used to determine whether a cached c_impl can be reused or not; otherwise it would be dangerous to inline the values like we're doing here for performance.
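To make that concrete, a minimal sketch (hypothetical helper names, not PyTensor's actual cache machinery) of why including the constants' values in the cache key makes inlining safe: two nodes that inline different axis values can never collide on a cached compiled module, even when they share the same code template.

import hashlib

def c_impl_cache_key(c_code: str, constant_signature: tuple) -> str:
    # The key covers both the generated code and the values of the
    # constants that get inlined into it.
    return hashlib.sha256(repr((c_code, constant_signature)).encode()).hexdigest()

template = "concatenate(inputs, %(axis)s);"   # axis is inlined at codegen time

# Different inlined constants -> different keys -> no stale reuse.
assert c_impl_cache_key(template, (0,)) != c_impl_cache_key(template, (1,))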
      [old_output] = node.outputs

      if ret.dtype != old_output.dtype:
          ret = ret.astype(old_output.dtype)
Is there any savings to doing this check vs just always casting?
There seems to be some extra work in cast:
def cast(x, dtype: str | np.dtype) -> TensorVariable:
    """Symbolically cast `x` to a Tensor of type `dtype`."""
    if isinstance(dtype, str) and dtype == "floatX":
        dtype = config.floatX
    dtype_name = np.dtype(dtype).name
    _x = as_tensor_variable(x)
    if _x.type.dtype == dtype_name:
        return _x
    if _x.type.dtype.startswith("complex") and not dtype_name.startswith("complex"):
        raise TypeError(
            "Casting from complex to real is ambiguous: consider real(), "
            "imag(), angle() or abs()"
        )
    return _cast_mapping[dtype_name](x)
Not a huge cost, but I don't mind keeping the one check to save those extra cycles.
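A tiny stand-in for the trade-off (plain NumPy, hypothetical helper name): the guard is one cheap attribute comparison, whereas an unconditional call still pays for the dtype normalization inside cast before returning the input unchanged.

import numpy as np

def cast_like(ret, old_output):
    # Guarded version, as in the diff above: cast only on dtype mismatch.
    if ret.dtype != old_output.dtype:
        ret = ret.astype(old_output.dtype)
    return ret

ret = np.arange(5, dtype="float32")
same = np.empty(1, dtype="float32")
other = np.empty(1, dtype="float64")

assert cast_like(ret, same) is ret               # no work when dtypes match
assert cast_like(ret, other).dtype == other.dtype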
Do not normalize constant axis in make_node, and fix the rewrite that assumed it would always be positive.
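In other words, rewrites can no longer assume the constant axis was already made non-negative. A generic sketch of the normalization a consumer now has to do itself (the standard convention, not code from this PR):

def normalize_axis(axis: int, ndim: int) -> int:
    # Map a possibly negative axis to the equivalent non-negative one.
    return axis + ndim if axis < 0 else axis

assert normalize_axis(-1, ndim=3) == 2
assert normalize_axis(1, ndim=3) == 1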
When GC is disabled we can get a meaningful speedup by managing the copy to the output buffer ourselves instead of using PyArray_CONCATENATE, which doesn't allow passing an output. I didn't circumvent PyArray_CONCATENATE altogether because numpy has clever logic to allocate the output array so that its strides are aligned with the inputs as much as possible; copying this logic to the case where we can't reuse the buffer would add quite some complexity to our codebase.

Also simplified the implementation by removing the exotic view_flag, which closes #753.
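A Python sketch of the buffer-reuse idea described above (illustrative only; the actual change lives in the C backend, and a non-negative axis is assumed here): when the output buffer can be reused, each input is copied into the right slice of the preallocated output instead of letting PyArray_CONCATENATE allocate a fresh array on every call.

import numpy as np

def join_into(out, tensors, axis):
    # `out` plays the role of the reusable output buffer managed by the
    # C backend when garbage collection is disabled.
    offset = 0
    index = [slice(None)] * out.ndim
    for t in tensors:
        index[axis] = slice(offset, offset + t.shape[axis])
        out[tuple(index)] = t                 # copy straight into the buffer
        offset += t.shape[axis]
    return out

xs = [np.ones((2, 3)), np.zeros((4, 3))]
out = np.empty((6, 3))
assert np.array_equal(join_into(out, xs, axis=0), np.concatenate(xs, axis=0))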
Benchmark with new test: Before / After results.
📚 Documentation preview 📚: https://pytensor--1340.org.readthedocs.build/en/1340/