Skip to content

Commit 6571b12

Browse files
committed
MAINT: split cov, corrcoef, concat into ndarray/torch parts
1 parent 35ec9ab commit 6571b12

File tree

3 files changed

+88
-61
lines changed

3 files changed

+88
-61
lines changed

torch_np/_decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def wrapped(*args, dtype=None, **kwds):
1515
torch_dtype = None
1616
if dtype is not None:
1717
dtype = _dtypes.dtype(dtype)
18-
torch_dtype = dtype._scalar_type.torch_dtype
18+
torch_dtype = dtype.torch_dtype
1919
return func(*args, dtype=torch_dtype, **kwds)
2020

2121
return wrapped

torch_np/_detail/implementations.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,82 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
9595
return result
9696

9797

98+
# #### concatenate and relatives
99+
100+
def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
101+
# np.concatenate ravels if axis=None
102+
tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
103+
104+
# figure out the type of the inputs and outputs
105+
if out is None and dtype is None:
106+
out_dtype = None
107+
else:
108+
out_dtype = out.dtype.torch_dtype if dtype is None else dtype
109+
110+
# cast input arrays if necessary; do not broadcast them agains `out`
111+
tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
112+
113+
try:
114+
result = torch.cat(tensors, axis)
115+
except (IndexError, RuntimeError):
116+
raise _util.AxisError
117+
118+
return result
119+
120+
121+
# #### cov & corrcoef
122+
123+
def corrcoef(xy_tensor, rowvar=True, *, dtype=None):
124+
if rowvar is False:
125+
# xy_tensor is at least 2D, so using .T is safe
126+
xy_tensor = x_tensor.T
127+
128+
is_half = dtype == torch.float16
129+
if is_half:
130+
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
131+
dtype = torch.float32
132+
133+
if dtype is not None:
134+
xy_tensor = xy_tensor.to(dtype)
135+
136+
result = torch.corrcoef(xy_tensor)
137+
138+
if is_half:
139+
result = result.to(torch.float16)
140+
141+
return result
142+
143+
144+
def cov(
145+
m_tensor,
146+
bias=False,
147+
ddof=None,
148+
fweights_tensor=None,
149+
aweights_tensor=None,
150+
*,
151+
dtype=None,
152+
):
153+
if ddof is None:
154+
ddof = 1 if bias == 0 else 0
155+
156+
is_half = dtype == torch.float16
157+
if is_half:
158+
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
159+
dtype = torch.float32
160+
161+
if dtype is not None:
162+
m_tensor = m_tensor.to(dtype)
163+
164+
result = torch.cov(
165+
m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor
166+
)
167+
168+
if is_half:
169+
result = result.to(torch.float16)
170+
171+
return result
172+
173+
98174
def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
99175
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L4892-L5047
100176
ndim = len(xi_tensors)
@@ -118,3 +194,4 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
118194
output = [x.clone() for x in output]
119195

120196
return output
197+

torch_np/_wrapper.py

Lines changed: 10 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -434,23 +434,7 @@ def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None):
434434
x = concatenate((x, y), axis=0)
435435

436436
x_tensor = asarray(x).get()
437-
438-
if rowvar is False:
439-
x_tensor = x_tensor.T
440-
441-
is_half = dtype == torch.float16
442-
if is_half:
443-
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
444-
dtype = torch.float32
445-
446-
if dtype is not None:
447-
x_tensor = x_tensor.to(dtype)
448-
449-
result = torch.corrcoef(x_tensor)
450-
451-
if is_half:
452-
result = result.to(torch.float16)
453-
437+
result = _impl.corrcoef(x_tensor, rowvar, dtype=dtype)
454438
return asarray(result)
455439

456440

@@ -478,45 +462,25 @@ def cov(
478462

479463
m = concatenate((m, y), axis=0)
480464

481-
if ddof is None:
482-
if bias == 0:
483-
ddof = 1
484-
else:
485-
ddof = 0
465+
# if ddof is None:
466+
# if bias == 0:
467+
# ddof = 1
468+
# else:
469+
# ddof = 0
486470

487471
m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none(
488472
m, fweights, aweights
489473
)
490-
491-
# work with tensors from now on
492-
is_half = dtype == torch.float16
493-
if is_half:
494-
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
495-
dtype = torch.float32
496-
497-
if dtype is not None:
498-
m_tensor = m_tensor.to(dtype)
499-
500-
result = torch.cov(
501-
m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor
502-
)
503-
504-
if is_half:
505-
result = result.to(torch.float16)
506-
474+
result = _impl.cov(m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype)
507475
return asarray(result)
508476

509477

478+
@_decorators.dtype_to_torch
510479
def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"):
511480
if ar_tuple == ():
512481
# XXX: RuntimeError in torch, ValueError in numpy
513482
raise ValueError("need at least one array to concatenate")
514483

515-
tensors = _helpers.to_tensors(*ar_tuple)
516-
517-
# np.concatenate ravels if axis=None
518-
tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
519-
520484
if out is not None:
521485
if not isinstance(out, ndarray):
522486
raise ValueError("'out' must be an array")
@@ -527,22 +491,8 @@ def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"):
527491
"concatenate() only takes `out` or `dtype` as an "
528492
"argument, but both were provided."
529493
)
530-
531-
# figure out the type of the inputs and outputs
532-
if out is None and dtype is None:
533-
out_dtype = None
534-
else:
535-
out_dtype = out.dtype if dtype is None else _dtypes.dtype(dtype)
536-
out_dtype = out_dtype.type.torch_dtype
537-
538-
# cast input arrays if necessary; do not broadcast them agains `out`
539-
tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
540-
541-
try:
542-
result = torch.cat(tensors, axis)
543-
except (IndexError, RuntimeError):
544-
raise _util.AxisError
545-
494+
tensors = _helpers.to_tensors(*ar_tuple)
495+
result = _impl.concatenate(tensors, axis, out, dtype, casting)
546496
return _helpers.result_or_out(result, out)
547497

548498

0 commit comments

Comments
 (0)