Skip to content

Commit e522772

Browse files
authored
Improve generics usage (#282)
This PR makes some improvements related to generic types. It adds public generic type vars so the documentation is rendered more nicely and we get documentation on the type parameters. It also extends the use of covariance for other classes or functions that work like containers that don't change the underlying data, so they can be used more flexibly.
2 parents 63c8dc0 + 13391c8 commit e522772

File tree

10 files changed

+127
-71
lines changed

10 files changed

+127
-71
lines changed

RELEASE_NOTES.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,11 @@
173173

174174
## Improvements
175175

176-
* `Receiver`: Use a covariant generic type, which allows the generic type to be broader than the actual type.
176+
* `Receiver`, `merge`/`Merger`, `Error` and its derived classes now use a covariant generic type, which allows the generic type to be broader than the actual type.
177177

178-
* `Sender`: Use a contravariant generic type, which allows the generic type to be narrower than the required type.
178+
* `Sender` now uses a contravariant generic type, which allows the generic type to be narrower than the required type.
179+
180+
* `ChannelError` is now generic, so when accessing the `channel` attribute, the type of the channel is preserved.
179181

180182
## Bug Fixes
181183

src/frequenz/channels/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@
7777
from ._anycast import Anycast
7878
from ._broadcast import Broadcast
7979
from ._exceptions import ChannelClosedError, ChannelError, Error
80+
from ._generic import (
81+
ChannelMessageT,
82+
ErroredChannelT_co,
83+
MappedMessageT_co,
84+
ReceiverMessageT_co,
85+
SenderMessageT_co,
86+
SenderMessageT_contra,
87+
)
8088
from ._merge import Merger, merge
8189
from ._receiver import Receiver, ReceiverError, ReceiverStoppedError
8290
from ._select import (
@@ -93,15 +101,21 @@
93101
"Broadcast",
94102
"ChannelClosedError",
95103
"ChannelError",
104+
"ChannelMessageT",
96105
"Error",
106+
"ErroredChannelT_co",
107+
"MappedMessageT_co",
97108
"Merger",
98109
"Receiver",
99110
"ReceiverError",
111+
"ReceiverMessageT_co",
100112
"ReceiverStoppedError",
101113
"SelectError",
102114
"Selected",
103115
"Sender",
104116
"SenderError",
117+
"SenderMessageT_co",
118+
"SenderMessageT_contra",
105119
"UnhandledSelectedError",
106120
"merge",
107121
"select",

src/frequenz/channels/_anycast.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
from typing import Generic, TypeVar
1212

1313
from ._exceptions import ChannelClosedError
14+
from ._generic import ChannelMessageT
1415
from ._receiver import Receiver, ReceiverStoppedError
1516
from ._sender import Sender, SenderError
1617

1718
_logger = logging.getLogger(__name__)
1819

19-
_T = TypeVar("_T")
20-
2120

22-
class Anycast(Generic[_T]):
21+
class Anycast(Generic[ChannelMessageT]):
2322
"""A channel that delivers each message to exactly one receiver.
2423
2524
# Description
@@ -213,7 +212,7 @@ def __init__(self, *, name: str, limit: int = 10) -> None:
213212
of the channel.
214213
"""
215214

216-
self._deque: deque[_T] = deque(maxlen=limit)
215+
self._deque: deque[ChannelMessageT] = deque(maxlen=limit)
217216
"""The channel's buffer."""
218217

219218
self._send_cv: Condition = Condition()
@@ -282,11 +281,11 @@ async def close(self) -> None:
282281
async with self._recv_cv:
283282
self._recv_cv.notify_all()
284283

285-
def new_sender(self) -> Sender[_T]:
284+
def new_sender(self) -> Sender[ChannelMessageT]:
286285
"""Return a new sender attached to this channel."""
287286
return _Sender(self)
288287

289-
def new_receiver(self) -> Receiver[_T]:
288+
def new_receiver(self) -> Receiver[ChannelMessageT]:
290289
"""Return a new receiver attached to this channel."""
291290
return _Receiver(self)
292291

@@ -302,6 +301,9 @@ def __repr__(self) -> str:
302301
)
303302

304303

304+
_T = TypeVar("_T")
305+
306+
305307
class _Sender(Sender[_T]):
306308
"""A sender to send messages to an Anycast channel.
307309

src/frequenz/channels/_broadcast.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
from typing import Generic, TypeVar
1313

1414
from ._exceptions import ChannelClosedError
15+
from ._generic import ChannelMessageT
1516
from ._receiver import Receiver, ReceiverStoppedError
1617
from ._sender import Sender, SenderError
1718

1819
_logger = logging.Logger(__name__)
1920

20-
_T = TypeVar("_T")
21-
2221

23-
class Broadcast(Generic[_T]):
22+
class Broadcast(Generic[ChannelMessageT]):
2423
"""A channel that deliver all messages to all receivers.
2524
2625
# Description
@@ -206,13 +205,15 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None:
206205
self._recv_cv: Condition = Condition()
207206
"""The condition to wait for data in the channel's buffer."""
208207

209-
self._receivers: dict[int, weakref.ReferenceType[_Receiver[_T]]] = {}
208+
self._receivers: dict[
209+
int, weakref.ReferenceType[_Receiver[ChannelMessageT]]
210+
] = {}
210211
"""The receivers attached to the channel, indexed by their hash()."""
211212

212213
self._closed: bool = False
213214
"""Whether the channel is closed."""
214215

215-
self._latest: _T | None = None
216+
self._latest: ChannelMessageT | None = None
216217
"""The latest message sent to the channel."""
217218

218219
self.resend_latest: bool = resend_latest
@@ -261,11 +262,13 @@ async def close(self) -> None:
261262
async with self._recv_cv:
262263
self._recv_cv.notify_all()
263264

264-
def new_sender(self) -> Sender[_T]:
265+
def new_sender(self) -> Sender[ChannelMessageT]:
265266
"""Return a new sender attached to this channel."""
266267
return _Sender(self)
267268

268-
def new_receiver(self, *, name: str | None = None, limit: int = 50) -> Receiver[_T]:
269+
def new_receiver(
270+
self, *, name: str | None = None, limit: int = 50
271+
) -> Receiver[ChannelMessageT]:
269272
"""Return a new receiver attached to this channel.
270273
271274
Broadcast receivers have their own buffer, and when messages are not
@@ -279,7 +282,7 @@ def new_receiver(self, *, name: str | None = None, limit: int = 50) -> Receiver[
279282
Returns:
280283
A new receiver attached to this channel.
281284
"""
282-
recv: _Receiver[_T] = _Receiver(self, name=name, limit=limit)
285+
recv: _Receiver[ChannelMessageT] = _Receiver(self, name=name, limit=limit)
283286
self._receivers[hash(recv)] = weakref.ref(recv)
284287
if self.resend_latest and self._latest is not None:
285288
recv.enqueue(self._latest)
@@ -300,6 +303,9 @@ def __repr__(self) -> str:
300303
)
301304

302305

306+
_T = TypeVar("_T")
307+
308+
303309
class _Sender(Sender[_T]):
304310
"""A sender to send messages to the broadcast channel.
305311

src/frequenz/channels/_exceptions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@
6666
```
6767
"""
6868

69-
from typing import Any
69+
from typing import Generic
70+
71+
from ._generic import ErroredChannelT_co
7072

7173

7274
class Error(RuntimeError):
@@ -84,28 +86,28 @@ def __init__(self, message: str):
8486
super().__init__(message)
8587

8688

87-
class ChannelError(Error):
89+
class ChannelError(Error, Generic[ErroredChannelT_co]):
8890
"""An error that originated in a channel.
8991
9092
All exceptions generated by channels inherit from this exception.
9193
"""
9294

93-
def __init__(self, message: str, channel: Any):
95+
def __init__(self, message: str, channel: ErroredChannelT_co):
9496
"""Initialize this error.
9597
9698
Args:
9799
message: The error message.
98100
channel: The channel where the error happened.
99101
"""
100102
super().__init__(message)
101-
self.channel: Any = channel
103+
self.channel: ErroredChannelT_co = channel
102104
"""The channel where the error happened."""
103105

104106

105-
class ChannelClosedError(ChannelError):
107+
class ChannelClosedError(ChannelError[ErroredChannelT_co]):
106108
"""A closed channel was used."""
107109

108-
def __init__(self, channel: Any):
110+
def __init__(self, channel: ErroredChannelT_co):
109111
"""Initialize this error.
110112
111113
Args:

src/frequenz/channels/_generic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Generic type variables."""
5+
6+
from typing import TypeVar
7+
8+
ChannelMessageT = TypeVar("ChannelMessageT")
9+
"""The type of the message that can be sent across a channel."""
10+
11+
ErroredChannelT_co = TypeVar("ErroredChannelT_co", covariant=True)
12+
"""The type of channel having an error."""
13+
14+
MappedMessageT_co = TypeVar("MappedMessageT_co", covariant=True)
15+
"""The type of the message received by the receiver after being mapped."""
16+
17+
ReceiverMessageT_co = TypeVar("ReceiverMessageT_co", covariant=True)
18+
"""The type of the message received by a receiver."""
19+
20+
SenderMessageT_co = TypeVar("SenderMessageT_co", covariant=True)
21+
"""The type of the message sent by a sender."""
22+
23+
SenderMessageT_contra = TypeVar("SenderMessageT_contra", contravariant=True)
24+
"""The type of the message sent by a sender."""

src/frequenz/channels/_merge.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,13 @@
5252
import asyncio
5353
import itertools
5454
from collections import deque
55-
from typing import Any, TypeVar
55+
from typing import Any
5656

57+
from ._generic import ReceiverMessageT_co
5758
from ._receiver import Receiver, ReceiverStoppedError
5859

59-
_T = TypeVar("_T")
6060

61-
62-
def merge(*receivers: Receiver[_T]) -> Merger[_T]:
61+
def merge(*receivers: Receiver[ReceiverMessageT_co]) -> Merger[ReceiverMessageT_co]:
6362
"""Merge messages coming from multiple receivers into a single stream.
6463
6564
Example:
@@ -95,31 +94,33 @@ def merge(*receivers: Receiver[_T]) -> Merger[_T]:
9594
return Merger(*receivers, name="merge")
9695

9796

98-
class Merger(Receiver[_T]):
97+
class Merger(Receiver[ReceiverMessageT_co]):
9998
"""A receiver that merges messages coming from multiple receivers into a single stream.
10099
101100
Tip:
102101
Please consider using the more idiomatic [`merge()`][frequenz.channels.merge]
103102
function instead of creating a `Merger` instance directly.
104103
"""
105104

106-
def __init__(self, *receivers: Receiver[_T], name: str | None) -> None:
105+
def __init__(
106+
self, *receivers: Receiver[ReceiverMessageT_co], name: str | None
107+
) -> None:
107108
"""Initialize this merger.
108109
109110
Args:
110111
*receivers: The receivers to merge.
111112
name: The name of the receiver. Used to create the string representation
112113
of the receiver.
113114
"""
114-
self._receivers: dict[str, Receiver[_T]] = {
115+
self._receivers: dict[str, Receiver[ReceiverMessageT_co]] = {
115116
str(id): recv for id, recv in enumerate(receivers)
116117
}
117118
self._name: str = name if name is not None else type(self).__name__
118119
self._pending: set[asyncio.Task[Any]] = {
119120
asyncio.create_task(anext(recv), name=name)
120121
for name, recv in self._receivers.items()
121122
}
122-
self._results: deque[_T] = deque(maxlen=len(self._receivers))
123+
self._results: deque[ReceiverMessageT_co] = deque(maxlen=len(self._receivers))
123124

124125
def __del__(self) -> None:
125126
"""Finalize this merger."""
@@ -170,7 +171,7 @@ async def ready(self) -> bool:
170171
asyncio.create_task(anext(self._receivers[name]), name=name)
171172
)
172173

173-
def consume(self) -> _T:
174+
def consume(self) -> ReceiverMessageT_co:
174175
"""Return the latest message once `ready` is complete.
175176
176177
Returns:

0 commit comments

Comments
 (0)