Skip to content

Commit d7340fe

Browse files
committed
ENH: add np.correlate
1 parent a149892 commit d7340fe

File tree

3 files changed

+25
-25
lines changed

3 files changed

+25
-25
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,6 @@ def compress(condition, a, axis=None, out=None):
176176
raise NotImplementedError
177177

178178

179-
def correlate(a, v, mode="valid"):
180-
raise NotImplementedError
181-
182-
183179
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
184180
raise NotImplementedError
185181

torch_np/_funcs_impl.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,28 +570,39 @@ def cov(
570570
return result
571571

572572

573-
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
574-
# NumPy: if v is longer than a, the arrays are swapped before computation
575-
if a.shape[0] < v.shape[0]:
576-
a, v = v, a
577-
573+
def _conv_corr_impl(a, v, mode):
578574
dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
579575
a = _util.cast_if_needed(a, dt)
580576
v = _util.cast_if_needed(v, dt)
581577

582578
padding = v.shape[0] - 1 if mode == "full" else mode
583579

584-
# 1. NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
585-
# 2. flip the weights since numpy does and torch does not
580+
# NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
586581
aa = a[None, :]
587-
vv = torch.flip(v, (0,))[None, None, :]
582+
vv = v[None, None, :]
588583

589584
result = torch.nn.functional.conv1d(aa, vv, padding=padding)
590585

591586
# torch returns a 2D result, numpy returns a 1D array
592587
return result[0, :]
593588

594589

590+
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
591+
# NumPy: if v is longer than a, the arrays are swapped before computation
592+
if a.shape[0] < v.shape[0]:
593+
a, v = v, a
594+
595+
# flip the weights since numpy does and torch does not
596+
v = torch.flip(v, (0,))
597+
598+
return _conv_corr_impl(a, v, mode)
599+
600+
601+
def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
602+
v = torch.conj_physical(v)
603+
return _conv_corr_impl(a, v, mode)
604+
605+
595606
# ### logic & element selection ###
596607

597608

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,7 +2369,6 @@ def test_dtype_str_bytes(self, likefunc, dtype):
23692369
assert result.strides == (4, 1)
23702370

23712371

2372-
@pytest.mark.xfail(reason="TODO")
23732372
class TestCorrelate:
23742373
def _setup(self, dt):
23752374
self.x = np.array([1, 2, 3, 4, 5], dtype=dt)
@@ -2391,20 +2390,13 @@ def test_float(self):
23912390
assert_array_almost_equal(z, self.z1_4)
23922391
z = np.correlate(self.y, self.x, 'full')
23932392
assert_array_almost_equal(z, self.z2)
2394-
z = np.correlate(self.x[::-1], self.y, 'full')
2393+
z = np.correlate(np.flip(self.x), self.y, 'full')
23952394
assert_array_almost_equal(z, self.z1r)
2396-
z = np.correlate(self.y, self.x[::-1], 'full')
2395+
z = np.correlate(self.y, np.flip(self.x), 'full')
23972396
assert_array_almost_equal(z, self.z2r)
23982397
z = np.correlate(self.xs, self.y, 'full')
23992398
assert_array_almost_equal(z, self.zs)
24002399

2401-
def test_object(self):
2402-
self._setup(Decimal)
2403-
z = np.correlate(self.x, self.y, 'full')
2404-
assert_array_almost_equal(z, self.z1)
2405-
z = np.correlate(self.y, self.x, 'full')
2406-
assert_array_almost_equal(z, self.z2)
2407-
24082400
def test_no_overwrite(self):
24092401
d = np.ones(100)
24102402
k = np.ones(3)
@@ -2416,16 +2408,17 @@ def test_complex(self):
24162408
x = np.array([1, 2, 3, 4+1j], dtype=complex)
24172409
y = np.array([-1, -2j, 3+1j], dtype=complex)
24182410
r_z = np.array([3-1j, 6, 8+1j, 11+5j, -5+8j, -4-1j], dtype=complex)
2419-
r_z = r_z[::-1].conjugate()
2411+
r_z = np.flip(r_z).conjugate()
24202412
z = np.correlate(y, x, mode='full')
24212413
assert_array_almost_equal(z, r_z)
24222414

24232415
def test_zero_size(self):
2424-
with pytest.raises(ValueError):
2416+
with pytest.raises((ValueError, RuntimeError)):
24252417
np.correlate(np.array([]), np.ones(1000), mode='full')
2426-
with pytest.raises(ValueError):
2418+
with pytest.raises((ValueError, RuntimeError)):
24272419
np.correlate(np.ones(1000), np.array([]), mode='full')
24282420

2421+
@pytest.mark.skip(reason='do not implement deprecated behavior')
24292422
def test_mode(self):
24302423
d = np.ones(100)
24312424
k = np.ones(3)

0 commit comments

Comments
 (0)