diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index dde3cd7e..2004b03d 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -14,6 +14,55 @@ newaxis = None +FLAGS = [ + "C_CONTIGUOUS", + "F_CONTIGUOUS", + "OWNDATA", + "WRITEABLE", + "ALIGNED", + "WRITEBACKIFCOPY", + "FNC", + "FORC", + "BEHAVED", + "CARRAY", + "FARRAY", +] + +SHORTHAND_TO_FLAGS = { + "C": "C_CONTIGUOUS", + "F": "F_CONTIGUOUS", + "O": "OWNDATA", + "W": "WRITEABLE", + "A": "ALIGNED", + "X": "WRITEBACKIFCOPY", + "B": "BEHAVED", + "CA": "CARRAY", + "FA": "FARRAY", +} + + +class Flags: + def __init__(self, flag_to_value: dict): + assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + self._flag_to_value = flag_to_value + + def __getattr__(self, attr: str): + if attr.islower() and attr.upper() in FLAGS: + return self[attr.upper()] + else: + raise AttributeError(f"No flag attribute '{attr}'") + + def __getitem__(self, key): + if key in SHORTHAND_TO_FLAGS.keys(): + key = SHORTHAND_TO_FLAGS[key] + if key in FLAGS: + try: + return self._flag_to_value[key] + except KeyError as e: + raise NotImplementedError(f"{key=}") from e + else: + raise KeyError(f"No flag key '{key}'") + ##################### ndarray class ########################### @@ -58,6 +107,11 @@ def strides(self): def base(self): return self._base + @property + def flags(self): + # Note contiguous in torch is assumed C-style + return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()}) + @property def T(self): return self.transpose()