如何将函数直接映射到列表列表?
我为图像构建了一个像素分类器,对于图像中的每个像素,我想定义它属于哪个预定义的颜色簇。它有效,但每张图像大约 5 分钟,我想我正在做一些肯定可以优化的非 Pythonic 的事情。
我们如何将函数直接映射到列表列表上?
#First I convert my image to a list
#Below list represents a true image size
list1=[[255, 114, 70],
[120, 89, 15],
[247, 190, 6],
[41, 38, 37],
[102, 102, 10],
[255,255,255]]*3583180
然后我们定义了将颜色映射到的集群以及执行此操作的函数(取自PIL 库)
#Define colors of interest
#Colors of interest
RED=[255, 114, 70]
DARK_YELLOW=[120, 89, 15]
LIGHT_YELLOW=[247, 190, 6]
BLACK=[41, 38, 37]
GREY=[102, 102, 10]
WHITE=[255,255,255]
Colors=[RED, DARK_YELLOW, LIGHT_YELLOW, GREY, BLACK, WHITE]
#Function to find closes cluster by root and squareroot distance of RGB
def distance(c1, c2):
(r1,g1,b1) = c1
(r2,g2,b2) = c2
return math.sqrt((r1 - r2)**2 + (g1 - g2) ** 2 + (b1 - b2) **2)
剩下的就是匹配每种颜色,并使用原始颜色的匹配索引创建一个新列表:
Filt_lab=[]
#Match colors and make new list with indexed colors
for pixel in tqdm(list1):
closest_colors = sorted(Colors, key=lambda color: distance(color, pixel))
closest_color = closest_colors[0]
for num, clust in enumerate(Colors):
if list(clust) == list(closest_color):
Filt_lab.append(num)
运行单个图像大约需要 5 分钟,这是可以的,但是可能有一种方法可以大大减少这个时间?
36%|???? | 7691707/21499080 [01:50<03:18, 69721.86it/s]
Filt_lab 的预期结果:
[0, 1, 2, 4, 3, 5]*3583180
回答
您可以使用Numba 的 JIT 大幅加快代码速度。这个想法是classified_pixels通过迭代每个像素的颜色来动态构建。颜色存储在一个 Numpy 数组中,其中索引是颜色键。整个计算可以并行运行。这避免了在内存中创建和写入/读取许多临时数组以及分配大量内存。此外,可以调整数据类型,以便结果数组在内存中更小(因此写入/读取速度更快)。这是最终的脚本:
import numpy as np
import numba as nb
@nb.njit('int32[:,::1](int32[:,:,::1], int32[:,::1])', parallel=True)
def classify(image, colors):
classified_pixels = np.empty((image.shape[0], image.shape[1]), dtype=np.int32)
for i in nb.prange(image.shape[0]):
for j in range(image.shape[1]):
minId = -1
minValue = 256*256 # The initial value is the maximum possible value
ir, ig, ib = image[i, j]
# Find the color index with the minimum difference
for k in range(len(colors)):
cr, cg, cb = colors[k]
total = (ir-cr)**2 + (ig-cg)**2 + (ib-cb)**2
if total < minValue:
minValue = total
minId = k
classified_pixels[i, j] = minId
return classified_pixels
# Representative image
np.random.seed(42)
imarray = np.random.rand(3650,2000,3) * 255
image = imarray.astype(np.int32)
# Colors of interest
RED = [255, 0, 0]
DARK_YELLOW = [120, 89, 15]
LIGHT_YELLOW = [247, 190, 6]
BLACK = [41, 38, 37]
GREY = [102, 102, 10]
WHITE = [255, 255, 255]
# Build a Numpy array rather than a dict
colors = np.array([RED, DARK_YELLOW, LIGHT_YELLOW, GREY, BLACK, WHITE], dtype=np.int32)
# Actual classification
classified_pixels = classify(image, colors)
# Convert array to list
cl_pixel_list = classified_pixels.reshape(classified_pixels.shape[0] * classified_pixels.shape[1]).tolist()
# Print
print(cl_pixel_list[0:10])
这个实现在我的 6 核机器上大约需要0.19 秒。它比迄今为止提供的最后一个答案快约15 倍,比初始实现快一千多倍。请注意,tolist()由于classify函数速度非常快,所以大约一半的时间都花在了上面。