diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index a4d8fb21..4df40c4f 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -112,7 +112,22 @@ def base(self): @property def flags(self): # Note contiguous in torch is assumed C-style - return Flags({"C_CONTIGUOUS": self._tensor.is_contiguous()}) + + # check if F contiguous + from itertools import accumulate + + f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x * y)) + f_strides = (1,) + f_strides[:-1] + is_f_contiguous = f_strides == self._tensor.stride() + + return Flags( + { + "C_CONTIGUOUS": self._tensor.is_contiguous(), + "F_CONTIGUOUS": is_f_contiguous, + "OWNDATA": self._tensor._base is None, + "WRITEABLE": True, # pytorch does not have readonly tensors + } + ) @property def T(self): diff --git a/torch_np/tests/numpy_tests/core/test_indexing.py b/torch_np/tests/numpy_tests/core/test_indexing.py index 58cf4c2a..ba6ef589 100644 --- a/torch_np/tests/numpy_tests/core/test_indexing.py +++ b/torch_np/tests/numpy_tests/core/test_indexing.py @@ -377,7 +377,6 @@ def test_subclass_writeable(self): assert_(d[...].flags.writeable) assert_(d[0].flags.writeable) - @pytest.mark.xfail(reason="can't determine f-style contiguous in torch") def test_memory_order(self): # This is not necessary to preserve. Memory layouts for # more complex indices are not as simple. diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index c735fa81..83fc3d87 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -2194,7 +2194,7 @@ def check_function(self, func, fill_value=None): **fill_kwarg) assert_equal(arr.dtype, dtype) - # assert_(getattr(arr.flags, self.orders[order])) # XXX: no ndarray.flags + assert_(getattr(arr.flags, self.orders[order])) if fill_value is not None: val = fill_value diff --git a/torch_np/tests/numpy_tests/lib/test_function_base.py b/torch_np/tests/numpy_tests/lib/test_function_base.py index c929a602..ea1b7ddc 100644 --- a/torch_np/tests/numpy_tests/lib/test_function_base.py +++ b/torch_np/tests/numpy_tests/lib/test_function_base.py @@ -272,7 +272,7 @@ def test_basic(self): assert_equal(a[0, 0], 1) assert_equal(a_copy[0, 0], 10) - @pytest.mark.xfail(reason="ndarray.flags not implemented") + @pytest.mark.xfail(reason="order='F' not implemented") def test_order(self): # It turns out that people rely on np.copy() preserving order by # default; changing this broke scikit-learn: