Skip to content

Commit 6a86527

Browse files
authored
Merge pull request #62 from Quansight-Labs/f_contiguous
add ndarrays.flags.f_contiguous
2 parents 478df33 + d42dc66 commit 6a86527

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

torch_np/_ndarray.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,22 @@ def base(self):
112112
@property
113113
def flags(self):
114114
# Note contiguous in torch is assumed C-style
115-
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()})
115+
116+
# check if F contiguous
117+
from itertools import accumulate
118+
119+
f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x * y))
120+
f_strides = (1,) + f_strides[:-1]
121+
is_f_contiguous = f_strides == self._tensor.stride()
122+
123+
return Flags(
124+
{
125+
"C_CONTIGUOUS": self._tensor.is_contiguous(),
126+
"F_CONTIGUOUS": is_f_contiguous,
127+
"OWNDATA": self._tensor._base is None,
128+
"WRITEABLE": True, # pytorch does not have readonly tensors
129+
}
130+
)
116131

117132
@property
118133
def T(self):

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,6 @@ def test_subclass_writeable(self):
377377
assert_(d[...].flags.writeable)
378378
assert_(d[0].flags.writeable)
379379

380-
@pytest.mark.xfail(reason="can't determine f-style contiguous in torch")
381380
def test_memory_order(self):
382381
# This is not necessary to preserve. Memory layouts for
383382
# more complex indices are not as simple.

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2194,7 +2194,7 @@ def check_function(self, func, fill_value=None):
21942194
**fill_kwarg)
21952195

21962196
assert_equal(arr.dtype, dtype)
2197-
# assert_(getattr(arr.flags, self.orders[order])) # XXX: no ndarray.flags
2197+
assert_(getattr(arr.flags, self.orders[order]))
21982198

21992199
if fill_value is not None:
22002200
val = fill_value

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_basic(self):
272272
assert_equal(a[0, 0], 1)
273273
assert_equal(a_copy[0, 0], 10)
274274

275-
@pytest.mark.xfail(reason="ndarray.flags not implemented")
275+
@pytest.mark.xfail(reason="order='F' not implemented")
276276
def test_order(self):
277277
# It turns out that people rely on np.copy() preserving order by
278278
# default; changing this broke scikit-learn:

0 commit comments

Comments
 (0)