当前位置: 主页 > 建站知识 > APP开发

新架构RNN反超Transformer:每个隐藏状态都是一个模型,一作:从根本上改变语言模型

发布时间:2024-07-19 09:21   浏览次数:次   作者:6kYzQ!yIEmp_M6UkZ

  新架构,再次向 Transformer 发起挑战!

  核心思想:将 RNN 中的隐藏状态换成可学习的模型。

  甚至在测试时都可以学习,所以该方法称为TTT(Test-Time Training)。

  共同一作 UC 伯克利的 Karen Dalal 表示:我相信这将从根本上改变语言模型。

  一个 TTT 层拥有比 RNN 表达能力更强的隐藏状态,可以直接取代 Transformer 中昂贵的自注意力层。

  在实验中,隐藏状态是线性模型的 TTT-Linear 表现超过了 Transformer 和 Mamba,用更少的算力达到更低的困惑度(左),也能更好利用长上下文(右)。

  此外,隐藏状态是 MLP 模型的 TTT-MLP 在 32k 长上下文时表现还要更好。

  Karen Dalel 还指出,理论上可学习的隐藏状态可以是任意模型,对于更长上下文来说,可以是 CNN、甚至可以是完整的 Transformer 来套娃。

  目前刚刚出炉的 TTT 论文已经在学术界引起关注和讨论,斯坦福博士生 Andrew Gao 认为,这篇论文或许能成为下一篇 Attention is all you need。

  另外有人表示,众多新架构能否真正击败 Transformer,还要看能不能扩展到更大规模。

  Karen Dalel 透露,马上就会推出 7B 模型。‍‍‍

  用机器学习模型来压缩上下文

  传统 RNN,隐藏状态固定大小表达能力受限,也不好并行训练。

  Transformer 强大,但自注意力机制随上下文长度呈平方复杂度,非常昂贵。

  最近一系列基于 RNN 的架构创新中:

  RWKV,用线性注意力结合 RNN 和 Transformer 的优点,在训练时可以并行计算。

  Mamba,赋予模型选择性记住或遗忘信息的能力来压缩上下文,同时设计了面向硬件的高效并行算法。

  它们的表现在短上下文时追上甚至超越了 Transformer,但在 32k 超长上下文以上,Trasformer 依旧称霸。

  TTT 团队的想法来自于:与其让隐藏状态被动地储存信息,不如让它主动学习。

  就像 Transformer 模型作为一个整体在压缩互联网数据到参数中一样,可学习的隐藏状态模型也在少量参数上不断缩上下文信息。

  这种 " 隐藏状态模型 " 随着时间的推移仍然具有固定的大小(固定的模型参数),但表达能力更强了。

  论文的联合指导 UCSD 助理教授王小龙认为:

网络科技模板设计图片

  Transformer 显式地储存所有输入 token,如果你认为个神经网络是压缩信息的好方法,那么压缩这些 token 也将是有意义的。

  如此一来,整个框架的时间复杂度还是线性的,

  至此,序列建模被拆解为两个嵌套的学习循环,外循环负责整体的语言建模,内循环通过自监督学习压缩上下文信息。

  外循环的参数变成了内循环的超参数,也就是元学习的一个变种了。

  标准的元学习是训练一个适应不同任务的模型,而 TTT 是让模型去适应每一个测试样本。单个样本虽然信息量小,但用来训练隐藏状态模型也绰绰有余。

  特别的,在内循环是一个线性模型时,相当于线性注意力。当内循环是一个 Nadaraya-Watson estimator 时,TTT 等价于自注意力。

青海智凌分析手机网站建设的要点

  在测试时学习

  在 TTT 层里,使用自监督学习方法将上下文压缩到隐藏状态。

  上下文就是未标记的数据集,隐藏状态不再是一个固定的向量,可以是线性模型、小型神经网络或任何机器学习模型,更新规则采用了在自监督损失上的一步梯度下降。

  这样一来,隐藏状态模型可以记住产生大梯度的输入,并且可以获得比选择性遗忘机制更强的拟合和泛化能力,并且在测试时仍然为每个输入序列训练不同的参数。

  到目前为止,朴素的 TTT 层已经有效了,但还无法并行化。

  团队提出的解决方案为 mini-batch 梯度下降,把一个 batch 内的梯度计算并行化。

  再通过 Dual form 方法,只在 mini-batch 结束时计算权重以及输出 token,避免冗余计算。在 JAX 版实现中快了 5 倍以上。

  TTT 能否成为 "Transformer 杀手 "?

  理论上都走的通了,那么 TTT 在实验中表现到底如何?

  最简单干净的测试方法,应该是直接替换掉 Transformer 中的自注意力层。

  但是在研究过程中,团队发现 Mamba 等现代 RNN 的骨干中在 RNN 层之前还包含时间卷积,对 TTT 也有帮助。

  所以实验中 TTT-Linear 和 TTT-MLP 主要应用到 Mamba 骨干上,其他训练细节也严格遵照 Mamba 论文中的设置。

  最终在 Pile 数据集短上下文测试中:

  2k 上下文时,TTT-Linear、Mamba 和 Transform 具有相当的性能,TTT-MLP 的表现略差。

  8k 上下文时,TTT-Linear 和 TTT-MLP 都优于 Mamba 和 Transformer,应用在 Transformer 骨干的 TTT-MLP(T)在 1.3B 参数左右也略好与 Mamba。

  总的来说,随着上下文长度的增长,TTT 层相对于 Mamba 的优势也会扩大。

  另外团队猜测,线性模型比 MLP 表达能力差,因此从 Mamba 骨干的卷积中受益更多。

  长上下文实验使用 Pile 的子集 Books3:

  32k 上下文,TTT-Linear 和 TTT-MLP 的表现都优于曼巴,类似于 Pile 8k 的观察。即使是带有 Transformer 骨干的 TTT-MLP(T)表现也略好于曼巴。

  1.3B 参数尺度上,TTT-MLP(T)仅比 TTT-MLP(M)稍差,Transformer 骨干可能更适合论文评估范围之外的更大模型和更长的上下文。

  在 A100 上测试速度,TTT-Linear 在预填充阶段比 Mamba 稍快,解码阶段几乎与 Mamba 速度相同。TTT-MLP 相比 Transformer 整体上也有线性复杂度的优势。

  共同一作 Karan Dala 表示:我一直被问到的一个问题是,我们是否相信 TTT 就是 "Transformer 杀手 ",我仍然认为我们需要继续努力。

  隐藏状态可以是任意模型,但目前的研究只涉及了线性模型和小型 MLP,更复杂的还有待研究。

  隐藏状态模型的学习可以用 Adam 代替普通的梯度下降等等。

  还可用于视频建模

  三位共同一作中:

  Yu Sun 博士毕业于 UC Berkeley,目前是斯坦福大学博士后。

  Xinhao Li 是电子科技大学校友,硕士毕业于 UCSD。

  Karan Dalel 本科毕业于 UC Berkley,正在机器人初创公司 1X 实习。

  最后,联合指导 UCSD 助理教授王小龙还透露,TTT 方法除了语言模型,还适用于视频。

  TTT 就是 "Transformer 杀手 ",我仍然认为我们需要继续努力。

  将来在对长视频进行建模时,我们可以密集地采样帧而不是采样 1 FPS,这些密集帧对 Transformer 来说是一种负担,但对 TTT 层来说是一种福音。

科技渐变粒子背景

  论文地址:

  https://arxiv.org/abs/2407.04620

  参考链接:

  [ 1 ] https://x.com/karansdalal/status/1810338845659131940

  [ 2 ] https://x.com/xiaolonw/status/1810387662060269668