机器之心报道
编辑:杜伟
正如论文一作所说,「新架构 Titans 既比 Transformer 和现代线性 RNN 更有效,也比 GPT-4 等超大型模型性能更强。」
终于,在 2017 年推出影响 AI 行业长达 8 年的 Transformer 架构之后,谷歌带来了全新的架构 Titans。这次,谷歌的重点是将推理领域非常重要的测试时(test-time)计算用在了记忆(memory)层面。
在谈到推出 Titans 的初衷时,论文一作 Ali Behrouz 表示,「注意力机制一直是大多数 LLM 进展的重要组成部分,不过它无法扩展到长上下文。因此,Titans 应运而出,它成为了一种同时具备注意力机制和元上下文记忆的结构,可以在测试时学习记忆。该架构可以将上下文窗口扩展到 200 万 tokens。」
图源:https://x.com/behrouz_ali/status/1878859086227255347
这意味着,谷歌 Transformer 迎来了它的「继任者」。
图源:https://x.com/mark_k/status/1878896628654022993
多年来,研究人员一直在广泛探究如何有效地利用循环模型和注意力机制,其中循环模型旨在将数据压缩到固定大小的记忆(称为隐状态)中,而注意力机制允许处理整个上下文窗口,捕捉所有 token 的直接依赖。不过,更准确的依赖建模往往伴随着二次成本,导致模型只能处理固定长度的上下文。
因此,谷歌提出了一种新的长期神经记忆模块(neural memory module),它能够学习记忆历史上下文,并帮助注意力机制在利用过去已久信息的同时处理当前上下文。结果表明,这种神经记忆具有快速并行化训练的优势,同时还能保持快速推理。
从记忆的角度来看,谷歌认为注意力机制虽然受限于上下文但可以更准确地建模依赖关系,因此可以起到短期记忆的作用;而神经记忆能够对数据进行记忆,起到了长期、更持久的记忆作用。基于这两个模块,谷歌引入了一个全新的系列架构 —— Titans,通过三种变体有效地将记忆融合到该系统架构中,它们分别是记忆作为上下文(Memory as a Context,MAC)、记忆作为门(Memory as a Gate,MAG)和记忆作为层(Memory as a Layer,MAL)。
在语言建模、常识推理、基因组学和时序预测任务上的实验结果表明,Titans 架构比 Transformer 和近年来的现代线性循环模型更有效。另外,在大海捞针(needle-in-haystack)中,Titans 架构能够有效地扩展到超过 200 万 tokens 的上下文窗口,并且比基准模型实现了更高的准确性。
- 论文标题:Titans: Learning to Memorize at Test Time
- 论文地址:https://arxiv.org/pdf/2501.00663v1
另外,论文作者之一 Peilin Zhong 为谷歌 NYC 算法与优化团队的研究科学家,2021 年加入谷歌。他本科毕业于清华姚班,博士毕业于哥伦比亚大学。
目前,已经有人搞出了有关 Titans 架构的非官方实现,感兴趣的读者可以去看一下。
GitHub 地址:https://github.com/lucidrains/titans-pytorch
学习测试时记忆
谷歌详细介绍了长期神经记忆模块,它成为了一种可以在测试时学习记忆的元模型。
长期记忆
为了设计一个长期神经记忆模块,我们需要模型能够将过去历史的抽象编码到其参数中。因此,一个简单的思路是训练神经网络并期望它能够记住自己的训练数据,然而记忆几乎一直是神经网络中令人头疼的现象,它限制了模型的泛化能力,还引发隐私问题,因此导致测试时性能不佳。
基于此,谷歌认为需要一个在线元模型来学习如何在测试时记忆或忘记数据。在这种设置下,模型学习一个能够记忆的函数,但不会过拟合训练数据,从而在测试时实现更好的泛化性能。
学习过程和意外指标(Learning Process and Surprise Metric)。训练长期记忆的关键思路是将训练视为在线学习问题,其中将过去信息 x_1, …, x_t-1 压缩到长期神经记忆模块中。人类往往能够记住背离预期(令人惊讶)的事件,受此启发,模型意外可以简单定义为它相对于输入的梯度。梯度越大,输入数据与过去数据的偏差就越大。因此,使用这个意外分数,可以将记忆更新如下:
这一意外指标可以导致在重大意外时刻之后出现重要信息缺失。从人类记忆的角度来看,即使一个事件令人难忘,但它可能不会在长时间内持续让我们感到惊讶。为了改进这一现象,谷歌将意外指标分解为了(1)过去意外,它衡量最近过去的意外程度;(2)瞬时意外,它衡量传入数据的意外。
并行化长期记忆训练
作为函数块的参数(Parameters as the Function of Chunks)。谷歌没有让参数 a_t、θ_t 和 η_t 依赖于输入,而是让它们成为函数块。尽管失去了表达能力,但可以帮助更快地训练。在这种情况下,谷歌在每个块中对每一个 a、θ 和 η 都使用了相同的值。在实验中,谷歌将这些参数作为了 token 的函数,并表示,这种简化(即作为块函数)可能是未来工作感兴趣的地方,以便以更高效的方式训练更大的模型。
下图 1 展示了如何并行并在使用矩阵乘法时完成神经记忆训练。
如何融合记忆?
接下来需要解决的一个重要问题是:如何有效且高效地将神经记忆融合到深度学习架构中?
从记忆的角度来看,Transformer 中的 K 和 V 矩阵对可以解释为联想记忆块。由于它们对依赖关系的精确建模以及有限的上下文窗口,它们可以被用作短期记忆模块,以处理当前上下文窗口大小。另一方面,神经记忆能够不断从数据中学习并存储在其权重中,因而可以发挥长期记忆的作用。谷歌通过三个不同的 Titans 变体来回答以上问题。
记忆作为上下文(Memory as a Context,MAC)
Titans 的第一个变体 MAC 的架构设计如下图 2 所示,将记忆作为当前信息的上下文。
该架构具有两个关键优势:一是注意力模块同时具有历史和当前上下文,能够根据当前数据决定是否需要长期记忆信息,二是注意力模块帮助长期记忆只存储来自当前上下文的有用信息。这意味着,并非每个片段中的所有 token 都是有用的,记忆所有 token 可能会导致内存溢出。因此,注意力模块帮助记忆了解哪些信息是有用的,从而更好地管理内存容量。
另外,在测试时,(i)持久记忆参数是固定的,它们编码了有关任务的知识,不应改变;(ii)注意力模块权重是上下文学习器;(iii)长期记忆模块在测试时仍然学习(记忆)信息。也就是说,即使在测试时,神经记忆的权重也会更新,这是因为权重对过去已久的抽象进行了编码。
记忆作为门(Memory as a Gate,MAG)
Titans 第二个变体 MAG 的架构设计如下图 4 所示:
在其中一个分支中,谷歌直接使用输入数据来更新长期记忆;在第二个分支中,谷歌使用了滑动窗口注意力(SWA):
该架构的整体注意力掩码如下图 3b 所示,其中滑动窗口注意力(SWA)充当精确的短期记忆,而神经记忆模块充当模型的衰减记忆。该设计也可以看作是多头架构,其中各头的结构不同。
记忆作为层(Memory as a Layer,MAL)
Titans 的第三个变体 MAL 使用了深度神经网络,这种架构设计在文献中更为常见,其中混合模型堆叠具有完整或滑动窗口注意力的循环模型。
给定输入 x,可以得到以下:
其中 SW-Attn 是滑动窗口注意力。
无注意力记忆(Memory Without Attention)。从记忆的角度来看,谷歌期望记忆系统的每个组件都能独立工作,即使其他组件受到了干扰。因此,即使没有短期记忆(即注意力),长期记忆模块仍然应该是一个强大的模型。谷歌在实验中将这种变体称为 Titans (LMM)。
架构细节
卷积(Convolution)。遵循最近的现代线性循环模型,谷歌在每个查询、键和值投影后都融合了一个 1D 深度可分离卷积层。这些 1D 卷积可以提升性能,并且计算高效。
门控(Gating)。谷歌还在最终输出投影之前利用线性层进行归一化和门控。
实验结果
谷歌在实验部分关注上述三种 Titans 变体,分别是 MAC、MAG 和 MAL,以及单独的神经记忆模块。对于每个模型,谷歌使用了四种尺寸的模型,参数分别是 (i) 170M、(ii) 340M、(iii) 400M 和 (iv) 760M。
语言建模
谷歌首先关注模型在语言建模和常识推理任务中的困惑度。下表 1 报告了 Titans 变体和三种不同大小(340M、400M 和 760M)基线的结果。在包括 Transformer++ 在内的非混合模型中,神经记忆模块在困惑度和准确度测量方面均取得了最佳性能。
谷歌还发现,Titans 的三种变体(MAC, MAG 和 MAL)都优于 Samba (Mamba + 注意力)和 Gated DeltaNet-H2(Gated DeltaNet + 注意力)。
大海捞针
下表 2 结果显示,与基线相比,神经记忆模块均取得了最佳结果。
谷歌将这种卓越的表现归因于 Titans 与现有序列模型的三个关键差异:(1)与 TTT 相比,神经记忆能够通过使用动量和遗忘机制(即权重衰减)更好地处理记忆容量。因此,随着序列长度的增加,神经记忆的性能不会下降,呈现出一致的趋势;(2)与具有门控(遗忘)机制的 Mamba2 相比,Titans 具有深度非线性记忆,从而实现了更好的记忆管理。此外,与神经记忆和 DeltaNet 不同,Mamba2 无法移除记忆,因此在增加序列长度时,其性能会出现显著下降;(3)与 DeltaNet 相比,尽管它能够使用增量规则移除记忆,但无法擦除记忆,缺乏遗忘机制。
最终,正如预期的那样,使用 Titans 变体时能看到相当或更好的结果,其中最佳结果来自 MAC。
BABILong 基准
在微调设置中,谷歌将小型微调版本的 Titans (MAC) 与其他模型进行了比较。
Titans 和基线的结果如下图 6b 所示。Titans 的表现优于所有模型,甚至比 GPT4 这样的超大型模型还要好。此外,与基于 Transformer 的 RMT 等记忆模型相比,Titans 表现出更好的性能,这主要归功于其强大的记忆。
深度记忆的影响
接下来的实验评估了深度记忆对 wall-clock 训练时间和模型性能的影响。
下图 7 中报告了 Titans(LMM)和基线的困惑度与序列长度的关系。有趣的是,随着记忆深度的增加,该模型可以在所有序列长度上实现更好的困惑度。此外,当模型的参数量较少时,更深的记忆模块对序列长度的鲁棒性更强。随着参数量的增加,所有模型在较长的序列上都表现出更好的性能。
时序预测
为了展示记忆模块在更广泛任务中的有效性,谷歌评估了 Titans 在时序预测任务中的表现。结果如下表 3 所示,谷歌的神经记忆模块优于所有基线,包括基于 Mamba、线性和 Transformer 的架构。
DNA 建模
谷歌还进一步评估了神经记忆模块在 DNA 建模任务上的表现,结果如下 4 所示,相较于当前的 SOTA 架构,Titans(LMM)在不同的下游基因组任务中仍具有竞争力。
效率
谷歌还对 Titans 与当前 SOTA 序列模型的效率进行了比较,下图 9 显示了不同序列长度 x 批大小的模型的训练吞吐量。可以看到,谷歌神经记忆模块比 Mamba2 和 Gated DeltaNet 稍慢,不过 Titans (MAL) 比基线和神经记忆模块都要快。
更多技术细节和实验结果请参阅原论文。