Commit 28f2648

Refactor nlinalg.norm to match np.linalg.norm
Expand TestNorm test coverage
Parent: 1a0d12d

2 files changed (+257, -55 lines)

2 files changed

+257
-55
lines changed
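For orientation before the diff, here is a minimal usage sketch of the signature this commit introduces (`ord`, `axis`, `keepdims`, with batch dimensions to the left of the core dimensions). It assumes this commit is installed; the variable names and shapes are illustrative only, not part of the commit.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import norm

x = pt.tensor3("x")  # batch of matrices, shape (batch, rows, cols)

# Frobenius norm of each matrix in the batch (core dims are the last two)
fro = norm(x, ord="fro", axis=(-2, -1))
# Row-wise L1 vector norms, keeping the reduced axis as a size-1 dummy
row_l1 = norm(x, ord=1, axis=-1, keepdims=True)

f = pytensor.function([x], [fro, row_l1])
vals = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(x.dtype)
fro_vals, row_vals = f(vals)
assert fro_vals.shape == (2,) and row_vals.shape == (2, 3, 1)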

pytensor/tensor/nlinalg.py

Lines changed: 194 additions & 29 deletions
@@ -1,7 +1,9 @@
 import warnings
 from functools import partial
+from typing import Callable, Literal, Optional, Union
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 from pytensor import scalar as ps
 from pytensor.gradient import DisconnectedType
@@ -688,41 +690,204 @@ def matrix_power(M, n):
     return result
 
 
-def norm(x, ord):
-    x = as_tensor_variable(x)
+def _multi_svd_norm(
+    x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable
+):
+    """Compute a function of the singular values of the 2-D matrices in `x`.
+
+    This is a private utility function used by `pytensor.tensor.nlinalg.norm()`.
+
+    Copied from `np.linalg._multi_svd_norm`.
+
+    Parameters
+    ----------
+    x : TensorVariable
+        Input tensor.
+    row_axis, col_axis : int
+        The axes of `x` that hold the 2-D matrices.
+    reduce_op : callable
+        Reduction op. Should be one of `pt.min`, `pt.max`, or `pt.sum`.
+
+    Returns
+    -------
+    result : float or ndarray
+        If `x` is 2-D, the return value is a float.
+        Otherwise, it is an array with ``x.ndim - 2`` dimensions.
+        The return values are the minimum, maximum, or sum of the
+        singular values of the matrices, depending on whether `reduce_op`
+        is `pt.min`, `pt.max`, or `pt.sum`.
+
+    """
+    y = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1))
+    result = reduce_op(svd(y, compute_uv=False), axis=-1)
+    return result
+
+
+VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]
+
+
+def norm(
+    x: ptb.TensorVariable,
+    ord: Optional[Union[float, VALID_ORD]] = None,
+    axis: Optional[Union[int, tuple[int, ...]]] = None,
+    keepdims: bool = False,
+):
+    """
+    Matrix or vector norm.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        Tensor to take the norm of.
+
+    ord: float, str or int, optional
+        Order of the norm. If `ord` is a str, it must be one of the following:
+            - 'fro' or 'f' : Frobenius norm
+            - 'nuc' : nuclear norm
+            - 'inf' : Infinity norm
+            - '-inf' : Negative infinity norm
+        If an integer, `ord` can be one of -2, -1, 0, 1, or 2.
+        Otherwise `ord` must be a float.
+
+        Default is the Frobenius (L2) norm.
+
+    axis: tuple of int, optional
+        Axes over which to compute the norm. If None, the norm of the entire
+        matrix (or vector) is computed. Row or column norms can be computed by
+        passing a single integer; this will treat a matrix like a batch of
+        vectors.
+
+    keepdims: bool
+        If True, dummy axes are inserted into the output so that
+        ``norm.ndim == x.ndim``. Default is False.
+
+    Returns
+    -------
+    TensorVariable
+        Norm of `x` along the axes specified by `axis`.
+
+    Notes
+    -----
+    Batched dimensions are supported to the left of the core dimensions. For
+    example, if `x` is a 3D tensor with shape (2, 3, 4), then `norm(x)` will
+    compute the norm of each 3x4 matrix in the batch.
+
+    If the input is a 2D tensor that should be treated as a batch of vectors,
+    the `axis` argument must be specified.
+    """
+    x = ptb.as_tensor_variable(x)
+
     ndim = x.ndim
-    if ndim == 0:
-        raise ValueError("'axis' entry is out of bounds.")
-    elif ndim == 1:
-        if ord is None:
-            return ptm.sum(x**2) ** 0.5
-        elif ord == "inf":
-            return ptm.max(abs(x))
-        elif ord == "-inf":
-            return ptm.min(abs(x))
+    core_ndim = min(2, ndim)
+    batch_ndim = ndim - core_ndim
+
+    if axis is None:
+        # Handle some common cases first. These can be computed more quickly
+        # than the default SVD way, so we always want to check for them.
+        if (
+            (ord is None)
+            or (ord in ("f", "fro") and core_ndim == 2)
+            or (ord == 2 and core_ndim == 1)
+        ):
+            x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
+            batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
+                range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
+            )
+
+            if x.dtype.startswith("complex"):
+                x_real = x.real  # type: ignore
+                x_imag = x.imag  # type: ignore
+                sqnorm = (
+                    ptb.transpose(x_real, batch_T_dim_order) @ x_real
+                    + ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
+                )
+            else:
+                sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
+            ret = ptm.sqrt(sqnorm).squeeze()
+            if keepdims:
+                ret = ptb.shape_padright(ret, core_ndim)
+            return ret
+
+        # No special computation to exploit -- set default axis before continuing
+        axis = tuple(range(core_ndim))
+
+    elif not isinstance(axis, tuple):
+        try:
+            axis = int(axis)
+        except Exception as e:
+            raise TypeError(
+                "'axis' must be None, an integer, or a tuple of integers"
+            ) from e
+
+        axis = (axis,)
+
+    if len(axis) == 1:
+        # Vector norms
+        if ord in [None, "fro", "f"] and (core_ndim == 2):
+            # This is here to catch the case where x is a 2D tensor but the
+            # user wants to treat it as a batch of vectors. Other vector norms
+            # will work fine in this case.
+            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))
+        elif (ord == "inf") or (ord == np.inf):
+            ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)
+        elif (ord == "-inf") or (ord == -np.inf):
+            ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)
         elif ord == 0:
-            return x[x.nonzero()].shape[0]
+            ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)
+        elif ord == 1:
+            ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)
+        elif isinstance(ord, str):
+            raise ValueError(f"Invalid norm order '{ord}' for vectors")
         else:
-            try:
-                z = ptm.sum(abs(x**ord)) ** (1.0 / ord)
-            except TypeError:
-                raise ValueError("Invalid norm order for vectors.")
-            return z
-    elif ndim == 2:
-        if ord is None or ord == "fro":
-            return ptm.sum(abs(x**2)) ** (0.5)
-        elif ord == "inf":
-            return ptm.max(ptm.sum(abs(x), 1))
-        elif ord == "-inf":
-            return ptm.min(ptm.sum(abs(x), 1))
+            ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
+            ret **= ptm.reciprocal(ord)
+
+        return ret
+
+    elif len(axis) == 2:
+        # Matrix norms
+        row_axis, col_axis = (
+            batch_ndim + x for x in normalize_axis_tuple(axis, core_ndim)
+        )
+        axis = (row_axis, col_axis)
+
+        if ord in [None, "fro", "f"]:
+            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))
+
+        elif (ord == "inf") or (ord == np.inf):
+            if row_axis > col_axis:
+                row_axis -= 1
+            ret = ptm.max(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)
+
+        elif (ord == "-inf") or (ord == -np.inf):
+            if row_axis > col_axis:
+                row_axis -= 1
+            ret = ptm.min(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)
+
         elif ord == 1:
-            return ptm.max(ptm.sum(abs(x), 0))
+            if col_axis > row_axis:
+                col_axis -= 1
+            ret = ptm.max(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)
+
         elif ord == -1:
-            return ptm.min(ptm.sum(abs(x), 0))
+            if col_axis > row_axis:
+                col_axis -= 1
+            ret = ptm.min(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)
+
+        elif ord == 2:
+            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.max)
+
+        elif ord == -2:
+            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.min)
+
+        elif ord == "nuc":
+            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.sum)
+
         else:
-            raise ValueError(0)
-    elif ndim > 2:
-        raise NotImplementedError("We don't support norm with ndim > 2")
+            raise ValueError(f"Invalid norm order for matrices: {ord}")
+
+        if keepdims:
+            ret = ptb.expand_dims(ret, axis)
+
+        return ret
+    else:
+        raise ValueError(
+            f"Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = {core_ndim}"
+        )
 
 
 class TensorInv(Op):
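To make the new `_multi_svd_norm` helper concrete, the following plain-numpy sketch computes the same quantity (illustrative only; the committed helper builds a pytensor graph with `ptb.moveaxis` and `svd(..., compute_uv=False)` instead):

import numpy as np

def multi_svd_norm_ref(x, row_axis, col_axis, reduce_op):
    # Move the matrix axes to the end, take the singular values of each
    # 2-D slice, then reduce them: sum -> nuclear norm, max -> spectral
    # norm (ord=2), min -> smallest singular value (ord=-2).
    y = np.moveaxis(x, (row_axis, col_axis), (-2, -1))
    return reduce_op(np.linalg.svd(y, compute_uv=False), axis=-1)

x = np.random.default_rng(1).normal(size=(2, 4, 3))  # batch of two 4x3 matrices
nuc = multi_svd_norm_ref(x, 1, 2, np.sum)  # nuclear norm per batch item
assert nuc.shape == (2,)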

tests/tensor/test_nlinalg.py

Lines changed: 63 additions & 26 deletions
@@ -3,7 +3,6 @@
 import numpy as np
 import numpy.linalg
 import pytest
-from numpy import inf
 from numpy.testing import assert_array_almost_equal
 
 import pytensor
@@ -463,44 +462,82 @@ def test_non_square_matrix(self):
         f(a)
 
 
-class TestNormTests:
+class TestNorm:
     def test_wrong_type_of_ord_for_vector(self):
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError, match="Invalid norm order 'fro' for vectors"):
             norm([2, 1], "fro")
 
     def test_wrong_type_of_ord_for_matrix(self):
-        with pytest.raises(ValueError):
-            norm([[2, 1], [3, 4]], 0)
+        ord = 0
+        with pytest.raises(ValueError, match=f"Invalid norm order for matrices: {ord}"):
+            norm([[2, 1], [3, 4]], ord)
 
     def test_non_tensorial_input(self):
-        with pytest.raises(ValueError):
-            norm(3, None)
+        with pytest.raises(
+            ValueError,
+            match="Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = 0",
+        ):
+            norm(3, ord=2)
+
+    def test_invalid_axis_input(self):
+        axis = scalar("i", dtype="int")
+        with pytest.raises(
+            TypeError, match="'axis' must be None, an integer, or a tuple of integers"
+        ):
+            norm([[1, 2], [3, 4]], axis=axis)
 
-    def test_tensor_input(self):
-        res = norm(np.random.random((3, 4, 5)), None)
-        assert res.shape.eval() == (3,)
+    @pytest.mark.parametrize(
+        "ord",
+        [None, np.inf, -np.inf, 1, -1, 2, -2],
+        ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
+    )
+    @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
+    @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
+    @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
+    @pytest.mark.parametrize(
+        "keepdims", [True, False], ids=["keep_dims=True", "keep_dims=False"]
+    )
+    def test_numpy_compare(
+        self,
+        ord: float,
+        core_dims: tuple[int, ...],
+        batch_dims: tuple[int, ...],
+        test_imag: bool,
+        keepdims: bool,
+        axis=None,
+    ):
+        is_matrix = len(core_dims) == 2
+        has_batch = len(batch_dims) > 0
+        if ord in [np.inf, -np.inf] and not is_matrix:
+            pytest.skip("Infinity norm not defined for vectors")
+        if test_imag and is_matrix and ord == -2:
+            pytest.skip("Complex matrices not supported")
+        if has_batch and not is_matrix:
+            # Handle batched vectors by row-normalizing a matrix
+            axis = (-1,)
 
-    def test_numpy_compare(self):
         rng = np.random.default_rng(utt.fetch_seed())
 
-        M = matrix("A", dtype=config.floatX)
-        V = vector("V", dtype=config.floatX)
+        if test_imag:
+            x_real, x_imag = rng.standard_normal((2, *batch_dims, *core_dims)).astype(
+                config.floatX
+            )
+            dtype = "complex128" if config.floatX.endswith("64") else "complex64"
+            X = (x_real + 1j * x_imag).astype(dtype)
+        else:
+            X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)
 
-        a = rng.random((4, 4)).astype(config.floatX)
-        b = rng.random(4).astype(config.floatX)
+        if batch_dims == ():
+            np_norm = np.linalg.norm(X, ord=ord, axis=axis, keepdims=keepdims)
+        else:
+            np_norm = np.stack(
+                [np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) for x in X]
+            )
 
-        A = (
-            [None, "fro", "inf", "-inf", 1, -1, None, "inf", "-inf", 0, 1, -1, 2, -2],
-            [M, M, M, M, M, M, V, V, V, V, V, V, V, V],
-            [a, a, a, a, a, a, b, b, b, b, b, b, b, b],
-            [None, "fro", inf, -inf, 1, -1, None, inf, -inf, 0, 1, -1, 2, -2],
-        )
+        pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
+        f = function([], pt_norm, mode="FAST_COMPILE")
 
-        for i in range(0, 14):
-            f = function([A[1][i]], norm(A[1][i], A[0][i]))
-            t_n = f(A[2][i])
-            n_n = np.linalg.norm(A[2][i], A[3][i])
-            assert _allclose(n_n, t_n)
+        utt.assert_allclose(np_norm, f())
 
 
 class TestTensorInv(utt.InferShapeTester):
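A note on the reference values in test_numpy_compare above: for most values of ord, np.linalg.norm with axis=None accepts only 1-D or 2-D input, so the batched expectation is assembled by stacking per-item norms. A standalone sketch of that pattern (shapes mirror the test's batch/matrix case; not part of the commit):

import numpy as np

X = np.random.default_rng(0).normal(size=(2, 4, 3))  # batch of two 4x3 matrices
expected = np.stack([np.linalg.norm(x, ord=2) for x in X])  # spectral norm per item
assert expected.shape == (2,)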
