|
14 | 14 |
|
15 | 15 | newaxis = None
|
16 | 16 |
|
| 17 | +FLAGS = [ |
| 18 | + "C_CONTIGUOUS", |
| 19 | + "F_CONTIGUOUS", |
| 20 | + "OWNDATA", |
| 21 | + "WRITEABLE", |
| 22 | + "ALIGNED", |
| 23 | + "WRITEBACKIFCOPY", |
| 24 | + "FNC", |
| 25 | + "FORC", |
| 26 | + "BEHAVED", |
| 27 | + "CARRAY", |
| 28 | + "FARRAY", |
| 29 | +] |
| 30 | + |
| 31 | +SHORTHAND_TO_FLAGS = { |
| 32 | + "C": "C_CONTIGUOUS", |
| 33 | + "F": "F_CONTIGUOUS", |
| 34 | + "O": "OWNDATA", |
| 35 | + "W": "WRITEABLE", |
| 36 | + "A": "ALIGNED", |
| 37 | + "X": "WRITEBACKIFCOPY", |
| 38 | + "B": "BEHAVED", |
| 39 | + "CA": "CARRAY", |
| 40 | + "FA": "FARRAY", |
| 41 | +} |
| 42 | + |
| 43 | + |
| 44 | +class Flags: |
| 45 | + def __init__(self, flag_to_value: dict): |
| 46 | + assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check |
| 47 | + self._flag_to_value = flag_to_value |
| 48 | + |
| 49 | + def __getattr__(self, attr: str): |
| 50 | + if attr.islower() and attr.upper() in FLAGS: |
| 51 | + return self[attr.upper()] |
| 52 | + raise AttributeError(f"No attribute '{attr}'") |
| 53 | + |
| 54 | + def __getitem__(self, key): |
| 55 | + if key in SHORTHAND_TO_FLAGS.keys(): |
| 56 | + key = SHORTHAND_TO_FLAGS[key] |
| 57 | + if key in FLAGS: |
| 58 | + try: |
| 59 | + return self._flag_to_value[key] |
| 60 | + except KeyError as e: |
| 61 | + raise NotImplementedError(f"{key=}") from e |
| 62 | + else: |
| 63 | + raise KeyError(f"{key=}") |
| 64 | + |
17 | 65 |
|
18 | 66 | ##################### ndarray class ###########################
|
19 | 67 |
|
@@ -58,6 +106,17 @@ def strides(self):
|
58 | 106 | def base(self):
|
59 | 107 | return self._base
|
60 | 108 |
|
| 109 | + @property |
| 110 | + def flags(self): |
| 111 | + # Note contiguous in torch is assumed C-style |
| 112 | + flag_to_value = {"C_CONTIGUOUS": self._tensor.is_contiguous()} |
| 113 | + if flag_to_value["C_CONTIGUOUS"]: |
| 114 | + # There's no proper way to determine if a tensor is Fortran-style |
| 115 | + # contiguous in torch, but at least we know it isn't when it is |
| 116 | + # C-style. |
| 117 | + flag_to_value["F_CONTIGUOUS"] = False |
| 118 | + return Flags(flag_to_value) |
| 119 | + |
61 | 120 | @property
|
62 | 121 | def T(self):
|
63 | 122 | return self.transpose()
|
|
0 commit comments