Skip to content

Commit 68add87

Browse files
committed
basic refactoring to remove errors
1 parent 04b1e0f commit 68add87

File tree

1 file changed

+70
-120
lines changed

1 file changed

+70
-120
lines changed

adafruit_ina3221.py

Lines changed: 70 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,11 @@
2323
* Adafruit's Bus Device library: https://github.com/adafruit/Adafruit_CircuitPython_BusDevice
2424
"""
2525

26-
import time
27-
28-
from adafruit_bus_device.i2c_device import I2CDevice
29-
3026
try:
3127
from typing import Any, List
32-
33-
from busio import I2C
3428
except ImportError:
3529
pass
30+
from adafruit_bus_device.i2c_device import I2CDevice
3631

3732
__version__ = "0.0.0+auto.0"
3833
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_INA3221.git"
@@ -152,32 +147,50 @@ class INA3221Channel:
152147
"""Represents a single channel of the INA3221.
153148
154149
Args:
155-
parent (Any): The parent INA3221 instance managing the I2C communication.
150+
device (Any): The device INA3221 instance managing the I2C communication.
156151
channel (int): The channel number (1, 2, or 3) for this instance.
157152
"""
158153

159-
def __init__(self, parent: Any, channel: int) -> None:
160-
self._parent = parent
154+
def __init__(self, device: Any, channel: int) -> None:
155+
self._device = device
161156
self._channel = channel
157+
self._shunt_resistance = 0.5
158+
159+
def enable(self) -> None:
160+
"""Enable this channel"""
161+
config = self._device._read_register(CONFIGURATION, 2)
162+
config_value = (config[0] << 8) | config[1]
163+
config_value |= 1 << (14 - self._channel) # Set the bit for the specific channel
164+
high_byte = (config_value >> 8) & 0xFF
165+
low_byte = config_value & 0xFF
166+
self._device._write_register(CONFIGURATION, bytes([high_byte, low_byte]))
162167

163168
@property
164169
def bus_voltage(self) -> float:
165170
"""Bus voltage in volts."""
166-
return self._parent._bus_voltage(self._channel)
171+
reg_address = [BUSVOLTAGE_CH1, BUSVOLTAGE_CH2, BUSVOLTAGE_CH3][self._channel]
172+
result = self._device._read_register(reg_address, 2)
173+
raw_value = int.from_bytes(result, "big")
174+
voltage = (raw_value >> 3) * 8e-3
175+
return voltage
167176

168177
@property
169178
def shunt_voltage(self) -> float:
170179
"""Shunt voltage in millivolts."""
171-
return self._parent._shunt_voltage(self._channel)
180+
reg_address = [SHUNTVOLTAGE_CH1, SHUNTVOLTAGE_CH2, SHUNTVOLTAGE_CH3][self._channel]
181+
result = self._device._read_register(reg_address, 2)
182+
raw_value = int.from_bytes(result, "big")
183+
raw_value = raw_value - 0x10000 if raw_value & 0x8000 else raw_value # convert to signed int16
184+
return (raw_value >> 3) * 40e-6
172185

173186
@property
174187
def shunt_resistance(self) -> float:
175188
"""Shunt resistance in ohms."""
176-
return self._parent._shunt_resistance[self._channel]
189+
return self._shunt_resistance
177190

178191
@shunt_resistance.setter
179192
def shunt_resistance(self, value: float) -> None:
180-
self._parent._shunt_resistance[self._channel] = value
193+
self._shunt_resistance = value
181194

182195
@property
183196
def current_amps(self) -> float:
@@ -189,25 +202,62 @@ def current_amps(self) -> float:
189202
shunt_voltage = self.shunt_voltage
190203
if shunt_voltage != shunt_voltage: # Check for NaN
191204
return float("nan")
192-
return shunt_voltage / self.shunt_resistance
205+
return shunt_voltage / self._shunt_resistance
206+
207+
@property
208+
def critical_alert_threshold(self) -> float:
209+
"""Critical-Alert threshold in amperes
210+
211+
Returns:
212+
float: The current critical alert threshold in amperes.
213+
"""
214+
reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel
215+
result = self._device._read_register(reg_addr, 2)
216+
threshold = int.from_bytes(result, "big")
217+
return (threshold >> 3) * 40e-6 / self._shunt_resistance
218+
219+
@critical_alert_threshold.setter
220+
def critical_alert_threshold(self, current: float) -> None:
221+
threshold = int(current * self._shunt_resistance / 40e-6 * 8)
222+
reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel
223+
threshold_bytes = threshold.to_bytes(2, "big")
224+
self._device._write_register(reg_addr, threshold_bytes)
225+
226+
@property
227+
def warning_alert_threshold(self) -> float:
228+
"""Warning-Alert threshold in amperes
229+
230+
Returns:
231+
float: The current warning alert threshold in amperes.
232+
"""
233+
reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel
234+
result = self._device._read_register(reg_addr, 2)
235+
threshold = int.from_bytes(result, "big")
236+
return threshold / (self._shunt_resistance * 8)
237+
238+
@warning_alert_threshold.setter
239+
def warning_alert_threshold(self, current: float) -> None:
240+
threshold = int(current * self._shunt_resistance * 8)
241+
reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel
242+
threshold_bytes = threshold.to_bytes(2, "big")
243+
self._device._write_register(reg_addr, threshold_bytes)
193244

194245

195246
class INA3221:
196247
"""Driver for the INA3221 device with three channels."""
197248

198-
def __init__(self, i2c, address: int = DEFAULT_ADDRESS) -> None:
249+
def __init__(self, i2c, address: int = DEFAULT_ADDRESS, enable: List = [0,1,2]) -> None:
199250
"""Initializes the INA3221 class over I2C
200251
Args:
201252
i2c (I2C): The I2C bus to which the INA3221 is connected.
202253
address (int, optional): The I2C address of the INA3221. Defaults to DEFAULT_ADDRESS.
203254
"""
204255
self.i2c_dev = I2CDevice(i2c, address)
205-
self._shunt_resistance: List[float] = [0.05, 0.05, 0.05] # Default shunt resistances
206256
self.reset()
207257

208258
self.channels: List[INA3221Channel] = [INA3221Channel(self, i) for i in range(3)]
209-
for i in range(3):
210-
self.enable_channel(i)
259+
for i in enable:
260+
self.channels[i].enable()
211261
self.mode: int = MODE.SHUNT_BUS_CONT
212262
self.shunt_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS
213263
self.bus_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS
@@ -238,25 +288,6 @@ def reset(self) -> None:
238288
config[0] |= 0x80 # Set the reset bit
239289
return self._write_register(CONFIGURATION, config)
240290

241-
def enable_channel(self, channel: int) -> None:
242-
"""Enable a specific channel of the INA3221.
243-
244-
Args:
245-
channel (int): The channel number to enable (0, 1, or 2).
246-
247-
Raises:
248-
ValueError: If the channel number is invalid (must be 0, 1, or 2).
249-
"""
250-
if channel > 2:
251-
raise ValueError("Invalid channel number. Must be 0, 1, or 2.")
252-
253-
config = self._read_register(CONFIGURATION, 2)
254-
config_value = (config[0] << 8) | config[1]
255-
config_value |= 1 << (14 - channel) # Set the bit for the specific channel
256-
high_byte = (config_value >> 8) & 0xFF
257-
low_byte = config_value & 0xFF
258-
self._write_register(CONFIGURATION, bytes([high_byte, low_byte]))
259-
260291
@property
261292
def die_id(self) -> int:
262293
"""Die ID of the INA3221.
@@ -362,56 +393,6 @@ def averaging_mode(self, mode: int) -> None:
362393
config[1] = (config[1] & 0xF1) | (mode << 1)
363394
self._write_register(CONFIGURATION, config)
364395

365-
@property
366-
def critical_alert_threshold(self) -> float:
367-
"""Critical-Alert threshold in amperes
368-
369-
Returns:
370-
float: The current critical alert threshold in amperes.
371-
"""
372-
if self._channel > 2:
373-
raise ValueError("Invalid channel number. Must be 0, 1, or 2.")
374-
375-
reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel
376-
result = self._parent._read_register(reg_addr, 2)
377-
threshold = int.from_bytes(result, "big")
378-
return (threshold >> 3) * 40e-6 / self.shunt_resistance
379-
380-
@critical_alert_threshold.setter
381-
def critical_alert_threshold(self, current: float) -> None:
382-
if self._channel > 2:
383-
raise ValueError("Invalid channel number. Must be 0, 1, or 2.")
384-
385-
threshold = int(current * self.shunt_resistance / 40e-6 * 8)
386-
reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel
387-
threshold_bytes = threshold.to_bytes(2, "big")
388-
self._parent._write_register(reg_addr, threshold_bytes)
389-
390-
@property
391-
def warning_alert_threshold(self) -> float:
392-
"""Warning-Alert threshold in amperes
393-
394-
Returns:
395-
float: The current warning alert threshold in amperes.
396-
"""
397-
if self._channel > 2:
398-
raise ValueError("Invalid channel number. Must be 0, 1, or 2.")
399-
400-
reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel
401-
result = self._parent._read_register(reg_addr, 2)
402-
threshold = int.from_bytes(result, "big")
403-
return threshold / (self.shunt_resistance * 8)
404-
405-
@warning_alert_threshold.setter
406-
def warning_alert_threshold(self, current: float) -> None:
407-
if self._channel > 2:
408-
raise ValueError("Invalid channel number. Must be 0, 1, or 2.")
409-
410-
threshold = int(current * self.shunt_resistance * 8)
411-
reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel
412-
threshold_bytes = threshold.to_bytes(2, "big")
413-
self._parent._write_register(reg_addr, threshold_bytes)
414-
415396
@property
416397
def flags(self) -> int:
417398
"""Flag indicators from the Mask/Enable register.
@@ -497,37 +478,6 @@ def _to_signed(self, val, bits):
497478
val -= 1 << bits
498479
return val
499480

500-
def _shunt_voltage(self, channel):
501-
if channel > 2:
502-
raise ValueError("Must be channel 0, 1 or 2")
503-
reg_address = [SHUNTVOLTAGE_CH1, SHUNTVOLTAGE_CH2, SHUNTVOLTAGE_CH3][channel]
504-
result = self._read_register(reg_address, 2)
505-
raw_value = int.from_bytes(result, "big")
506-
raw_value = self._to_signed(raw_value, 16)
507-
508-
return (raw_value >> 3) * 40e-6
509-
510-
def _bus_voltage(self, channel):
511-
if channel > 2:
512-
raise ValueError("Must be channel 0, 1 or 2")
513-
514-
reg_address = [BUSVOLTAGE_CH1, BUSVOLTAGE_CH2, BUSVOLTAGE_CH3][channel]
515-
result = self._read_register(reg_address, 2)
516-
raw_value = int.from_bytes(result, "big")
517-
voltage = (raw_value >> 3) * 8e-3
518-
519-
return voltage
520-
521-
def _current_amps(self, channel):
522-
if channel >= 3:
523-
raise ValueError("Must be channel 0, 1 or 2")
524-
525-
shunt_voltage = self._shunt_voltage(channel)
526-
if shunt_voltage != shunt_voltage:
527-
raise ValueError("Must be channel 0, 1 or 2")
528-
529-
return shunt_voltage / self._shunt_resistance[channel]
530-
531481
def _write_register(self, reg, data):
532482
with self.i2c_dev:
533483
self.i2c_dev.write(bytes([reg]) + data)
@@ -538,7 +488,7 @@ def _read_register(self, reg, length):
538488
with self.i2c_dev:
539489
self.i2c_dev.write(bytes([reg]))
540490
self.i2c_dev.readinto(result)
541-
except OSError as e:
542-
print(f"I2C error: {e}")
491+
except OSError as ex:
492+
print(f"I2C error: {ex}")
543493
return None
544494
return result

0 commit comments

Comments
 (0)