-
-
Notifications
You must be signed in to change notification settings - Fork 324
/
Copy path_option.py
153 lines (123 loc) · 4.53 KB
/
_option.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from __future__ import annotations
import os
from inspect import currentframe
from logging import getLogger
from types import FrameType
from typing import Any, Callable, Generic, Iterator, TypeVar, cast
from warnings import warn
_O = TypeVar("_O")  # type of the value held by an Option
logger = getLogger(__name__)


class Option(Generic[_O]):
    """An option that can be set using an environment variable of the same name

    Parameters:
        name:
            The name of the option and of the environment variable it is loaded from.
        default:
            The value used while no current value is set. If another
            :class:`Option` is given, this option subscribes to it and uses that
            option's current value as its own default.
        mutable:
            Whether the option may be changed after it has been loaded.
        validator:
            Called on every raw value (including the environment variable's
            string) to produce the value that is stored. The default validator
            returns its input unchanged.
    """

    def __init__(
        self,
        name: str,
        default: _O | Option[_O],
        mutable: bool = True,
        validator: Callable[[Any], _O] = lambda x: cast(_O, x),
    ) -> None:
        self._name = name
        self._mutable = mutable
        self._validator = validator
        self._subscribers: list[Callable[[_O], None]] = []

        # ``_current`` only exists while a value is explicitly set - its absence
        # is how ``is_set`` and ``current`` know to fall back to the default.
        if name in os.environ:
            self._current = validator(os.environ[name])

        self._default: _O
        if isinstance(default, Option):
            self._default = default.default
            # ``subscribe`` calls the handler immediately with the other option's
            # current value and again on every change, so our default tracks it.
            # It raises ``TypeError`` if the other option is immutable.
            default.subscribe(lambda value: setattr(self, "_default", value))
        else:
            self._default = default

        logger.debug(f"{self._name}={self.current}")

    @property
    def name(self) -> str:
        """The name of this option (used to load environment variables)"""
        return self._name

    @property
    def mutable(self) -> bool:
        """Whether this option can be modified after being loaded"""
        return self._mutable

    @property
    def default(self) -> _O:
        """This option's default value"""
        return self._default

    @property
    def current(self) -> _O:
        """The explicitly set value of this option, or its default if none is set"""
        try:
            return self._current
        except AttributeError:
            return self._default

    @current.setter
    def current(self, new: _O) -> None:
        self.set_current(new)

    def subscribe(self, handler: Callable[[_O], None]) -> Callable[[_O], None]:
        """Register a callback that will be triggered when this option changes

        The ``handler`` is also called once, immediately, with the current value.
        It is returned unchanged so this method can be used as a decorator.

        Raises a ``TypeError`` if this option is not :attr:`Option.mutable`.
        """
        if not self.mutable:
            raise TypeError("Immutable options cannot be subscribed to.")
        self._subscribers.append(handler)
        handler(self.current)
        return handler

    def is_set(self) -> bool:
        """Whether this option has a value other than its default."""
        return hasattr(self, "_current")

    def set_current(self, new: Any) -> None:
        """Set the value of this option

        The ``new`` value is passed through the validator before being stored.
        Subscribers are notified only if the validated value differs from the
        previous current value.

        Raises a ``TypeError`` if this option is not :attr:`Option.mutable`.
        """
        if not self._mutable:
            raise TypeError(f"{self} cannot be modified after initial load")
        old = self.current
        new = self._current = self._validator(new)
        logger.debug(f"{self._name}={self._current}")
        if new != old:
            for sub_func in self._subscribers:
                sub_func(new)

    def set_default(self, new: _O) -> _O:
        """Set the value of this option if not :meth:`Option.is_set`

        Returns the current value (a la :meth:`dict.setdefault`)
        """
        if not self.is_set():
            self.set_current(new)
        return self._current

    def reload(self) -> None:
        """Reload this option from its environment variable

        If the variable is absent the default is passed to
        :meth:`Option.set_current` instead, so the option counts as set afterwards.
        """
        self.set_current(os.environ.get(self._name, self._default))

    def unset(self) -> None:
        """Remove the current value, the default will be used until it is set again.

        Does nothing if no current value is set. Subscribers are notified only
        if falling back to the default changes the effective value.

        Raises a ``TypeError`` if this option is not :attr:`Option.mutable`.
        """
        if not self._mutable:
            raise TypeError(f"{self} cannot be modified after initial load")
        if not self.is_set():
            # nothing to remove - deleting a missing attribute would raise
            return
        old = self.current
        delattr(self, "_current")
        if self.current != old:
            for sub_func in self._subscribers:
                sub_func(self.current)

    def __repr__(self) -> str:
        return f"Option({self._name}={self.current!r})"
class DeprecatedOption(Option[_O]):  # pragma: no cover
    """An :class:`Option` that emits a ``DeprecationWarning`` when its value is read

    Parameters:
        message: The text of the warning.
        args: Positional arguments forwarded to :class:`Option`.
        kwargs: Keyword arguments forwarded to :class:`Option`.
    """

    def __init__(self, message: str, *args: Any, **kwargs: Any) -> None:
        # The base initializer reads ``self.current`` for its debug log. The
        # message is stored only afterwards so that merely constructing the
        # option does not emit the deprecation warning.
        super().__init__(*args, **kwargs)
        self._deprecation_message = message

    @Option.current.getter  # type: ignore
    def current(self) -> _O:
        try:
            message = self._deprecation_message
        except AttributeError:
            # still inside ``__init__`` - not a user access, so stay silent
            pass
        else:
            warn(
                message,
                DeprecationWarning,
                # attribute the warning to the first caller outside this module
                stacklevel=_frame_depth_in_module() + 1,
            )
        return super().current
def _frame_depth_in_module() -> int:
    """Count the consecutive frames, starting at our caller, that run in this module"""
    count = 0
    # Skipping two frames passes over the frame iterator itself and this
    # function. This must stay a plain ``for`` loop: a comprehension or
    # generator expression could add a frame of its own to the stack.
    for caller in _iter_frames(2):
        if caller.f_globals.get("__name__") == __name__:
            count += 1
        else:
            break
    return count
def _iter_frames(index: int = 1) -> Iterator[FrameType]:
    """Yield stack frames from innermost to outermost, skipping the first ``index``

    The walk starts at this generator's own frame, so ``index=0`` includes it
    and ``index=1`` begins with the frame that is consuming the generator.
    """
    to_skip = index
    here = currentframe()
    while here is not None:
        if to_skip:
            # still inside the part of the stack the caller asked us to skip
            to_skip -= 1
        else:
            yield here
        here = here.f_back