为什么np.hypot和np.subtract.outer与普通广播相比非常快?使用Numba并行加速numpy进行距离矩阵计算

我有两组大的 2D 点,我需要计算一个距离矩阵。

我需要它在 python 中运行得很快,所以很明显我使用了 numpy。我最近了解了 numpy 广播并使用了它,而不是在 python 中循环,numpy 将在 C 中进行。

我真的认为广播就是我所需要的,直到我看到其他方法比普通广播更好,我有两种计算距离矩阵的方法,但我不明白为什么一种比另一种更好。

我在这里查找https://github.com/numpy/numpy/issues/14761并且我得到了相互矛盾的结果。

下面是两种计算距离矩阵的方法

单元格 [3, 4, 6] 和 [8, 9] 都计算距离矩阵,但 3+4 使用减法。outer 比使用 vanilla 广播的 8 快,使用 hypot 的 6 比 9 快,这很简单道路。我没有尝试在 python 循环中假设它永远不会完成。

我想知道

1. 有没有更快的方法来计算距离矩阵(可能是 scikit-learn 或 scipy)?

2.为什么hypot和subtract.outer这么快?

为了方便起见,我还附上了代码段 tp run 整个事情,并更改了种子以防止缓存恢复

### Cell 1
import numpy as np

np.random.seed(858442)

### Cell 2
%%time
obs = np.random.random((50000, 2))
interp = np.random.random((30000, 2))

CPU times: user 2.02 ms, sys: 1.4 ms, total: 3.42 ms
Wall time: 1.84 ms

### Cell 3
%%time
d0 = np.subtract.outer(obs[:,0], interp[:,0])

CPU times: user 2.46 s, sys: 1.97 s, total: 4.42 s
Wall time: 4.42 s

### Cell 4
%%time
d1 = np.subtract.outer(obs[:,1], interp[:,1])

CPU times: user 3.1 s, sys: 2.7 s, total: 5.8 s
Wall time: 8.34 s

### Cell 5
%%time
h = np.hypot(d0, d1)

CPU times: user 12.7 s, sys: 24.6 s, total: 37.3 s
Wall time: 1min 6s

### Cell 6
np.random.seed(773228)

### Cell 7
%%time
obs = np.random.random((50000, 2))
interp = np.random.random((30000, 2))

CPU times: user 1.84 ms, sys: 1.56 ms, total: 3.4 ms
Wall time: 2.03 ms

### Cell 8
%%time
d = obs[:, np.newaxis, :] - interp
d0, d1 = d[:, :, 0], d[:, :, 1]

CPU times: user 22.7 s, sys: 8.24 s, total: 30.9 s
Wall time: 33.2 s

### Cell 9
%%time
h = np.sqrt(d0**2 + d1**2)

CPU times: user 29.1 s, sys: 2min 12s, total: 2min 41s
Wall time: 6min 10s

在此感谢Jérôme Richard 的 更新

  • Stackoverflow 从不让人失望
  • 使用numba有一种更快的方法
  • 它有及时的编译器,可以将 python 代码段转换为快速机器代码,第一次使用它会比后续使用慢一点,因为它会编译。但即使是第一次 njit parallel 以 9x margin for (49000, 12000) 矩阵击败了 hypot +subtract.outer

各种方法的表现

  • 确保每次运行脚本时使用不同的种子
import sys
import time

import numba as nb
import numpy as np

np.random.seed(int(sys.argv[1]))

d0 = np.random.random((49000, 2))
d1 = np.random.random((12000, 2))

def f1(d0, d1):
    print('Numba without parallel')
    res = np.empty((d0.shape[0], d1.shape[0]), dtype=d0.dtype)
    for i in nb.prange(d0.shape[0]):
        for j in range(d1.shape[0]):
            res[i, j] = np.sqrt((d0[i, 0] - d1[j, 0])**2 + (d0[i, 1] - d1[j, 1])**2)
    return res

# Add eager compilation, compiles before hand
@nb.njit((nb.float64[:, :], nb.float64[:, :]), parallel=True)
def f2(d0, d1):
    print('Numba with parallel')
    res = np.empty((d0.shape[0], d1.shape[0]), dtype=d0.dtype)
    for i in nb.prange(d0.shape[0]):
        for j in range(d1.shape[0]):
            res[i, j] = np.sqrt((d0[i, 0] - d1[j, 0])**2 + (d0[i, 1] - d1[j, 1])**2)
    return res

def f3(d0, d1):
    print('hypot + subtract.outer')
    np.hypot(
        np.subtract.outer(d0[:,0], d1[:,0]),
        np.subtract.outer(d0[:,1], d1[:,1])
    )

if __name__ == '__main__':
    s1 = time.time()
    eval(f'{sys.argv[2]}(d0, d1)')
    print(time.time() - s1)
(base) ~/xx@xx:~/xx$ python3 test.py 523432 f3
hypot + subtract.outer
9.79756784439087
(base) xx@xx:~/xx$ python3 test.py 213622 f2
Numba with parallel
0.3393140316009521

如果我找到了更快的方法,我将更新这篇文章以获得进一步的发展

以上是为什么np.hypot和np.subtract.outer与普通广播相比非常快?使用Numba并行加速numpy进行距离矩阵计算的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>