Skip to content

Commit 50e0dd9

Browse files
committed
MAINT: deduplicate y != None in cov, corrcoef
1 parent 8f10e1b commit 50e0dd9

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

torch_np/_detail/implementations.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,7 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
301301
# #### cov & corrcoef
302302

303303

304-
def corrcoef(xy_tensor, rowvar=True, *, dtype=None):
305-
if rowvar is False:
306-
# xy_tensor is at least 2D, so using .T is safe
307-
xy_tensor = x_tensor.T
308-
304+
def corrcoef(xy_tensor, *, dtype=None):
309305
is_half = dtype == torch.float16
310306
if is_half:
311307
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"

torch_np/_wrapper.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -388,26 +388,40 @@ def diag(v, k=0):
388388
###### misc/unordered
389389

390390

391+
def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
392+
"""Prepate inputs for cov and corrcoef."""
393+
394+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
395+
if y_tensor is not None:
396+
# make sure x and y are at least 2D
397+
ndim_extra = 2 - x_tensor.ndim
398+
if ndim_extra > 0:
399+
x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape)
400+
if not rowvar and x_tensor.shape[0] != 1:
401+
x_tensor = x_tensor.mT
402+
x_tensor = x_tensor.clone()
403+
404+
ndim_extra = 2 - y_tensor.ndim
405+
if ndim_extra > 0:
406+
y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape)
407+
if not rowvar and y_tensor.shape[0] != 1:
408+
y_tensor = y_tensor.mT
409+
y_tensor = y_tensor.clone()
410+
411+
x_tensor = _impl.concatenate((x_tensor, y_tensor), axis=0)
412+
413+
return x_tensor
414+
415+
391416
@_decorators.dtype_to_torch
392417
def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None):
393418
if bias is not None or ddof is not None:
394419
# deprecated in NumPy
395420
raise NotImplementedError
396421

397-
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
398-
if y is not None:
399-
x = array(x, ndmin=2)
400-
if not rowvar and x.shape[0] != 1:
401-
x = x.T
402-
403-
y = array(y, ndmin=2)
404-
if not rowvar and y.shape[0] != 1:
405-
y = y.T
406-
407-
x = concatenate((x, y), axis=0)
408-
409-
x_tensor = asarray(x).get()
410-
result = _impl.corrcoef(x_tensor, rowvar, dtype=dtype)
422+
x_tensor, y_tensor = _helpers.to_tensors_or_none(x, y)
423+
tensor = _xy_helper_corrcoef(x_tensor, y_tensor, rowvar)
424+
result = _impl.corrcoef(tensor, dtype=dtype)
411425
return asarray(result)
412426

413427

@@ -423,21 +437,12 @@ def cov(
423437
*,
424438
dtype=None,
425439
):
426-
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
427-
if y is not None:
428-
m = array(m, ndmin=2)
429-
if not rowvar and m.shape[0] != 1:
430-
m = m.T
431-
432-
y = array(y, ndmin=2)
433-
if not rowvar and y.shape[0] != 1:
434-
y = y.T
435440

436-
m = concatenate((m, y), axis=0)
437-
438-
m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none(
439-
m, fweights, aweights
441+
m_tensor, y_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none(
442+
m, y, fweights, aweights
440443
)
444+
m_tensor = _xy_helper_corrcoef(m_tensor, y_tensor, rowvar)
445+
441446
result = _impl.cov(
442447
m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype
443448
)

0 commit comments

Comments
 (0)