@@ -97,14 +97,13 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
97
97
98
98
# #### concatenate and relatives
99
99
100
+
100
101
def concatenate (tensors , axis = 0 , out = None , dtype = None , casting = "same_kind" ):
101
102
# np.concatenate ravels if axis=None
102
103
tensors , axis = _util .axis_none_ravel (* tensors , axis = axis )
103
104
104
- # figure out the type of the inputs and outputs
105
- if out is None and dtype is None :
106
- out_dtype = None
107
- else :
105
+ if out is not None or dtype is not None :
106
+ # figure out the type of the inputs and outputs
108
107
out_dtype = out .dtype .torch_dtype if dtype is None else dtype
109
108
110
109
# cast input arrays if necessary; do not broadcast them agains `out`
@@ -120,7 +119,8 @@ def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
120
119
121
120
# #### cov & corrcoef
122
121
123
- def corrcoef (xy_tensor , rowvar = True , * , dtype = None ):
122
+
123
+ def corrcoef (xy_tensor , rowvar = True , * , dtype = None ):
124
124
if rowvar is False :
125
125
# xy_tensor is at least 2D, so using .T is safe
126
126
xy_tensor = x_tensor .T
@@ -194,4 +194,3 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
194
194
output = [x .clone () for x in output ]
195
195
196
196
return output
197
-
0 commit comments