@@ -388,26 +388,40 @@ def diag(v, k=0):
388
388
###### misc/unordered
389
389
390
390
391
+ def _xy_helper_corrcoef (x_tensor , y_tensor = None , rowvar = True ):
392
+ """Prepate inputs for cov and corrcoef."""
393
+
394
+ # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
395
+ if y_tensor is not None :
396
+ # make sure x and y are at least 2D
397
+ ndim_extra = 2 - x_tensor .ndim
398
+ if ndim_extra > 0 :
399
+ x_tensor = x_tensor .view ((1 ,) * ndim_extra + x_tensor .shape )
400
+ if not rowvar and x_tensor .shape [0 ] != 1 :
401
+ x_tensor = x_tensor .mT
402
+ x_tensor = x_tensor .clone ()
403
+
404
+ ndim_extra = 2 - y_tensor .ndim
405
+ if ndim_extra > 0 :
406
+ y_tensor = y_tensor .view ((1 ,) * ndim_extra + y_tensor .shape )
407
+ if not rowvar and y_tensor .shape [0 ] != 1 :
408
+ y_tensor = y_tensor .mT
409
+ y_tensor = y_tensor .clone ()
410
+
411
+ x_tensor = _impl .concatenate ((x_tensor , y_tensor ), axis = 0 )
412
+
413
+ return x_tensor
414
+
415
+
391
416
@_decorators .dtype_to_torch
392
417
def corrcoef (x , y = None , rowvar = True , bias = NoValue , ddof = NoValue , * , dtype = None ):
393
418
if bias is not None or ddof is not None :
394
419
# deprecated in NumPy
395
420
raise NotImplementedError
396
421
397
- # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
398
- if y is not None :
399
- x = array (x , ndmin = 2 )
400
- if not rowvar and x .shape [0 ] != 1 :
401
- x = x .T
402
-
403
- y = array (y , ndmin = 2 )
404
- if not rowvar and y .shape [0 ] != 1 :
405
- y = y .T
406
-
407
- x = concatenate ((x , y ), axis = 0 )
408
-
409
- x_tensor = asarray (x ).get ()
410
- result = _impl .corrcoef (x_tensor , rowvar , dtype = dtype )
422
+ x_tensor , y_tensor = _helpers .to_tensors_or_none (x , y )
423
+ tensor = _xy_helper_corrcoef (x_tensor , y_tensor , rowvar )
424
+ result = _impl .corrcoef (tensor , dtype = dtype )
411
425
return asarray (result )
412
426
413
427
@@ -423,21 +437,12 @@ def cov(
423
437
* ,
424
438
dtype = None ,
425
439
):
426
- # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
427
- if y is not None :
428
- m = array (m , ndmin = 2 )
429
- if not rowvar and m .shape [0 ] != 1 :
430
- m = m .T
431
-
432
- y = array (y , ndmin = 2 )
433
- if not rowvar and y .shape [0 ] != 1 :
434
- y = y .T
435
440
436
- m = concatenate ((m , y ), axis = 0 )
437
-
438
- m_tensor , fweights_tensor , aweights_tensor = _helpers .to_tensors_or_none (
439
- m , fweights , aweights
441
+ m_tensor , y_tensor , fweights_tensor , aweights_tensor = _helpers .to_tensors_or_none (
442
+ m , y , fweights , aweights
440
443
)
444
+ m_tensor = _xy_helper_corrcoef (m_tensor , y_tensor , rowvar )
445
+
441
446
result = _impl .cov (
442
447
m_tensor , bias , ddof , fweights_tensor , aweights_tensor , dtype = dtype
443
448
)
0 commit comments