我如何将这个嵌套的for循环编写为列表理解?

我正在处理 4D 数据集,其中有一个嵌套的 for 循环(4 个循环)。该for环路的作品,但它需要一段时间来运行:〜5分钟。我试图用列表理解来正确地写这个,但我对如何做到这一点感到困惑,因为我的嵌套循环:

data = np.random.rand(12, 27, 282, 375)

stdev_data = np.std(data, axis=1)

## nested for loop 

count = []

for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        for lat in range(data.shape[2]):
            for lon in range(data.shape[3]):
                count.append((data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]).sum(axis=0))

reshape_counts = np.reshape(count, data.shape)

这是我对列表理解的尝试:

i, j, lat, lon = data.shape[0], data.shape[1], data.shape[2], data.shape[3]
print(i, j, lat, lon)

test_list = [[(data < -1.282 * stdev_data).sum(axis=0) for lon in lat] for j in i]

我收到一条错误消息,指出“int”对象不可迭代。如何以列表理解的形式重写我的嵌套 for 循环以加快进程?

回答

鉴于您使用的是 numpy,我建议您利用这样一个事实,即它们的for循环是用 C 编写的,并且经常被优化。您最终仍将逐步浏览数据,但速度要快得多。这种方法称为矢量化。

在这种情况下,您试图制作一个布尔掩码,这可以说是简化了操作。请记住,.sum()您的表达式中的调用是一个红鲱鱼:您实际上是在对一个标量布尔值求和,它总是会给您零或一。

以下是如何-1.282在第二维中找到小于sigma 的点:

result = data < -1.282 * stdev_data[:, None, ...]

或者,你可以做

result = data < -1.282 * stdev_data.reshape(stdev_data.shape[0], 1, *stdev_data.shape[1:])

或者

result = data < -1.282 * np.reshape(stdev_data, stdev_data.shape[:1] + (1,) + stdev_data.shape[1:])

一个更简单的解决方案是从一开始就传递keepdims=Truenp.std

result = data < -1.282 * np.std(data, axis=1, keepdims=True)

keepdims=True确保 的输出std具有形状(12, 1, 282, 375)而不仅仅是(12, 282, 375),因此您无需自己重新插入尺寸。

现在,如果您真的想计算您的问题似乎暗示的计数,您可以result沿着第二个维度对掩码求和:

counts = result.sum(axis=1)

最后,完全按照说明回答您的实际问题:for循环直接转化为列表推导式。在您的情况下,这意味着for理解中有四个s,完全按照您最初拥有它们的顺序:

[data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]
    for i in range(data.shape[0])
        for j in range(data.shape[1])
            for lat in range(data.shape[2])
                for lon in range(data.shape[3])]

由于推导式被方括号包围,您可以像我所做的那样将它们的内容自由地写在单独的行上,尽管这当然不是必需的。请注意,唯一真正的区别是内容append在前,并且没有冒号。而且,那个红鲱鱼sum也不见了。


以上是我如何将这个嵌套的for循环编写为列表理解?的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>