如何从函数列表[f1,f2,f3,…fn]组合嵌套函数g=fn(…(f3(f2(f1()))…)

是否有一种现成的 Pythonic 方法可以从函数列表 [f1, f2, f3] 中组合多个嵌套函数 g = f3(f2(f1())) ,其中列表中有更多函数。

如果有几个,我可以这样做:

g = lambda x: f3(f2(f1(x)))

然而,当我在深度神经网络中有几十个功能(例如层)时,它是无法管理的。宁愿不创建另一个函数来组合 g 而是找到一种可用的方法。


更新

基于@Chris 的回答。对于顺序神经网络层[ batchnorm, matmul, activation, softmaxloss ],每个层都有一种forward(X)计算其输出到下一层的方法,损失函数 L 和 loss 为:

L = reduce(lambda f, g: lambda X: g(f(X)),  [ layer.forward for layer in layers ] )   # Loss function
network_loss = L(X)

回答

一种使用方式functools.reduce

from functools import reduce

f1 = lambda x: x+1
f2 = lambda x: x*2
f3 = lambda x: x+3
funcs = [f1, f2, f3]

g = reduce(lambda f, g: lambda x: g(f(x)), funcs)

输出:

g(1)==7 # ((1+1) * 2) + 3
g(2)==9 # ((2+1) * 2) + 3

洞察力:

functools.reducefuncs根据其第一个参数(lambda此处)链接其第二个参数(此处)。

话虽如此,它将开始链接f1and f2as f_21(x) = f2(f1(x)),然后f3f_21as f3(f_21(x))which 变为g(x)


以上是如何从函数列表[f1,f2,f3,…fn]组合嵌套函数g=fn(…(f3(f2(f1()))…)的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>