Skip to content

Basic ndarray.flags implementation #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,55 @@

newaxis = None

FLAGS = [
"C_CONTIGUOUS",
"F_CONTIGUOUS",
"OWNDATA",
"WRITEABLE",
"ALIGNED",
"WRITEBACKIFCOPY",
"FNC",
"FORC",
"BEHAVED",
"CARRAY",
"FARRAY",
]

SHORTHAND_TO_FLAGS = {
"C": "C_CONTIGUOUS",
"F": "F_CONTIGUOUS",
"O": "OWNDATA",
"W": "WRITEABLE",
"A": "ALIGNED",
"X": "WRITEBACKIFCOPY",
"B": "BEHAVED",
"CA": "CARRAY",
"FA": "FARRAY",
}


class Flags:
def __init__(self, flag_to_value: dict):
assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
self._flag_to_value = flag_to_value

def __getattr__(self, attr: str):
if attr.islower() and attr.upper() in FLAGS:
return self[attr.upper()]
else:
raise AttributeError(f"No flag attribute '{attr}'")

def __getitem__(self, key):
if key in SHORTHAND_TO_FLAGS.keys():
key = SHORTHAND_TO_FLAGS[key]
if key in FLAGS:
try:
return self._flag_to_value[key]
except KeyError as e:
raise NotImplementedError(f"{key=}") from e
else:
raise KeyError(f"No flag key '{key}'")


##################### ndarray class ###########################

Expand Down Expand Up @@ -58,6 +107,11 @@ def strides(self):
def base(self):
return self._base

@property
def flags(self):
# Note contiguous in torch is assumed C-style
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()})

@property
def T(self):
return self.transpose()
Expand Down