Skip to content

Reuse output of Join in C backend #1340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 1, 2025

When GC is disabled we can get a meaningful speedup by managing the copy to the output buffer ourselves instead of using PyArray_CONCATENATE which doesn't allow passing an output.

I didn't circumnvent PyArray_CONCATENATE altogether because numpy has clever logic to allocate the output array so that strides are as aligned with the inputs as possible. Copying this logic to the case where we can't reuse the buffer would add quite some complexity to our codebase.

Also simplified the implementation by removing the exotic view_flag, which closes #753

Benchmark with new test

Before
------------------------------------------------------------------------------------------------------ benchmark: 14 tests ------------------------------------------------------------------------------------------------------
Name (time in us)                                                 Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_join_performance[vector-axis=0-C-contiguous-gc=False]     1.7930 (1.0)       67.6160 (1.62)     1.9835 (1.0)      1.0833 (2.68)     1.9130 (1.0)      0.0500 (1.02)      105;331      504.1614 (1.0)       16058           1
test_join_performance[vector-axis=0-C-contiguous-gc=True]      1.8630 (1.04)      42.3300 (1.01)     1.9850 (1.00)     0.4037 (1.0)      1.9640 (1.03)     0.0490 (1.0)       144;280      503.7859 (1.00)      17698           1
test_join_performance[matrix-axis=0-C-contiguous-gc=False]     5.2790 (2.94)      93.8450 (2.24)     6.8665 (3.46)     2.1207 (5.25)     6.5620 (3.43)     0.6610 (13.49)   1119;2611      145.6345 (0.29)      31241           1
test_join_performance[matrix-axis=1-F-contiguous-gc=True]      5.2790 (2.94)     232.6360 (5.56)     6.4849 (3.27)     2.9288 (7.26)     6.1120 (3.19)     0.6110 (12.47)    524;1934      154.2043 (0.31)      32470           1
test_join_performance[matrix-axis=0-C-contiguous-gc=True]      5.5010 (3.07)      41.8190 (1.0)      6.8975 (3.48)     1.2191 (3.02)     6.7120 (3.51)     0.7820 (15.96)     423;332      144.9798 (0.29)       7710           1
test_join_performance[matrix-axis=1-F-contiguous-gc=False]     5.5600 (3.10)      96.0400 (2.30)     6.6411 (3.35)     2.0886 (5.17)     6.4020 (3.35)     0.4300 (8.78)    1061;2063      150.5766 (0.30)      40362           1
test_join_performance[matrix-axis=1-C-contiguous-gc=True]      6.2110 (3.46)     128.7820 (3.08)     6.9118 (3.48)     2.2850 (5.66)     6.7520 (3.53)     0.2600 (5.31)      485;773      144.6810 (0.29)      21447           1
test_join_performance[matrix-axis=1-C-contiguous-gc=False]     6.2420 (3.48)     164.4780 (3.93)     6.9927 (3.53)     2.6832 (6.65)     6.7220 (3.51)     0.2410 (4.92)     927;1771      143.0059 (0.28)      32791           1
test_join_performance[matrix-axis=0-F-contiguous-gc=True]      6.3820 (3.56)     174.2060 (4.17)     7.0810 (3.57)     2.3573 (5.84)     6.7830 (3.55)     0.4820 (9.84)     744;1426      141.2225 (0.28)      35815           1
test_join_performance[matrix-axis=0-F-contiguous-gc=False]     6.7330 (3.76)     178.5840 (4.27)     7.4504 (3.76)     3.6387 (9.01)     7.1530 (3.74)     0.2190 (4.47)      180;371      134.2214 (0.27)       9803           1
test_join_performance[matrix-axis=1-Mixed-gc=False]            8.4960 (4.74)     266.8400 (6.38)     9.3448 (4.71)     3.0742 (7.62)     8.8570 (4.63)     0.1700 (3.47)    1083;3208      107.0115 (0.21)      25238           1
test_join_performance[matrix-axis=0-Mixed-gc=False]            8.5860 (4.79)     146.2040 (3.50)     9.7387 (4.91)     3.7827 (9.37)     9.3180 (4.87)     0.3490 (7.12)     935;2404      102.6827 (0.20)      38524           1
test_join_performance[matrix-axis=0-Mixed-gc=True]             8.6360 (4.82)      66.5050 (1.59)     9.6553 (4.87)     1.5583 (3.86)     9.4170 (4.92)     0.3710 (7.57)      636;985      103.5696 (0.21)      16076           1
test_join_performance[matrix-axis=1-Mixed-gc=True]             8.6360 (4.82)     249.1970 (5.96)     9.1403 (4.61)     2.6017 (6.45)     8.9460 (4.68)     0.1100 (2.24)     695;1458      109.4057 (0.22)      28300           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
After
----------------------------------------------------------------------------------------------------- benchmark: 14 tests ------------------------------------------------------------------------------------------------------
Name (time in us)                                                 Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_join_performance[vector-axis=0-C-contiguous-gc=False]     1.3130 (1.0)       42.8210 (5.85)     1.4596 (1.0)      0.4355 (2.12)     1.4230 (1.0)      0.0510 (1.02)    1453;4965      685.1113 (1.0)      108614           1
test_join_performance[vector-axis=0-C-contiguous-gc=True]      1.8940 (1.44)       7.3140 (1.0)      2.0291 (1.39)     0.2053 (1.0)      2.0130 (1.41)     0.0500 (1.0)       188;362      492.8326 (0.72)      16798           1
test_join_performance[matrix-axis=0-C-contiguous-gc=False]     4.5780 (3.49)      41.6180 (5.69)     4.7945 (3.28)     0.6530 (3.18)     4.7190 (3.32)     0.0510 (1.02)     524;1560      208.5727 (0.30)      27020           1
test_join_performance[matrix-axis=1-F-contiguous-gc=False]     4.6690 (3.56)      95.9900 (13.12)    5.3016 (3.63)     1.5013 (7.31)     5.1190 (3.60)     0.2010 (4.02)    1257;3622      188.6222 (0.28)      44284           1
test_join_performance[matrix-axis=0-C-contiguous-gc=True]      5.0790 (3.87)     110.3970 (15.09)    5.6590 (3.88)     2.2420 (10.92)    5.2500 (3.69)     0.0810 (1.62)     422;3279      176.7105 (0.26)      18352           1
test_join_performance[matrix-axis=1-F-contiguous-gc=True]      5.1000 (3.88)      92.7940 (12.69)    5.6332 (3.86)     2.0571 (10.02)    5.4100 (3.80)     0.2610 (5.22)     548;2454      177.5185 (0.26)      35458           1
test_join_performance[matrix-axis=1-C-contiguous-gc=False]     5.6200 (4.28)     121.6780 (16.64)    6.0844 (4.17)     2.1302 (10.37)    5.9310 (4.17)     0.1200 (2.40)     379;1630      164.3542 (0.24)      35483           1
test_join_performance[matrix-axis=0-F-contiguous-gc=False]     5.9110 (4.50)      87.9250 (12.02)    6.7605 (4.63)     4.5311 (22.07)    6.2710 (4.41)     0.1402 (2.80)       32;160      147.9172 (0.22)       2533           1
test_join_performance[matrix-axis=1-C-contiguous-gc=True]      5.9610 (4.54)      80.6110 (11.02)    6.5380 (4.48)     2.6319 (12.82)    6.2920 (4.42)     0.1410 (2.82)      122;832      152.9522 (0.22)      10727           1
test_join_performance[matrix-axis=0-F-contiguous-gc=True]      6.2320 (4.75)      90.2890 (12.34)    6.6963 (4.59)     1.0750 (5.24)     6.5830 (4.63)     0.1400 (2.80)     457;1665      149.3355 (0.22)      22722           1
test_join_performance[matrix-axis=0-Mixed-gc=False]            7.4540 (5.68)     137.1470 (18.75)    8.1993 (5.62)     2.5490 (12.41)    7.8450 (5.51)     0.3300 (6.60)     548;1068      121.9620 (0.18)      21392           1
test_join_performance[matrix-axis=1-Mixed-gc=False]            7.7950 (5.94)      34.8160 (4.76)     8.1700 (5.60)     0.8154 (3.97)     8.0650 (5.67)     0.1000 (2.00)     721;1447      122.3986 (0.18)      25168           1
test_join_performance[matrix-axis=0-Mixed-gc=True]             7.8750 (6.00)     100.1180 (13.69)    8.4308 (5.78)     2.3699 (11.54)    8.1650 (5.74)     0.2290 (4.58)     497;1723      118.6128 (0.17)      25061           1
test_join_performance[matrix-axis=1-Mixed-gc=True]             8.5360 (6.50)     153.1270 (20.94)    9.0898 (6.23)     3.5232 (17.16)    8.7970 (6.18)     0.1090 (2.18)     334;1496      110.0135 (0.16)      26109           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

📚 Documentation preview 📚: https://pytensor--1340.org.readthedocs.build/en/1340/

@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch 2 times, most recently from d8efc2f to 9b81d41 Compare April 2, 2025 12:10
@ricardoV94 ricardoV94 marked this pull request as draft April 2, 2025 17:15
@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch 2 times, most recently from c6536ac to deb4a19 Compare April 7, 2025 12:09
@ricardoV94 ricardoV94 changed the title Faster implementation of Join in C backend Reuse output of Join in C backend Apr 7, 2025
@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch 2 times, most recently from 81e22db to 4d3f133 Compare April 7, 2025 13:32
@ricardoV94 ricardoV94 marked this pull request as ready for review April 7, 2025 16:48
@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch from 4d3f133 to 948977d Compare April 8, 2025 12:32
Copy link

codecov bot commented Apr 8, 2025

Codecov Report

Attention: Patch coverage is 97.29730% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.13%. Comparing base (ff09268) to head (8e4bb7e).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/basic.py 96.49% 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (97.29%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1340   +/-   ##
=======================================
  Coverage   82.13%   82.13%           
=======================================
  Files         211      211           
  Lines       49773    49734   -39     
  Branches     8830     8816   -14     
=======================================
- Hits        40879    40850   -29     
+ Misses       6714     6709    -5     
+ Partials     2180     2175    -5     
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/tensor_basic.py 100.00% <100.00%> (+1.78%) ⬆️
pytensor/link/numba/dispatch/tensor_basic.py 87.38% <100.00%> (-0.34%) ⬇️
pytensor/scan/checkpoints.py 75.51% <100.00%> (-0.97%) ⬇️
pytensor/tensor/rewriting/basic.py 95.50% <100.00%> (+0.92%) ⬆️
pytensor/tensor/basic.py 91.68% <96.49%> (+<0.01%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (3)

tests/tensor/test_basic.py:1766

  • The exception expectation was changed from IndexError to ValueError. Please verify that all related code and tests are updated accordingly.
with pytest.raises(ValueError):

pytensor/tensor/basic.py:2415

  • The error message for out‐of‐bounds axis errors has been updated to match numpy's standard. Confirm that this change is intentional and that no downstream dependencies rely on the old message.
numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2

pytensor/link/numba/dispatch/tensor_basic.py:122

  • Switching to axis.item() assumes the axis input is a 0-d array; please ensure that all cases (including negative axes and non-array types) are correctly handled by this update.
return np.concatenate(tensors, axis.item())

@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch from 948977d to 2e0c79d Compare April 8, 2025 14:54

Joined tensors must have the same rank
>>> pt.join(0, x, u)
Traceback (most recent call last):
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1].
TypeError: Only tensors with the same number of dimensions can be joined
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love that we're losing information in the traceback, is that 100% necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, reintroduced

raise TypeError(
"Join cannot handle arguments of dimension 0."
" Use `stack` to join scalar values."
" Use `stack` to join scalar values and/or increase rank of scalars."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "rank" the right word here? I only know it in the context of matrices as a synonym for column span.

Maybe "Use stack to join scalar values into higher dimensional objects" ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reworded


# Most times axis is constant, inline it
# This is safe to do because the hash of the c_code includes the constant signature
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by "constant signature"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values of the constant are used to determine if a cached c_impl can be reused or not, otherwise it would be dangerous to inline the values like we're doing here for performance

[old_output] = node.outputs

if ret.dtype != old_output.dtype:
ret = ret.astype(old_output.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any savings to doing this check vs just always casting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be some extra work in cast:

def cast(x, dtype: str | np.dtype) -> TensorVariable:
    """Symbolically cast `x` to a Tensor of type `dtype`."""

    if isinstance(dtype, str) and dtype == "floatX":
        dtype = config.floatX

    dtype_name = np.dtype(dtype).name

    _x = as_tensor_variable(x)
    if _x.type.dtype == dtype_name:
        return _x
    if _x.type.dtype.startswith("complex") and not dtype_name.startswith("complex"):
        raise TypeError(
            "Casting from complex to real is ambiguous: consider real(), "
            "imag(), angle() or abs()"
        )
    return _cast_mapping[dtype_name](x)

Not super crazy, but I don't mind too much the one check to save those extra cycles

@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch from 2e0c79d to e828a98 Compare May 30, 2025 13:01
Do not normalize constant axis in make_node and fix rewrite that assumed this would always be positive
@ricardoV94 ricardoV94 force-pushed the faster_join_impl_c_backend branch from e828a98 to 8e4bb7e Compare May 30, 2025 13:01
@ricardoV94 ricardoV94 merged commit f695840 into pymc-devs:main May 30, 2025
67 checks passed
@ricardoV94 ricardoV94 deleted the faster_join_impl_c_backend branch May 30, 2025 13:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Get rid of join view flag
2 participants