Skip to content

Commit 61f431b

Browse files
committed
Rudimentary flags implementation
1 parent cdc6050 commit 61f431b

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

torch_np/_ndarray.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,55 @@
1414

1515
newaxis = None
1616

17+
FLAGS = [
18+
"C_CONTIGUOUS",
19+
"F_CONTIGUOUS",
20+
"OWNDATA",
21+
"WRITEABLE",
22+
"ALIGNED",
23+
"WRITEBACKIFCOPY",
24+
"FNC",
25+
"FORC",
26+
"BEHAVED",
27+
"CARRAY",
28+
"FARRAY",
29+
]
30+
31+
SHORTHAND_TO_FLAGS = {
32+
"C": "C_CONTIGUOUS",
33+
"F": "F_CONTIGUOUS",
34+
"O": "OWNDATA",
35+
"W": "WRITEABLE",
36+
"A": "ALIGNED",
37+
"X": "WRITEBACKIFCOPY",
38+
"B": "BEHAVED",
39+
"CA": "CARRAY",
40+
"FA": "FARRAY",
41+
}
42+
43+
44+
class Flags:
45+
def __init__(self, flag_to_value: dict):
46+
assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
47+
self._flag_to_value = flag_to_value
48+
49+
def __getattr__(self, attr: str):
50+
if attr.islower() and attr.upper() in FLAGS:
51+
return self[attr.upper()]
52+
else:
53+
raise AttributeError(f"No flag '{attr}'")
54+
55+
def __getitem__(self, key):
56+
if key in SHORTHAND_TO_FLAGS.keys():
57+
key = SHORTHAND_TO_FLAGS[key]
58+
if key in FLAGS:
59+
try:
60+
return self._flag_to_value[key]
61+
except KeyError as e:
62+
raise NotImplementedError(f"{key=}") from e
63+
else:
64+
raise KeyError(f"No flag '{key}'")
65+
1766

1867
##################### ndarray class ###########################
1968

@@ -58,6 +107,17 @@ def strides(self):
58107
def base(self):
59108
return self._base
60109

110+
@property
111+
def flags(self):
112+
# Note contiguous in torch is assumed C-style
113+
flag_to_value = {"C_CONTIGUOUS": self._tensor.is_contiguous()}
114+
if flag_to_value["C_CONTIGUOUS"]:
115+
# There's no proper way to determine if a tensor is Fortran-style
116+
# contiguous in torch, but at least we know it isn't when it is
117+
# C-style.
118+
flag_to_value["F_CONTIGUOUS"] = False
119+
return Flags(flag_to_value)
120+
61121
@property
62122
def T(self):
63123
return self.transpose()

0 commit comments

Comments
 (0)