Skip to content

Commit 72ed1b4

Browse files
committed
TST/BUG: add test for functional times vector.
This commit adds a test for multiplication of a functional and a vector, both left and right multiplication. In doing so, two bugs were found and corrected.
1 parent 19d5a02 commit 72ed1b4

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

odl/solvers/functional/functional.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __mul__(self, other):
273273
else:
274274
return FunctionalRightScalarMult(self, other)
275275
else:
276-
super().__mul__(self, other)
276+
return super().__mul__(other)
277277

278278
def __rmul__(self, other):
279279
"""Return ``other * self``.
@@ -318,18 +318,18 @@ def __rmul__(self, other):
318318
rmul : {`Functional`, `Operator`}
319319
Multiplication result
320320
321-
If ``other`` is an `Operator`, ``mul`` is a `OperatorComp`.
321+
If ``other`` is an `Operator`, ``rmul`` is a `OperatorComp`.
322322
323-
If ``other`` is a scalar, ``mul`` is a
323+
If ``other`` is a scalar, ``rmul`` is a
324324
`FunctionalLeftScalarMult`.
325325
326-
If ``other`` is a vector, ``mul`` is a
326+
If ``other`` is a vector, ``rmul`` is a
327327
`OperatorLeftVectorMult`.
328328
"""
329329
if other in self.domain.field:
330330
return FunctionalLeftScalarMult(self, other)
331331
else:
332-
super().__rmul__(self, other)
332+
return super().__rmul__(other)
333333

334334
def __add__(self, other):
335335
"""Return ``self + other``.

test/solvers/functional/functional_test.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
# Places for the accepted error when comparing results
3434
PLACES = 8
3535

36-
# TODO: make some tests that check that prox work.
3736

3837
# TODO: Test that prox and conjugate functionals are not returned for negative
3938
# left scaling.
@@ -42,8 +41,6 @@
4241

4342
# TODO: Test flags for translations etc.
4443

45-
# TODO: Add test for composition, both from letf and right, with a vector
46-
4744

4845
def test_derivative():
4946
"""Test for the derivative of a functional.
@@ -188,7 +185,6 @@ def test_functional_composition():
188185

189186
def test_functional_sum():
190187
"""Test for the sum of two functionals."""
191-
192188
space = odl.uniform_discr(0, 1, 10)
193189

194190
func1 = odl.solvers.L2NormSquare(space)
@@ -271,7 +267,6 @@ def test_functional_plus_scalar():
271267

272268
def test_translation_of_functional():
273269
"""Test for the translation of a functional: (f(. - y))^*"""
274-
275270
space = odl.uniform_discr(0, 1, 10)
276271

277272
# The translation; an element in the domain
@@ -328,10 +323,49 @@ def test_translation_of_functional():
328323
places=PLACES)
329324

330325

326+
def test_multiplication_with_vector():
327+
"""Test for multiplying a functional with a vector, both left and right."""
328+
329+
space = odl.uniform_discr(0, 1, 10)
330+
331+
x = example_element(space)
332+
y = example_element(space)
333+
func = odl.solvers.L1Norm(space)
334+
335+
wrong_space = odl.uniform_discr(1, 2, 10)
336+
y_other_space = example_element(wrong_space)
337+
338+
# Multiplication from the right. Make sure it is a OperatorRightVectorMult
339+
func_times_y = func * y
340+
assert isinstance(func_times_y, odl.OperatorRightVectorMult)
341+
342+
expected_result = func(y*x)
343+
assert almost_equal((func*y)(x), expected_result, places=PLACES)
344+
345+
# Make sure that right muliplication is not allowed with vector from
346+
# another space
347+
with pytest.raises(TypeError):
348+
func * y_other_space
349+
350+
# Multiplication from the left. Make sure it is a FunctionalLeftVectorMult
351+
y_times_func = y * func
352+
assert isinstance(y_times_func, odl.FunctionalLeftVectorMult)
353+
354+
expected_result = y * func(x)
355+
assert all_almost_equal(y_times_func(x), expected_result, places=PLACES)
356+
357+
# Now, multiplication with vector from another space is ok (since it is the
358+
# same as scaling that vector with the scalar returned by the functional).
359+
y_other_times_func = y_other_space * func
360+
assert isinstance(y_other_times_func, odl.FunctionalLeftVectorMult)
361+
362+
expected_result = y_other_space * func(x)
363+
assert all_almost_equal(y_other_times_func(x), expected_result,
364+
places=PLACES)
365+
366+
331367
def test_convex_conjugate_translation():
332368
"""Test for the convex conjugate of a translation: (f(. - y))^*"""
333-
334-
# Image space
335369
space = odl.uniform_discr(0, 1, 10)
336370

337371
# The translation; an element in the domain
@@ -394,8 +428,6 @@ def test_convex_conjugate_translation():
394428

395429
def test_convex_conjugate_arg_scaling():
396430
"""Test for the convex conjugate of a scaling: (f(. scaling))^*"""
397-
398-
# Image space
399431
space = odl.uniform_discr(0, 1, 10)
400432

401433
# The scaling parameter
@@ -446,8 +478,6 @@ def test_convex_conjugate_arg_scaling():
446478

447479
def test_convex_conjugate_functional_scaling():
448480
"""Test for the convex conjugate of a scaling: (scaling * f(.))^*"""
449-
450-
# Image space
451481
space = odl.uniform_discr(0, 1, 10)
452482

453483
# The scaling parameter
@@ -500,8 +530,6 @@ def test_convex_conjugate_functional_scaling():
500530

501531
def test_convex_conjugate_linear_perturbation():
502532
"""Test for the convex conjugate of a scaling: (f(.) + <y,.>)^*"""
503-
504-
# Image space
505533
space = odl.uniform_discr(0, 1, 10)
506534

507535
# The perturbation; an element in the domain (which is the same as the dual

0 commit comments

Comments
 (0)