Skip to content

Commit 4798f75

Browse files
committed
MAINT: trivially simplify concat
1 parent 6571b12 commit 4798f75

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

torch_np/_detail/implementations.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,13 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
9797

9898
# #### concatenate and relatives
9999

100+
100101
def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
101102
# np.concatenate ravels if axis=None
102103
tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
103104

104-
# figure out the type of the inputs and outputs
105-
if out is None and dtype is None:
106-
out_dtype = None
107-
else:
105+
if out is not None or dtype is not None:
106+
# figure out the type of the inputs and outputs
108107
out_dtype = out.dtype.torch_dtype if dtype is None else dtype
109108

110109
# cast input arrays if necessary; do not broadcast them agains `out`
@@ -120,7 +119,8 @@ def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
120119

121120
# #### cov & corrcoef
122121

123-
def corrcoef(xy_tensor, rowvar=True, *, dtype=None):
122+
123+
def corrcoef(xy_tensor, rowvar=True, *, dtype=None):
124124
if rowvar is False:
125125
# xy_tensor is at least 2D, so using .T is safe
126126
xy_tensor = x_tensor.T
@@ -194,4 +194,3 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
194194
output = [x.clone() for x in output]
195195

196196
return output
197-

torch_np/_wrapper.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -462,16 +462,12 @@ def cov(
462462

463463
m = concatenate((m, y), axis=0)
464464

465-
# if ddof is None:
466-
# if bias == 0:
467-
# ddof = 1
468-
# else:
469-
# ddof = 0
470-
471465
m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none(
472466
m, fweights, aweights
473467
)
474-
result = _impl.cov(m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype)
468+
result = _impl.cov(
469+
m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype
470+
)
475471
return asarray(result)
476472

477473

0 commit comments

Comments
 (0)