我如何将这个嵌套的for循环编写为列表理解?
我正在处理 4D 数据集,其中有一个嵌套的 for 循环(4 个循环)。该for环路的作品,但它需要一段时间来运行:〜5分钟。我试图用列表理解来正确地写这个,但我对如何做到这一点感到困惑,因为我的嵌套循环:
data = np.random.rand(12, 27, 282, 375)
stdev_data = np.std(data, axis=1)
## nested for loop
count = []
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for lat in range(data.shape[2]):
for lon in range(data.shape[3]):
count.append((data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]).sum(axis=0))
reshape_counts = np.reshape(count, data.shape)
这是我对列表理解的尝试:
i, j, lat, lon = data.shape[0], data.shape[1], data.shape[2], data.shape[3]
print(i, j, lat, lon)
test_list = [[(data < -1.282 * stdev_data).sum(axis=0) for lon in lat] for j in i]
我收到一条错误消息,指出“int”对象不可迭代。如何以列表理解的形式重写我的嵌套 for 循环以加快进程?
回答
鉴于您使用的是 numpy,我建议您利用这样一个事实,即它们的for循环是用 C 编写的,并且经常被优化。您最终仍将逐步浏览数据,但速度要快得多。这种方法称为矢量化。
在这种情况下,您试图制作一个布尔掩码,这可以说是简化了操作。请记住,.sum()您的表达式中的调用是一个红鲱鱼:您实际上是在对一个标量布尔值求和,它总是会给您零或一。
以下是如何-1.282在第二维中找到小于sigma 的点:
result = data < -1.282 * stdev_data[:, None, ...]
或者,你可以做
result = data < -1.282 * stdev_data.reshape(stdev_data.shape[0], 1, *stdev_data.shape[1:])
或者
result = data < -1.282 * np.reshape(stdev_data, stdev_data.shape[:1] + (1,) + stdev_data.shape[1:])
一个更简单的解决方案是从一开始就传递keepdims=True给np.std:
result = data < -1.282 * np.std(data, axis=1, keepdims=True)
keepdims=True确保 的输出std具有形状(12, 1, 282, 375)而不仅仅是(12, 282, 375),因此您无需自己重新插入尺寸。
现在,如果您真的想计算您的问题似乎暗示的计数,您可以result沿着第二个维度对掩码求和:
counts = result.sum(axis=1)
最后,完全按照说明回答您的实际问题:for循环直接转化为列表推导式。在您的情况下,这意味着for理解中有四个s,完全按照您最初拥有它们的顺序:
[data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]
for i in range(data.shape[0])
for j in range(data.shape[1])
for lat in range(data.shape[2])
for lon in range(data.shape[3])]
由于推导式被方括号包围,您可以像我所做的那样将它们的内容自由地写在单独的行上,尽管这当然不是必需的。请注意,唯一真正的区别是内容append在前,并且没有冒号。而且,那个红鲱鱼sum也不见了。