Skip to content

Commit 2a4d081

Browse files
Refactor linspace, logspace, and geomspace to match numpy implementation
1 parent b8e26cd commit 2a4d081

File tree

2 files changed

+319
-40
lines changed

2 files changed

+319
-40
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 297 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Collection, Iterable
23

34
import numpy as np
@@ -20,14 +21,25 @@
2021
from pytensor.raise_op import Assert
2122
from pytensor.scalar import int32 as int_t
2223
from pytensor.scalar import upcast
23-
from pytensor.tensor import as_tensor_variable
24+
from pytensor.tensor import TensorLike, as_tensor_variable
2425
from pytensor.tensor import basic as ptb
2526
from pytensor.tensor.basic import alloc, second
2627
from pytensor.tensor.exceptions import NotScalarConstantError
2728
from pytensor.tensor.math import abs as pt_abs
2829
from pytensor.tensor.math import all as pt_all
30+
from pytensor.tensor.math import (
31+
bitwise_and,
32+
ge,
33+
gt,
34+
log,
35+
lt,
36+
maximum,
37+
minimum,
38+
prod,
39+
sign,
40+
switch,
41+
)
2942
from pytensor.tensor.math import eq as pt_eq
30-
from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
3143
from pytensor.tensor.math import max as pt_max
3244
from pytensor.tensor.math import sum as pt_sum
3345
from pytensor.tensor.shape import specify_broadcastable
@@ -1583,27 +1595,294 @@ def broadcast_shape_iter(
15831595
return tuple(result_dims)
15841596

15851597

1586-
def geomspace(start, end, steps, base=10.0):
1587-
from pytensor.tensor.math import log
1598+
def _check_deprecated_inputs(stop, end, num, steps):
1599+
if end is not None:
1600+
warnings.warn(
1601+
"The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.",
1602+
DeprecationWarning,
1603+
)
1604+
stop = end
1605+
if steps is not None:
1606+
warnings.warn(
1607+
"The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.",
1608+
DeprecationWarning,
1609+
)
1610+
num = steps
1611+
1612+
return stop, num
1613+
1614+
1615+
def _linspace_core(
1616+
start: TensorVariable,
1617+
stop: TensorVariable,
1618+
num: int,
1619+
dtype: str,
1620+
endpoint=True,
1621+
retstep=False,
1622+
axis=0,
1623+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1624+
div = (num - 1) if endpoint else num
1625+
delta = (stop - start).astype(dtype)
1626+
samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
1627+
1628+
step = switch(gt(div, 0), delta / div, np.nan)
1629+
samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start)
1630+
samples = switch(
1631+
bitwise_and(gt(num, 1), np.asarray(endpoint)),
1632+
set_subtensor(samples[-1, ...], stop),
1633+
samples,
1634+
)
1635+
1636+
if axis != 0:
1637+
samples = ptb.moveaxis(samples, 0, axis)
1638+
1639+
if retstep:
1640+
return samples, step
1641+
1642+
return samples
1643+
1644+
1645+
def _broadcast_inputs_and_dtypes(*args, dtype=None):
1646+
args = map(ptb.as_tensor_variable, args)
1647+
args = broadcast_arrays(*args)
1648+
1649+
if dtype is None:
1650+
dtype = pytensor.config.floatX
1651+
1652+
return args, dtype
1653+
1654+
1655+
def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
1656+
"""
1657+
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
1658+
may change how the axis argument is interpreted in the final output.
1659+
1660+
Parameters
1661+
----------
1662+
start
1663+
stop
1664+
base
1665+
dtype
1666+
axis
1667+
1668+
Returns
1669+
-------
1670+
1671+
"""
1672+
base = ptb.as_tensor_variable(base, dtype=dtype)
1673+
if base.ndim > 0:
1674+
ndmax = len(broadcast_shape(start, stop, base))
1675+
start, stop, base = (
1676+
ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base)
1677+
)
1678+
base = ptb.expand_dims(base, axis=(axis,))
1679+
1680+
return start, stop, base
1681+
1682+
1683+
def linspace(
1684+
start: TensorLike,
1685+
stop: TensorLike,
1686+
num: TensorLike = 50,
1687+
endpoint: bool = True,
1688+
retstep: bool = False,
1689+
dtype: str | None = None,
1690+
axis: int = 0,
1691+
end: TensorLike | None = None,
1692+
steps: TensorLike | None = None,
1693+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1694+
"""
1695+
Return evenly spaced numbers over a specified interval.
1696+
1697+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1698+
1699+
The endpoint of the interval can optionally be excluded.
15881700
1589-
start = ptb.as_tensor_variable(start)
1590-
end = ptb.as_tensor_variable(end)
1591-
return base ** linspace(log(start) / log(base), log(end) / log(base), steps)
1701+
Parameters
1702+
----------
1703+
start: int, float, or TensorVariable
1704+
The starting value of the sequence.
15921705
1706+
stop: int, float or TensorVariable
1707+
The end value of the sequence, unless `endpoint` is set to False.
1708+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
15931709
1594-
def logspace(start, end, steps, base=10.0):
1595-
start = ptb.as_tensor_variable(start)
1596-
end = ptb.as_tensor_variable(end)
1597-
return base ** linspace(start, end, steps)
1710+
num: int
1711+
Number of samples to generate. Must be non-negative.
15981712
1713+
endpoint: bool
1714+
Whether to include the endpoint in the range.
1715+
1716+
retstep: bool
1717+
If true, returns both the samples and an array of steps between samples.
1718+
1719+
dtype: str, optional
1720+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1721+
and `end` arguments.
1722+
1723+
axis: int
1724+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1725+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
1726+
1727+
end: int, float or TensorVariable
1728+
.. warning::
1729+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1730+
The end value of the sequence, unless `endpoint` is set to False.
1731+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1732+
excluded.
1733+
1734+
steps: float, int, or TensorVariable
1735+
.. warning::
1736+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
1737+
1738+
Number of samples to generate. Must be non-negative
1739+
1740+
Returns
1741+
-------
1742+
samples: TensorVariable
1743+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1744+
1745+
step: TensorVariable
1746+
Tensor containing the spacing between samples. Only returned if `retstep` is True.
1747+
"""
1748+
end, num = _check_deprecated_inputs(stop, end, num, steps)
1749+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1750+
1751+
return _linspace_core(
1752+
start=start,
1753+
stop=stop,
1754+
num=num,
1755+
dtype=dtype,
1756+
endpoint=endpoint,
1757+
retstep=retstep,
1758+
axis=axis,
1759+
)
1760+
1761+
1762+
def geomspace(
1763+
start: TensorLike,
1764+
stop: TensorLike,
1765+
num: int = 50,
1766+
base: float = 10.0,
1767+
endpoint: bool = True,
1768+
dtype: str | None = None,
1769+
axis: int = 0,
1770+
end: TensorLike | None = None,
1771+
steps: TensorLike | None = None,
1772+
) -> TensorVariable:
1773+
"""
1774+
Return numbers spaced evenly on a log scale (a geometric progression).
1775+
1776+
Parameters
1777+
----------
1778+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1779+
1780+
The endpoint of the interval can optionally be excluded.
1781+
1782+
Parameters
1783+
----------
1784+
start: int, float, or TensorVariable
1785+
The starting value of the sequence.
1786+
1787+
stop: int, float or TensorVariable
1788+
The end value of the sequence, unless `endpoint` is set to False.
1789+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
1790+
1791+
num: int
1792+
Number of samples to generate. Must be non-negative.
1793+
1794+
base: float
1795+
The base of the log space. The step size between the elements in ln(samples) / ln(base)
1796+
(or log_base(samples)) is uniform.
1797+
1798+
endpoint: bool
1799+
Whether to include the endpoint in the range.
1800+
1801+
dtype: str, optional
1802+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1803+
and `end` arguments.
1804+
1805+
axis: int
1806+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1807+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
1808+
1809+
end: int, float or TensorVariable
1810+
.. warning::
1811+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1812+
The end value of the sequence, unless `endpoint` is set to False.
1813+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1814+
excluded.
1815+
1816+
steps: float, int, or TensorVariable
1817+
.. warning::
1818+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
1819+
1820+
Number of samples to generate. Must be non-negative
1821+
1822+
Returns
1823+
-------
1824+
samples: TensorVariable
1825+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1826+
"""
1827+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1828+
(start, stop), dtype = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1829+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1830+
1831+
out_sign = sign(start)
1832+
log_start, log_stop = (
1833+
log(start * out_sign) / log(base),
1834+
log(stop * out_sign) / log(base),
1835+
)
1836+
result = _linspace_core(
1837+
start=log_start,
1838+
stop=log_stop,
1839+
num=num,
1840+
endpoint=endpoint,
1841+
dtype=dtype,
1842+
axis=0,
1843+
retstep=False,
1844+
)
1845+
result = base**result
1846+
1847+
if num > 0:
1848+
set_subtensor(result[0, ...], start, inplace=True)
1849+
if num > 1 and endpoint:
1850+
set_subtensor(result[-1, ...], stop, inplace=True)
1851+
1852+
result = result * out_sign
1853+
1854+
if axis != 0:
1855+
result = ptb.moveaxis(result, 0, axis)
1856+
1857+
return result
1858+
1859+
1860+
def logspace(
1861+
start: TensorLike,
1862+
stop: TensorLike,
1863+
num: int = 50,
1864+
base: float = 10.0,
1865+
endpoint: bool = True,
1866+
dtype: str | None = None,
1867+
axis: int = 0,
1868+
end: TensorLike | None = None,
1869+
steps: TensorLike | None = None,
1870+
) -> TensorVariable:
1871+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1872+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1873+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1874+
1875+
ls = _linspace_core(
1876+
start=start,
1877+
stop=stop,
1878+
num=num,
1879+
endpoint=endpoint,
1880+
dtype=dtype,
1881+
axis=axis,
1882+
retstep=False,
1883+
)
15991884

1600-
def linspace(start, end, steps):
1601-
start = ptb.as_tensor_variable(start)
1602-
end = ptb.as_tensor_variable(end)
1603-
arr = ptb.arange(steps)
1604-
arr = ptb.shape_padright(arr, max(start.ndim, end.ndim))
1605-
multiplier = (end - start) / (steps - 1)
1606-
return start + arr * multiplier
1885+
return base**ls
16071886

16081887

16091888
def broadcast_to(

0 commit comments

Comments
 (0)