Skip to content

Commit f7898c4

Browse files
michaelosthegebrandonwillard
authored andcommitted
Support and test passing a dict as first argument to change_flags
1 parent 082d02f commit f7898c4

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

tests/test_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def test_config_context():
158158

159159
with configparser.change_flags(test_config_context="new_value"):
160160
assert root.test_config_context == "new_value"
161+
with root.change_flags({"test_config_context": "new_value2"}):
162+
assert root.test_config_context == "new_value2"
163+
assert root.test_config_context == "new_value"
161164
assert root.test_config_context == "test_default"
162165

163166

theano/configparser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ def warn(cls, message, stacklevel=0):
2323
class _ChangeFlagsDecorator:
2424
def __init__(self, *args, _root=None, **kwargs):
2525
# the old API supported passing a dict as the first argument:
26-
args = dict(args)
27-
args.update(kwargs)
28-
self.confs = {k: _root._config_var_dict[k] for k in args}
29-
self.new_vals = args
26+
if args:
27+
assert len(args) == 1 and isinstance(args[0], dict)
28+
kwargs = dict(**args[0], **kwargs)
29+
self.confs = {k: _root._config_var_dict[k] for k in kwargs}
30+
self.new_vals = kwargs
3031
self._root = _root
3132

3233
def __call__(self, f):

0 commit comments

Comments
 (0)