Skip to content

Commit 85cdd65

Browse files
committed
ENH/DOC: updates in functionals and default_ops, see full commit message.
functional_basic_example: - update comments functional_basic_example_solver: - update comments - remove example when solving with Chambolle-Pock default_ops: - update doc in ConstantOperator - update input checks ConstantOperator__init__ default_functionals: - update doc in sevaral methods in several functionals functional: - update doc in some methods of Functional - new imlpementation of derivative in Functional - update doc in several other places
1 parent 43d1083 commit 85cdd65

File tree

5 files changed

+342
-223
lines changed

5 files changed

+342
-223
lines changed

examples/solvers/functional_basic_example.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@
4545
import numpy as np
4646
import odl
4747

48+
4849
# Here we define the functional
4950
class MyFunctional(odl.solvers.Functional):
5051
"""This is my functional: ||x||_2^2 + <x, y>."""
5152

52-
# Defining the __init__ function
5353
def __init__(self, domain, y):
5454
# This comand calls the init of Functional and sets a number of
5555
# parameters associated with a functional. All but domain have default
@@ -63,8 +63,7 @@ def __init__(self, domain, y):
6363
raise TypeError('y is not in the domain!')
6464
self._y = y
6565

66-
# Now we define a propert which returns y, so that the user can see which
67-
# value is used in a particular instance of the class.
66+
# Property that returns the linear term.
6867
@property
6968
def y(self):
7069
return self._y
@@ -76,35 +75,36 @@ def _call(self, x):
7675
# Next we define the gradient. Note that this is a property.
7776
@property
7877
def gradient(self):
79-
# Inside this property, we define the gradient operator. This can be
80-
# defined anywhere and just returned here, but in this example we will
81-
# also define it here.
8278

79+
# The class corresponding to the gradient operator.
8380
class MyGradientOperator(odl.Operator):
81+
"""Class that implements the gradient operator of the functional
82+
``||x||_2^2 + <x,y>``.
83+
"""
8484

85-
# Define an __init__ method for this operator
8685
def __init__(self, functional):
8786
super().__init__(domain=functional.domain,
8887
range=functional.domain)
8988

9089
self._functional = functional
9190

92-
# Define a _call method for this operator
9391
def _call(self, x):
9492
return 2.0 * x + self._functional.y
9593

9694
return MyGradientOperator(functional=self)
9795

98-
# Next we define the convex conjugate functional
96+
# Next we define the convex conjugate functional.
9997
@property
10098
def conjugate_functional(self):
101-
# This functional is implemented below
99+
# This functional is implemented below.
102100
return MyFunctionalConjugate(domain=self.domain, y=self.y)
103101

104102

105-
# Here is the conjugate functional
103+
# Here is the conjugate functional.
106104
class MyFunctionalConjugate(odl.solvers.Functional):
107-
"""Calculations give that this funtional has the analytic expression
105+
"""Conjugate functional to ``||x||_2^2 + <x,y>``.
106+
107+
Calculations give that this funtional has the analytic expression
108108
f^*(x) = ||x||^2/2 - ||x-y||^2/4 + ||y||^2/2 - <x,y>.
109109
"""
110110
def __init__(self, domain, y):

examples/solvers/functional_basic_example_solver.py

+8-48
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
using the default functionals. The problem we will solve is to minimize
2222
1/2 * ||x - g||_2^2 + lam*||x||_1, for some vector g and some constant lam,
2323
subject to that all components in x are greater than or equal to 0. The
24-
theoretical optimal solution to this problem is x = (g - lam)_+, where ( )+
25-
denotes the positive part of the element, i.e., (z_i)_+ = z_i if z_i >= 0, and
26-
0 otherwise.
27-
"""
24+
theoretical optimal solution to this problem is x = (g - lam)_+, where ( )_+
25+
denotes the positive part of the element, i.e., (z_i)_+ = max(z_i, 0)."""
2826

2927
import numpy as np
3028
import odl
@@ -33,12 +31,12 @@
3331
n = 10
3432
space = odl.rn(n)
3533

36-
# Create parameters. First half of g are ones, second half are minus ones.
34+
# Create parameters.
3735
g = space.element(np.hstack((np.ones(n/2), -np.ones(n/2))))
3836
lam = 0.5
3937

4038
# Note that with the values above, the optimal solution is given by a vector
41-
# with fires half of the elements equal to 0.5, the second half equal to 0.
39+
# with first half of the elements equal to 0.5, the second half equal to 0.
4240

4341
# Create the L1-norm functional and multiplyit with the constant lam.
4442
lam_l1_func = lam * odl.solvers.L1Norm(space)
@@ -62,53 +60,15 @@
6260
sigma = 0.5
6361

6462
# Starting point, and also updated inplace in the solver
65-
x_fbpd = space.zero()
63+
x = space.element(np.random.randn(n))
64+
print('Initial guess: x = {}'.format(x.asarray()))
6665

6766
# Optional: pass callback objects to solver
6867
callback = (odl.solvers.CallbackPrintIteration())
6968

7069
# Run the algorithm
71-
odl.solvers.forward_backward_pd(x=x_fbpd, prox_f=prox_f, prox_cc_g=[prox_cc_g],
70+
odl.solvers.forward_backward_pd(x=x, prox_f=prox_f, prox_cc_g=[prox_cc_g],
7271
L=[L], grad_h=grad_h, tau=tau, sigma=[sigma],
7372
niter=niter, callback=callback)
7473

75-
print(x_fbpd.asarray())
76-
77-
78-
# The problem can also be solved using, e.g., the Chambolle-Pock algorithm.
79-
# Here we create the necessary proximal factories of the conjugate functionals
80-
# (see the Chambolle-Pock algorithm and examples on this for more information).
81-
prox_cc_l2 = trans_l2_func.conjugate_functional.proximal
82-
prox_cc_l1 = lam_l1_func.conjugate_functional.proximal
83-
84-
# Combined the proximals for use in the solver
85-
proximal_dual = odl.solvers.combine_proximals(prox_cc_l2, prox_cc_l1)
86-
87-
# Create the matrix of operators for the Chambolle-Pock solver
88-
op = odl.BroadcastOperator(odl.IdentityOperator(space),
89-
odl.IdentityOperator(space))
90-
91-
# Create the proximal operator for the constraint
92-
proximal_primal = odl.solvers.proximal_nonnegativity(op.domain)
93-
94-
# The operator norm is sqrt(2), since only identity operators are used
95-
op_norm = np.sqrt(2)
96-
97-
# Some solver parameters
98-
niter = 50 # Number of iterations
99-
tau = 1.0 / op_norm # Step size for the primal variable
100-
sigma = 1.0 / op_norm # Step size for the dual variable
101-
102-
# Optional: pass callback objects to solver
103-
callback = (odl.solvers.CallbackPrintIteration())
104-
105-
# Starting point, and also updated inplace in the solver
106-
x_cp = op.domain.zero()
107-
108-
# Run the algorithm
109-
odl.solvers.chambolle_pock_solver(op=op, x=x_cp, tau=tau, sigma=sigma,
110-
proximal_primal=proximal_primal,
111-
proximal_dual=proximal_dual, niter=niter,
112-
callback=callback)
113-
114-
print(x_cp.asarray())
74+
print('Solution guess: x = {}'.format(x.asarray()))

odl/operator/default_ops.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -826,24 +826,22 @@ class ConstantOperator(Operator):
826826
"""
827827

828828
def __init__(self, vector, domain=None, range=None):
829-
"""Initialize an instance.
829+
"""Initialize a new instance.
830830
831831
Parameters
832832
----------
833833
vector : `LinearSpaceVector`
834834
The vector constant to be returned
835835
domain : `LinearSpace`, optional
836-
default is vector.space
837-
The domain of the operator.
836+
Domain of the operator. Default: ``vector.space``
838837
range : `LinearSpace`, optional
839-
default : vector.space
840-
The range of the operator.
838+
Range of the operator. Default: ``vector.space``
841839
"""
842-
if domain is None or range is None:
843-
if not isinstance(vector, LinearSpaceVector):
844-
raise TypeError('If either domain or range is unspecified, '
845-
'`vector` {!r} has to be a LinearSpaceVector '
846-
'instance'.format(vector))
840+
if ((domain is None or range is None) and
841+
not isinstance(vector, LinearSpaceVector)):
842+
raise TypeError('If either domain or range is unspecified '
843+
'`vector` must be LinearSpaceVector, got {!r}.'
844+
''.format(vector))
847845

848846
if domain is None:
849847
domain = vector.space

0 commit comments

Comments
 (0)