Skip to content

Commit ef2e852

Browse files
authored
Merge pull request #49 from honno/flags
Basic `ndarray.flags` implementation
2 parents 65907fc + c569064 commit ef2e852

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

torch_np/_ndarray.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,55 @@
1515

1616
newaxis = None
1717

18+
FLAGS = [
19+
"C_CONTIGUOUS",
20+
"F_CONTIGUOUS",
21+
"OWNDATA",
22+
"WRITEABLE",
23+
"ALIGNED",
24+
"WRITEBACKIFCOPY",
25+
"FNC",
26+
"FORC",
27+
"BEHAVED",
28+
"CARRAY",
29+
"FARRAY",
30+
]
31+
32+
SHORTHAND_TO_FLAGS = {
33+
"C": "C_CONTIGUOUS",
34+
"F": "F_CONTIGUOUS",
35+
"O": "OWNDATA",
36+
"W": "WRITEABLE",
37+
"A": "ALIGNED",
38+
"X": "WRITEBACKIFCOPY",
39+
"B": "BEHAVED",
40+
"CA": "CARRAY",
41+
"FA": "FARRAY",
42+
}
43+
44+
45+
class Flags:
46+
def __init__(self, flag_to_value: dict):
47+
assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
48+
self._flag_to_value = flag_to_value
49+
50+
def __getattr__(self, attr: str):
51+
if attr.islower() and attr.upper() in FLAGS:
52+
return self[attr.upper()]
53+
else:
54+
raise AttributeError(f"No flag attribute '{attr}'")
55+
56+
def __getitem__(self, key):
57+
if key in SHORTHAND_TO_FLAGS.keys():
58+
key = SHORTHAND_TO_FLAGS[key]
59+
if key in FLAGS:
60+
try:
61+
return self._flag_to_value[key]
62+
except KeyError as e:
63+
raise NotImplementedError(f"{key=}") from e
64+
else:
65+
raise KeyError(f"No flag key '{key}'")
66+
1867

1968
##################### ndarray class ###########################
2069

@@ -59,6 +108,11 @@ def strides(self):
59108
def base(self):
60109
return self._base
61110

111+
@property
112+
def flags(self):
113+
# Note contiguous in torch is assumed C-style
114+
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()})
115+
62116
@property
63117
def T(self):
64118
return self.transpose()

0 commit comments

Comments
 (0)