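What is the most efficient way to get the intersection of k sorted arrays?
Given k sorted arrays, what is the most efficient way to get the intersection of these lists?
Example
Input: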
[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
Output:
[1,7]
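From what I read in Elements of Programming Interviews, there is a way to get the union of k sorted arrays in n log k time. I was wondering whether something similar can be done for the intersection.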
# merge sorted arrays in n log k time [ regular appending and merging is n log n time ]
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # a plain `if elem:` would wrongly skip a 0
            heapq.heappush(heap, (elem, idx))
    res = []
    # collect results in n log k time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, ary))
    return res
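EDIT: Obviously this is an algorithm problem that I am trying to solve, so I cannot use any built-in functions such as set intersection.
Answer
Exploit the sort order
Here is an O(n) approach that doesn't require any special data structures or auxiliary memory beyond the fundamental requirement of one iterator and one value per sublist: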
from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        # A list comprehension (not a generator expression), so that the
        # StopIteration from an empty sublist reaches the except clause below.
        pairs = cycle([[(it := iter(sublist)), next(it)] for sublist in data])
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
                pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:  # '<' (not '!=') so a single sublist also terminates
                continue
            result.append(curr)
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result
Here is a sample session:
>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]
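Algorithm in words
The algorithm cycles around the iterator/value pairs. If a value matches across all of the pairs, it belongs in the intersection. If a value is lower than the current target (the largest value seen so far), the current iterator is advanced. If a value is greater than any value seen so far, it becomes the new target and the match count is reset to one. When any iterator is exhausted, the algorithm is complete.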
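The description above can also be read as a sequence of rounds over the current "head" of every sublist. The following is only a simplified sketch of that reading, not the answer's code: the names `intersection_by_rounds` and `pos` are made up for illustration, it uses list indices instead of iterators, and it rescans all heads in every round, so it does more work per step than the iterator version.

```python
def intersection_by_rounds(data):
    """Round-based sketch: the largest head is the target; smaller heads advance."""
    if not data:
        return []
    pos = [0] * len(data)  # current index into each sublist (hypothetical name)
    result = []
    # Stop as soon as any sublist is exhausted.
    while all(p < len(sub) for p, sub in zip(pos, data)):
        heads = [sub[p] for p, sub in zip(pos, data)]
        target = max(heads)
        if all(h == target for h in heads):
            # Every sublist sits on the target: it belongs to the intersection.
            result.append(target)
            for i, sub in enumerate(data):
                # Skip past every copy of the target so it is reported once.
                while pos[i] < len(sub) and sub[pos[i]] == target:
                    pos[i] += 1
        else:
            for i, sub in enumerate(data):
                # Advance every sublist whose head is below the target.
                while pos[i] < len(sub) and sub[pos[i]] < target:
                    pos[i] += 1
    return result


print(intersection_by_rounds([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```

Each round moves at least one index forward, so the loop always terminates once one sublist runs out.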
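Doesn't rely on built-in functions
The use of itertools.cycle() is entirely optional. It is easily emulated by incrementing an index that wraps around at the end.
Instead of: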
iterator, value = pair = next(pairs)
you can write:
pairnum += 1
if pairnum == n:
    pairnum = 0
iterator, value = pair = pairs[pairnum]
or more compactly:
pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum]
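Put together, a cycle-free variant might look like the sketch below. It assumes that `pairs` is built as a plain list so it can be indexed, and the function name `intersection_no_cycle` is made up for illustration. The early return for empty input and the `matches < n` test are defensive choices of this sketch, so that zero or one sublists are also handled.

```python
def intersection_no_cycle(data):
    """Same loop as the cycle() version, driven by a wrapping index instead."""
    VALUE = 1
    n = len(data)
    result = []
    if n == 0:
        return result
    try:
        # Plain list of [iterator, current value] pairs; an empty sublist
        # raises StopIteration here, which the except clause turns into [].
        pairs = [[(it := iter(sublist)), next(it)] for sublist in data]
        pairnum = 0
        pair = pairs[pairnum]
        curr = pair[VALUE]  # candidate: largest value seen so far
        matches = 1         # number of pairs currently on the candidate
        while True:
            pairnum = (pairnum + 1) % n  # wrap around, replacing next(pairs)
            iterator, value = pair = pairs[pairnum]
            while value < curr:
                value = next(iterator)
                pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:
                continue
            result.append(curr)
            # Skip further copies of the match in this iterator.
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result


print(intersection_no_cycle([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```

The assignment expression (`:=`) requires Python 3.8 or later, as in the version that uses cycle().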
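Duplicate values
If duplicates are to be retained (as in a multiset), that is a simple modification. Just change the four lines after result.append(curr) so that the matching element is removed from each iterator: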
from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        # A list comprehension (not a generator expression), so that the
        # StopIteration from an empty sublist reaches the except clause below.
        pairs = cycle([[(it := iter(sublist)), next(it)] for sublist in data])
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
                pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:  # '<' (not '!=') so a single sublist also terminates
                continue
            result.append(curr)
            # Remove one matching element from every iterator
            for i in range(n):
                iterator, value = pair = next(pairs)
                pair[VALUE] = next(iterator)
            curr, matches = pair[VALUE], 1
    except StopIteration:
        return result
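To make the intended multiset behaviour concrete, here is a small reference sketch that could serve as a test oracle. It deliberately uses built-ins (`collections.Counter`), which the question rules out for the actual solution, so it is only meant for checking results. The name `multiset_intersection_reference` is made up for illustration.

```python
from collections import Counter
from functools import reduce


def multiset_intersection_reference(data):
    """Keep each value as many times as its smallest count across all sublists."""
    if not data:
        return []
    # Counter & Counter keeps the minimum count of every shared key.
    common = reduce(lambda a, b: a & b, (Counter(sub) for sub in data))
    return sorted(common.elements())


print(multiset_intersection_reference([[1, 1, 3], [1, 1, 3, 5], [1, 1, 4]]))  # [1, 1]
```

In this input the value 1 occurs twice in every sublist, so a duplicate-preserving intersection reports it twice, while 3 is dropped because the last sublist does not contain it.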