《Titans: Learning to Memorize at Test Time》原文:https://arxiv.org/pdf/2501.00663v1
Titans 架构 非官方实现:https://github.com/lucidrains/titans-pytorch
一、 研究背景与动机:Transformer 的局限性与人类记忆的启发
1. Transformer 的局限性:长序列处理的瓶颈
Transformer 模型自提出以来,凭借其强大的自注意力机制,在自然语言处理、计算机视觉等领域取得了革命性的进展。然而,随着任务复杂度的提升,Transformer 在处理长序列时逐渐暴露出一些关键问题:
- 计算复杂度高,限制模型扩展性:
- 自注意力机制需要计算每个 token 与序列中所有其他 token 之间的相似度,其时间复杂度和空间复杂度均为 O(N²),其中 N 为序列长度。
- 这意味着当序列长度增加时,计算量和内存消耗呈平方级增长,严重限制了模型处理长序列的能力,例如在处理长文本、视频理解或长期时间序列预测等任务时,Transformer 往往力不从心。
图 1:自注意力机制的计算过程。
- 上下文窗口有限,难以捕捉长距离依赖关系:
- 为了缓解计算负担,Transformer 通常采用固定长度的上下文窗口(例如 512 或 1024),这意味着模型只能关注到当前窗口内的信息。
- 然而,许多现实世界的任务都需要模型能够捕捉到更长时间范围内的依赖关系,例如理解长篇文章或对话中的上下文信息,整合视频中不同时间点的信息,以及利用历史数据中的长期趋势和模式进行预测。
2. 线性 Transformer 的折中方案:效率与性能的矛盾
为了解决 Transformer 的计算瓶颈,研究人员提出了线性 Transformer,其主要改进在于:
- 用核函数替代 softmax: 将自注意力机制中的 softmax 计算替换为核函数,从而将计算复杂度降低为 O(N)。
- 可并行化推理: 线性 Transformer 的计算过程可以表示为循环形式,从而实现更高效的推理。
然而,线性 Transformer 也存在一些局限性:
- 性能下降:
- 核技巧使得模型退化为线性循环网络,数据被压缩为矩阵值状态,导致其性能不如标准 Transformer。
- 这种压缩方式难以有效捕捉复杂的非线性依赖关系。
- 内存管理问题:
- 线性 Transformer 将历史数据压缩到固定大小的矩阵中,但当处理非常长的上下文时,这种压缩方式会导致内存溢出,从而影响模型性能。
图 2:线性 Transformer 的内存更新过程。
3. 人类记忆系统的启发:构建更强大的长时记忆机制
为了克服上述挑战,作者从人类记忆系统中汲取灵感:
- 记忆与学习的关系:
- 论文借鉴了神经心理学文献中关于记忆和学习的定义,将记忆视为由输入引起的神经更新,将学习定义为在给定目标的情况下获取有效且有用的记忆的过程。
- 这意味着,有效的学习离不开强大的记忆机制。
- 人类记忆的多层次性:
- 人类记忆系统并非单一结构,而是由多个子系统组成,例如短期记忆、工作记忆和长期记忆,每个子系统都有不同的功能和组织结构,并能够独立运作。
- 这种多层次性使得人类能够高效地存储、检索和管理信息。
- 现有模型的不足:
- 现有神经网络架构(从 Hopfield 网络到 LSTM 和 Transformer)在处理泛化、长度外推和推理方面都存在挑战,而这些能力对于许多现实世界的复杂任务至关重要。
- 这些架构虽然从人脑中汲取灵感,但都缺乏对长时记忆的有效建模,以及对记忆系统多层次性的模拟。
二、 核心创新:神经长时记忆模块与 Titans 架构
基于以上思考,作者提出了以下创新点:
1. 神经长时记忆模块 (Neural Long-term Memory Module)
(1) 设计理念:
- 元上下文学习机制:
- 该模块被设计为一个元模型,在测试时学习如何将数据记忆/存储到其参数中。
- 这种在线学习的方式使得模型能够根据当前输入动态调整记忆,而不是依赖于预训练时固定的记忆。
- 基于惊讶度的记忆更新:
- 作者借鉴了人类记忆机制中“令人惊讶的事件更容易被记住”的特点,提出了一种基于惊讶度的记忆更新机制。
- 惊讶度通过计算神经网络相对于输入的梯度来衡量,梯度越大,说明输入数据与历史数据差异越大,更值得被记住。
- 这种方法可以有效地捕捉到数据中的关键信息,并将其存储到长时记忆中。
- 相比之下,线性 Transformer 只能基于当前的输入数据进行线性变换,难以有效捕捉长距离的依赖关系。
图 3:基于惊讶度的记忆更新机制。
(2) 关键技术:
- 动量机制:
- 为了防止模型被单个令人惊讶的事件过度影响,作者引入了动量机制,将过去时刻的惊讶度也考虑在内。
- 这意味着模型会综合考虑当前输入和历史输入的惊讶度,从而实现更平滑的记忆更新。
- 衰减机制:
- 为了防止内存溢出,作者还引入了衰减机制,通过权重衰减的方式逐渐遗忘不重要的信息。
- 该机制可以看作是一种门控机制,可以根据需要选择性地清除记忆。
- 作者指出,这种衰减机制是现代循环模型中遗忘机制的泛化,并且与元神经网络在小批量梯度下降、动量和权重衰减下的优化等价。
(3) 记忆结构:
- 与传统线性记忆模型不同,作者采用了多层感知机 (MLP) 作为记忆模块。
- MLP 具有更强的非线性表达能力,能够更有效地存储和检索复杂的信息。
- 相比之下,线性 Transformer 只能使用矩阵值状态来存储信息,难以捕捉复杂的非线性关系。
2. Titans 架构:整合长时记忆与短期记忆
在设计出神经长时记忆模块之后,作者进一步思考如何将其有效地整合到深度学习架构中,并提出了 Titans 架构,其主要特点如下:
(1) 三个超头协同工作:
- 核心 (Core):
- 由短期记忆组成,负责处理数据的主要流程。
- 使用有限窗口大小的注意力机制,例如滑动窗口注意力 (SWA) 或全连接注意力 (FCA)。
- 短期记忆可以看作是短时记忆,用于捕捉当前上下文中的依赖关系。
- 长期记忆 (Long-term Memory):
- 负责存储/记住长期过去的信息。
- 采用上述神经长时记忆模块。
- 长期记忆可以看作是长时记忆,用于存储和检索更长时间范围内的信息。
- 持久记忆 (Persistent Memory):
- 是一组可学习但与数据无关的参数,用于编码有关任务的先验知识。
- 类似于 Transformer 中全连接层的参数,但具有不同的功能。
- 持久记忆可以看作是元记忆,用于存储任务相关的知识,例如语法规则、常识知识等。
图 4:Titans 架构示意图 (MAC 变体)。
(2) 三种不同的整合方式:
- 记忆作为上下文 (MAC):
- 将长时记忆和持久记忆与输入序列连接起来,作为当前上下文的补充信息。
- 注意力机制决定哪些信息需要存储到长时记忆中。
- 在测试时,持久记忆参数保持固定,注意力模块权重进行上下文学习,而长时记忆模块继续学习/记忆信息。
- 这种设计使得模型能够根据当前输入灵活地利用长时记忆信息。
- 记忆作为门控 (MAG):
- 使用滑动窗口注意力作为短期记忆,神经记忆模块作为长期记忆。
- 通过门控机制将两者结合起来,例如使用可学习的向量值权重对两者进行归一化,然后应用非线性激活函数。
- 这种设计可以看作是一种多头架构,其中不同头的结构不同。
图 5:Titans 架构的不同变体 (MAC 和 MAG)。 - 记忆作为层 (MAL):
- 将神经记忆模块作为深度神经网络的层使用,在注意力模块之前压缩过去和当前上下文信息。
- 这种设计在文献中更为常见,例如 H3 模型。
(3) 优势:
- 更灵活的记忆管理:
- 通过将记忆模块作为上下文或门控分支,Titans 架构能够根据当前输入动态地利用长时记忆信息。
- 这与将记忆模块作为层使用的传统方法相比,更具灵活性。
- 更强的表达能力:
- 三个超头的协同工作,使得 Titans 架构能够更有效地处理长序列数据,并整合短期记忆、长期记忆和持久记忆的优势。
- 可扩展性:
- 相比 Transformer,Titans 架构在处理长序列时具有更好的可扩展性,能够在更大的上下文窗口下保持高性能。
三、 实验结果与分析:验证 Titans 架构的有效性
作者在多个任务上进行了广泛的实验,以评估 Titans 架构及其变体的性能:
1. 语言建模与常识推理:
- 实验设置:
- 使用了三种不同规模的 Titans 模型 (340M, 400M, 760M 参数) 以及多个基线模型,包括 Transformer++, RetNet, GLA, Mamba, Mamba2, DeltaNet, TTT 和 Gated DeltaNet。
- 训练数据采用 FineWeb-Edu 数据集。
- 主要结果:
- 在非混合模型中,神经长时记忆模块在困惑度和准确率指标上均取得了最佳性能。
- Titans 的三种变体 (MAC, MAG, MAL) 均优于 Samba (Mamba + 注意力) 和 Gated DeltaNet-H2 (Gated DeltaNet + 注意力)。
- MAC 在处理长距离依赖关系时表现更佳,而 MAG 和 MAC 均优于 MAL 变体。
图 6:Titans 与基线模型在语言建模和常识推理任务上的性能对比。
2. “针在干草堆中”任务:
- 实验设置:
- 使用 RULER 基准测试中的 Single NIAH (S-NIAH) 任务,评估模型在 2K, 4K, 8K 和 16K 长度序列上的检索能力。
- 主要结果:
- 神经长时记忆模块在所有三个任务中均取得了最佳结果。
- Titans 变体也表现出色,其中 MAC 变体表现最佳。
3. BABILong 基准测试:
- 实验设置:
- 该任务要求模型在极长的文档中推理分布的事实信息。
- 分为少样本设置和微调设置。
- 主要结果:
- 在少样本设置中,Titans 优于所有基线,包括参数数量更大的模型,例如 GPT-4 和 GPT4o-mini。
- 在微调设置中,Titans 也优于所有模型,即使是像 GPT-4 这样的超大模型。
- 与基于 Transformer 的记忆模型 (RMT) 相比,Titans 表现出更好的性能,主要归功于其强大的记忆能力。
图 7:Titans 与基线模型在 BABILong 基准测试上的性能对比。
4. 时间序列预测:
- 实验设置:
- 使用 Simba 框架,将 Mamba 模块替换为神经长时记忆模块。
- 在 ETT, ECL, Traffic 和 Weather 基准数据集上进行评估。
- 主要结果:
- 神经长时记忆模块优于所有基线,包括基于 Mamba, 线性模型和 Transformer 的架构。
5. DNA 建模:
- 实验设置:
- 在 GenomicsBenchmarks 上评估预训练模型的下游任务性能。
- 主要结果:
- Titans (LMM) 在不同的下游基因组学任务中均具有竞争力,与最先进的方法不相上下。
6. 效率分析:
- 主要结果:
- 与其他循环模型相比,神经长时记忆模块的训练速度略慢,主要原因是其具有更深的记忆和更复杂的转换过程,以及 Mamba2 实现了高度优化的内核。
- Titans (MAL) 比基线以及记忆模块更快,主要原因是使用了 FlashAttention 的高度优化的内核。
7. 消融研究:
- 主要结果:
- 神经记忆设计的所有组件均对性能有积极贡献,其中权重衰减, 动量, 卷积和持久记忆的贡献最大。
- 架构设计对性能也有重要影响,MAC 和 MAG 在语言建模和常识推理任务中表现接近,而 MAC 在长上下文任务中表现更优。
四、 论文的创新点与优势
- 提出了一种新颖的神经长时记忆模块:
- 借鉴了人类记忆机制中的关键要素,例如惊讶度, 动量和遗忘机制,实现了更有效的记忆更新和存储。
- 采用深度神经网络作为记忆模块,赋予模型更强的表达能力。
- 设计了 Titans 架构,将长时记忆与短期记忆有机结合:
- 提出了三种不同的整合方式,为不同应用场景提供了灵活的选择。
- 核心, 长期记忆和持久记忆三个超头的协同工作,使得模型能够更有效地处理长序列数据。
- 在多个任务上均表现出色:
- 无论是语言建模, 常识推理, 还是时间序列预测和 DNA 建模,Titans 架构均展现出强大的性能,优于现有的 Transformer 和线性循环模型。
- 可扩展性强:
- 能够在更大的上下文窗口下保持高性能,为处理超长序列提供了可能性。
五、 未来展望
尽管 Titans 架构在多个方面取得了令人瞩目的成果,但仍有以下方向值得进一步探索:
- 探索更复杂的记忆模块架构:
- 例如,引入层次化记忆结构,或者将记忆模块与图神经网络等其他模型相结合。
- 开发更高效的记忆更新和存储机制:
- 例如,利用稀疏化技术或量化技术来降低内存消耗和计算成本。
- 将 Titans 架构应用于更广泛的领域:
- 例如,视频理解, 机器人控制, 推荐系统等。
- 探索更有效的训练策略:
- 例如,引入更先进的优化算法,或者利用元学习来加速模型训练。
- 研究 Titans 架构的可解释性:
- 深入理解 Titans 如何存储和利用长时记忆信息,可以为构建更强大的 AI 系统提供新的思路。
六、 总结
这篇论文的核心贡献在于:
- 提出了一种新型的神经长时记忆模块,其设计灵感来源于人类记忆系统,并结合了深度学习中的关键概念,例如梯度下降, 动量和权重衰减。
- 构建了 Titans 架构,将长时记忆与短期记忆有机结合,并探索了三种不同的整合方式,为不同应用场景提供了灵活的选择。
- 通过严谨的实验验证了 Titans 的优越性能,在多个任务上均表现出色,特别是在处理长上下文任务时,展现了强大的扩展能力和更高的准确率。