Skip to content

Commit 5f423d2

Browse files
committed
ENH: implement flags.f_contiguous
1 parent 23c2bf3 commit 5f423d2

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

torch_np/_ndarray.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,16 @@ def base(self):
112112
@property
113113
def flags(self):
114114
# Note contiguous in torch is assumed C-style
115-
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()})
115+
116+
# check if F contiguous
117+
from itertools import accumulate
118+
f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x*y))
119+
f_strides = (1,) + f_strides[:-1]
120+
is_f_contiguous = f_strides == self._tensor.stride()
121+
122+
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous(),
123+
"F_CONTIGUOUS": is_f_contiguous,}
124+
)
116125

117126
@property
118127
def T(self):

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,6 @@ def test_subclass_writeable(self):
436436
assert_(d[...].flags.writeable)
437437
assert_(d[0].flags.writeable)
438438

439-
@pytest.mark.xfail(reason="can't determine f-style contiguous in torch")
440439
def test_memory_order(self):
441440
# This is not necessary to preserve. Memory layouts for
442441
# more complex indices are not as simple.

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2195,7 +2195,7 @@ def check_function(self, func, fill_value=None):
21952195
**fill_kwarg)
21962196

21972197
assert_equal(arr.dtype, dtype)
2198-
# assert_(getattr(arr.flags, self.orders[order])) # XXX: no ndarray.flags
2198+
assert_(getattr(arr.flags, self.orders[order]))
21992199

22002200
if fill_value is not None:
22012201
val = fill_value

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_basic(self):
272272
assert_equal(a[0, 0], 1)
273273
assert_equal(a_copy[0, 0], 10)
274274

275-
@pytest.mark.xfail(reason="ndarray.flags not implemented")
275+
@pytest.mark.xfail(reason="order='F' not implemented")
276276
def test_order(self):
277277
# It turns out that people rely on np.copy() preserving order by
278278
# default; changing this broke scikit-learn:

0 commit comments

Comments
 (0)