I am an AI Reseach Engineer. I was formerly a researcher @Oxford VGG before founding the AI Bites YouTube channel.
This story contains new, firsthand information uncovered by the writer.
我们已经将经典的多层感知器 (MLP) 视为理所当然,并围绕它构建了许多架构。MLP 是当今我们看到的每个 LLM 或基础模型(例如 chatGPT、LLAMA、DALLE 和 CLIP)的重要组成部分。甚至是简单的识别模型(例如 YOLO-v*)。
如果我现在告诉你,我们在 MLP 领域有一个竞争对手,你会怎么想?城里有一篇新论文,名为“Kolmogorov-Arnold 网络”,简称 KAN,它对 MLP 提出了挑战。如果他们提出的解决方案真正具有可扩展性,那么我们就可以拥有下一代神经网络,这将使我们更接近通用人工智能 (AGI)。
虽然 MLP 包含 ReLU、sigmoid、tanh、GeLU 等激活函数,但 KAN 建议我们学习这些激活函数。那么,KAN 是如何做到的呢?它背后的数学原理是什么?它是如何实现的?我们又如何训练 KAN?
我已尽力在此总结了 KAN 论文。您可以选择阅读此摘要或阅读长达 48 页的论文!
如果您和我一样,希望将事物形象化以便更好地理解,这里有这篇文章的视频形式:
让我们从我们非常熟悉的 MLP 开始。MLP 由节点和边组成。在每个节点中,我们将输入相加,并应用 ReLU、GeLU 和 SeLU 等激活函数,以生成该特定节点的输出。
这些激活函数在训练过程中永远不会改变。换句话说,它们没有任何参数。它们不够智能,无法根据给定的训练数据集进行自我调整。因此,在训练过程中进行训练或更新的是每个节点的权重。
现在,如果我们质疑激活函数需要固定并使其可训练的假设,结果会怎样?这就是 KAN 网络试图解决的挑战。KAN 网络的激活函数在训练过程中得到更新。在深入研究之前,让我们先从多项式和曲线拟合开始。
因此,KAN 的基本思想是任何多元复合函数都可以分解为几个单变量函数的总和。
例如,假设我们有一个 3 次方程,其中 y=x³,如上图黄色曲线所示。还有另一个 2 次方程,y=x²,如上图动画中的蓝色曲线所示。我们可以在这个可视化中看到,使用 x² 永远无法实现 x³ 所实现的曲率。
假设我们给出了下面红点和蓝点所表示的数据,并且我们希望找到两个类之间的二元分类边界。
使用二阶多项式x² ,我们将无法找到两者之间的边界,因为 x² 曲线是“U”形,但数据是“S”形。虽然使用x³适合这些数据,但它会带来额外的计算成本。另一种解决方案可能是当输入x为负时使用x² ,而当 x 为正时使用 - x² (上图中用手绘制的蓝色曲线)。
我们所做的只是添加两个低次多项式,以获得具有更高自由度的曲线。这正是 KAN 网络背后的确切思想。
现在让我们来看一个稍微复杂一点的玩具问题,我们知道数据是由一个简单的方程式y=exp(sin(x1² + x2²) + sin(x3² + x4²))生成的。所以我们有 4 个输入变量,并且有三个运算,即指数、正弦和平方。因此,我们可以选择四个输入节点和三个层,每个层专用于三种不同的运算,如下所示。
KAN 网络用于解决一个玩具问题,该网络具有四个输入和三个用于计算的基函数(指数、正弦和平方)
训练后,节点将收敛到平方、正弦和指数函数以拟合数据。
由于这是一个玩具问题,我们知道数据来自哪个方程。但实际上,我们不知道真实世界数据的分布。解决这个问题的一种方法是使用 B 样条函数。
B 样条的基本思想是,任何给定的函数或曲线都可以表示为更简单的函数或曲线的组合。这些更简单的函数称为基函数。例如,让我们以下图中的红色曲线为例。为了简单起见,我们尝试仅用两个基函数来表示它。
我们可以将其分解为 3 个点,因为我们将用两个基函数之和来表示它。这些点称为节点。基函数的数量可以是任意的n 。控制此基函数组合方式的参数是c。当我们“连接”两条曲线时,节点处可能会出现不连续性。解决方案是限制节点处曲线的曲率,以便获得平滑的曲线。例如,我们可以将两条曲线的斜率限制为在节点处相同,如下图中的绿色箭头所示。
由于我们无法在神经网络中施加这样的约束,他们在论文中引入了残差激活函数。这更像是一种正则化。实际上,这是添加到标准样条函数的 SeLU 激活,如下文所示。
KAN 引入了一种称为细粒度的新训练方法。我们都熟悉的是微调,即向模型添加更多参数。然而,在细粒度的情况下,我们可以提高样条网格的密度。这就是所谓的网格扩展。
从论文上图我们可以看出,细粒度只是使 B 样条网格变得密集,从而使其更具代表性,从而更加强大。
样条函数的缺点之一是它们是递归的,因此计算成本很高。它们的计算复杂度为 O(N²LG),高于 MLP 的通常复杂度 O(N²L)。额外的复杂性来自网格间隔 G。
作者通过以下方式来解释这个固有问题:
我们将在结果部分看到这些防御的图表。现在,让我们进一步了解 KAN 的另一个特色。
由于 KAN 可以学习函数,因此它并不像 MLP 那样只是一个黑盒子,我们可以通过为给定数据或问题选择 MLP 的深度和宽度来设计它们。因此,为了使 KAN 更具可解释性并设计出良好的 KAN 网络,我们需要遵循以下步骤:
fix_symbolic(l,i,j,f)
接口函数使我们能够做到这一点,其中l、i、j是节点层和位置, f是可以是sine, cosine, log, etc
的函数本文建议的训练 KAN 网络的不同步骤总结
上图总结了不同的步骤。我们从一个大型网络开始,然后进行稀疏化(步骤 1),修剪生成的网络(步骤 2),设置一些符号化(步骤 3),训练网络(步骤 4),最后得到训练好的模型。
使用上述步骤,他们针对五种不同的小问题训练了 KAN 网络,以说明其有效性并将其与 MLP 进行比较。比较的关键要点如下:
第一个点由上图顶部的五张图中的粗蓝线表示,这些图分别表示 5 个小问题。最后两个点由底部的图表示,该图显示了解决任何给定问题的损失曲线和参数计数。
下一个要点是,在灾难性遗忘问题上,KAN 远胜于 MLP。如果我们输入序列数据进行持续学习,KAN 似乎比 MLP 更能记住过去的数据。如下图所示,KAN 重现了数据中的 5 个阶段,但 MLP 却举步维艰。
他们还进行了大量的实验,以证明 KAN 可用于解决涉及偏微分和物理方程的问题。我们先不讨论这些细节,而是看看何时应该选择 KAN 而不是 MLP。
他们给出了下图来指导我们何时选择 KAN 而不是 MLP。因此,如果符合以下条件,请选择 KAN:
否则,MLP 仍可获胜。
如果你喜欢这篇文章,为什么不关注我呢
也请订阅我的
在我看来,KAN 不会取代 MLP,就像 transformer 横扫 NLP 领域一样。相反,KAN 将证明其对数学和物理中的小众问题很有用。即便如此,我觉得我们还需要更多改进。但对于用基础模型解决的大数据问题,KAN 还有很长的路要走,至少就目前的情况而言。
此外,KAN 架构的训练方法和设计往往偏离了现代神经网络设计和训练的标准方式。尽管如此,GitHub 页面已经有 13k 个 star 和 1.2k 个 fork,这表明它有发展前途。让我们拭目以待吧。