|
| 1 | +# Copyright 2014-2016 The ODL development group |
| 2 | +# |
| 3 | +# This file is part of ODL. |
| 4 | +# |
| 5 | +# ODL is free software: you can redistribute it and/or modify |
| 6 | +# it under the terms of the GNU General Public License as published by |
| 7 | +# the Free Software Foundation, either version 3 of the License, or |
| 8 | +# (at your option) any later version. |
| 9 | +# |
| 10 | +# ODL is distributed in the hope that it will be useful, |
| 11 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 12 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 13 | +# GNU General Public License for more details. |
| 14 | +# |
| 15 | +# You should have received a copy of the GNU General Public License |
| 16 | +# along with ODL. If not, see <http://www.gnu.org/licenses/>. |
| 17 | + |
| 18 | +"""Basic examples on how to write a fucntional. |
| 19 | +
|
| 20 | +When defining a new functional, there are a few standard methods and properties |
| 21 | +that can be implemeted. These are: |
| 22 | +- ``__init__``. This method intialize the functional |
| 23 | +- ``_call``. The actual function call ``functional(x)`` |
| 24 | +- ``gradient`` (property). This gives the gradient operator of the functional, |
| 25 | + i.e., the operator that corresponds to the mapping ``x -> grad_f(x)`` |
| 26 | +- ``proximal``. This returns the proximal operator. If called only as |
| 27 | + ``functional.proximal``, it corresponds to a `Proximal factory`. |
| 28 | +- ``conjugate_functional`` (property). This gives the convex conjugate |
| 29 | + functional |
| 30 | +- ``derivative``. This returns the (directional) derivative operator in a point |
| 31 | + y, such that when called with a point x it corresponds to the linear |
| 32 | + operator ``x --> <x, grad_f(y)>``. Note that this has a default |
| 33 | + implemetation that uses the gradient in order to achieve this. |
| 34 | +
|
| 35 | +Below follows an example of implementing the functional ``||x||_2^2 + <x,y>, |
| 36 | +for some parameter y.`` |
| 37 | +""" |
| 38 | + |
| 39 | +# Imports for common Python 2/3 codebase |
| 40 | +from __future__ import print_function, division, absolute_import |
| 41 | +from future import standard_library |
| 42 | +standard_library.install_aliases() |
| 43 | +from builtins import super |
| 44 | + |
| 45 | +import numpy as np |
| 46 | +import odl |
| 47 | + |
| 48 | +# Here we define the functional |
| 49 | +class MyFunctional(odl.solvers.Functional): |
| 50 | + """This is my functional: ||x||_2^2 + <x, y>.""" |
| 51 | + |
| 52 | + # Defining the __init__ function |
| 53 | + def __init__(self, domain, y): |
| 54 | + # This comand calls the init of Functional and sets a number of |
| 55 | + # parameters associated with a functional. All but domain have default |
| 56 | + # values if not set. |
| 57 | + super().__init__(domain=domain, linear=False, convex=True, |
| 58 | + concave=False, smooth=True, grad_lipschitz=2) |
| 59 | + |
| 60 | + # We need to check that y is in the domain. Then we store the value of |
| 61 | + # y for future use. |
| 62 | + if y not in domain: |
| 63 | + raise TypeError('y is not in the domain!') |
| 64 | + self._y = y |
| 65 | + |
| 66 | + # Now we define a propert which returns y, so that the user can see which |
| 67 | + # translation is used in a particular instance of the class. |
| 68 | + @property |
| 69 | + def y(self): |
| 70 | + return self._y |
| 71 | + |
| 72 | + # Defining the _call function |
| 73 | + def _call(self, x): |
| 74 | + return x.norm()**2 + x.inner(self.y) |
| 75 | + |
| 76 | + # Next we define the gradient. Note that this is a property. |
| 77 | + @property |
| 78 | + def gradient(self): |
| 79 | + # Inside this property, we define the gradient operator. This can be |
| 80 | + # defined anywhere and just returned here, but in this example we will |
| 81 | + # also define it here. |
| 82 | + |
| 83 | + # In order for the initialization of the operator to know which |
| 84 | + # functional it comes from. |
| 85 | + this_functional = self |
| 86 | + |
| 87 | + class MyGradientOperator(odl.Operator): |
| 88 | + |
| 89 | + # Define an __init__ method for this operator |
| 90 | + def __init__(self, functional): |
| 91 | + super().__init__(domain=this_functional.domain, |
| 92 | + range=this_functional.domain) |
| 93 | + |
| 94 | + self._functional = functional |
| 95 | + |
| 96 | + # Define a _call method for this operator |
| 97 | + def _call(self, x): |
| 98 | + return 2.0 * x + self._functional.y |
| 99 | + |
| 100 | + return MyGradientOperator(functional=this_functional) |
| 101 | + |
| 102 | + # Next we define the convex conjugate functional |
| 103 | + @property |
| 104 | + def conjugate_functional(self): |
| 105 | + # This functional is implemented below |
| 106 | + return MyFunctionalConjugate(domain=self.domain, y=self.y) |
| 107 | + |
| 108 | + |
| 109 | +# Here is the conjugate functional |
| 110 | +class MyFunctionalConjugate(odl.solvers.Functional): |
| 111 | + """Hand calculations give that this funtional has the analytic expression |
| 112 | + f^*(x) = ||x||^2/2 - ||x-y||^2/4 + ||y||^2/2 - <x,y>. |
| 113 | + """ |
| 114 | + def __init__(self, domain, y): |
| 115 | + super().__init__(domain=domain, linear=False, convex=True, |
| 116 | + concave=False, smooth=True, grad_lipschitz=2) |
| 117 | + |
| 118 | + if y not in domain: |
| 119 | + raise TypeError('y is not in the domain!') |
| 120 | + self._y = y |
| 121 | + |
| 122 | + @property |
| 123 | + def y(self): |
| 124 | + return self._y |
| 125 | + |
| 126 | + def _call(self, x): |
| 127 | + return (x.norm()**2 / 2.0 - (x - self.y).norm()**2 / 4.0 + |
| 128 | + self.y.norm()**2 / 2.0 - x.inner(self.y)) |
| 129 | + |
| 130 | + |
| 131 | +# Now we test the functional. First we create an instance of the functional |
| 132 | +n = 10 |
| 133 | +space = odl.rn(n) |
| 134 | +y = space.element(np.random.randn(n)) |
| 135 | +my_func = MyFunctional(domain=space, y=y) |
| 136 | + |
| 137 | +# Now we evaluate it, and see that it returns the expected value |
| 138 | +x = space.element(np.random.randn(n)) |
| 139 | + |
| 140 | +if my_func(x) == x.norm()**2 + x.inner(y): |
| 141 | + print('My functional evaluates corretly.') |
| 142 | +else: |
| 143 | + print('There is a bug in the evaluation of my functional.') |
| 144 | + |
| 145 | +# Next we create the gradient |
| 146 | +my_gradient = my_func.gradient |
| 147 | + |
| 148 | +# Frist we test that it is indeed an odl Operator |
| 149 | +if isinstance(my_gradient, odl.Operator): |
| 150 | + print('The gradient is an operator, as it should be.') |
| 151 | +else: |
| 152 | + print('There is an error in the gradient; it is not an operator.') |
| 153 | + |
| 154 | +# Second, we test that it evaluates correctly |
| 155 | +if my_gradient(x) == 2.0 * x + y: |
| 156 | + print('The gradient evaluates correctly.') |
| 157 | +else: |
| 158 | + print('There is an error in the evaluation of the gradient.') |
| 159 | + |
| 160 | +# Since we have not implemented the (directional) derivative, but we have |
| 161 | +# implemeted the gradient, the default implementation will use this in order to |
| 162 | +# evaluate the derivative. We test this behaviour. |
| 163 | +p = space.element(np.random.randn(n)) |
| 164 | +my_deriv = my_func.derivative(x) |
| 165 | + |
| 166 | +if my_deriv(p) == my_gradient(x).inner(p): |
| 167 | + print('The default implementation of the gradient works as intended.') |
| 168 | +else: |
| 169 | + print('There is a bug in the implementation of the derivative') |
| 170 | + |
| 171 | +# Since the proximal operator was not implemented it will raise a |
| 172 | +# NotImplementedError |
| 173 | +try: |
| 174 | + my_func.proximal() |
| 175 | +except NotImplementedError: |
| 176 | + print('As expected we caught a NotImplementedError when trying to create ' |
| 177 | + 'the proximal operator') |
| 178 | +else: |
| 179 | + print('There should have been an error, but it did not occure.') |
| 180 | + |
| 181 | +# We now create the conjugate functional and test a call to it |
| 182 | +my_func_conj = my_func.conjugate_functional |
| 183 | + |
| 184 | +if my_func_conj(x) == (x.norm()**2 / 2.0 - (x - my_func.y).norm()**2 / 4.0 + |
| 185 | + my_func.y.norm()**2 / 2.0 - x.inner(my_func.y)): |
| 186 | + print('The conjugate functional evaluates correctly.') |
| 187 | +else: |
| 188 | + print('There is an error in the evaluation of the conjugate functional.') |
| 189 | + |
| 190 | +# Nothing else has been implemented in the conjugate functional. For example, |
| 191 | +# there is no gradient. |
| 192 | +try: |
| 193 | + my_func_conj.gradient |
| 194 | +except NotImplementedError: |
| 195 | + print('As expected we caught a NotImplementedError when trying to access ' |
| 196 | + 'the gradient operator.') |
| 197 | +else: |
| 198 | + print('There should have been an error, but it did not occure.') |
| 199 | + |
| 200 | +# There is no derivative either. |
| 201 | +try: |
| 202 | + my_func_conj.derivative(x)(p) |
| 203 | +except NotImplementedError: |
| 204 | + print('As expected we caught a NotImplementedError when trying to ' |
| 205 | + 'evaluate the derivative.') |
| 206 | +else: |
| 207 | + print('There should have been an error, but it did not occure.') |
| 208 | + |
| 209 | +# We now test some general properties that exists for all functionals. We can |
| 210 | +# add two functioanls, scale it by multiplying with a scalar from the left, |
| 211 | +# scale the argument by multiplying with a scalar from the right, and also |
| 212 | +# translate the argument. Except for the sum of two functional, the other |
| 213 | +# operations will apply corrections in order to evaluate gradients, etc., |
| 214 | +# in a correct way. |
| 215 | + |
| 216 | +# Scaling the functional |
| 217 | +func_scal = np.random.rand() |
| 218 | +my_func_scaled = func_scal * my_func |
| 219 | + |
| 220 | +if my_func_scaled(x) == func_scal * (my_func(x)): |
| 221 | + print('Scaling of functional works.') |
| 222 | +else: |
| 223 | + print('There is an error in the scaling of functionals.') |
| 224 | + |
| 225 | +my_func_scaled_grad = my_func_scaled.gradient |
| 226 | +if my_func_scaled_grad(x) == func_scal * (my_func.gradient(x)): |
| 227 | + print('Scaling of functional evaluates gradient correctly.') |
| 228 | +else: |
| 229 | + print('There is an error in evaluating the gradient in the scaling of ' |
| 230 | + 'functionals.') |
| 231 | + |
| 232 | +# Scaling of the argument |
| 233 | +arg_scal = np.random.rand() |
| 234 | +my_func_arg_scaled = my_func * arg_scal |
| 235 | + |
| 236 | +if my_func_arg_scaled(x) == my_func(arg_scal * x): |
| 237 | + print('Scaling of functional argument works.') |
| 238 | +else: |
| 239 | + print('There is an error in the scaling of functional argument.') |
| 240 | + |
| 241 | +# Sum of two functionals |
| 242 | +y_2 = space.element(np.random.randn(n)) |
| 243 | +my_func_2 = MyFunctional(domain=space, y=y_2) |
| 244 | +my_funcs_sum = my_func + my_func_2 |
| 245 | + |
| 246 | +if my_funcs_sum(x) == my_func(x) + my_func_2(x): |
| 247 | + print('Summing two functionals works.') |
| 248 | +else: |
| 249 | + print('There is an error in the summation of functionals.') |
| 250 | + |
| 251 | +# Translation of the functional, i.e., creating the functional f(. - y). |
| 252 | +my_func_translated = my_func.translate(y_2) |
| 253 | +if my_func_translated(x) == my_func(x - y_2): |
| 254 | + print('Translation of functional works.') |
| 255 | +else: |
| 256 | + print('There is an error in the translation of functional.') |
0 commit comments