Skip to content

Commit 3b28331

Browse files
committed
ENH: implement flags.f_contiguous
1 parent 478df33 commit 3b28331

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

torch_np/_ndarray.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,16 @@ def base(self):
112112
@property
113113
def flags(self):
114114
# Note contiguous in torch is assumed C-style
115-
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()})
115+
116+
# check if F contiguous
117+
from itertools import accumulate
118+
f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x*y))
119+
f_strides = (1,) + f_strides[:-1]
120+
is_f_contiguous = f_strides == self._tensor.stride()
121+
122+
return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous(),
123+
"F_CONTIGUOUS": is_f_contiguous,}
124+
)
116125

117126
@property
118127
def T(self):

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,6 @@ def test_subclass_writeable(self):
377377
assert_(d[...].flags.writeable)
378378
assert_(d[0].flags.writeable)
379379

380-
@pytest.mark.xfail(reason="can't determine f-style contiguous in torch")
381380
def test_memory_order(self):
382381
# This is not necessary to preserve. Memory layouts for
383382
# more complex indices are not as simple.

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2194,7 +2194,7 @@ def check_function(self, func, fill_value=None):
21942194
**fill_kwarg)
21952195

21962196
assert_equal(arr.dtype, dtype)
2197-
# assert_(getattr(arr.flags, self.orders[order])) # XXX: no ndarray.flags
2197+
assert_(getattr(arr.flags, self.orders[order]))
21982198

21992199
if fill_value is not None:
22002200
val = fill_value

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_basic(self):
272272
assert_equal(a[0, 0], 1)
273273
assert_equal(a_copy[0, 0], 10)
274274

275-
@pytest.mark.xfail(reason="ndarray.flags not implemented")
275+
@pytest.mark.xfail(reason="order='F' not implemented")
276276
def test_order(self):
277277
# It turns out that people rely on np.copy() preserving order by
278278
# default; changing this broke scikit-learn:

0 commit comments

Comments
 (0)