Skip to content

added typehints to midimessage and init #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions adafruit_midi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
https://github.com/adafruit/circuitpython/releases

"""
try:
from typing import Union, Tuple, Any, List, Optional, Dict
except ImportError:
pass

from .midi_message import MIDIMessage

Expand Down Expand Up @@ -54,13 +58,13 @@ class MIDI:

def __init__(
self,
midi_in=None,
midi_out=None,
midi_in: Optional[Any] = None,
midi_out: Optional[Any] = None,
*,
in_channel=None,
out_channel=0,
in_buf_size=30,
debug=False
in_channel: Optional[Union[int, Tuple[int, ...]]] = None,
out_channel: int = 0,
in_buf_size: int = 30,
debug: bool = False
):
if midi_in is None and midi_out is None:
raise ValueError("No midi_in or midi_out provided")
Expand All @@ -78,7 +82,7 @@ def __init__(
self._skipped_bytes = 0

@property
def in_channel(self):
def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]:
"""The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``in_channel = 3`` will listen on MIDI channel 4.
Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)``
Expand All @@ -87,7 +91,7 @@ def in_channel(self):
return self._in_channel

@in_channel.setter
def in_channel(self, channel):
def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None:
if channel is None or channel == "ALL":
self._in_channel = tuple(range(16))
elif isinstance(channel, int) and 0 <= channel <= 15:
Expand All @@ -98,18 +102,20 @@ def in_channel(self, channel):
raise RuntimeError("Invalid input channel")

@property
def out_channel(self):
def out_channel(self) -> int:
"""The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1)."""
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1).
"""
return self._out_channel

@out_channel.setter
def out_channel(self, channel):
def out_channel(self, channel: Optional[int]) -> None:
assert channel is not None
if not 0 <= channel <= 15:
raise RuntimeError("Invalid output channel")
self._out_channel = channel

def receive(self):
def receive(self) -> Optional[MIDIMessage]:
"""Read messages from MIDI port, store them in internal read buffer, then parse that data
and return the first MIDI message (event).
This maintains the blocking characteristics of the midi_in port.
Expand All @@ -120,6 +126,7 @@ def receive(self):
# If the buffer here is not full then read as much as we can fit from
# the input port
if len(self._in_buf) < self._in_buf_size:
assert self._midi_in is not None
bytes_in = self._midi_in.read(self._in_buf_size - len(self._in_buf))
if bytes_in:
if self._debug:
Expand All @@ -140,7 +147,7 @@ def receive(self):
# msg could still be None at this point, e.g. in middle of monster SysEx
return msg

def send(self, msg, channel=None):
def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None:
"""Sends a MIDI message.

:param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects.
Expand All @@ -164,7 +171,8 @@ def send(self, msg, channel=None):

self._send(data, len(data))

def _send(self, packet, num):
def _send(self, packet: bytes, num: int) -> None:
if self._debug:
print("Sending: ", [hex(i) for i in packet[:num]])
assert self._midi_out is not None
self._midi_out.write(packet, num)
65 changes: 46 additions & 19 deletions adafruit_midi/midi_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,19 @@
__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git"

try:
from typing import Union, Tuple, Any, List, Optional, Dict
except ImportError:
pass

# From C3 - A and B are above G
# Semitones A B C D E F G
NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19]


def channel_filter(channel, channel_spec):
def channel_filter(
channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]]
) -> bool:
"""
Utility function to return True iff the given channel matches channel_spec.
"""
Expand All @@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec):
raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec)))


def note_parser(note):
def note_parser(note: Union[int, str]) -> int:
"""If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g.
"C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned.

:param note: Either 0-127 int or a str representing the note, e.g. "C#4"
"""
midi_note = note
if isinstance(note, str):
if len(note) < 2:
raise ValueError("Bad note format")
Expand All @@ -61,7 +67,8 @@ def note_parser(note):
sharpen = -1
# int may throw exception here
midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen

elif isinstance(note, int):
midi_note = note
return midi_note


Expand All @@ -82,57 +89,70 @@ class MIDIMessage:
This is an *abstract* class.
"""

_STATUS = None
_STATUS: Optional[int] = None
_STATUSMASK = None
LENGTH = None
LENGTH: Optional[int] = None
CHANNELMASK = 0x0F
ENDSTATUS = None

# Commonly used exceptions to save memory
@staticmethod
def _raise_valueerror_oor():
def _raise_valueerror_oor() -> None:
raise ValueError("Out of range")

# Each element is ((status, mask), class)
# order is more specific masks first
_statusandmask_to_class = []
# Add better type hints for status, mask, class referenced above
_statusandmask_to_class: List[
Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"]
] = []

def __init__(self, *, channel=None):
def __init__(self, *, channel: Optional[int] = None) -> None:
self._channel = channel # dealing with pylint inadequacy
self.channel = channel

@property
def channel(self):
def channel(self) -> Optional[int]:
"""The channel number of the MIDI message where appropriate.
This is *updated* by MIDI.send() method.
"""
return self._channel

@channel.setter
def channel(self, channel):
def channel(self, channel: int) -> None:
if channel is not None and not 0 <= channel <= 15:
raise ValueError("Channel must be 0-15 or None")
self._channel = channel

@classmethod
def register_message_type(cls):
def register_message_type(cls) -> None:
"""Register a new message by its status value and mask.
This is called automagically at ``import`` time for each message.
"""
### These must be inserted with more specific masks first
insert_idx = len(MIDIMessage._statusandmask_to_class)
for idx, m_type in enumerate(MIDIMessage._statusandmask_to_class):
assert cls._STATUSMASK is not None
if cls._STATUSMASK > m_type[0][1]:
insert_idx = idx
break

assert cls._STATUS is not None
assert cls._STATUSMASK is not None
MIDIMessage._statusandmask_to_class.insert(
insert_idx, ((cls._STATUS, cls._STATUSMASK), cls)
)

# pylint: disable=too-many-arguments
@classmethod
def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx):
def _search_eom_status(
cls,
buf: Dict[int, bool],
eom_status: bool,
msgstartidx: int,
msgendidxplusone: int,
endidx: int,
) -> Tuple[int, bool, bool]:
good_termination = False
bad_termination = False

Expand All @@ -155,14 +175,17 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi
return (msgendidxplusone, good_termination, bad_termination)

@classmethod
def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
def _match_message_status(
cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int
) -> Tuple[Optional[Any], bool, bool, bool, bool, int]:
msgclass = None
status = buf[msgstartidx]
known_msg = False
complete_msg = False
bad_termination = False

# Rummage through our list looking for a status match
assert msgclass is not None
for status_mask, msgclass in MIDIMessage._statusandmask_to_class:
masked_status = status & status_mask[1]
if status_mask[0] == masked_status:
Expand Down Expand Up @@ -198,7 +221,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):

# pylint: disable=too-many-locals,too-many-branches
@classmethod
def from_message_bytes(cls, midibytes, channel_in):
def from_message_bytes(
cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]]
) -> Tuple[Optional["MIDIMessage"], int, int]:
"""Create an appropriate object of the correct class for the
first message found in some MIDI bytes filtered by channel_in.

Expand Down Expand Up @@ -240,6 +265,7 @@ def from_message_bytes(cls, midibytes, channel_in):
channel_match_orna = True
if complete_message and not bad_termination:
try:
assert msgclass is not None
msg = msgclass.from_bytes(midibytes[msgstartidx:msgendidxplusone])
if msg.channel is not None:
channel_match_orna = channel_filter(msg.channel, channel_in)
Expand Down Expand Up @@ -270,17 +296,18 @@ def from_message_bytes(cls, midibytes, channel_in):

# A default method for constructing wire messages with no data.
# Returns an (immutable) bytes with just the status code in.
def __bytes__(self):
def __bytes__(self) -> bytes:
"""Return the ``bytes`` wire protocol representation of the object
with channel number applied where appropriate."""
assert self._STATUS is not None
return bytes([self._STATUS])

# databytes value present to keep interface uniform but unused
# A default method for constructing message objects with no data.
# Returns the new object.
# pylint: disable=unused-argument
@classmethod
def from_bytes(cls, msg_bytes):
def from_bytes(cls, msg_bytes: bytes) -> "MIDIMessage":
"""Creates an object from the byte stream of the wire protocol
representation of the MIDI message."""
return cls()
Expand All @@ -298,7 +325,7 @@ class MIDIUnknownEvent(MIDIMessage):

LENGTH = -1

def __init__(self, status):
def __init__(self, status: int):
self.status = status
super().__init__()

Expand All @@ -316,7 +343,7 @@ class MIDIBadEvent(MIDIMessage):

LENGTH = -1

def __init__(self, msg_bytes, exception):
def __init__(self, msg_bytes: bytearray, exception: Exception):
self.data = bytes(msg_bytes)
self.exception_text = repr(exception)
super().__init__()