Skip to content

Commit 7d4c2c2

Browse files
authored
Merge pull request #52 from jordanhemingway-revvity/type_annotations
added typehints to midimessage and init
2 parents 36f8d5a + 9817d84 commit 7d4c2c2

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

adafruit_midi/__init__.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
https://github.com/adafruit/circuitpython/releases
2626
2727
"""
28+
try:
29+
from typing import Union, Tuple, Any, List, Optional, Dict, BinaryIO
30+
except ImportError:
31+
pass
2832

2933
from .midi_message import MIDIMessage
3034

@@ -54,13 +58,13 @@ class MIDI:
5458

5559
def __init__(
5660
self,
57-
midi_in=None,
58-
midi_out=None,
61+
midi_in: Optional[BinaryIO] = None,
62+
midi_out: Optional[BinaryIO] = None,
5963
*,
60-
in_channel=None,
61-
out_channel=0,
62-
in_buf_size=30,
63-
debug=False
64+
in_channel: Optional[Union[int, Tuple[int, ...]]] = None,
65+
out_channel: int = 0,
66+
in_buf_size: int = 30,
67+
debug: bool = False
6468
):
6569
if midi_in is None and midi_out is None:
6670
raise ValueError("No midi_in or midi_out provided")
@@ -78,7 +82,7 @@ def __init__(
7882
self._skipped_bytes = 0
7983

8084
@property
81-
def in_channel(self):
85+
def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]:
8286
"""The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
8387
``in_channel = 3`` will listen on MIDI channel 4.
8488
Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)``
@@ -87,7 +91,7 @@ def in_channel(self):
8791
return self._in_channel
8892

8993
@in_channel.setter
90-
def in_channel(self, channel):
94+
def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None:
9195
if channel is None or channel == "ALL":
9296
self._in_channel = tuple(range(16))
9397
elif isinstance(channel, int) and 0 <= channel <= 15:
@@ -98,19 +102,19 @@ def in_channel(self, channel):
98102
raise RuntimeError("Invalid input channel")
99103

100104
@property
101-
def out_channel(self):
105+
def out_channel(self) -> int:
102106
"""The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
103107
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1).
104108
"""
105109
return self._out_channel
106110

107111
@out_channel.setter
108-
def out_channel(self, channel):
112+
def out_channel(self, channel: int) -> None:
109113
if not 0 <= channel <= 15:
110114
raise RuntimeError("Invalid output channel")
111115
self._out_channel = channel
112116

113-
def receive(self):
117+
def receive(self) -> Optional[MIDIMessage]:
114118
"""Read messages from MIDI port, store them in internal read buffer, then parse that data
115119
and return the first MIDI message (event).
116120
This maintains the blocking characteristics of the midi_in port.
@@ -141,7 +145,7 @@ def receive(self):
141145
# msg could still be None at this point, e.g. in middle of monster SysEx
142146
return msg
143147

144-
def send(self, msg, channel=None):
148+
def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None:
145149
"""Sends a MIDI message.
146150
147151
:param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects.
@@ -165,7 +169,7 @@ def send(self, msg, channel=None):
165169

166170
self._send(data, len(data))
167171

168-
def _send(self, packet, num):
172+
def _send(self, packet: bytes, num: int) -> None:
169173
if self._debug:
170174
print("Sending: ", [hex(i) for i in packet[:num]])
171175
self._midi_out.write(packet, num)

adafruit_midi/midi_message.py

+41-20
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,19 @@
2525
__version__ = "0.0.0+auto.0"
2626
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git"
2727

28+
try:
29+
from typing import Union, Tuple, Any, List, Optional
30+
except ImportError:
31+
pass
32+
2833
# From C3 - A and B are above G
2934
# Semitones A B C D E F G
3035
NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19]
3136

3237

33-
def channel_filter(channel, channel_spec):
38+
def channel_filter(
39+
channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]]
40+
) -> bool:
3441
"""
3542
Utility function to return True iff the given channel matches channel_spec.
3643
"""
@@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec):
4148
raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec)))
4249

4350

44-
def note_parser(note):
51+
def note_parser(note: Union[int, str]) -> int:
4552
"""If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g.
4653
"C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned.
4754
4855
:param note: Either 0-127 int or a str representing the note, e.g. "C#4"
4956
"""
50-
midi_note = note
5157
if isinstance(note, str):
5258
if len(note) < 2:
5359
raise ValueError("Bad note format")
@@ -61,7 +67,8 @@ def note_parser(note):
6167
sharpen = -1
6268
# int may throw exception here
6369
midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen
64-
70+
elif isinstance(note, int):
71+
midi_note = note
6572
return midi_note
6673

6774

@@ -82,40 +89,43 @@ class MIDIMessage:
8289
This is an *abstract* class.
8390
"""
8491

85-
_STATUS = None
92+
_STATUS: Optional[int] = None
8693
_STATUSMASK = None
87-
LENGTH = None
94+
LENGTH: Optional[int] = None
8895
CHANNELMASK = 0x0F
8996
ENDSTATUS = None
9097

9198
# Commonly used exceptions to save memory
9299
@staticmethod
93-
def _raise_valueerror_oor():
100+
def _raise_valueerror_oor() -> None:
94101
raise ValueError("Out of range")
95102

96103
# Each element is ((status, mask), class)
97104
# order is more specific masks first
98-
_statusandmask_to_class = []
105+
# Add better type hints for status, mask, class referenced above
106+
_statusandmask_to_class: List[
107+
Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"]
108+
] = []
99109

100-
def __init__(self, *, channel=None):
110+
def __init__(self, *, channel: Optional[int] = None) -> None:
101111
self._channel = channel # dealing with pylint inadequacy
102112
self.channel = channel
103113

104114
@property
105-
def channel(self):
115+
def channel(self) -> Optional[int]:
106116
"""The channel number of the MIDI message where appropriate.
107117
This is *updated* by MIDI.send() method.
108118
"""
109119
return self._channel
110120

111121
@channel.setter
112-
def channel(self, channel):
122+
def channel(self, channel: int) -> None:
113123
if channel is not None and not 0 <= channel <= 15:
114124
raise ValueError("Channel must be 0-15 or None")
115125
self._channel = channel
116126

117127
@classmethod
118-
def register_message_type(cls):
128+
def register_message_type(cls) -> None:
119129
"""Register a new message by its status value and mask.
120130
This is called automagically at ``import`` time for each message.
121131
"""
@@ -132,7 +142,14 @@ def register_message_type(cls):
132142

133143
# pylint: disable=too-many-arguments
134144
@classmethod
135-
def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx):
145+
def _search_eom_status(
146+
cls,
147+
buf: bytearray,
148+
eom_status: Optional[int],
149+
msgstartidx: int,
150+
msgendidxplusone: int,
151+
endidx: int,
152+
) -> Tuple[int, bool, bool]:
136153
good_termination = False
137154
bad_termination = False
138155

@@ -155,7 +172,9 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi
155172
return (msgendidxplusone, good_termination, bad_termination)
156173

157174
@classmethod
158-
def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
175+
def _match_message_status(
176+
cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int
177+
) -> Tuple[Optional[Any], int, bool, bool, bool, int]:
159178
msgclass = None
160179
status = buf[msgstartidx]
161180
known_msg = False
@@ -198,7 +217,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
198217

199218
# pylint: disable=too-many-locals,too-many-branches
200219
@classmethod
201-
def from_message_bytes(cls, midibytes, channel_in):
220+
def from_message_bytes(
221+
cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]]
222+
) -> Tuple[Optional["MIDIMessage"], int, int]:
202223
"""Create an appropriate object of the correct class for the
203224
first message found in some MIDI bytes filtered by channel_in.
204225
@@ -270,7 +291,7 @@ def from_message_bytes(cls, midibytes, channel_in):
270291

271292
# A default method for constructing wire messages with no data.
272293
# Returns an (immutable) bytes with just the status code in.
273-
def __bytes__(self):
294+
def __bytes__(self) -> bytes:
274295
"""Return the ``bytes`` wire protocol representation of the object
275296
with channel number applied where appropriate."""
276297
return bytes([self._STATUS])
@@ -280,12 +301,12 @@ def __bytes__(self):
280301
# Returns the new object.
281302
# pylint: disable=unused-argument
282303
@classmethod
283-
def from_bytes(cls, msg_bytes):
304+
def from_bytes(cls, msg_bytes: bytes) -> "MIDIMessage":
284305
"""Creates an object from the byte stream of the wire protocol
285306
representation of the MIDI message."""
286307
return cls()
287308

288-
def __str__(self):
309+
def __str__(self) -> str:
289310
"""Print an instance"""
290311
cls = self.__class__
291312
if slots := getattr(cls, "_message_slots", None):
@@ -313,7 +334,7 @@ class MIDIUnknownEvent(MIDIMessage):
313334
_message_slots = ["status"]
314335
LENGTH = -1
315336

316-
def __init__(self, status):
337+
def __init__(self, status: int):
317338
self.status = status
318339
super().__init__()
319340

@@ -333,7 +354,7 @@ class MIDIBadEvent(MIDIMessage):
333354

334355
_message_slots = ["msg_bytes", "exception"]
335356

336-
def __init__(self, msg_bytes, exception):
357+
def __init__(self, msg_bytes: bytearray, exception: Exception):
337358
self.data = bytes(msg_bytes)
338359
self.exception_text = repr(exception)
339360
super().__init__()

0 commit comments

Comments
 (0)