Skip to content

Commit 894f754

Browse files
committed
Rudimentary flags implementation
1 parent cdc6050 commit 894f754

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

torch_np/_ndarray.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,54 @@
1414

1515
newaxis = None
1616

17+
FLAGS = [
18+
"C_CONTIGUOUS",
19+
"F_CONTIGUOUS",
20+
"OWNDATA",
21+
"WRITEABLE",
22+
"ALIGNED",
23+
"WRITEBACKIFCOPY",
24+
"FNC",
25+
"FORC",
26+
"BEHAVED",
27+
"CARRAY",
28+
"FARRAY",
29+
]
30+
31+
SHORTHAND_TO_FLAGS = {
32+
"C": "C_CONTIGUOUS",
33+
"F": "F_CONTIGUOUS",
34+
"O": "OWNDATA",
35+
"W": "WRITEABLE",
36+
"A": "ALIGNED",
37+
"X": "WRITEBACKIFCOPY",
38+
"B": "BEHAVED",
39+
"CA": "CARRAY",
40+
"FA": "FARRAY",
41+
}
42+
43+
44+
class Flags:
45+
def __init__(self, flag_to_value: dict):
46+
assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
47+
self._flag_to_value = flag_to_value
48+
49+
def __getattr__(self, attr: str):
50+
if attr.islower() and attr.upper() in FLAGS:
51+
return self[attr.upper()]
52+
raise AttributeError(f"No attribute '{attr}'")
53+
54+
def __getitem__(self, key):
55+
if key in SHORTHAND_TO_FLAGS.keys():
56+
key = SHORTHAND_TO_FLAGS[key]
57+
if key in FLAGS:
58+
try:
59+
return self._flag_to_value[key]
60+
except KeyError as e:
61+
raise NotImplementedError(f"{key=}") from e
62+
else:
63+
raise KeyError(f"{key=}")
64+
1765

1866
##################### ndarray class ###########################
1967

@@ -58,6 +106,17 @@ def strides(self):
58106
def base(self):
59107
return self._base
60108

109+
@property
110+
def flags(self):
111+
# Note contiguous in torch is assumed C-style
112+
flag_to_value = {"C_CONTIGUOUS": self._tensor.is_contiguous()}
113+
if flag_to_value["C_CONTIGUOUS"]:
114+
# There's no proper way to determine if a tensor is Fortran-style
115+
# contiguous in torch, but at least we know it isn't when it is
116+
# C-style.
117+
flag_to_value["F_CONTIGUOUS"] = False
118+
return Flags(flag_to_value)
119+
61120
@property
62121
def T(self):
63122
return self.transpose()

0 commit comments

Comments
 (0)