@@ -44,7 +44,7 @@ def cumulative_sum(
44
44
if axis < 0 :
45
45
axis += x .ndim
46
46
x = concat ([zeros (x .shape [:axis ] + (1 ,) + x .shape [axis + 1 :], dtype = dt ), x ], axis = axis )
47
- return Array ._new (np .cumsum (x ._array , axis = axis , dtype = dtype ))
47
+ return Array ._new (np .cumsum (x ._array , axis = axis , dtype = dtype ), device = x . device )
48
48
49
49
def max (
50
50
x : Array ,
@@ -55,7 +55,7 @@ def max(
55
55
) -> Array :
56
56
if x .dtype not in _real_numeric_dtypes :
57
57
raise TypeError ("Only real numeric dtypes are allowed in max" )
58
- return Array ._new (np .max (x ._array , axis = axis , keepdims = keepdims ))
58
+ return Array ._new (np .max (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
59
59
60
60
61
61
def mean (
@@ -67,7 +67,7 @@ def mean(
67
67
) -> Array :
68
68
if x .dtype not in _real_floating_dtypes :
69
69
raise TypeError ("Only real floating-point dtypes are allowed in mean" )
70
- return Array ._new (np .mean (x ._array , axis = axis , keepdims = keepdims ))
70
+ return Array ._new (np .mean (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
71
71
72
72
73
73
def min (
@@ -79,7 +79,7 @@ def min(
79
79
) -> Array :
80
80
if x .dtype not in _real_numeric_dtypes :
81
81
raise TypeError ("Only real numeric dtypes are allowed in min" )
82
- return Array ._new (np .min (x ._array , axis = axis , keepdims = keepdims ))
82
+ return Array ._new (np .min (x ._array , axis = axis , keepdims = keepdims ), device = x . device )
83
83
84
84
85
85
def prod (
@@ -104,7 +104,7 @@ def prod(
104
104
dtype = np .complex128
105
105
else :
106
106
dtype = dtype ._np_dtype
107
- return Array ._new (np .prod (x ._array , dtype = dtype , axis = axis , keepdims = keepdims ))
107
+ return Array ._new (np .prod (x ._array , dtype = dtype , axis = axis , keepdims = keepdims ), device = x . device )
108
108
109
109
110
110
def std (
@@ -118,7 +118,7 @@ def std(
118
118
# Note: the keyword argument correction is different here
119
119
if x .dtype not in _real_floating_dtypes :
120
120
raise TypeError ("Only real floating-point dtypes are allowed in std" )
121
- return Array ._new (np .std (x ._array , axis = axis , ddof = correction , keepdims = keepdims ))
121
+ return Array ._new (np .std (x ._array , axis = axis , ddof = correction , keepdims = keepdims ), device = x . device )
122
122
123
123
124
124
def sum (
@@ -143,7 +143,7 @@ def sum(
143
143
dtype = np .complex128
144
144
else :
145
145
dtype = dtype ._np_dtype
146
- return Array ._new (np .sum (x ._array , axis = axis , dtype = dtype , keepdims = keepdims ))
146
+ return Array ._new (np .sum (x ._array , axis = axis , dtype = dtype , keepdims = keepdims ), device = x . device )
147
147
148
148
149
149
def var (
@@ -157,4 +157,4 @@ def var(
157
157
# Note: the keyword argument correction is different here
158
158
if x .dtype not in _real_floating_dtypes :
159
159
raise TypeError ("Only real floating-point dtypes are allowed in var" )
160
- return Array ._new (np .var (x ._array , axis = axis , ddof = correction , keepdims = keepdims ))
160
+ return Array ._new (np .var (x ._array , axis = axis , ddof = correction , keepdims = keepdims ), device = x . device )
0 commit comments