File tree 1 file changed +54
-0
lines changed
1 file changed +54
-0
lines changed Original file line number Diff line number Diff line change 15
15
16
16
newaxis = None
17
17
18
+ FLAGS = [
19
+ "C_CONTIGUOUS" ,
20
+ "F_CONTIGUOUS" ,
21
+ "OWNDATA" ,
22
+ "WRITEABLE" ,
23
+ "ALIGNED" ,
24
+ "WRITEBACKIFCOPY" ,
25
+ "FNC" ,
26
+ "FORC" ,
27
+ "BEHAVED" ,
28
+ "CARRAY" ,
29
+ "FARRAY" ,
30
+ ]
31
+
32
+ SHORTHAND_TO_FLAGS = {
33
+ "C" : "C_CONTIGUOUS" ,
34
+ "F" : "F_CONTIGUOUS" ,
35
+ "O" : "OWNDATA" ,
36
+ "W" : "WRITEABLE" ,
37
+ "A" : "ALIGNED" ,
38
+ "X" : "WRITEBACKIFCOPY" ,
39
+ "B" : "BEHAVED" ,
40
+ "CA" : "CARRAY" ,
41
+ "FA" : "FARRAY" ,
42
+ }
43
+
44
+
45
+ class Flags :
46
+ def __init__ (self , flag_to_value : dict ):
47
+ assert all (k in FLAGS for k in flag_to_value .keys ()) # sanity check
48
+ self ._flag_to_value = flag_to_value
49
+
50
+ def __getattr__ (self , attr : str ):
51
+ if attr .islower () and attr .upper () in FLAGS :
52
+ return self [attr .upper ()]
53
+ else :
54
+ raise AttributeError (f"No flag attribute '{ attr } '" )
55
+
56
+ def __getitem__ (self , key ):
57
+ if key in SHORTHAND_TO_FLAGS .keys ():
58
+ key = SHORTHAND_TO_FLAGS [key ]
59
+ if key in FLAGS :
60
+ try :
61
+ return self ._flag_to_value [key ]
62
+ except KeyError as e :
63
+ raise NotImplementedError (f"{ key = } " ) from e
64
+ else :
65
+ raise KeyError (f"No flag key '{ key } '" )
66
+
18
67
19
68
##################### ndarray class ###########################
20
69
@@ -59,6 +108,11 @@ def strides(self):
59
108
def base (self ):
60
109
return self ._base
61
110
111
+ @property
112
+ def flags (self ):
113
+ # Note contiguous in torch is assumed C-style
114
+ return Flags ({"C_CONTIGUOUS" : self ._tensor .is_contiguous ()})
115
+
62
116
@property
63
117
def T (self ):
64
118
return self .transpose ()
You can’t perform that action at this time.
0 commit comments