Skip to content

Basic ndarray.flags implementation #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,55 @@

newaxis = None

FLAGS = [
"C_CONTIGUOUS",
"F_CONTIGUOUS",
"OWNDATA",
"WRITEABLE",
"ALIGNED",
"WRITEBACKIFCOPY",
"FNC",
"FORC",
"BEHAVED",
"CARRAY",
"FARRAY",
]

SHORTHAND_TO_FLAGS = {
"C": "C_CONTIGUOUS",
"F": "F_CONTIGUOUS",
"O": "OWNDATA",
"W": "WRITEABLE",
"A": "ALIGNED",
"X": "WRITEBACKIFCOPY",
"B": "BEHAVED",
"CA": "CARRAY",
"FA": "FARRAY",
}


class Flags:
def __init__(self, flag_to_value: dict):
assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
self._flag_to_value = flag_to_value

def __getattr__(self, attr: str):
if attr.islower() and attr.upper() in FLAGS:
return self[attr.upper()]
else:
raise AttributeError(f"No flag attribute '{attr}'")

def __getitem__(self, key):
if key in SHORTHAND_TO_FLAGS.keys():
key = SHORTHAND_TO_FLAGS[key]
if key in FLAGS:
try:
return self._flag_to_value[key]
except KeyError as e:
raise NotImplementedError(f"{key=}") from e
else:
raise KeyError(f"No flag key '{key}'")


##################### ndarray class ###########################

Expand Down Expand Up @@ -58,6 +107,17 @@ def strides(self):
def base(self):
return self._base

@property
def flags(self):
# Note contiguous in torch is assumed C-style
flag_to_value = {"C_CONTIGUOUS": self._tensor.is_contiguous()}
if flag_to_value["C_CONTIGUOUS"]:
# There's no proper way to determine if a tensor is Fortran-style
# contiguous in torch, but at least we know it isn't when it is
# C-style.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, 1D arrays can be both C and F contiguous. Not sure how deep into that rabbit hole we need fall though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right—this was a bit magical to begin with so removed.

flag_to_value["F_CONTIGUOUS"] = False
return Flags(flag_to_value)

@property
def T(self):
return self.transpose()
Expand Down