Skip to content

Commit 70fdcf9

Browse files
Remove theanof.set_theano_conf and instead use the config context (#4329)
* Remove theanof.set_theano_conf and instead use the config context properly
1 parent 6f15cbb commit 70fdcf9

File tree

4 files changed

+10
-65
lines changed

4 files changed

+10
-65
lines changed

RELEASE-NOTES.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
### Maintenance
66
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
77
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
8+
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
9+
810

911
## PyMC3 3.10.0 (7 December 2020)
1012

pymc3/model.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,7 @@
3636
from pymc3.exceptions import ImputationWarning
3737
from pymc3.math import flatten_list
3838
from pymc3.memoize import WithMemoization, memoize
39-
from pymc3.theanof import (
40-
floatX,
41-
generator,
42-
gradient,
43-
hessian,
44-
inputvars,
45-
set_theano_conf,
46-
)
39+
from pymc3.theanof import floatX, generator, gradient, hessian, inputvars
4740
from pymc3.util import get_transformed_name, get_var_name
4841
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter
4942

@@ -288,15 +281,17 @@ def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
288281
def __enter__(self):
289282
self.__class__.context_class.get_contexts().append(self)
290283
# self._theano_config is set in Model.__new__
284+
self._config_context = None
291285
if hasattr(self, "_theano_config"):
292-
self._old_theano_config = set_theano_conf(self._theano_config)
286+
self._config_context = theano.change_flags(**self._theano_config)
287+
self._config_context.__enter__()
293288
return self
294289

295290
def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
296291
self.__class__.context_class.get_contexts().pop()
297292
# self._theano_config is set in Model.__new__
298-
if hasattr(self, "_old_theano_config"):
299-
set_theano_conf(self._old_theano_config)
293+
if self._config_context:
294+
self._config_context.__exit__(typ, value, traceback)
300295

301296
dct[__enter__.__name__] = __enter__
302297
dct[__exit__.__name__] = __exit__

pymc3/tests/test_theanof.py

+1-24
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import collections
16-
1715
from itertools import product
1816

1917
import numpy as np
2018
import pytest
2119
import theano
2220
import theano.tensor as tt
2321

24-
from pymc3.theanof import _conversion_map, set_theano_conf, take_along_axis
22+
from pymc3.theanof import _conversion_map, take_along_axis
2523
from pymc3.vartypes import int_types
2624

2725
FLOATX = str(theano.config.floatX)
@@ -72,27 +70,6 @@ def np_take_along_axis(arr, indices, axis):
7270
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]
7371

7472

75-
class TestSetTheanoConfig:
76-
def test_invalid_key(self):
77-
with pytest.raises(ValueError) as e:
78-
set_theano_conf({"bad_key": True})
79-
e.match("Unknown")
80-
81-
def test_restore_when_bad_key(self):
82-
with theano.configparser.change_flags(compute_test_value="off"):
83-
with pytest.raises(ValueError):
84-
conf = collections.OrderedDict([("compute_test_value", "raise"), ("bad_key", True)])
85-
set_theano_conf(conf)
86-
assert theano.config.compute_test_value == "off"
87-
88-
def test_restore(self):
89-
with theano.configparser.change_flags(compute_test_value="off"):
90-
conf = set_theano_conf({"compute_test_value": "raise"})
91-
assert conf == {"compute_test_value": "off"}
92-
conf = set_theano_conf(conf)
93-
assert conf == {"compute_test_value": "raise"}
94-
95-
9673
class TestTakeAlongAxis:
9774
def setup_class(self):
9875
self.inputs_buffer = dict()

pymc3/theanof.py

+1-30
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
import numpy as np
1616
import theano
1717

18-
from theano import scalar
18+
from theano import change_flags, scalar
1919
from theano import tensor as tt
20-
from theano.configparser import change_flags
2120
from theano.gof import Op
2221
from theano.gof.graph import inputs
2322
from theano.sandbox.rng_mrg import MRG_RandomStreams
@@ -442,34 +441,6 @@ def floatX_array(x):
442441
return floatX(np.array(x))
443442

444443

445-
def set_theano_conf(values):
446-
"""Change the theano configuration and return old values.
447-
448-
This is similar to `theano.configparser.change_flags`, but it
449-
returns the original values in a pickleable form.
450-
"""
451-
variables = {}
452-
unknown = set(values.keys())
453-
for variable in theano.configparser._config_var_list:
454-
if variable.fullname in values:
455-
variables[variable.fullname] = variable
456-
unknown.remove(variable.fullname)
457-
if len(unknown) > 0:
458-
raise ValueError("Unknown theano config settings: %s" % unknown)
459-
460-
old = {}
461-
for name, variable in variables.items():
462-
old_value = variable.__get__(True, None)
463-
try:
464-
variable.__set__(None, values[name])
465-
except Exception:
466-
for key, old_value in old.items():
467-
variables[key].__set__(None, old_value)
468-
raise
469-
old[name] = old_value
470-
return old
471-
472-
473444
def ix_(*args):
474445
"""
475446
Theano np.ix_ analog

0 commit comments

Comments
 (0)