摘要

作为推荐系统(RS)的关键最终环节,多任务融合(MTF)负责将多任务学习(MTL)生成的多个评分整合为最终评分,以最大化用户满意度并决定最终推荐结果。近年来,业界开始采用强化学习(RL)进行MTF,以优化推荐会话(session)中的长期(long-term)用户满意度。然而,当前用于MTF的离线RL算法存在三个严重缺陷

1) 为避免分布外(OOD:out-of-distribution)问题,其约束条件过于严格,严重损害模型性能;
2) 算法无法感知训练数据生成所用的探索策略,且未与真实环境交互,导致仅能学习次优策略;
3) 传统探索策略效率低下且损害用户体验。

针对这些问题,我们提出面向大规模推荐系统MTF的创新方法IntegratedRL-MTF,其核心创新包括:

  • 离线/在线策略融合:通过将离线RL模型与在线探索策略相整合,放宽过严的约束条件,显著提升性能;
  • 高效探索策略:剔除低价值探索空间(low-value exploration space),聚焦潜在高价值状态-动作对的探索;
  • 渐进式训练:借助探索策略进一步优化模型表现。

在腾讯新闻短视频频道的离线和在线实验表明,该方法显著优于基线模型。目前IntegratedRL-MTF已在腾讯推荐系统及其他大型推荐场景中全面部署,取得显著效果提升。

1 引言

推荐系统(Recommender Systems, RSs)[1, 2]通过分析用户偏好提供个性化推荐服务,目前已广泛应用于短视频平台[3, 7, 14]、视频平台[4, 5]、电子商务平台[6, 8-11]及社交网络[12, 13]等场景,每日服务数十亿用户。工业级推荐系统通常包含三阶段流程:候选生成(candidate generation)、排序(ranking)和多任务融合(Multi-Task Fusion, MTF)[4, 15]。在候选生成阶段,系统需从数百万乃至数十亿候选项中筛选出数千个候选项目;排序阶段则采用多任务学习模型(Multi-Task Learning, MTL)[4, 8, 16-18]预测用户点击、观看时长、快速滑动、点赞、分享等多种行为的预估分数;最终通过MTF模型将MTL输出的多任务分数融合为单一分数,生成候选项目的最终排序[15],从而决定推荐结果。然而目前针对MTF的研究仍缺乏实质性突破。

MTF的核心目标是最大化用户满意度。用户满意度通常通过加权计算单次推荐或推荐会话中的多种反馈指标来评估,包括观看时长、有效点击、点赞、分享等行为。其中,推荐会话定义为用户从开始访问推荐系统到离开的完整过程,可能包含一次或多次连续请求。

如图1

在腾讯新闻、抖音和快手等推荐系统中,当前推荐结果会对后续推荐产生显著影响,特别是在同一推荐会话内。因此,我们需要同时考虑当前推荐的即时收益和整个会话内的长期累积收益。最近,部分研究[15,24,25]开始采用离线强化学习(RL)[26]来寻找最优融合权重,以最大化长期收益。与前述方法相比,RL不仅考虑会话内的累积奖励,还能推荐既满足当前用户需求又能带来长期正向交互的内容。此外,RL相比进化策略(ES)具有更强的模型性能和更高的样本效率[23]。目前,RL已在腾讯[15]等多家公司的推荐系统中应用于MTF任务。

然而,现有RL-MTF方法存在以下严重问题[15,26-31]:

  • 1)为避免分布外(OOD)问题,现有离线RL算法采用了过于严格复杂的约束条件,严重损害了模型性能;
  • 2)在线探索与离线训练相互割裂,离线RL算法无法感知训练数据背后的探索策略,也不再与真实环境交互,因此只能学习到次优策略;
  • 3)现有探索策略效率低下且损害用户体验。

针对这些问题,我们提出了一种专门为推荐系统MTF任务设计的新方法IntegratedRL-MTF。首先,该方法将离线RL模型与我们的在线探索策略相结合。在离线训练时,可以直接获取探索策略生成的训练数据分布,从而放宽为避免OOD问题而设置的过度约束,显著提升RL模型性能。其次,我们设计了一种简单但极其高效的探索策略,不仅加快了模型迭代速度,还减少了对用户体验的负面影响,这对商业公司具有重要价值。最后,我们提出渐进式训练模式,借助高效探索策略通过多轮在线探索和离线训练的迭代,使目标策略快速收敛至最优策略。

我们使用自设计的新评估指标(该指标更简单且更适用于RL-MTF评估)在相同数据集上进行了离线实验对比。此外,在大规模推荐系统中进行的在线实验表明,我们的RL模型显著优于其他模型。IntegratedRL-MTF已在我们的推荐系统中稳定运行近一年,并推广至腾讯其他大型推荐系统,取得了显著效果提升。本文将重点阐述IntegratedRL-MTF的核心思想,不深入讨论实现细节。

本研究的主要贡献包括:

  • 系统分析了现有RL-MTF方法,指出其存在约束条件过严影响性能、在线探索与离线训练割裂导致策略次优、传统探索策略低效损害用户体验等核心问题
  • 提出面向大规模推荐系统MTF的定制化RL算法,通过离线RL与探索策略的融合放宽约束条件提升性能,并采用渐进式训练模式实现策略快速收敛
  • 在腾讯新闻短视频频道进行实验验证:离线实验采用新设计的评估指标,在线A/B测试显示模型显著优于基线(用户有效消费时长提升+4.64%,用户停留时长提升+1.74%)

2 问题定义

本节给出腾讯新闻短视频频道(与抖音类似)中RL-MTF的问题定义。如前所述,在当前推荐会话中,推荐结果会对后续推荐产生显著影响。在每个时间步$t$,推荐系统(RS)接收到用户请求后:

  1. 首先从数百万内容中筛选出数千候选项目
  2. 多任务学习(MTL)模型预测每个候选的多种用户行为得分
  3. 多任务融合(MTF)模型使用公式(1)生成融合权重,将MTL模型输出的多个得分组合为最终得分
  4. 最后将推荐列表发送给用户,并将用户反馈上报至平台数据系统

我们将上述融合问题建模为推荐会话内的马尔可夫决策过程(MDP)。在这个MDP中,推荐系统作为智能体与用户(环境)交互,进行序列化推荐,目标是最大化会话内的累积奖励。该MDP框架包含以下关键组件[26]:

  • 状态空间$\mathcal{S}$:是状态$s$的集合,包括用户画像特征(如年龄、性别、top K兴趣、刷新次数等)和用户历史行为序列(如观看、有效点击、点赞等)

  • 动作空间$\mathcal{A}$:是RL模型生成的动作$a$的集合。在我们的问题中,动作$a$是一个融合权重向量$(w_1,…,w_k)$,其中每个元素对应公式(1)中的不同幂次项或偏置项

  • 奖励$\mathcal{R}$:当推荐系统在状态$s_t$采取动作$a_t$并向用户发送推荐列表后,用户对这些内容的各种行为将上报至RS,基于这些行为计算即时奖励$r(s_t,a_t)$

  • 状态转移概率$\mathcal{P}$:转移概率$p(s_{t+1} s_t,a_t)$表示采取动作$a_t$后从状态$s_t$转移到$s_{t+1}$的概率。在我们的问题中,状态包含用户画像特征和用户历史行为序列,因此下一状态$s_{t+1}$取决于用户反馈且是确定性的
  • 折扣因子$\gamma$:决定智能体对未来奖励相对于即时奖励的重视程度,$\gamma \in [0,1]$

基于以上定义,在推荐系统中应用RL进行MTF的目标可以定义为:给定推荐会话内RS与用户以MDP形式交互的历史,如何学习最优策略以最大化累积奖励。

3 提出的解决方案

3.1 奖励函数

在推荐会话中,RS在状态$s_t$采取动作$a_t$计算每个候选的最终得分,并向用户发送推荐列表,随后用户的多种反馈会上报至RS,如图1所示。为了评估即时奖励,我们定义如公式(2)所示的即时奖励函数:

\[r(s_t,a_t) = \sum_{i=1}^k \alpha_i \cdot b_i\]

其中$\alpha_i$是行为$b_i$的权重。在我们的推荐场景中,用户行为$b_1,…,b_k$包括观看时长、有效消费(观看视频超过10秒)以及点赞、分享、收藏等交互行为。通过分析不同用户行为与用户停留时长的相关性,我们为这些行为设置了不同的权重。

3.2 在线探索

在训练RL模型之前,首先需要收集大量探索数据,这对模型性能有关键影响。然而,传统探索策略面临两个挑战[15,32]:

  • 低效率:在实践中,使用传统探索策略在大规模RS中收集足够的探索数据通常需要很长时间。例如,在我们的平台上使用动作噪声探索策略收集一次探索数据通常需要五天或更长时间。这影响了模型迭代速度并意味着收入损失

  • 对用户体验的负面影响:传统探索策略生成的过多探索动作(包括异常动作)会对用户体验产生显著负面影响,甚至导致用户流失,这是不可接受的

为解决上述问题,我们首先在推荐场景数据集上,对新学习的RL策略与基线RL策略在相同状态下生成动作的绝对差值分布进行了分析。为简化分析,我们将动作各维度的取值范围归一化至$[-1,1]$区间,并选取最重要的4维动作(包括有效消费、观看时长、播放完成率和正向行为率)进行说明,如图2所示。我们观察到,对于相同状态,新学习RL策略生成的动作通常不会与基线RL策略生成的动作产生显著偏离,这一现象也与我们的直觉相符。

基于此发现,我们提出了一种简单但极其高效的探索策略,如公式(3)所示,该策略根据基线策略为每个用户定义个性化的探索上下界:

\[a_{explore} = a_{baseline} + \delta,\quad \delta \sim \mathcal{U}(lower_b, upper_b)\]

探索动作由基线策略输出的动作加上由$lower_b$和$upper_b$定义的均匀分布随机扰动生成。我们通过统计分析精心选择了$lower_b$和$upper_b$的取值。该探索策略的核心思想是消除低价值探索空间,仅聚焦于探索潜在高价值的状态-动作对,如图3所示。相较于传统探索策略(本文以常用于生成探索数据的动作噪声探索策略为例,如图3珊瑚色曲线所示),我们的策略展现出极高的效率。在相同探索密度要求下,我们推荐场景中的探索策略效率约为动作噪声探索策略的210倍(具体分析见第4节)。此外,相比动作噪声探索策略,我们的策略能减少数据分布对RL-MTF模型训练的干扰。第3节详述的渐进式训练模式进一步扩展了探索策略的探索空间,因此可设置更小的个性化探索空间上下界。

```markdown

3.3 IntegratedRL-MTF:面向大规模推荐系统MTF定制的强化学习算法

为解决前文所述问题,我们提出名为IntegratedRL-MTF的新方法。下面将分别介绍其执行器网络(Actor Network)、评价器网络(Critic Network)和渐进式训练模式。

3.3.1 执行器网络

执行器网络的目标是为特定状态输出最优动作。遵循常规设置,我们在学习过程中构建两个执行器网络:

  • 当前执行器网络 $\pi(s)$
  • 目标执行器网络 $\pi’(s)$

$\pi(s)$ 通过将执行器网络与我们的探索策略相融合,实现了以下创新设计(如公式4-5所示):

  1. 约束松弛机制:通过整合在线探索策略的数据分布知识,放宽传统RL-MTF的严格约束条件
  2. 多评价器一致性惩罚项:基于多个评价器输出的一致性引入额外惩罚项,有效缓解外推误差

数学表达为:
\(\pi(s) = \arg\min_a \left[ \mathcal{L}_{actor} + \lambda \cdot \mathbb{E}_{\xi\sim\mathcal{U}}[\max_{j}Q_j(s,a+\xi) - \min_{j}Q_j(s,a+\xi)] \right]\)
其中$\lambda$为调节系数,$\xi$为探索噪声,$Q_j$表示第$j$个评价器网络。

在训练$\pi(s)$期间,如第3.2节所述,可以直接获取每个用户探索数据分布的上界和下界。因此,我们可以利用这一特性来简化过于严格的约束条件,并充分发挥$\pi(s)$的能力。如果$\pi(s)$在状态$s_t$生成的动作处于用户的上界和下界范围内,则公式4中第二项的值为零,即不施加惩罚以避免影响模型能力。否则,将根据超出用户上界或下界的偏差施加惩罚。通过这种方式,当前actor网络的性能相比现有方法得到显著提升,这一点在第4节的实验中得到验证。

此外,我们还引入了一个惩罚机制,该机制定义为多个独立critics[33]输出估计值的标准差,以减轻外推误差,这是公式4中的第三项。由于我们的探索策略具有极高的效率,在用户上下界范围内收集的探索动作与传统动作噪声探索策略相比具有显著更高的平均密度,这对模型优化极具价值。此外,与高斯扰动相比,个性化上下界内的随机扰动减轻了数据分布对模型训练的干扰。如果$\pi(s)$输出的动作处于用户的探索空间内,公式4中第三项的值会很小甚至可以忽略。否则,将施加相应的惩罚来减轻外推误差。

目标actor网络$\pi’(s)$是一个辅助网络,负责基于下一状态生成下一最优动作,以缓解由bootstrapping引起的过估计问题。其参数会使用当前actor网络进行周期性的软更新。

3.3.2 Critic网络

Critic网络$Q(s,a)$负责估计推荐会话中状态-动作对$(s,a)$的累积奖励。$Q(s,a)$还将critic网络与我们的探索策略相结合以避免外推误差。在我们的解决方案中,创建了多个独立的critic网络,这些网络被随机初始化并独立训练。每个critic网络的目标是最小化TD-error,如公式6所示。如果$\pi’(s)$在下一状态$s_{t+1}$生成的下一动作处于用户的上下界范围内,公式6中第二项的值为零。否则,将根据超出用户规定上下界的偏差施加惩罚。实践中,我们通常将critic网络数量设为24,这足以在我们的推荐场景中取得良好效果。为了获得更好的性能,我们为每个critic定义了一个目标网络,其参数会使用相应的critic网络进行周期性软更新。

3.3.3 渐进式训练模式

离线RL的一个严重缺点是当模型离线训练时,它仅依赖于之前收集的数据而不再与真实环境交互。离线训练期间缺乏实时交互会导致学习策略与实际环境之间的差异,这对离线RL算法的性能产生显著负面影响[15,26-31]。

为了在大规模RS中缓解这个问题,我们的解决方案采用渐进式训练模式,通过高效的探索策略进行多轮在线探索和离线模型训练来学习最优策略,使目标策略能够快速收敛到最优策略。由于我们的探索策略效率很高,我们将之前的单次数据探索和离线模型训练划分为五轮在线数据探索和离线模型训练。最新学习到的策略将作为下一轮在线探索的基线策略。通过迭代高效地探索环境,学习到的策略将不断改进,从而进一步提升我们RL模型的性能。

3.4 基于RL-MTF的推荐系统

我们在腾讯新闻短视频频道实现了IntegratedRL-MTF,如图4所示。我们的RL-MTF框架由两个组件组成:离线模型训练和在线模型服务。离线模型训练组件负责预处理探索数据和训练RL-MTF模型。在线模型服务组件主要负责在接收到用户请求时生成个性化最优动作,计算每个候选的最终得分。此外,在线模型服务组件还负责在线探索以收集训练数据。

4.实验

https://arxiv.org/pdf/2404.17589

Deepseek AI在《DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model》详细介绍了V2版的实现:

1.介绍

在过去的几年中,大型语言模型(LLMs)(Anthropic, 2023; Google, 2023; OpenAI, 2022, 2023)经历了快速发展,为通用人工智能(AGI)的曙光提供了初步的展望。通常,随着参数数量的增加,LLM的智能水平会显著提升,使其能够在各种任务中展现出涌现(emergent)能力(Wei et al., 2022)。然而,这种改进的代价是:更大的训练计算资源需求以及推理吞吐量的潜在下降。这些限制对LLM的广泛采用和利用构成了重大挑战。为了解决这一问题,我们提出了DeepSeek-V2,一个强大的开源混合专家(MoE)语言模型,其特点是通过创新的Transformer架构实现经济的训练和高效的推理。该模型总参数量为236B,每个token激活21B参数,并支持128K token的上下文长度。

我们在Transformer框架(Vaswani et al., 2017)中优化了注意力模块和前馈网络(FFNs),分别提出了多头隐注意力(MLA)DeepSeekMoE

  1. 在注意力机制方面,多头注意力(MHA)(Vaswani et al., 2017)的键值(KV)缓存对LLM的推理效率构成了显著障碍。为了解决这一问题,研究者们探索了多种方法,包括分组查询注意力(GQA)(Ainslie et al., 2023)和多查询注意力(MQA)(Shazeer, 2019)。然而,这些方法在减少KV缓存的同时往往会牺牲性能。为了兼顾两者,我们引入了MLA,这是一种配备低秩键值联合压缩(low-rank key-value joint compression)的注意力机制。实验表明,MLA在性能上优于MHA,同时显著减少了推理过程中的KV缓存,从而提升了推理效率。

  2. 在前馈网络(FFNs)方面,我们采用了DeepSeekMoE架构(Dai et al., 2024),该架构通过细粒度的专家分割和共享专家隔离,实现了更高的专家专业化潜力。与传统的MoE架构(如GShard(Lepikhin et al., 2021))相比,DeepSeekMoE展现出了显著优势,使我们能够以较低的成本训练强大的模型。在训练过程中,我们采用专家并行策略,并设计了补充机制来控制通信开销并确保负载均衡。

通过结合这两种技术,DeepSeek-V2在性能(图1(a))、训练成本和推理吞吐量(图1(b))方面均表现出色。我们构建了一个高质量、多来源的预训练语料库,包含8.1T token。与DeepSeek 67B(我们之前的版本)(DeepSeek-AI, 2024)使用的语料库相比,该语料库的数据量更大,尤其是中文数据,且数据质量更高。

  • 首先我们在完整的预训练语料库上对DeepSeek-V2进行预训练。
  • 然后,我们收集了1.5M个对话会话,涵盖数学、代码、写作、推理、安全等多个领域,用于对DeepSeek-V2 Chat(SFT)进行监督微调。
  • 最后,我们遵循DeepSeekMath(Shao et al., 2024)的方法,采用组相对策略优化(GRPO)进一步对齐模型与人类偏好,生成DeepSeek-V2 Chat(RL)。

我们在广泛的中英文基准测试中评估了DeepSeek-V2,并将其与代表性的开源模型进行了比较。评估结果表明,即使仅激活21B参数,DeepSeek-V2仍然在开源模型中表现出顶级性能,成为最强的开源MoE语言模型。

  • 图1(a)显示,在MMLU上,DeepSeek-V2仅以少量激活参数就达到了顶级性能。
  • 如图1(b)所示,与DeepSeek 67B相比,DeepSeek-V2节省了42.5%的训练成本,减少了93.3%的KV缓存,并将最大生成吞吐量提升至5.76倍。

我们还在开放式基准测试中评估了DeepSeek-V2 Chat(SFT)和DeepSeek-V2 Chat(RL)。值得注意的是,DeepSeek-V2 Chat(RL)在AlpacaEval 2.0(Dubois et al., 2024)上达到了38.9的长度控制胜率,在MT-Bench(Zheng et al., 2023)上获得了8.97的综合评分,在AlignBench(Liu et al., 2023)上获得了7.91的综合评分。英文开放式对话评估表明,DeepSeek-V2 Chat(RL)在开源聊天模型中具有顶级性能。此外,AlignBench的评估表明,在中文方面,DeepSeek-V2 Chat(RL)超越了所有开源模型,甚至击败了大多数闭源模型。

图片名称

图1 (a)不同开源模型中MMLU精度与激活参数的对比。(b) DeepSeek-67B(Dense)和DeepSeek-v2的训练成本和推理效率。

为了促进对MLA和DeepSeekMoE的进一步研究和开发,我们还向开源社区发布了DeepSeek-V2-Lite,这是一个配备MLA和DeepSeekMoE的小型模型。其总参数量为15.7B,每个token激活2.4B参数。关于DeepSeek-V2-Lite的详细描述见附录B。

在本文的其余部分,我们首先详细描述了DeepSeek-V2的模型架构(第2节)。随后,我们介绍了预训练工作,包括训练数据构建、超参数设置、基础设施、长上下文扩展以及模型性能和效率的评估(第3节)。接着,我们展示了对齐工作(alignment),包括监督微调(SFT)、强化学习(RL)、评估结果及其他讨论(第4节)。最后,我们总结了结论,探讨了DeepSeek-V2的当前局限性,并展望了未来的工作(第5节)。

2. 架构

总体而言,DeepSeek-V2仍然基于Transformer架构(Vaswani et al., 2017),其中每个Transformer块由一个注意力模块和一个前馈网络(FFN)组成。然而,对于注意力模块和FFN,我们设计并采用了创新的架构。对于注意力模块,我们设计了多头隐注意力(MLA),利用低秩键值联合压缩来消除推理时键值(KV)缓存的瓶颈,从而支持高效推理。对于FFN,我们采用了DeepSeekMoE架构(Dai et al., 2024),这是一种高性能的MoE架构,能够以较低的成本训练强大的模型。图2展示了DeepSeek-V2的架构示意图,本节将详细介绍MLA和DeepSeekMoE的细节。对于其他微小细节(例如层归一化和FFN中的激活函数),除非特别说明,DeepSeek-V2遵循DeepSeek 67B(DeepSeek-AI, 2024)的设置。

图片名称

图2 DeepSeek-V2架构示意图。MLA通过显著减少生成的KV缓存来确保高效的推理,DeepSeekMoE通过稀疏架构以经济的成本训练强模型

2.1 多头隐注意力:提升推理效率

传统的Transformer模型通常采用多头注意力(MHA)(Vaswani et al., 2017),但在生成过程中,其庞大的键值(KV)缓存会成为限制推理效率的瓶颈。为了减少KV缓存,研究者提出了多查询注意力(MQA)(Shazeer, 2019)和分组查询注意力(GQA)(Ainslie et al., 2023)。这些方法需要更少的KV缓存,但其性能无法与MHA媲美(我们在附录D.1中提供了MHA、GQA和MQA的消融实验)。

对于DeepSeek-V2,我们设计了一种创新的注意力机制,称为多头隐注意力(MLA)。MLA配备了低秩键值联合压缩,不仅性能优于MHA,而且所需的KV缓存显著减少。以下我们将介绍其架构,并在附录D.2中提供MLA与MHA的对比。

2.1.1 预备知识:标准多头注意力

我们首先介绍标准MHA机制作为背景。设:

  • $d$为嵌入维度
  • $n_h$为注意力头(attention heads)的数量
  • $d_h$为每个头的维度
  • $h_t \in \mathbb{R}^d$为第$t$个token在注意力层的输入

标准MHA首先通过三个矩阵$W_Q$、$W_K$、$W_V \in \mathbb{R}^{d_h n_h \times d}$分别生成$q_t$、$k_t$、$v_t \in \mathbb{R}^{d_h n_h}$:

\[q_t = W^Q h_t, \quad (1) \\ k_t = W^K h_t, \quad (2) \\ v_t = W^V h_t. \quad (3)\]

然后,$q_t$、$k_t$、$v_t$ 将被切分为 $n_h$ 个头以进行多头注意力计算:

\[[q_{t,1}; q_{t,2}; \dots; q_{t,n_h}] = q_t, \quad (4) \\ [k_{t,1}; k_{t,2}; \dots; k_{t,n_h}] = k_t, \quad (5) \\ [v_{t,1}; v_{t,2}; \dots; v_{t,n_h}] = v_t, \quad (6) \\ o_{t,i} = \sum_{j=1}^t \text{Softmax}\ _j \left( \frac{q_{t,i}^T k_{j,i}}{\sqrt{d_h}} \right) v_{j,i}, \quad (7) \\ u_t = W_O [o_{t,1}; o_{t,2}; \dots; o_{t,n_h}], \quad (8)\]

其中:

  • $q_{t,i}$、$k_{t,i}$、$v_{t,i} \in \mathbb{R}^{d_h}$ 分别表示第 $i$ 个注意力头的查询、键和值;
  • $W_O \in \mathbb{R}^{d \times d_h n_h}$ 表示输出投影矩阵。

在推理过程中,所有键和值都需要被缓存以加速推理,因此MHA需要为每个token缓存 $2 n_h d_h l$ 个元素($l$ 为层数)。在模型部署中,这种庞大的KV缓存是一个巨大的瓶颈,限制了最大batch-size和序列长度。

2.1.2 低秩键值联合压缩

MLA的核心是:通过低秩联合压缩键和值来减少KV缓存

\[c^{KV}_t = W^{DKV} h_t, \quad (9) \\ k^C_t = W^{UK} c^{KV}_t, \quad (10) \\ v^C_t = W^{UV} c^{KV}_t, \quad (11)\]

其中:

  • $c^{KV}_t \in \mathbb{R}^{d_c}$ 是键和值的压缩隐向量;
  • $d_c (\ll d_h n_h)$ 表示KV压缩维度
  • $W^{DKV} \in \mathbb{R}^{d_c \times d}$ 是下投影矩阵(down-projection matrix)
  • $W^{UK}$ 和 $W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$ 分别是键和值的上投影矩阵(up-projection matrices)

在推理过程中,MLA只需缓存 $c_t^{KV}$,因此其KV缓存仅为 $d_c l$ 个元素。此外,在推理过程中,由于 $W^{UK}$ 可以被吸收到 $W^Q$ 中,$W^{UV}$ 可以被吸收到 $W_O$中,我们甚至不需要显式计算键和值来进行注意力计算。图3直观地展示了MLA中的KV联合压缩如何减少KV缓存。

图片名称

图3 多头注意(MHA)、分组查询注意(GQA)、多查询注意(MQA)和多头潜在注意(MLA)的简化说明。通过将键和值联合压缩成一个隐向量,MLA显著降低了推理过程中的KV缓存

此外,为了减少训练期间的激活内存,我们还对Query进行了低秩压缩,尽管这并不能减少KV缓存:

\[c^Q_t = W^{DQ} h_t, \quad (12) \\ q^C_t = W^{UQ} c^Q_t, \quad (13)\]

其中:

  • $c^Q_t \in \mathbb{R}^{d’_c}$ 是查询的压缩隐向量;
  • $d’_c (\ll d_h n_h)$ 表示查询压缩维度;
  • $W^{DQ} \in R^{d’c \times d}$ 和 $W{UQ} \in R^{d_h n_h \times d’_c}$ 分别是查询的下投影和上投影矩阵。

2.1.3 解耦的旋转位置嵌入

我们计划为DeepSeek-V2使用旋转位置嵌入(RoPE)(Su et al., 2024),这与DeepSeek 67B(DeepSeek-AI, 2024)一致。然而,RoPE与低秩KV压缩不兼容。具体来说,RoPE对键和查询都是位置敏感的。如果我们对键 $k^C_t$ 应用RoPE,公式10中的 $W_{UK}$ 将与一个位置敏感的RoPE矩阵耦合。这样,$W_{UK}$ 在推理过程中无法再被吸收到 $W_Q$ 中,因为与当前生成token相关的RoPE矩阵会位于 $W_Q$ 和 $W_{UK}$ 之间,而矩阵乘法不满足交换律。因此,我们必须在推理过程中重新计算所有前缀token的键,这将显著降低推理效率。

作为解决方案,我们提出了解耦RoPE策略,该策略使用额外的多头查询 $q^R_{t,i} \in \mathbb{R}^{d^R_h}$ 和一个共享键 $k^R_t \in \mathbb{R}^{d^R_h}$ 来承载RoPE,其中 $d^R_h$ 表示解耦查询和键的每头维度。配备解耦RoPE策略后,MLA执行以下计算:

\[[q^R_{t,1}; q^R_{t,2}; \dots; q^R_{t,n_h}] = q^R_t = \text{RoPE}(W^{QR} c^Q_t), \quad (14) \\ k^R_t = \text{RoPE}(W^{KR} h_t), \quad (15) \\ q_{t,i} = [q^C_{t,i}; q^R_{t,i}], \quad (16) \\ k_{t,i} = [k^C_{t,i}; k^R_t], \quad (17) \\ o_{t,i} = \sum_{j=1}^t \text{Softmax}_j \left( \frac{q_{t,i}^T k_{j,i}}{\sqrt{d_h + d^R_h}} \right) v^C_{j,i}, \quad (18) \\ u_t = W^O [o_{t,1}; o_{t,2}; \dots; o_{t,n_h}], \quad (19)\]

其中:

  • $W^{QR} \in \mathbb{R}^{d^R_h n_h \times d’_c}$ 和 $W^{KR} \in \mathbb{R}^{d^R_h \times d}$ 是生成解耦查询和键的矩阵;
  • $\text{RoPE}(\cdot)$ 表示应用RoPE矩阵的操作;$[\cdot; \cdot]$ 表示拼接操作。

在推理过程中,解耦键也需要被缓存。因此,DeepSeek-V2需要的总KV缓存为 $(d_c + d^R_h) l$ 个元素。

为了展示MLA的完整计算过程,我们在附录C中整理并提供了其完整公式。

2.1.4 KV缓存对比

我们在表1中展示了不同注意力机制下每个token的KV缓存对比。MLA仅需要少量的KV缓存,相当于仅2.25组的GQA,但其性能优于MHA。

图片名称

表1

2.2 DeepSeekMoE:以经济成本训练强大模型

2.2.1 基本架构

对于FFN,我们采用了DeepSeekMoE架构(Dai et al., 2024)。DeepSeekMoE有两个关键思想:将专家分割为更细粒度以实现更高的专家专业化和更准确的知识获取,以及隔离一些共享专家以减少路由专家之间的知识冗余。在激活专家参数和总专家参数数量相同的情况下,DeepSeekMoE能够大幅超越传统MoE架构(如GShard(Lepikhin et al., 2021))。

设 $u_t$ 为第 $t$ 个token的FFN输入,我们计算FFN输出 $h’_t$ 如下:

\[h'_t = u_t + \sum_{i=1}^{N_s} \text{FFN}^{(s)}_i (u_t) + \sum_{i=1}^{N_r} g_{i,t} \text{FFN}^{(r)}_i (u_t), \quad (20) \\ g_{i,t} = \begin{cases} s_{i,t}, & s_{i,t} \in \text{Topk}(\{s_{j,t} | 1 \leq j \leq N_r\}, K_r), \\ 0, & \text{否则}, \end{cases} \quad (21) \\ s_{i,t} = \text{Softmax}_i (u_t^T e_i), \quad (22)\]

其中:

  • $N_s$ 和 $N_r$ 分别表示共享专家和路由专家的数量;
  • $\text{FFN}^{(s)}_i (\cdot)$ 和 $\text{FFN}^{(r)}_i (\cdot)$ 分别表示第 $i$ 个共享专家和第 $i$ 个路由专家;
  • $K_r$ 表示激活的路由专家数量;
  • $g_{i,t}$ 是第 $i$ 个专家的门控值;
  • $s_{i,t}$ 是token与专家的亲和度;
  • $e_i$ 是第 $i$ 个路由专家在该层的中心点;
  • $\text{Topk}(\cdot, K)$ 表示从第 $t$ 个token与所有路由专家的亲和度分数中选取最高的 $K$ 个分数。

2.2.2 设备限制路由

我们设计了一种设备限制路由机制,以限制MoE相关的通信成本。当采用专家并行时,路由专家将分布在多个设备上。对于每个token,其MoE相关的通信频率与其目标专家覆盖的设备数量成正比。由于DeepSeekMoE中的细粒度专家分割,激活的专家数量可能较大,因此如果采用专家并行,MoE相关的通信成本会更高。

对于DeepSeek-V2,除了简单的top-K选择路由专家外,我们还确保每个token的目标专家最多分布在 $M$ 个设备上。具体来说,对于每个token,我们首先选择 $M$ 个设备,这些设备中的专家具有最高的亲和度分数。然后,我们在这些设备上的专家中进行top-K选择。在实践中,我们发现当 $M \geq 3$ 时,设备限制路由能够实现与无限制top-K路由大致相当的良好性能。

2.2.3 负载均衡的辅助损失

在自动学习的路由策略中,我们考虑了负载均衡问题。首先,负载不均衡会增加路由崩溃的风险(Shazeer et al., 2017),导致某些专家无法被充分训练和利用。其次,当采用专家并行时,负载不均衡会降低计算效率。在DeepSeek-V2的训练过程中,我们设计了三种辅助损失,分别用于控制专家级负载均衡($L_{\text{ExpBal}}$)、设备级负载均衡($L_{\text{DevBal}}$)和通信均衡($L_{\text{CommBal}}$)。

图片名称

表2

图片名称

表3

图片名称

表4

图片名称

表5

专家级均衡损失

我们使用专家级均衡损失(Fedus et al., 2021; Lepikhin et al., 2021)来减轻路由崩溃的风险:

\[L_{\text{ExpBal}} = \alpha_1 \sum_{i=1}^{N_r} f_i P_i, \quad (23) \\ f_i = \frac{N_r}{K_r T} \sum_{t=1}^T \mathbb{1}(\text{Token } t \text{ 选择专家 } i), \quad (24) \\ P_i = \frac{1}{T} \sum_{t=1}^T s_{i,t}, \quad (25)\]

其中,$\alpha_1$ 是一个超参数,称为专家级均衡因子;$\mathbb{1}(\cdot)$ 是指示函数;$T$ 表示序列中的token数量。

设备级均衡损失

除了专家级均衡损失外,我们还设计了设备级均衡损失,以确保不同设备之间的计算负载均衡。在DeepSeek-V2的训练过程中,我们将所有路由专家划分为 $D$ 个组 ${E_1, E_2, \dots, E_D}$,并将每个组部署在单个设备上。设备级均衡损失计算如下:

\[L_{\text{DevBal}} = \alpha_2 \sum_{i=1}^D f'_i P'_i, \quad (26) \\ f'_i = \frac{1}{|E_i|} \sum_{j \in E_i} f_j, \quad (27) \\ P'_i = \sum_{j \in E_i} P_j, \quad (28)\]

其中:

  • $\alpha_2$ 是一个超参数,称为设备级均衡因子。

通信均衡损失

最后,我们引入了通信均衡损失,以确保每个设备的通信负载均衡。尽管设备限制路由机制保证了每个设备的发送通信是有界的,但如果某个设备接收的token比其他设备多,实际通信效率也会受到影响。为了缓解这一问题,我们设计了通信均衡损失如下:

\[L_{\text{CommBal}} = \alpha_3 \sum_{i=1}^D f''_i P''_i, \quad (29) \\ f''_i = \frac{D}{M T} \sum_{t=1}^T \mathbb{1}(\text{Token } t \text{ 被发送到设备 } i), \quad (30) \\ P''_i = \sum_{j \in E_i} P_j, \quad (31)\]

其中:

  • $\alpha_3$ 是一个超参数,称为通信均衡因子。

设备限制路由机制的原则是确保每个设备最多向其他设备传输 $M T$ 个隐藏状态。同时,通信均衡损失用于鼓励每个设备从其他设备接收大约 $M T$ 个隐藏状态。通信均衡损失保证了设备之间的信息交换均衡,从而提高了通信效率。

2.2.4 Token丢弃策略

尽管均衡损失旨在鼓励负载均衡,但必须承认它们无法保证严格的负载均衡。为了进一步减轻因负载不均衡导致的计算浪费,我们在训练期间引入了设备级的token丢弃策略。该方法首先计算每个设备的平均计算预算,这意味着每个设备的容量因子为1.0。然后,受Riquelme et al. (2021)启发,我们在每个设备上丢弃具有最低亲和度分数的token,直到达到计算预算。此外,我们确保属于大约10%训练序列的token永远不会被丢弃。通过这种方式,我们可以根据效率需求灵活决定在推理期间是否丢弃token,并始终确保训练和推理之间的一致性。

3. 预训练

3.1 实验设置

3.1.1 数据构建

在保持与DeepSeek 67B(DeepSeek-AI, 2024)相同的数据处理阶段的基础上,我们扩展了数据量并提升了数据质量。为了扩大预训练语料库,我们探索了互联网数据的潜力并优化了清理流程,从而恢复了大量被错误删除的数据。此外,我们加入了更多的中文数据,旨在更好地利用中文互联网上的语料库。除了数据量,我们还关注数据质量。我们从各种来源丰富了预训练语料库的高质量数据,同时改进了基于质量的过滤算法。改进后的算法确保大量无益数据被移除,而有价值的数据则大部分被保留。此外,我们从预训练语料库中过滤掉了争议性内容,以减轻特定区域文化引入的数据偏差。关于该过滤策略影响的详细讨论见附录E。

我们采用了与DeepSeek 67B相同的分词器,该分词器基于字节级字节对编码(BBPE)算法构建,词汇量为100K。我们的分词预训练语料库包含8.1T token,其中中文token比英文多约12%。

3.1.2 超参数

模型超参数:我们将Transformer层数设置为60,隐藏维度设置为5120。所有可学习参数均以标准差0.006随机初始化。在MLA中,我们将注意力头数$n_h$设置为128,每头维度$d_h$设置为128。KV压缩维度$d_c$设置为512,查询压缩维度$d’_c$设置为1536。对于解耦查询和键,我们将每头维度$d^R_h$设置为64。根据Dai et al. (2024),我们将除第一层外的所有FFN替换为MoE层。每个MoE层包含2个共享专家和160个路由专家,其中每个专家的中间隐藏维度为1536。在路由专家中,每个token激活6个专家。此外,低秩压缩和细粒度专家分割会影响层的输出规模。因此,在实践中,我们在压缩隐向量后使用额外的RMS Norm层,并在宽度瓶颈(即压缩隐向量和路由专家的中间隐藏状态)处乘以额外的缩放因子以确保训练稳定。在此配置下,DeepSeek-V2总参数量为236B,每个token激活21B参数。

训练超参数:我们使用AdamW优化器(Loshchilov and Hutter, 2017),超参数设置为$\beta_1 = 0.9$、$\beta_2 = 0.95$、$\text{weight_decay} = 0.1$。学习率采用预热和阶梯衰减策略(DeepSeek-AI, 2024)。初始时,学习率在前2K步从0线性增加到最大值。随后,在训练约60%的token后,学习率乘以0.316,在训练约90%的token后再次乘以0.316。最大学习率设置为$2.4 \times 10^{-4}$,梯度裁剪范数设置为1.0。我们还使用了批量大小调度策略,在前225B token的训练中,批量大小从2304逐步增加到9216,之后保持9216。我们将最大序列长度设置为4K,并在8.1T token上训练DeepSeek-V2。我们利用流水线并行将模型的不同层部署在不同设备上,每层的路由专家均匀分布在8个设备上($D = 8$)。对于设备限制路由,每个token最多发送到3个设备($M = 3$)。对于均衡损失,我们设置$\alpha_1 = 0.003$、$\alpha_2 = 0.05$、$\alpha_3 = 0.02$。我们在训练期间使用token丢弃策略以加速训练,但在评估时不丢弃任何token。

3.1.3 基础设施

DeepSeek-V2基于HAI-LLM框架(High-flyer, 2023)进行训练,这是我们工程师开发的高效轻量级训练框架。它采用了16路零气泡流水线并行(Qi et al., 2023)、8路专家并行(Lepikhin et al., 2021)和ZeRO-1数据并行(Rajbhandari et al., 2020)。由于DeepSeek-V2激活参数相对较少,并且部分算子被重新计算以节省激活内存,因此可以在不需要张量并行的情况下进行训练,从而减少通信开销。此外,为了进一步提高训练效率,我们将共享专家的计算与专家并行的all-to-all通信重叠。我们还为通信、路由算法和跨专家的融合线性计算定制了更快的CUDA内核。此外,MLA还基于改进版的FlashAttention-2(Dao, 2023)进行了优化。

我们在配备NVIDIA H800 GPU的集群上进行所有实验。H800集群中的每个节点包含8个GPU,节点内通过NVLink和NVSwitch连接。节点间使用InfiniBand互连以促进通信。

3.1.4 长上下文扩展

在DeepSeek-V2的初始预训练后,我们使用YaRN(Peng et al., 2023)将默认上下文窗口长度从4K扩展到128K。YaRN特别应用于解耦共享键$k^R_t$,因为它负责承载RoPE(Su et al., 2024)。对于YaRN,我们将缩放因子$s$设置为40,$\alpha$设置为1,$\beta$设置为32,目标最大上下文长度设置为160K。在这些设置下,我们可以预期模型在128K的上下文长度下表现良好。与原始YaRN略有不同,由于我们独特的注意力机制,我们调整了长度缩放因子以调节注意力熵。因子$\sqrt{t}$计算为$\sqrt{t} = 0.0707 \ln s + 1$,旨在最小化困惑度。

我们额外训练了1000步,序列长度为32K,批量大小为576个序列。尽管训练仅在32K的序列长度下进行,但模型在128K的上下文长度下仍表现出色。如图4所示,在“Needle In A Haystack”(NIAH)测试中,DeepSeek-V2在所有上下文窗口长度(最高128K)下均表现良好。

图片名称

图4

3.2. 评估

3.2.1. 评估基准

DeepSeek-V2 是在双语语料库上进行预训练的,因此我们在英语和中文的一系列基准上对其进行了评估。我们的评估基于集成在 HAI-LLM 框架中的内部评估框架。包含的基准分类如下,其中带下划线的基准为中文:

  • 多学科选择题数据集 包括 MMLU (Hendrycks et al., 2020)、C-Eval (Huang et al., 2023) 和 CMMLU (Li et al., 2023)。
  • 语言理解和推理数据集 包括 HellaSwag (Zellers et al., 2019)、PIQA (Bisk et al., 2020)、ARC (Clark et al., 2018) 和 BigBench Hard (BBH) (Suzgun et al., 2022)。
  • 闭卷问答数据集 包括 TriviaQA (Joshi et al., 2017) 和 NaturalQuestions (Kwiatkowski et al., 2019)。
  • 阅读理解数据集 包括 RACE (Lai et al., 2017)、DROP (Dua et al., 2019)、C3 (Sun et al., 2019) 和 CMRC (Cui et al., 2019)。
  • 指代消解数据集 包括 WinoGrande (Sakaguchi et al., 2019) 和 CLUEWSC (Xu et al., 2020)。
  • 语言建模数据集 包括 Pile (Gao et al., 2020)。
  • 中文理解与文化数据集 包括 CHID (Zheng et al., 2019) 和 CCPM (Li et al., 2021)。
  • 数学数据集 包括 GSM8K (Cobbe et al., 2021)、MATH (Hendrycks et al., 2021) 和 CMath (Wei et al., 2023)。
  • 代码数据集 包括 HumanEval (Chen et al., 2021)、MBPP (Austin et al., 2021) 和 CRUXEval (Gu et al., 2024)。
  • 标准化考试 包括 AGIEval (Zhong et al., 2023)。注意,AGIEval 包含英语和中文子集。

根据我们之前的工作 (DeepSeek-AI, 2024),我们对以下数据集采用基于困惑度(perplexity)的评估:HellaSwag、PIQA、WinoGrande、RACE-Middle、RACE-High、MMLU、ARC-Easy、ARC-Challenge、CHID、C-Eval、CMMLU、C3 和 CCPM;对以下数据集采用基于生成的评估:TriviaQA、NaturalQuestions、DROP、MATH、GSM8K、HumanEval、MBPP、CRUXEval、BBH、AGIEval、CLUEWSC、CMRC 和 CMath。此外,我们对 Pile-test 进行基于语言建模的评估,并使用 Bits-Per-Byte (BPB) 作为指标,以确保使用不同分词器的模型之间的公平比较。

为了直观地了解这些基准,我们在附录 G 中提供了每个基准的评估格式。

3.2.2. 评估结果

在表 2 中,我们将 DeepSeek-V2 与几个代表性的开源模型进行了比较,包括 DeepSeek 67B (DeepSeek-AI, 2024)(我们之前的版本)、Qwen1.5 72B (Bai et al., 2023)、LLaMA3 70B (AI@Meta, 2024) 和 Mixtral 8x22B (Mistral, 2024)。我们使用内部评估框架评估了所有这些模型,并确保它们共享相同的评估设置。总体而言,DeepSeek-V2 仅激活了 21B 参数,但在几乎所有基准上都显著优于 DeepSeek 67B,并在开源模型中达到了顶级性能。

进一步,我们详细比较了 DeepSeek-V2 与其他开源模型的表现:

  1. 与 Qwen1.5 72B 的比较:Qwen1.5 72B 是另一个支持中文和英文的模型。DeepSeek-V2 在大多数英语、代码和数学基准上表现出压倒性优势。在中文基准上,Qwen1.5 72B 在多学科选择题任务上表现更好,而 DeepSeek-V2 在其他任务上表现相当或更好。需要注意的是,对于 CHID 基准,Qwen1.5 72B 的分词器在我们的评估框架中会遇到错误,因此我们未记录 Qwen1.5 72B 的 CHID 分数。

  2. 与 Mixtral 8x22B 的比较:DeepSeek-V2 在英语基准上表现相当或更好,除了与英语常识知识密切相关的 TriviaQA、NaturalQuestions 和 HellaSwag。值得注意的是,DeepSeek-V2 在 MMLU 上优于 Mixtral 8x22B。在代码和数学基准上,DeepSeek-V2 与 Mixtral 8x22B 表现相当。由于 Mixtral 8x22B 并未专门针对中文数据进行训练,其中文能力远不及 DeepSeek-V2。

  3. 与 LLaMA3 70B 的比较:DeepSeek-V2 的训练数据量不到 LLaMA3 70B 的四分之一。因此,我们承认 DeepSeek-V2 在基础英语能力上仍与 LLaMA3 70B 存在轻微差距。然而,即使训练数据和激活参数少得多,DeepSeek-V2 在代码和数学能力上仍与 LLaMA3 70B 相当。此外,作为双语模型,DeepSeek-V2 在中文基准上显著优于 LLaMA3 70B。

最后,值得一提的是,某些先前的研究 (Hu et al., 2024) 在预训练阶段引入了 SFT 数据,而 DeepSeek-V2 在预训练期间从未接触过 SFT 数据。

3.2.3. 训练和推理效率

训练成本:由于 DeepSeek-V2 为每个 token 激活的参数较少,且所需的 FLOPs 少于 DeepSeek 67B,理论上训练 DeepSeek-V2 比训练 DeepSeek 67B 更经济。尽管训练 MoE 模型会引入额外的通信开销,但通过我们的操作符和通信优化,DeepSeek-V2 的训练可以达到相对较高的模型 FLOPs 利用率(MFU)。在实际的 H800 集群训练中,每训练一万亿 token,DeepSeek 67B 需要 300.6K GPU 小时,而 DeepSeek-V2 仅需 172.8K GPU 小时,即稀疏的 DeepSeek-V2 比密集的 DeepSeek 67B 节省了 42.5% 的训练成本。

推理效率:为了高效部署 DeepSeek-V2 提供服务,我们首先将其参数转换为 FP8 精度。此外,我们还对 DeepSeek-V2 进行了 KV 缓存量化 (Hooper et al., 2024; Zhao et al., 2023),进一步将其 KV 缓存中的每个元素平均压缩到 6 位。得益于 MLA 和这些优化,实际部署的 DeepSeek-V2 所需的 KV 缓存显著少于 DeepSeek 67B,因此可以支持更大的批量大小。我们基于实际部署的 DeepSeek 67B 服务的提示和生成长度分布,评估了 DeepSeek-V2 的生成吞吐量。在单个 8 H800 GPU 节点上,DeepSeek-V2 的生成吞吐量超过每秒 50K token,是 DeepSeek 67B 最大生成吞吐量的 5.76 倍。此外,DeepSeek-V2 的提示输入吞吐量超过每秒 100K token。

4. 对齐

4.1. 监督微调(SFT)

基于我们之前的研究(DeepSeek-AI, 2024),我们构建了包含 150 万条实例的指令微调数据集,其中 120 万条用于提升模型的有用性,30 万条用于提升安全性。与初始版本相比,我们提高了数据质量,以减少幻觉响应并增强写作能力。我们对 DeepSeek-V2 进行了 2 轮微调,学习率设置为 $5 \times 10^{-6}$。

对于 DeepSeek-V2 Chat(SFT)的评估,我们主要采用基于生成的基准测试,除了几个具有代表性的选择题任务(如 MMLU 和 ARC)。我们还对 DeepSeek-V2 Chat(SFT)进行了指令跟随评估(IFEval)(Zhou et al., 2023),使用提示级别的宽松准确率作为指标。此外,我们使用 2023 年 9 月 1 日至 2024 年 4 月 1 日的 LiveCodeBench(Jain et al., 2024)问题来评估聊天模型。除了标准基准测试外,我们还在开放式对话基准测试上进一步评估了模型,包括 MT-Bench(Zheng et al., 2023)、AlpacaEval 2.0(Dubois et al., 2024)和 AlignBench(Liu et al., 2023)。为了进行比较,我们还在我们的评估框架和设置中评估了 Qwen1.5 72B Chat、LLaMA-3-70B Instruct 和 Mistral-8x22B Instruct。对于 DeepSeek 67B Chat,我们直接参考了之前发布的评估结果。

4.2. 强化学习(RL)

为了进一步释放 DeepSeek-V2 的潜力并使其与人类偏好对齐,我们进行了强化学习(RL)以调整其偏好。

强化学习算法:为了节省 RL 的训练成本,我们采用了组相对策略优化(GRPO)(Shao et al., 2024),该方法省略了通常与策略模型大小相同的评论模型,而是从组分数中估计基线。具体来说,对于每个问题 $q$,GRPO 从旧策略 $\pi_{\theta_{\text{old}}}$ 中采样一组输出 ${o_1, o_2, \dots, o_G}$,然后通过最大化以下目标来优化策略模型 $\pi_{\theta}$:

\[J_{\text{GRPO}}(\theta) = \mathbb{E}\left[q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q)\right] \frac{1}{G} \sum_{i=1}^G \min\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip}\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right) A_i\right) - \beta D_{\text{KL}}(\pi_{\theta} || \pi_{\text{ref}}),\]

其中:

  • $\epsilon$ 和 $\beta$ 是超参数,
  • $A_i$ 是优势值,通过每组输出对应的奖励 ${r_1, r_2, \dots, r_G}$ 计算:
\[A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \dots, r_G\})}{\text{std}(\{r_1, r_2, \dots, r_G\})}.\]

训练策略:在初步实验中,我们发现针对推理数据(如代码和数学提示)的 RL 训练表现出与通用数据训练不同的独特特性。例如,模型的数学和编码能力可以在更长的训练步骤中持续提升。因此,我们采用了两阶段 RL 训练策略:首先进行推理对齐,然后进行人类偏好对齐。在第一阶段,我们训练了一个用于代码和数学推理任务的奖励模型 $RM_{\text{reasoning}}$,并使用其反馈优化策略模型:

\[r_i = RM_{\text{reasoning}}(o_i).\]

在第二阶段,我们采用多奖励框架,从有用性奖励模型 $RM_{\text{helpful}}$、安全性奖励模型 $RM_{\text{safety}}$ 和基于规则的奖励模型 $RM_{\text{rule}}$ 中获取奖励。最终响应 $o_i$ 的奖励为:

\[r_i = c_1 \cdot RM_{\text{helpful}}(o_i) + c_2 \cdot RM_{\text{safety}}(o_i) + c_3 \cdot RM_{\text{rule}}(o_i),\]

其中:

  • $c_1, c_2, c_3$ 是对应的系数。

为了获得可靠的奖励模型,我们精心收集了偏好数据,并进行了严格的质量过滤和比例调整。我们基于编译器反馈获取代码偏好数据,基于真实标签获取数学偏好数据。对于奖励模型训练,我们使用 DeepSeek-V2 Chat(SFT)初始化奖励模型,并使用点对或对对的损失进行训练。在实验中,我们观察到 RL 训练能够充分挖掘和激活模型的潜力,使其能够从可能的响应中选择正确且令人满意的答案。

训练效率优化:在极大模型上进行 RL 训练对训练框架提出了高要求。我们实施了以下工程优化:

  • (1)采用混合引擎,分别针对训练和推理采用不同的并行策略以提高 GPU 利用率;
  • (2)利用 vLLM(Kwon et al., 2023)作为推理后端,加速推理速度;
  • (3)精心设计模型卸载到 CPU 和加载回 GPU 的调度策略,以实现训练速度和内存消耗的近最优平衡。

4.3. 评估结果

标准基准测试评估:我们首先在标准基准测试上评估了 DeepSeek-V2 Chat(SFT)和 DeepSeek-V2 Chat(RL)。值得注意的是,DeepSeek-V2 Chat(SFT)在 GSM8K、MATH 和 HumanEval 评估中相比其基础版本有显著提升,这归因于我们的 SFT 数据中包含大量数学和代码相关内容。此外,DeepSeek-V2 Chat(RL)进一步提升了数学和代码基准测试的表现。

与其他模型的比较中,DeepSeek-V2 Chat(SFT)在几乎所有英语、数学和代码基准测试上均优于 Qwen1.5 72B Chat。在中文基准测试上,DeepSeek-V2 Chat(SFT)在多学科选择题任务上略低于 Qwen1.5 72B Chat,与其基础版本的表现一致。与最先进的开源 MoE 模型 Mixtral 8x22B Instruct 相比,DeepSeek-V2 Chat(SFT)在大多数基准测试上表现更好,除了 NaturalQuestions 和 IFEval。与最先进的开源模型 LLaMA3 70B Chat 相比,DeepSeek-V2 Chat(SFT)在代码和数学相关基准测试上表现相似,LLaMA3 70B Chat 在 MMLU 和 IFEval 上表现更好,而 DeepSeek-V2 Chat(SFT)在中文任务上表现更强。最终,DeepSeek-V2 Chat(RL)在数学和编码任务上相比 DeepSeek-V2 Chat(SFT)进一步提升了性能。

开放式生成评估:我们在开放式对话基准测试上进一步评估了模型。对于英语开放式对话生成,我们使用 MT-Bench 和 AlpacaEval 2.0 作为基准测试。评估结果显示,DeepSeek-V2 Chat(RL)相比 DeepSeek-V2 Chat(SFT)有显著优势,展示了 RL 训练在实现更好对齐方面的有效性。与其他开源模型相比,DeepSeek-V2 Chat(RL)在 MT-Bench 和 AlpacaEval 2.0 上均优于 Mistral 8x22B Instruct 和 Qwen1.5 72B Chat。与 LLaMA3 70B Instruct 相比,DeepSeek-V2 Chat(RL)在 MT-Bench 上表现相当,在 AlpacaEval 2.0 上显著优于后者。

对于中文开放式生成能力,我们基于 AlignBench 进行了评估。结果显示,DeepSeek-V2 Chat(RL)相比 DeepSeek-V2 Chat(SFT)有轻微优势。值得注意的是,DeepSeek-V2 Chat(SFT)显著优于所有开源中文模型,在中文推理和语言任务上大幅领先第二好的开源模型 Qwen1.5 72B Chat。此外,DeepSeek-V2 Chat(SFT)和 DeepSeek-V2 Chat(RL)均优于 GPT-4-0613 和 ERNIEBot 4.0,巩固了我们的模型在支持中文的顶级 LLM 中的地位。

4.4. 讨论

SFT 数据量:关于是否需要大量 SFT 数据的讨论一直存在争议。先前的工作(Young et al., 2024; Zhou et al., 2024)认为少于 10K 条 SFT 数据足以产生令人满意的结果。然而,在我们的实验中,如果使用少于 10K 条数据,我们在 IFEval 基准测试上观察到显著的性能下降。可能的解释是,语言模型需要一定量的数据来发展特定技能。尽管随着模型规模的增加,所需数据量可能会减少,但不能完全消除。我们的观察强调了为 LLM 提供足够数据以具备所需能力的关键性。此外,SFT 数据的质量也至关重要,尤其是在涉及写作或开放式问题的任务中。

强化学习的对齐税:在人类偏好对齐过程中,我们观察到开放式生成基准测试的显著性能提升,无论是 AI 还是人类评估者的评分。然而,我们也注意到“对齐税”现象(Ouyang et al., 2022),即对齐过程可能会对某些标准基准测试(如 BBH)的性能产生负面影响。为了缓解对齐税,我们在 RL 阶段在数据处理和训练策略改进上付出了巨大努力,最终在标准基准测试和开放式基准测试之间实现了可接受的权衡。

在线强化学习:在我们的偏好对齐实验中,我们发现在线方法显著优于离线方法。因此,我们投入了大量精力实现了一个在线 RL 框架来对齐 DeepSeek-V2。关于在线或离线偏好对齐的结论可能因不同情境而异,我们将在未来工作中进行更深入的比较和分析。

5. 结论、局限性与未来工作

本文介绍了 DeepSeek-V2,一个支持 128K 上下文长度的大型 MoE 语言模型。除了强大的性能外,它还以经济高效的训练和推理为特点,得益于其创新的架构(包括 MLA 和 DeepSeekMoE)。在实际应用中,相比 DeepSeek 67B,DeepSeek-V2 在显著提升性能的同时,节省了 42.5% 的训练成本,减少了 93.3% 的 KV 缓存,并将最大生成吞吐量提升至 5.76 倍。评估结果表明,仅激活 21B 参数的 DeepSeek-V2 在开源模型中达到了顶级性能,成为最强的开源 MoE 模型。

DeepSeek-V2 及其聊天版本具有其他 LLM 常见的局限性,包括预训练后缺乏持续知识更新、可能生成未经核实的信息(如未经验证的建议)以及可能产生幻觉。此外,由于我们的数据主要由中文和英文内容组成,模型在其他语言上的表现可能有限。在中文和英文以外的场景中,应谨慎使用。

DeepSeek 将持续投资于开源大模型,致力于逐步接近通用人工智能的目标。我们正在探索进一步扩展 MoE 模型的方法,同时保持经济高效的训练和推理成本。我们的下一步目标是在即将发布的版本中实现与 GPT-4 相当的性能。我们的对齐团队不断努力提升模型,旨在开发一个不仅有用而且诚实、安全的模型,最终目标是使模型的价值与人类价值观对齐,同时尽量减少人类监督的需求。通过优先考虑伦理和负责任的发展,我们致力于为社会创造积极和有益的影响。目前,DeepSeek-V2 仅支持文本模态。在未来计划中,我们打算使模型支持多模态,增强其在不同场景中的多功能性和实用性。

#

https://arxiv.org/pdf/2412.19437

追一科技在《ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING》提出了RoPE的方法:

摘要

位置编码最近在Transformer架构中显示出其有效性。它为序列中不同位置元素之间的依赖关系建模提供了有价值的监督。在本文中,我们:

  • 首先研究了将位置信息整合到基于Transformer的语言模型学习过程中的各种方法。
  • 接着,我们提出了一种名为旋转位置嵌入(Rotary Position Embedding, RoPE)的新方法,以有效利用位置信息。具体来说,所提出的RoPE通过旋转矩阵对绝对位置进行编码,同时将显式的相对位置依赖关系引入自注意力机制中。值得注意的是,RoPE具有一些有价值的特性,包括:序列长度的灵活性、随着相对距离增加而衰减的token间依赖性,以及为线性自注意力配备相对位置编码的能力
  • 最后,我们在多种长文本分类基准数据集上评估了增强后的Transformer(称为RoFormer)的性能。

实验结果表明,RoFormer始终优于其他替代方法。此外,我们还提供了理论分析来解释一些实验结果。RoFormer已经集成到Huggingface中:https://huggingface.co/docs/transformers/model_doc/roformer

1 引言

词序对于自然语言理解具有重要价值

  • 基于循环神经网络(RNN)的模型通过沿时间维度递归计算隐藏状态来编码词的顺序。
  • 基于卷积神经网络(CNN)的模型(如Gehring等[2017])通常被认为是与位置无关的,但最近的研究(Islam等[2020])表明,常用的填充操作可以隐式地学习位置信息。
  • 近年来,基于Transformer(Vaswani等[2017])构建的预训练语言模型(PLMs)在各种自然语言处理(NLP)任务中取得了最先进的性能,包括上下文表示学习(Devlin等[2019])、机器翻译(Vaswani等[2017])和语言建模(Radford等[2019])等。

与基于RNN和CNN的模型不同,PLMs利用自注意力机制从语义上捕捉给定语料的上下文表示。因此,PLMs在并行化方面相比RNN取得了显著改进,并且与CNN相比,能够更好地建模更长的token内关系。

值得注意的是,当前PLMs的自注意力架构被证明是与位置无关的(Yun等[2020])。基于这一观点,研究者们提出了多种方法将位置信息编码到学习过程中。

  • 一方面,通*过预定义函数生成的绝对位置编码(Vaswani等[2017])被添加到上下文表示中,而可训练的绝对位置编码(Gehring等[2017]、Devlin等[2019]、Lan等[2020]、Clark等[2020]、Radford等[2019]、Radford和Narasimhan[2018])也被广泛使用。
  • 另一方面,之前的工作(Parikh等[2016]、Shaw等[2018]、Huang等[2018]、Dai等[2019]、Yang等[2019]、Raffel等[2020]、Ke等[2020]、He等[2020]、Huang等[2020])主要集中在相对位置编码上,通常将相对位置信息编码到注意力机制中。
  • 此外,Liu等[2020]的作者提出从神经微分方程(Neural ODE,Chen等[2018a]的角度建模位置编码的依赖性,Wang等[2020]的作者提出在复数空间中建模位置信息

尽管这些方法有效,但它们通常将位置信息添加到上下文表示中,因此不适合线性自注意力架构。

在本文中,我们提出了一种新方法,即旋转位置嵌入(Rotary Position Embedding, RoPE),将位置信息引入PLMs的学习过程中。具体来说,RoPE通过旋转矩阵对绝对位置进行编码,同时将显式的相对位置依赖关系引入自注意力机制中。值得注意的是,RoPE具有一些优于现有方法的特性,包括序列长度的灵活性、随着相对距离增加而衰减的token间依赖性,以及为线性自注意力配备相对位置编码的能力。在多种长文本分类基准数据集上的实验结果表明,增强后的Transformer(称为RoFormer)相比基线方法能够取得更好的性能,从而证明了RoPE的有效性。

简而言之,我们的贡献如下:

  • 我们研究了现有的相对位置编码方法,发现它们大多基于将位置编码添加到上下文表示的分解思想。我们提出了一种新方法,即旋转位置嵌入(RoPE),将位置信息引入PLMs的学习过程中。其核心思想是通过将上下文表示与旋转矩阵相乘来编码相对位置,并具有清晰的理论解释。
  • 我们研究了RoPE的特性,并表明其随着相对距离的增加而衰减,这符合自然语言编码的需求。我们还指出,之前的相对位置编码方法与线性自注意力不兼容。
  • 我们在多种长文本基准数据集上评估了提出的RoFormer。实验结果表明,RoFormer始终优于其他替代方法。一些预训练语言模型的实验代码已在GitHub上开源:https://github.com/ZhuiyiTechnology/roformer

本文的其余部分组织如下:在第2节中,我们建立了自注意力架构中位置编码问题的形式化描述,并回顾了之前的工作;在第3节中,我们描述了旋转位置编码(RoPE)并研究了其特性;在第4节中,我们报告了实验结果;最后,在第5节中对本文进行了总结。

2 背景与相关工作

2.1 初步知识

设 $S_N = \lbrace w_i \rbrace_{i=1}^N$ 为一个包含 N 个输入token的序列,其中:

  • $ w_i $ 是第 $ i $ 个元素
  • $ S_N $ 对应的词嵌入表示为 $ E_N = \lbrace x_i \rbrace_{i=1}^N $,其中 $ x_i \in R^d $ 是第i个token的d维词嵌入向量(不包含位置信息)。

自注意力机制首先将位置信息融入词嵌入中,并将其转换为查询(query)、键(key)和值(value)表示:

\[\begin{aligned} q_m &= f_q(x_m, m) \\ k_n &= f_k(x_n, n) \\ v_n &= f_v(x_n, n), \end{aligned}\]

…(1)

其中:

  • $ q_m $、$ k_n $ 和 $ v_n $ 分别通过函数 $ f_q $、$ f_k $ 和 $ f_v $ 融入了第 $ m $ 和第 $ n $ 个位置的信息。

查询和键值用于计算注意力权重,而输出则是对值表示的加权和:

\[\begin{aligned} a_{m,n} &= \frac{\exp\left(\frac{q_m^\intercal k_n}{\sqrt{d}}\right)}{\sum_{j=1}^N \exp\left(\frac{q_m^\intercal k_j}{\sqrt{d}}\right)} \\ o_m &= \sum_{n=1}^N a_{m,n} v_n. \end{aligned}\]

…(2)

现有的基于Transformer的位置编码方法主要集中在选择合适的函数来构建公式(1)。

2.2 绝对位置嵌入

公式(1)的一个典型选择是:

\[f_{t: t \in \{q, k, v\}}(x_i, i) := W_{t: t \in \{q, k, v\}}(x_i + p_i),\]

…(3)

其中:

  • $ p_i \in \mathbb{R}^d $ 是一个依赖于token $ x_i $ 位置的 $ d $ 维向量。

之前的工作(Devlin等[2019]、Lan等[2020]、Clark等[2020]、Radford等[2019]、Radford和Narasimhan[2018])引入了一组可训练的向量 $ p_i \in {p_t}_{t=1}^L $,其中 $ L $ 是最大序列长度。Vaswani等[2017]的作者提出使用正弦函数生成 $ p_i $:

\[\begin{aligned} p_{i,2t} &= \sin\left(\frac{k}{10000^{2t/d}}\right) \\ p_{i,2t+1} &= \cos\left(\frac{k}{10000^{2t/d}}\right), \end{aligned}\]

…(4)

其中:

  • $ p_{i,2t} $ 是 $ d $ 维向量 $ p_i $ 的第 $ 2t $ 个元素。

在下一节中,我们将展示我们提出的RoPE与这种正弦函数直觉相关。然而,RoPE不是直接将位置信息添加到上下文表示中,而是通过与正弦函数相乘来融入相对位置信息

2.3 相对位置嵌入

Shaw等[2018]的作者对公式(1)应用了以下不同的设置:

\[\begin{aligned} f_q(x_m) &:= W_q x_m \\ f_k(x_n, n) &:= W_k (x_n + \widetilde{p}^k_r) \\ f_v(x_n, n) &:= W_v (x_n + \widetilde{p}^v_r), \end{aligned}\]

…(5)

其中:

  • $ \widetilde{p}^k_r $、$ \widetilde{p}^v_r \in \mathbb{R}^d $ 是可训练的相对位置嵌入。注意,$ r = \text{clip}(m - n, r_{\text{min}}, r_{\text{max}}) $ 表示位置 $ m $ 和 $ n $ 之间的相对距离。

他们通过假设超出一定距离的相对位置信息无用,对相对距离进行了裁剪。Dai等[2019]的作者在保持公式(3)的形式下,提出将公式(2)中的 $ q_m^\intercal k_n $ 分解为:

\[q_m^\intercal k_n = x_m^\intercal W_q^\intercal W_k x_n + x_m^\intercal W_q^\intercal W_k p_n + p_m^\intercal W_q^\intercal W_k x_n + p_m^\intercal W_q^\intercal W_k p_n,\]

…(6)

其核心思想是:将绝对位置嵌入 $ p_n $ 替换为其正弦编码的相对对应项 $ \widetilde{p}_{m-n} $,同时将第三和第四项中的绝对位置 $ p_m $ 替换为两个与查询位置无关的可训练向量 $ u $ 和 $ v $。此外,$ W_k $ 被区分为基于内容的键向量 $ x_n $ 和基于位置的键向量 $ p_n $,分别表示为 $ W_k $ 和 $ \widetilde{W}_k $,从而得到:

\[q_m^\intercal k_n = x_m^\intercal W_q^\intercal W_k x_n + x_m^\intercal W_q^\intercal \widetilde{W}_k \widetilde{p}_{m-n} + u^\intercal W_q^\intercal W_k x_n + v^\intercal W_q^\intercal \widetilde{W}_k \widetilde{p}_{m-n}.\]

…(7)

值得注意的是,值项中的位置信息通过设置 $ f_v(x_j) := W_v x_j $ 被移除。后续工作(Raffel等[2020]、He等[2020]、Ke等[2020]、Huang等[2020])遵循了这些设置,仅将相对位置信息编码到注意力权重中。然而,Raffel等[2020]的作者将公式(6)重新表述为:

\[q_m^\intercal k_n = x_m^\intercal W_q^\intercal W_k x_n + b_{i,j},\]

…(8)

其中:

  • $ b_{i,j} $ 是一个可训练的偏置项。

Ke等[2020]的作者研究了公式(6)中的中间两项,发现绝对位置与词之间的相关性较弱。Raffel等[2020]的作者提出使用不同的投影矩阵对词或位置进行建模:

\[q_m^\intercal k_n = x_m^\intercal W_q^\intercal W_k x_n + p_m^\intercal U_q^\intercal U_k p_n + b_{i,j}.\]

…(9)

He等[2020]的作者认为,两个token的相对位置只能通过公式(6)中的中间两项完全建模。因此,绝对位置嵌入 $ p_m $ 和 $ p_n $ 被简单地替换为相对位置嵌入 $ \widetilde{p}_{m-n} $:

\[q_m^\intercal k_n = x_m^\intercal W_q^\intercal W_k x_n + x_m^\intercal W_q^\intercal W_k \widetilde{p}_{m-n} + \widetilde{p}_{m-n}^\intercal W_q^\intercal W_k x_n.\]

…(10)

对四种相对位置嵌入变体的比较(Radford和Narasimhan[2018])表明,类似于公式(10)的变体在其他三种中效率最高。总的来说,所有这些方法都试图在自注意力设置下基于公式(3)的分解来修改公式(6),这是Vaswani等[2017]最初提出的。它们通常直接将位置信息添加到上下文表示中。与之不同,我们的方法旨在在某些约束下从公式(1)推导出相对位置编码。接下来,我们将展示通过旋转上下文表示融入相对位置信息,推导出的方法更具可解释性。

3 提出的方法

在本节中,我们讨论提出的旋转位置嵌入(Rotary Position Embedding, RoPE)。我们首先在第3.1节中形式化相对位置编码问题,然后在第3.2节中推导RoPE,并在第3.3节中研究其特性。

3.1 形式化

基于Transformer的语言建模通常通过自注意力机制利用各个token的位置信息。如公式(2)所示,$ q_m^\intercal k_n $ 通常用于在不同位置的token之间传递知识。为了融入相对位置信息,我们要求查询 $ q_m $ 和键 $ k_n $ 的内积由一个函数 $ g $ 表示,该函数仅以词嵌入 $ x_m $、$ x_n $ 及其相对位置 $ m - n $ 作为输入变量。换句话说,我们希望内积仅以相对形式编码位置信息:

\[\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m - n).\]

…(11)

最终目标是:找到一种等效的编码机制来解决函数 $ f_q(x_m, m) $ 和 $ f_k(x_n, n) $,以符合上述关系。

3.2 旋转位置嵌入

3.2.1 二维情况

我们从维度 $ d = 2 $ 的简单情况开始。在这些设置下,我们利用二维平面上向量的几何性质及其复数形式来证明(详见第3.4.1节),公式(11)的一个解为:

\[\begin{aligned} f_q(x_m, m) &= (W_q x_m) e^{i m \theta} \\ f_k(x_n, n) &= (W_k x_n) e^{i n \theta} \\ g(x_m, x_n, m - n) &= \text{Re}\left[(W_q x_m) (W_k x_n)^* e^{i (m - n) \theta}\right], \end{aligned}\]

…(12)

其中:

  • $\text{Re}[\cdot]$ 表示复数的实部,$(W_k x_n)^*$ 表示 $(W_k x_n)$ 的共轭复数。
  • $\theta \in \mathbb{R}$ 是一个预设的非零常数。

我们可以进一步将 $ f_{{q, k}} $ 写成矩阵乘法形式

\[f_{\{q, k\}}(x_m, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} W_{\lbrace,q,k\rbrace}^{11} & W_{\lbrace,q,k\rbrace}^{12} \\ W_{\lbrace,q,k\rbrace}^{21} & W_{\lbrace,q,k\rbrace}^{22} \end{pmatrix} \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix},\]

…(13)

其中:

  • $(x_m^{(1)}, x_m^{(2)})$ 是 $ x_m $ 在二维坐标中的表示。
  • 类似地,$ g $ 可以视为一个矩阵,从而在二维情况下解决第3.1节中的形式化问题。

具体来说,融入相对位置嵌入非常简单:只需将仿射变换后的词嵌入向量旋转其位置索引的角度倍数,从而解释旋转位置嵌入的直觉。

3.2.2 一般形式

为了将二维结果推广到任意 $ x_i \in \mathbb{R}^d $(其中 $ d $ 为偶数),我们将 $ d $ 维空间划分为 $ d/2 $ 个子空间,并利用内积的线性性将它们组合起来,将 $ f_{{q, k}} $ 转化为:

\[f_{\{q, k\}}(x_m, m) = R^d_{\Theta, m} W_{\{q, k\}} x_m,\]

…(14)

其中:

\[R^d_{\Theta, m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}\]

…(15)

是旋转矩阵,其预定义参数为 $\Theta = {\theta_i = 10000^{-2(i-1)/d}, i \in [1, 2, …, d/2]}$。图(1)展示了RoPE的图形化说明。将RoPE应用于公式(2)中的自注意力机制,我们得到:

\[q_m^\intercal k_n = (R^d_{\Theta, m} W_q x_m)^\intercal (R^d_{\Theta, n} W_k x_n) = x_m^\intercal W_q R^d_{\Theta, n-m} W_k x_n,\]

…(16)

其中:

  • $ R^d_{\Theta, n-m} = (R^d_{\Theta, m})^\intercal R^d_{\Theta, n} $。
  • 注意,$ R^d_{\Theta} $ 是一个正交矩阵,这确保了在编码位置信息过程中的稳定性。

此外,由于 $ R^d_{\Theta} $ 的稀疏性,直接应用矩阵乘法如公式(16)所示在计算上并不高效;我们在理论解释中提供了另一种实现方式。

图片名称

图1 Rotary Position Embedding(RoPE)的实现

与之前工作中采用的加法性质的位置嵌入方法(即公式(3)到(10))不同,我们的方法是乘法的。此外,RoPE通过旋转矩阵乘积自然地融入了相对位置信息,而不是在应用于自注意力时修改加法位置编码的扩展公式中的项。

3.3 RoPE的特性

长期衰减性:遵循Vaswani等[2017],我们设置 $\theta_i = 10000^{-2i/d}$。可以证明,这种设置提供了长期衰减特性(详见第3.4.3节),这意味着内积会随着相对位置的增加而衰减。这一特性与直觉一致,即具有较长相对距离的token对应该具有较少的连接。

RoPE与线性注意力:自注意力可以改写为更一般的形式:

\[\text{Attention}(Q, K, V)_m = \frac{\sum_{n=1}^N \text{sim}(q_m, k_n) v_n}{\sum_{n=1}^N \text{sim}(q_m, k_n)},\]

…(17)

其中原始自注意力选择 $\text{sim}(q_m, k_n) = \exp(q_m^\intercal k_n / \sqrt{d})$。注意,原始自注意力需要计算每对token的查询和键的内积,其复杂度为 $ O(N^2) $。遵循Katharopoulos等[2020],线性注意力将公式(17)重新表述为:

\[\text{Attention}(Q, K, V)_m = \frac{\sum_{n=1}^N \phi(q_m)^\intercal \varphi(k_n) v_n}{\sum_{n=1}^N \phi(q_m)^\intercal \varphi(k_n)},\]

…(18)

其中 $\phi(\cdot)$ 和 $\varphi(\cdot)$ 通常是非负函数。Katharopoulos等[2020]的作者提出 $\phi(x) = \varphi(x) = \text{elu}(x) + 1$,并首先利用矩阵乘法的结合性计算键和值的乘积。Shen等[2021]使用softmax函数在内积之前分别对查询和键进行归一化,这等价于 $\phi(q_i) = \text{softmax}(q_i)$ 和 $\phi(k_j) = \exp(k_j)$。有关线性注意力的更多细节,我们鼓励读者参考原始论文。在本节中,我们重点讨论将RoPE与公式(18)结合。由于RoPE通过旋转注入位置信息,这保持了隐藏表示的范数不变,因此我们可以通过将旋转矩阵与非负函数的输出相乘来将RoPE与线性注意力结合:

\[\text{Attention}(Q, K, V)_m = \frac{\sum_{n=1}^N \left(R^d_{\Theta, m} \phi(q_m)\right)^\intercal \left(R^d_{\Theta, n} \varphi(k_n)\right) v_n}{\sum_{n=1}^N \phi(q_m)^\intercal \varphi(k_n)}.\]

…(19)

值得注意的是,我们保持分母不变以避免除以零的风险,而分子中的求和可能包含负项。尽管公式(19)中每个值 $ v_i $ 的权重并未严格概率归一化,但我们认为计算仍可以建模值的重要性。

图片名称

图 2

3.4 理论解释

3.4.1 二维情况下RoPE的推导

在 $ d = 2 $ 的情况下,我们考虑两个词嵌入向量 $ x_q $ 和 $ x_k $,分别对应于查询和键,以及它们的位置 $ m $ 和 $ n $。根据公式(1),它们的位置编码对应为:

\[\begin{aligned} q_m &= f_q(x_q, m), \\ k_n &= f_k(x_k, n), \end{aligned}\]

…(20)

其中:

  • $ q_m $ 和 $ k_n $ 的下标表示编码的位置信息。

假设存在一个函数 $ g $,定义了由 $ f_{{q, k}} $ 生成的向量之间的内积:

\[q_m^\intercal k_n = \langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, n - m),\]

…(21)

我们进一步要求以下初始条件成立:

\[\begin{aligned} q &= f_q(x_q, 0), \\ k &= f_k(x_k, 0), \end{aligned}\]

…(22)

这可以理解为未编码位置信息的向量。在这些设定下,我们尝试找到 $ f_q $ 和 $ f_k $ 的解。首先,我们利用二维向量的几何意义及其复数形式,将公式(20)和(21)中的函数分解为:

\[\begin{aligned} f_q(x_q, m) &= R_q(x_q, m) e^{i \Theta_q(x_q, m)}, \\ f_k(x_k, n) &= R_k(x_k, n) e^{i \Theta_k(x_k, n)}, \\ g(x_q, x_k, n - m) &= R_g(x_q, x_k, n - m) e^{i \Theta_g(x_q, x_k, n - m)}, \end{aligned}\]

…(23)

其中:

  • $ R_f $、$ R_g $ 和 $ \Theta_f $、$ \Theta_g $ 分别是 $ f_{{q, k}} $ 和 $ g $ 的径向和角度分量。将它们代入公式(21),我们得到以下关系:
\[\begin{aligned} R_q(x_q, m) R_k(x_k, n) &= R_g(x_q, x_k, n - m), \\ \Theta_k(x_k, n) - \Theta_q(x_q, m) &= \Theta_g(x_q, x_k, n - m), \end{aligned}\]

…(24)

对应的初始条件为:

\[\begin{aligned} q &= \|q\| e^{i \theta_q} = R_q(x_q, 0) e^{i \Theta_q(x_q, 0)}, \\ k &= \|k\| e^{i \theta_k} = R_k(x_k, 0) e^{i \Theta_k(x_k, 0)}, \end{aligned}\]

…(25)

其中:

  • $ |q| $、$ |k| $ 和 $ \theta_q $、$ \theta_k $ 分别是 $ q $ 和 $ k $ 在二维平面上的径向和角度分量。

接下来,我们在公式(24)中设 $ m = n $,并考虑公式(25)中的初始条件:

\[\begin{aligned} R_q(x_q, m) R_k(x_k, m) &= R_g(x_q, x_k, 0) = R_q(x_q, 0) R_k(x_k, 0) = \|q\| \|k\|, \\ \Theta_k(x_k, m) - \Theta_q(x_q, m) &= \Theta_g(x_q, x_k, 0) = \Theta_k(x_k, 0) - \Theta_q(x_q, 0) = \theta_k - \theta_q. \end{aligned}\]

…(26)

一方面,从公式(26a)可以直接得到 $ R_f $ 的解:

\[\begin{aligned} R_q(x_q, m) &= R_q(x_q, 0) = \|q\|, \\ R_k(x_k, n) &= R_k(x_k, 0) = \|k\|, \\ R_g(x_q, x_k, n - m) &= R_g(x_q, x_k, 0) = \|q\| \|k\|, \end{aligned}\]

…(27)

这表明径向函数 $ R_q $、$ R_k $ 和 $ R_g $ 与位置信息无关。另一方面,从公式(26b)可以看出,$ \Theta_q(x_q, m) - \theta_q = \Theta_k(x_k, m) - \theta_k $ 表明角度函数不依赖于查询和键,我们设 $ \Theta_f := \Theta_q = \Theta_k $,并将 $ \Theta_f(x_{{q, k}}, m) - \theta_{{q, k}} $ 表示为位置 $ m $ 的函数,记为 $ \phi(m) $,从而得到:

\[\Theta_f(x_{\{q, k\}}, m) = \phi(m) + \theta_{\{q, k\}}.\]

…(28)

进一步,将 $ n = m + 1 $ 代入公式(24)并考虑上述方程,我们得到:

\[\phi(m + 1) - \phi(m) = \Theta_g(x_q, x_k, 1) + \theta_q - \theta_k,\]

…(29)

由于右边是一个与 $ m $ 无关的常数,$ \phi(m) $ 在连续整数输入下形成一个等差数列:

\[\phi(m) = m \theta + \gamma,\]

…(30)

其中 $ \theta, \gamma \in \mathbb{R} $ 是常数,且 $ \theta $ 非零。总结从公式(27)到(30)的解:

\[\begin{aligned} f_q(x_q, m) &= \|q\| e^{i (\theta_q + m \theta + \gamma)} = q e^{i (m \theta + \gamma)}, \\ f_k(x_k, n) &= \|k\| e^{i (\theta_k + n \theta + \gamma)} = k e^{i (n \theta + \gamma)}. \end{aligned}\]

…(31)

注意,我们没有对公式(22)中的 $ f_q $ 和 $ f_k $ 施加任何约束,因此 $ f_q(x_m, 0) $ 和 $ f_k(x_n, 0) $ 可以自由选择。为了使我们的结果与公式(3)可比,我们定义:

\[\begin{aligned} q &= f_q(x_m, 0) = W_q x_n, \\ k &= f_k(x_n, 0) = W_k x_n. \end{aligned}\]

…(32)

然后,我们在最终解中简单地设 $ \gamma = 0 $:

\[\begin{aligned} f_q(x_m, m) &= (W_q x_m) e^{i m \theta}, \\ f_k(x_n, n) &= (W_k x_n) e^{i n \theta}. \end{aligned}\]

…(33)

3.4.2 旋转矩阵乘法的计算高效实现

利用公式(15)中 $ R^d_{\Theta, m} $ 的稀疏性,$ R^d_{\Theta} $ 与 $ x \in \mathbb{R}^d $ 的乘法可以更高效地实现为:

\[R^d_{\Theta, m} x = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m \theta_1 \\ \cos m \theta_1 \\ \cos m \theta_2 \\ \cos m \theta_2 \\ \vdots \\ \cos m \theta_{d/2} \\ \cos m \theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m \theta_1 \\ \sin m \theta_1 \\ \sin m \theta_2 \\ \sin m \theta_2 \\ \vdots \\ \sin m \theta_{d/2} \\ \sin m \theta_{d/2} \end{pmatrix}.\]

…(34)

3.4.3 RoPE的长期衰减性

我们可以将向量 $ q = W_q x_m $ 和 $ k = W_k x_n $ 的条目成对分组,公式(16)中的RoPE内积可以写成复数乘法的形式:

\[(R^d_{\Theta, m} W_q x_m)^\intercal (R^d_{\Theta, n} W_k x_n) = \text{Re} \left[ \sum_{i=0}^{d/2-1} q_{[2i:2i+1]} k^*_{[2i:2i+1]} e^{i (m - n) \theta_i} \right],\]

…(35)

其中:

  • $ q_{[2i:2i+1]} $ 表示 $ q $ 的第 $ 2i $ 到 $ 2i+1 $ 个条目。

记 $ h_i = q_{[2i:2i+1]} k^*{[2i:2i+1]}, S_j = \sum{i=0}^{j-1} e^{i (m - n) \theta_i} $,并设 $ h_{d/2} = 0 $ 和 $ S_0 = 0 $,我们可以使用Abel变换将求和重写为:

\[\sum_{i=0}^{d/2-1} q_{[2i:2i+1]} k^*_{[2i:2i+1]} e^{i (m - n) \theta_i} = \sum_{i=0}^{d/2-1} h_i (S_{i+1} - S_i) = -\sum_{i=0}^{d/2-1} S_{i+1} (h_{i+1} - h_i).\]

…(36)

因此,

\[\left| \sum_{i=0}^{d/2-1} q_{[2i:2i+1]} k^*_{[2i:2i+1]} e^{i (m - n) \theta_i} \right| = \left| \sum_{i=0}^{d/2-1} S_{i+1} (h_{i+1} - h_i) \right| \leq \sum_{i=0}^{d/2-1} |S_{i+1}| |h_{i+1} - h_i| \leq \left( \max_i |h_{i+1} - h_i| \right) \sum_{i=0}^{d/2-1} |S_{i+1}|.\]

…(37)

注意到,通过设置 $ \theta_i = 10000^{-2i/d} $,$ \frac{1}{d/2} \sum_{i=1}^{d/2} \mid S_i \mid $ 的值会随着相对距离 $ m - n $ 的增加而衰减,如图(2)所示。

4.实验

#

https://arxiv.org/pdf/2104.09864

google在《Titans: Learning to Memorize at Test Time》提出了区别于Transformer的的一种新架构:Titans。我们来看一下它的实现,是否有前景:

摘要

在过去的十多年里,关于如何有效利用循环模型(recurrent models)和注意力机制(attentions)的研究已经非常广泛。循环模型的目标是将数据压缩到一个固定大小的内存中(称为hidden state),而注意力机制则允许模型关注整个上下文窗口,捕捉所有token之间的直接依赖关系。然而,这种更精确的依赖关系建模带来了二次方的计算成本(quadratic cost),限制了模型只能处理固定长度的上下文。我们提出了一种新的神经长期记忆模块(neural long-term memory module),该模块能够学习记忆历史上下文,并帮助注意力机制在利用过去信息的同时关注当前上下文。我们展示了这种神经记忆模块具有快速并行化训练的优势,同时保持了快速的推理能力。从记忆的角度来看,我们认为注意力机制由于其有限的上下文但精确的依赖关系建模,起到了短期记忆的作用;而神经记忆(neural memory)由于其能够记忆数据的能力,起到了长期、更持久的记忆作用。基于这两个模块,我们引入了一系列新的架构,称为Titans,并提出了三种变体,以探讨如何有效地将记忆融入这一架构中。我们在语言建模、常识推理、基因组学和时间序列任务上的实验结果表明,Titans比Transformers和最近的现代线性循环模型更有效。此外,与基线模型相比,Titans能够有效地扩展到超过200万的上下文窗口大小,并在“大海捞针”任务中表现出更高的准确性。

1.介绍

Transformers,一种纯基于注意力机制的架构(Vaswani 等人,2017),已被牢牢确立为序列建模中的最先进模型,主要归功于其上下文学习能力和大规模学习能力(Kaplan 等人,2020)。Transformers 的核心构建模块——注意力模块——充当关联记忆模块(Bietti 等人,2024),它们学习存储key-value关联性(associations),并通过计算query(即搜索信号)和key(即上下文)之间的成对相似性来检索这些关联。因此,从设计上看,Transformer 的输出完全取决于当前上下文窗口中token的直接依赖关系。然而,这种精确的依赖关系建模带来了与上下文长度相关的二次方时间和内存复杂度。在复杂的现实任务中(例如语言建模(N. F. Liu 等人,2024)、视频理解(C.-Y. Wu 等人,2019)、长期时间序列预测(H. Zhou 等人,2021)),上下文窗口可能变得非常大,这使得 Transformers 在这些下游任务中的适用性面临挑战。

为了克服 Transformers 的可扩展性问题,最近的研究旨在设计不同变体的线性 Transformers(Kacham、Mirrokni 和 P. Zhong,2024;Katharopoulos 等人,2020;S. Yang、B. Wang、Shen 等人,2024),其中注意力机制中的 softmax 被核函数取代(详见 §2.1),从而显著降低了内存消耗。尽管线性 Transformers 具有高效性并能够扩展到更长的上下文,但与 Transformers 相比,它们的性能并不具有竞争力,因为核技巧使模型变成了线性循环网络,其中数据被压缩为矩阵值状态(Katharopoulos 等人,2020)。然而,这带来了关于线性循环(或线性 Transformers)模型的一个矛盾事实:一方面,我们使用这些线性模型来增强可扩展性和效率(线性与二次方复杂度),其优势在非常长的上下文中显现;另一方面,非常长的上下文无法被适当地压缩到一个小的向量值或矩阵值状态中(S. Wang,2024)。

此外,除了效率问题外,大多数现有架构——从 Hopfield 网络(Hopfield,1982)到 LSTM(Jürgen Schmidhuber 和 Hochreiter,1997)以及 Transformers(Vaswani 等人,2017)——在处理泛化、长度外推和/或推理(Anil 等人,2022;Qin、Y. Zhong 和 Deng,2024)时都面临挑战,而这些是许多复杂现实任务中不可分割的部分。尽管这些架构从人类大脑中汲取了灵感,但它们都缺少以下关键部分:

  • (1)学习过程中的关键组件——例如短期记忆、长期记忆、元记忆、关注当前上下文等(Cowan,2008);
  • (2)这些组件如何作为可以独立运行的互联系统;以及/或
  • (3)从数据中主动学习并记忆过去历史的抽象能力。

我们认为,在一个有效的学习范式中,类似于人类大脑,存在独立但相互关联的模块,每个模块都负责学习过程中至关重要的组件。

记忆视角

记忆(memory)是一种基本的心理过程,也是人类学习中不可分割的组成部分(Terry,2017)。如果没有一个正常运作的记忆系统,人类和动物将只能局限于基本的反射和刻板行为。因此,记忆一直是机器学习文献中许多开创性研究的灵感来源;例如,Hopfield 网络(Hopfield,1982)、LSTM(Jürgen Schmidhuber 和 Hochreiter,1997)以及 Transformers(Vaswani 等人,2017)。

从神经心理学文献中对记忆和学习的常见定义中汲取灵感(Okano、Hirano 和 Balaban,2000),大多数现有架构将记忆视为由输入引起的神经更新,并将学习定义为在给定目标的情况下获取有效且有用记忆的过程。从这个角度来看,循环神经网络(RNN)(Williams 和 Zipser,1989)可以被定义为具有向量值记忆模块 M(也称为hidden state)的模型,其主要步骤包括:

在时间 𝑡 给定新输入 $𝑥_𝑡$ 时,模型

  • (1)使用函数 $𝑓(M_{𝑡−1}, 𝑥_𝑡) $ 更新记忆(带有压缩);
  • (2)使用函数 $𝑔(M_𝑡, 𝑥_𝑡)$ 检索输入的相应记忆(详见 §2.1)。

类似地,Transformers 可以被视为具有增长记忆和两个相似步骤的架构。即,key和value矩阵对充当模型的记忆,模型:

  • (1)通过将key和value附加到记忆中来更新记忆(无压缩),
  • (2)通过查找query向量与key向量的相似性来检索query向量的相应记忆,然后将其用于加权value向量以生成输出。

这种视角可以帮助我们更好地理解现有范式、它们的关键差异,并设计更有效的架构。例如,Transformers(Vaswani 等人,2017)和线性 Transformers(Katharopoulos 等人,2020)之间的主要区别在于记忆结构以及记忆更新步骤,其中:线性 Transformers 将历史数据压缩为固定大小的矩阵值记忆,而 Transformers 则保留所有历史数据(在上下文长度内)而不进行任何压缩。虽然线性 Transformers 和线性 RNN(包括状态空间模型)都在记忆更新步骤中压缩信息,但关键区别在于记忆的结构,其中线性 RNN(相对于线性 Transformers)使用向量值记忆(相对于矩阵值记忆)。因此,这种视角促使我们提出以下问题:

  • (Q1)什么是良好的记忆结构
  • (Q2)什么是适当的记忆更新机制
  • (Q3)什么是良好的记忆检索过程

重新审视我们对人类记忆的理解,它既不是一个单一的过程,也不服务于单一的功能(Cowan,2008)。事实上,记忆是一个系统的联合体——例如短期记忆、工作记忆(working memory)和长期记忆——每个系统服务于不同的功能,具有不同的神经结构,并且每个系统都能够独立运行(Willingham,1997)。这一事实促使我们提出:

  • (Q4)如何设计一个包含不同互联记忆模块的高效架构

最后,存储记忆是一个神经过程,需要对过去的抽象进行编码和存储。假设一个单一向量或矩阵(其参数以线性方式编码数据)足以存储长期历史可能过于简化。

  • (Q5)是否需要深度记忆模块来有效存储/记住遥远的过去?

贡献与路线图

在本文中,我们旨在通过设计一个长期神经记忆模块来回答上述五个问题,该模块能够在测试时高效且有效地学习记忆。基于其设计,我们讨论了如何将其融入架构中。

神经记忆(§3)。我们提出了一种(深度)神经长期记忆模块,它(作为元上下文模型)学习如何在测试时将数据记忆/存储到其参数中。受人类长期记忆系统(Mandler,2014)的启发,

我们设计了这个记忆模块,使得违反预期的事件(即令人惊讶的事件: surprising)更容易被记住。为此,我们通过神经网络在关联记忆损失中对输入的梯度来衡量输入的“惊讶度(surprise)”(详见 §3.1)。为了更好地处理有限的内存,我们提出了一种衰减机制,该机制考虑了内存大小与数据惊讶度的比例,从而实现更好的内存管理。我们展示了这种衰减机制实际上是现代循环模型中遗忘机制的泛化(Dao 和 Gu,2024;Gu 和 Dao,2024;S. Yang、Kautz 和 Hatamizadeh,2024)。有趣的是,我们发现这种机制等同于使用小批量梯度下降、动量和权重衰减来优化元神经网络。基于张量化小批量梯度下降以使用更多矩阵乘法操作(Yu Sun 等人,2024),我们提出了一种快速且可并行化的算法来训练我们的深度神经长期记忆模块。

Titans 架构(§4)

在设计完长期神经记忆模块后,一个重要的问题是:如何高效且有效地将记忆融入深度学习架构中。我们提出了 Titans,这是一个由三个超头部(hyper-heads)组成的深度模型家族:

  • (1)核心模块:该模块包含短期记忆,负责数据处理的主要流程(我们使用有限窗口大小的注意力机制);
  • (2)长期记忆模块:这一分支是我们的神经长期记忆模块,负责存储/记住遥远的过去;
  • (3)持久记忆模块:这是一组可学习但与数据无关的参数,用于编码任务相关知识。

最后,作为概念验证,我们提出了 Titans 的三种变体,其中我们将记忆分别融入为:

  • (i)一个上下文(context)
  • (ii)层(layer)
  • (iii)一个门控分支(gated branch)

实验结果(§5)

我们在语言建模、常识推理、记忆密集型任务、“大海捞针”任务、时间序列预测和 DNA 建模任务上进行了实验评估。我们观察到,Titans 架构在所有现代循环模型及其混合变体(结合滑动窗口注意力机制)的综合基准测试中均表现优异。此外,Titans 在相同上下文窗口下优于 Transformers,并在使用整个上下文的 Transformers 中表现出竞争力。这些结果是在 Titans 能够扩展到超过 200 万上下文窗口大小的情况下实现的,而 Transformers 则无法做到这一点

2 预备知识

在本节中,我们将讨论本文中使用的符号和一些背景概念。我们令:

  • $ x \in \mathbb{R}^{N \times d_{\text{in}}} $ 表示输入
  • $ \mathbf{M} $ 表示神经网络(神经记忆模块:neural memory)
  • $ \mathbf{Q} $、$ \mathbf{K} $、$ \mathbf{V} $ 分别表示注意力机制中的query、key和value
  • $ M $ 表示注意力掩码(attention mask)
  • $ S^{(i)} $ 表示:在对序列进行分段时,使用第 $ i $ 段。

在本文中,我们简化符号并使用下标来指代矩阵、向量或段中的特定元素。例如,我们令:

  • $ S^{(i)}_j $ 表示第 $ i $ 段中的第 $ j $ 个 token。

唯一的例外是下标为 $ t $ 的情况,我们保留它来表示时间上的递归或神经网络在时间 $ t $ 的状态。

给定:神经网络 $ \mathbf{N} $ 和数据样本 $ x $,我们使用:

  • $ \mathbf{N}(x) $(或 $ \mathbf{N}^*(x) $)表示带权重调整(或不带权重调整)的前向传播

此外,我们简化符号并使用:

  • $ \mathbf{N}^{(k)} $ 表示神经网络的第 $ k $ 层

接下来,我们首先讨论注意力机制及其高效变体的背景,然后回顾现代线性 RNN,最后讨论这些架构的记忆视角,这促使我们设计了 Titans。

2.1 背景

注意力机制。Transformers(Vaswani 等人,2017)作为许多深度学习模型的实际骨干,基于注意力机制。给定:

  • 输入 $ x \in \mathbb{R}^{N \times d_{\text{in}}} $

因果注意力机制基于输入依赖的key、value和query矩阵计算输出 $ y \in \mathbb{R}^{N \times d_{\text{in}}} $:

\[\mathbf{Q} = x \mathbf{W}_Q, \quad \mathbf{K} = x \mathbf{W}_K, \quad \mathbf{V} = x \mathbf{W}_V, \quad (1)\] \[y_i = \frac{\sum_{j=1}^i \exp\left(\frac{\mathbf{Q}_i^\top \mathbf{K}_j}{\sqrt{d_{\text{in}}}}\right) \mathbf{V}_j}{\sum_{\ell=1}^i \exp\left(\frac{\mathbf{Q}_i^\top \mathbf{K}_\ell}{\sqrt{d_{\text{in}}}}\right)}, \quad (2)\]

其中:

  • $ W_Q, W_K, W_V \in R^{d_{in} \times d_{in}} $ 是可学习参数

尽管 Transformers 在召回能力上表现出色,但它们至少需要 $ N \times d $ 次操作来计算输出,导致内存消耗较大且对较长序列的吞吐量较低。

高效注意力机制。为了提高软注意力机制在长序列上的内存消耗和吞吐量,许多研究集中在注意力机制的 I/O 感知实现(Dao 2024;Dao, D. Fu 等人,2022),通过稀疏化注意力矩阵(B. Chen 等人,2021;Choromanski 等人,2021;Dai 等人,2019)、近似 softmax(Arora 等人,2024)或开发基于核的(线性)注意力机制(Aksenov 等人,2024;Kacham, Mirrokni 和 P. Zhong,2024;Schlag, Irie 和 Jürgen Schmidhuber,2021;S. Yang, B. Wang, Shen 等人,2024)来设计更高效的注意力机制。在本部分,我们重点关注后者,即线性注意力机制,其中标准注意力中的 softmax 被替换为替代核函数 $ \phi(\cdot, \cdot) $,使得 $ \phi(x, y) = \phi(x) \phi(y) $。因此,注意力可以写成:

\[y_i = \frac{\sum_{j=1}^i \phi(\mathbf{Q}_i^\top \mathbf{K}_j) \mathbf{V}_j}{\sum_{\ell=1}^i \phi(\mathbf{Q}_i^\top \mathbf{K}_\ell)} = \frac{\phi(\mathbf{Q}_i)^\top \sum_{j=1}^i \phi(\mathbf{K}_j) \mathbf{V}_j}{\phi(\mathbf{Q}_i)^\top \sum_{\ell=1}^i \phi(\mathbf{K}_\ell)}, \quad (3)\]

由于:

  • 项 $ \sum_{j=1}^i \phi(K_j), \sum_{\ell=1}^i \phi(K_\ell) $ 在每一步中重复使用,因此吞吐量更高。

当选择核函数为单位矩阵时(Yutao Sun 等人,2023),上述公式可以写成递归形式:

\[\mathbf{M}_t = \mathbf{M}_{t-1} + \mathbf{K}_t^\top \mathbf{V}_t, \quad (4)\] \[y_t = \mathbf{Q}_t \mathbf{M}_t, \quad (5)\]

这使得线性注意力机制能够高效推理。

现代线性模型及其记忆视角。如前所述,可以将学习定义为获取有效且有用记忆的过程。基于此,可以将循环神经网络(RNN)的hidden state视为记忆单元,模型旨在将信息压缩到其中。因此,在一般形式的循环神经网络中,hidden state可以被视为记忆单元,递归过程可以分为记忆单元的读和写操作。即,令:

  • $ x \in \mathbb{R}^{N \times d_{\text{in}}} $ 为输入
  • $ \mathbf{M} \in \mathbb{R}^d $ 为记忆单元
  • $ y \in \mathbb{R}^{d_{\text{in}}} $ 为输出

则循环神经网络的一般形式定义为:

\[\mathbf{M}_t = f(\mathbf{M}_{t-1}, x_t), \quad \text{写操作} \quad (6)\] \[y_t = g(\mathbf{M}_t, x_t), \quad \text{读操作} \quad (7)\]

其中:

  • $ f(\cdot, \cdot) $ 是读操作,
  • $ g(\cdot, \cdot) $ 是写操作。

注意,这里的 $ \mathbf{M}_t $ 下标表示记忆在时间 $ t $ 的状态。

从这一视角来看,线性 Transformers 的递归公式(见公式 4)等同于将键和值 $ (\mathbf{K}_t, \mathbf{V}_t) $ 加性地压缩并写入矩阵值记忆单元 $ \mathbf{M}_t $ 中。因此,在处理长上下文数据时,这种加性特性会导致内存溢出,显著损害模型性能。为了解决这一问题,研究集中在两个有前景的方向上:

  • (1)添加遗忘机制:一些研究提出了线性模型的自适应(数据依赖)遗忘门机制,可以在需要时擦除记忆。例如,GLA(S. Yang, B. Wang, Shen 等人,2024)、LRU(Orvieto 等人,2023)、Griffin(De 等人,2024)、xLSTM(Beck 等人,2024)和 Mamba2(Dao 和 Gu,2024)等模型,后者还与离散化的传统状态空间模型(Gu 和 Dao,2024)相关联。
  • (2)改进写操作:为了克服传统循环模型中记忆写操作的加性特性,Widrow 和 Hoff(1988)提出了 Delta 规则,在添加记忆(即键值对)之前,模型首先移除其过去的值。为了增强可并行化训练和扩展性,S. Yang, B. Wang, Yu Zhang 等人(2024)提出了一种快速并行化算法。最后,最近 S. Yang, Kautz 和 Hatamizadeh(2024)通过添加遗忘门改进了 DeltaNets。

记忆模块。记忆一直是神经网络设计的核心部分之一(Graves, Wayne 和 Danihelka,2014;JH Schmidhuber,1992;Jürgen Schmidhuber 和 Hochreiter,1997;J. Zhang 等人,2024)。将线性层视为键值(关联)记忆系统(key-value (associative) memory system)的思想可以追溯到快速权重程序(fast weight programs),其中动态快速程序被纳入循环神经网络中作为可写记忆(JH Schmidhuber,1992)。Hebbian(Hebb,2005)和 delta(Prados 和 Kak,1989)学习规则是快速权重程序中最流行的学习规则,已在各种研究中广泛探索(Irie, Schlag 等人,2021;Munkhdalai, Sordoni 等人,2019;Munkhdalai 和 H. Yu,2017;Schlag, Irie 和 Jürgen Schmidhuber,2021;JH Schmidhuber,1992;S. Yang, Kautz 和 Hatamizadeh,2024;S. Yang, B. Wang, Yu Zhang 等人,2024)。然而,所有这些模型都基于瞬时惊讶度,忽略了序列中的 token 流(见第 3.1 节),并且大多数模型缺乏遗忘门,导致内存管理不佳。

我们在附录 C 中进一步讨论了我们的架构与最近模型的联系。其他相关工作在附录 A 中讨论。

3 测试时的记忆学习

为了克服长期记忆的不足,并使模型能够学习、遗忘和检索信息,本节提出了一种神经长期记忆模块,这是一种在测试时学习记忆的元模型。

  • 在 3.1 节中,我们首先讨论神经记忆的动机和设计。
  • 在 3.2 节中,我们讨论如何通过快速且可并行化的训练使我们的架构设计受益。
  • 在 3.3 节中,我们通过持久记忆模块增强我们的架构,其中使用可学习但与数据无关的参数来学习任务的元信息。

图片名称

图1 关于如何并行训练神经记忆并使用矩阵乘法(matmuls)的图示说明。

3.1 长期记忆

为了设计一个神经长期记忆模块,我们需要一个能够将过去历史的抽象编码到其参数中的模型。一个例子是大语言模型(LLMs),它们被证明能够记忆训练数据(Leybzon 和 Kervadec,2024;Schwarzschild 等人,2024;Staab 等人,2024)。因此,一个简单的想法是:训练一个神经网络并期望它记忆其训练数据。然而,记忆化几乎总是被认为是神经网络中的不良现象,因为它限制了模型的泛化能力(Bayat 等人,2024),引发隐私问题(Staab 等人,2024),并导致测试时性能不佳。此外,训练数据的记忆化在测试时可能没有帮助,因为数据可能是分布外的。我们认为,我们需要一个在测试时学习如何记忆/遗忘数据的在线元模型。在这种设置中,模型学习的是一个能够记忆的函数,但它不会过度拟合训练数据,从而在测试时实现更好的泛化。

学习过程与惊讶度度量(Surprise Metric)。训练长期记忆的关键思想是:将其训练视为一个在线学习问题,我们的目标是将过去的信息 $ x_1, \ldots, x_{t-1} $ 压缩到长期神经记忆模块 $ \mathbf{M}_t $ 的参数中。如前所述,违反预期的事件(即令人惊讶的事件)对人类来说更容易被记住(Mandler,2014)。受此启发,模型惊讶度的一个简单定义可以是:其相对于输入的梯度。梯度越大,输入数据与过去数据的差异越大。因此,使用这种惊讶度评分,我们可以更新记忆如下:

\[\mathbf{M}_t = \mathbf{M}_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) \quad \text{(惊讶度)} \quad (8)\]

然而,这种惊讶度度量可能会导致错过在重大惊讶时刻之后的重要信息。也就是说,梯度在几次惊讶步骤后可能变得非常小,导致陷入平坦区域(即局部最小值),并错过序列某些部分的信息。从人类记忆的角度来看,一个事件可能不会在长时间内持续让我们感到惊讶,尽管它是值得记忆的。原因是初始时刻足够令人惊讶,足以在长时间内吸引我们的注意力,从而记住整个时间段。为了改进上述惊讶度度量(公式 8),我们将惊讶度度量分为:

  • (1)过往惊讶度(past surprise),衡量最近过去的惊讶程度;
  • (2)瞬时惊讶度(momentary surprise),衡量输入数据的惊讶程度:
\[\mathbf{M}_t = \mathbf{M}_{t-1} + S_t, \quad (9)\] \[S_t = \eta_t S_{t-1} \quad \text{(过去惊讶度)} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) \quad \text{(瞬时惊讶度)} \quad (10)\]

有趣的是,这个公式类似于带有动量的梯度下降,其中:

  • $ S_t $ 是动量项

因此,这里的动量充当了跨时间(序列长度)的惊讶度记忆。在这个公式中:

  • 项 $ \eta_t $ 是一个数据依赖的惊讶度衰减($ x_t $ 的函数),控制惊讶度随时间衰减的程度
  • 项 $ \theta_t $ 则控制瞬时惊讶度应以数据依赖的方式纳入最终惊讶度度量的多少

这种数据依赖性在这个设计中尤为重要:虽然前一个 token 的惊讶度可能需要影响下一个 token 的惊讶度,但这只有在所有 token 都相关且处于同一上下文中时才有效。因此,数据依赖的 $ \eta $ 可以控制记忆是否需要:

  • (1)通过设置 $ \eta_t \to 0 $ 忽略上一次的惊讶度(可能由于上下文的变化),
  • (2)通过设置 $ \eta_t \to 1 $ 完全纳入上一次的惊讶度(可能因为 token 与其最近的过去 token 高度相关)。

目标。我们上述的惊讶度度量基于损失函数 $ \ell(\cdot; \cdot) $,这是我们的记忆模块在测试时学习的目标。也就是说,我们的记忆模块是一个元模型,它基于损失函数 $ \ell(\cdot; \cdot) $ 学习一个函数。

在本节中,我们重点讨论关联记忆,其目标是将过去的数据存储为k-V对。给定输入 $ x_t $,类似于 Transformers(Vaswani 等人,2017),我们使用两个线性层将 $ x_t $ 投影为key和value:

\[\mathbf{k}_t = x_t \mathbf{W}_K, \quad \mathbf{v}_t = x_t \mathbf{W}_V, \quad (11)\]

其中 $ W_K $ 和 $ W_V \in R^{d_{in} \times d_{in}} $。接下来,我们希望记忆模块能够学习键和值之间的关联。为此,我们定义损失函数如下:

\[\ell(\mathbf{M}_{t-1}; x_t) = \|\mathbf{M}_{t-1}(\mathbf{k}_t) - \mathbf{v}_t\|_2^2, \quad (12)\]

通过在元模型(记忆)的内循环中优化上述损失函数,模型学习如何在测试时记忆键和值之间的映射。需要注意的是,类似于元学习模型(Nichol,2018;Zintgraf 等人,2019),记忆的训练是在内循环中进行的,因此参数 $ \mathbf{W}_K $ 和 $ \mathbf{W}_V $ 是上述损失函数中的超参数。因此,在内循环中,我们优化记忆模块 $ \mathbf{M} $ 的权重,而在外循环中,我们优化整个架构的其他参数。

遗忘机制

当处理非常长的序列(例如数百万个 token)时,管理哪些过去信息应该被遗忘至关重要——即使使用深度或非常大的矩阵值记忆。为此,我们使用一种自适应遗忘机制,允许记忆遗忘不再需要的信息,从而更好地管理记忆的有限容量。具体来说,给定下一个 token $ x_t $,我们修改更新规则如下:

\[\mathbf{M}_t = (1 - \alpha_t) \mathbf{M}_{t-1} + S_t, \quad (13)\] \[S_t = \eta_t S_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t), \quad (14)\]

其中:

  • $ \alpha_t \in [0, 1] $ 是一个门控机制,灵活控制记忆;即决定应该遗忘多少信息。例如:

  • 通过设置 $ \alpha_t \to 0 $,可以在不影响过去抽象的情况下更新记忆;
  • 通过设置 $ \alpha_t \to 1 $,可以清除整个记忆

在本节后面,我们将展示这种权重衰减机制与现代 RNN 中的门控机制密切相关(Dao 和 Gu,2024;Orvieto 等人,2023)。

记忆架构

在本文中,我们专注于使用具有 $ L_M \geq 1 $ 层的简单多层感知机(MLP)作为长期记忆的架构。选择这种架构的主要原因是,我们希望集中精力更好地激励长期记忆的设计及其融入架构的方式。然而,我们的公式和架构设计为设计在数据记忆方面更有效和高效的神经架构开辟了新的研究方向。最近,有一些有前景的工作致力于设计此类架构(Berges 等人,2024;Cetin 等人,2024;J. Zhang 等人,2024),将这些架构融入我们的框架(即用此类架构替换简单的 MLP)可能是一个有趣的未来工作方向。

当使用向量值或矩阵值记忆(De 等人,2024;Orvieto 等人,2023;S. Yang, B. Wang, Shen 等人,2024)时,记忆模块会压缩过去的数据并将其拟合到一条线上。也就是说,从元学习或在线学习的角度来看(Yu Sun 等人,2024),使用矩阵值记忆 $ \mathbf{M} = \mathbf{W} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{in}}} $ 等同于优化 $ \ell(\mathbf{W}{t-1}; x_t) = |\mathbf{W}{t-1} \mathbf{k}_t - \mathbf{v}_t|_2^2 $,这是一个在线线性回归目标,因此最优解假设历史数据的潜在依赖关系是线性的。另一方面,我们认为深度记忆模块(即 $ L_M \geq 2 $ 层)在实践中更有效。这与理论结果一致,即至少具有两层的 MLP 严格比线性模型更具表达能力(Hornik, Stinchcombe, and White, 1989)。在第 5.5 节中,我们展示了深度记忆模块在实际应用中的有效性。


记忆检索

在上面,我们讨论了如何设计和训练一个在测试时学习记忆的长期记忆模块。一个关键的问题是:如何从记忆中检索信息?我们简单地使用不更新权重的前向传播(即推理)来检索与查询对应的记忆。形式上,给定输入 $ x_t $,我们使用线性层 $ \mathbf{W}_Q $ 投影输入,即 $ \mathbf{q}_t = x_t \mathbf{W}_Q $,并通过以下方式从记忆中检索相应的(或有用的)信息 $ y_t $:

\[y_t = \mathbf{M}^*(\mathbf{q}_t). \quad (15)\]

图片名称

图2:记忆作为上下文(MAC)架构。该架构包括三个分支:(1)核心分支,(2)上下文(长期)记忆分支,以及(3)持久记忆分支。核心分支将相应的长期记忆和持久记忆与输入序列连接起来。接下来,注意力机制在序列上执行,并决定哪些信息应存储在长期记忆中。在测试时,与上下文记忆对应的参数仍在学习,与核心分支对应的参数负责上下文学习,而持久记忆的参数负责存储任务知识,因此是固定的。

3.2 如何并行化长期记忆的训练

如上所述,我们的长期记忆模块的设计等同于通过优化关联记忆损失函数 $ \ell(\mathbf{M}{t-1}; x_t) = |\mathbf{M}{t-1}(\mathbf{k}_t) - \mathbf{v}_t|_2^2 $ 来训练一个元模型,使用带有动量和权重衰减的梯度下降。因此,理论上,长期记忆模块的训练需要 $ O(N) $ 的浮点运算(FLOPs),其中 $ N $ 是序列长度。然而,在实践中,我们需要并行化训练过程,并充分利用硬件加速器(例如 TPU、GPU),因此需要将过程张量化并使用更多的矩阵乘法(matmuls)。

接下来,我们展示如何通过小批量梯度下降、数据依赖的学习率和权重衰减来重新表述内循环中的权重计算,使其仅使用矩阵乘法和求和。我们基于 Yu Sun 等人(2024)的工作,该工作表明,使用小批量梯度下降(具有恒定学习率)优化的模型的前向传播可以通过矩阵乘法计算。我们可以将序列分割为大小为 $ b \geq 1 $ 的块,并将小批量梯度下降表示为:

\[\mathbf{M}_t = (1 - \alpha_t) \mathbf{M}_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) = \beta_t \mathbf{M}_0 - \sum_{i=1}^t \theta_i \frac{\beta_t}{\beta_i} \nabla \ell(\mathbf{M}_{t'}; x_i), \quad (16)\]

其中 $ t’ = t - \text{mod}(t, b) $,且 $ \beta_i = \prod_{j=1}^i (1 - \alpha_j) $。为了简化,我们专注于第一个块,即 $ t = b $,因此 $ t’ = 0 $。此外,我们解释当 $ \mathbf{M}_t = \mathbf{W}_t $ 是线性时的情况。对于具有 $ N_p \geq 2 $ 层的 MLP,过程类似。使用我们的损失函数,我们有:

\[\nabla \ell(\mathbf{W}_0; x_t) = (\mathbf{W}_0 x_t - x_t) x_t^\top \Rightarrow \sum_{i=1}^b \theta_i \frac{\beta_b}{\beta_i} \nabla \ell(\mathbf{W}_0; x_i) = \Theta_b \mathbf{B}_b (\mathbf{W}_0 \mathbf{X} - \mathbf{X}) \mathbf{X}^\top, \quad (17)\]

其中 $ \Theta_b = \text{diag}(\theta_1, \theta_2, \ldots, \theta_b) $,且 $ \mathbf{B}b $ 类似地定义在 $ \frac{\beta_b}{\beta_i} $ 上。需要注意的是,我们不需要存储所有 $ \Theta{kb} $ 和 $ \mathbf{B}_{kb} $($ k = 1, \ldots, N/b $),而是为每个块存储这些矩阵,从而减少内存使用。接下来,我们扩展这种表示,以便还可以纳入动量项。在带有动量的小批量梯度下降中,如果我们看动量项,我们有:

\[S_t = \eta_t S_{t-1} - \theta_t u_t, \quad (18)\]

其中 $ u_t = \nabla \ell(\mathbf{M}_{t’}; x_t) $。需要注意的是,我们可以同时计算所有 $ u_t $,因此公式 (18) 是一个线性递归,其中 $ u_t $ 是输入,$ S_t $ 是隐藏状态,$ \eta_t $ 是输入依赖的转移值。因此,我们可以使用并行关联扫描(J. T. Smith, Warrington, and Linderman, 2023)来计算该块中的 $ S_t $。

参数作为块的函数

与其让参数 $ \alpha_t $、$ \theta_t $ 和 $ \eta_t $ 依赖于输入(即 token $ x_t $ 的函数),我们可以让它们成为块的函数。尽管这会降低表达能力,但这种表述可以帮助使训练更快。在这种情况下,我们在每个块中对 $ \alpha $、$ \theta $ 和 $ \eta $ 使用相同的值。因此,在公式 (17) 中,我们可以使用单个标量存储 $ \Theta $。类似地,我们可以使公式 (18) 更快。也就是说,当 $ \eta $ 和 $ \theta $ 在每个块内可学习但时间不变时,该方程变为线性时不变系统(LTI),可以通过全局卷积计算(Gu, Goel, and Re, 2022)。在我们的实验中,我们将这些参数作为 token 的函数。然而,这种简化(即作为块的函数)可能是未来工作的兴趣点,以便以更高效的方式训练更大的模型。

3.3 持久记忆

我们的长期记忆也可以被视为一种上下文记忆,这意味着输出完全依赖于上下文。因此,除了长期记忆外,我们还使用一组可学习但与输入无关的参数来充当任务相关的记忆。这种类型的记忆在文献中被称为持久记忆或元记忆(X. Dong 等人,2024;Sukhbaatar, Grave 等人,2019)。给定 $ N_p \geq 1 $,我们使用可学习参数 $ P = [p_1, p_2, \ldots, p_{N_p}] $ 并将其附加到序列的开头:即,给定上下文窗口大小为 $ N $,我们将输入修改为:

\[x_{\text{new}} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (19)\]

其中 $ \parallel $ 表示连接操作。接下来,我们从三个角度讨论持久记忆的动机:


记忆视角

如前所述,我们的神经长期记忆是一种上下文记忆,其中所有参数都依赖于输入。然而,一个有效的记忆系统还需要与输入无关的参数来存储任务知识的抽象。也就是说,掌握一个任务需要记忆如何完成该任务的知识,而这些参数负责存储此类知识。


前馈网络视角

在 Transformer 架构中,注意力模块之后有全连接层,这些层被证明类似于注意力权重,但具有与数据无关的参数。即,Sukhbaatar, Grave 等人(2019)表明,将全连接层中的 ReLU 替换为 Softmax 可以产生类似注意力的权重,其中权重与数据无关:

\[FFN(x) = W_V \text{Softmax}(W_K x), \quad (20)\]

实际上,当 $ W_K $ 和 $ W_V $ 与输入无关时,它们的作用类似于注意力模块中的 $ K $ 和 $ V $ 矩阵。持久记忆权重预计具有相同的功能,这意味着在序列的开头部分使用它们会导致具有与输入无关的注意力权重(Sukhbaatar, Grave 等人,2019)。


技术视角

带有因果掩码的注意力机制对序列中的初始 token 具有隐式偏差,因此注意力权重几乎总是对初始 token 高度活跃,从而导致性能下降。从技术角度来看,序列开头的这些可学习参数可以通过更有效地重新分配注意力权重来缓解这种影响(Han 等人,2024;Xiao 等人,2024)。


总结

  • 持久记忆的作用:存储任务知识的抽象,与输入无关。
  • 前馈网络的类比:持久记忆权重类似于注意力机制中的 $ K $ 和 $ V $ 矩阵,但具有与数据无关的特性。
  • 技术优势:通过在序列开头引入可学习参数,持久记忆可以缓解注意力机制对初始 token 的偏差,从而提升模型性能。

持久记忆的引入为模型提供了任务知识的存储能力,并通过优化注意力权重的分配进一步提升了模型的性能。

图片名称

图3

4 如何融入记忆?

一个重要但尚未解答的问题是:如何有效且高效地将设计的神经记忆融入深度学习架构中?如前所述,从记忆的角度来看,Transformers 中的键值对矩阵可以解释为关联记忆块。由于其依赖关系的精确建模以及有限的上下文窗口,我们将其解释为短期记忆模块,专注于当前上下文窗口大小。另一方面,我们的神经记忆能够从数据中持续学习并将其存储在其权重中,可以扮演长期记忆的角色。在本节中,我们旨在通过提出 Titans 的三种不同变体来回答上述问题。在后续实验中,我们将展示每种变体的优缺点,以及在超长上下文中的效率与有效性之间的权衡。


4.1 记忆作为上下文(Memory as a Context, MAC)

在第一种架构设计中(见图 2),我们将记忆视为当前信息的上下文。即,给定一个长序列 $ x \in \mathbb{R}^{N \times d_{\text{in}}} $,我们首先将序列分割为固定大小的段 $ S^{(i)} $($ i = 1, \ldots, N/C $)。给定传入的段 $ S^{(t)} $,我们将其视为当前上下文,并将其过去的段视为历史信息。因此,设 $ \mathbf{M}{t-1} $ 为段 $ S^{(t)} $ 之前的长期记忆状态,我们使用输入上下文作为查询,从长期记忆 $ \mathbf{M}{t-1} $ 中检索相应的信息。即,我们检索与 $ S^{(t)} $ 对应的过去信息:

\[h_t = \mathbf{M}^*_{t-1}(\mathbf{q}_t), \quad (21)\]

其中 $ \mathbf{q}_t = S^{(t)} \mathbf{W}_Q $。接下来,我们将此历史信息与持久记忆参数一起作为注意力模块的输入序列:

\[\tilde{S}^{(t)} = [p_1, p_2, \ldots, p_{N_p}] \parallel h_t \parallel S^{(t)}, \quad (22)\] \[y_t = \text{Attn}(\tilde{S}^{(t)}). \quad (23)\]

整个序列的注意力图结构如图 3a 所示。然后,我们使用 $ y_t $ 更新长期记忆模块以用于下一个段和最终输出:

\[\mathbf{M}_t = \mathbf{M}_{t-1}(y_t), \quad (24)\] \[o_t = y_t \otimes \mathbf{M}^*_t(y_t). \quad (25)\]

需要注意的是,在上述过程中,我们通过前向传播更新 $ \mathbf{M}_{t-1} $ 的权重。

图片名称

图4

该架构具有两个关键优势:

  1. 注意力机制:通过同时拥有历史和当前上下文,能够决定在当前数据下是否需要长期记忆信息。
  2. 记忆管理:注意力模块帮助长期记忆仅存储当前上下文中有用的信息,避免记忆溢出。
  3. 测试时的行为
    • 持久记忆参数固定,因为它们编码了任务知识,不应更改。
    • 注意力模块权重是上下文学习者。
    • 长期记忆模块在测试时仍在学习(记忆)信息,即神经记忆的权重在测试时也会更新,以编码过去的抽象。

4.2 门控记忆(Gated Memory)

在第二种变体中(见图 4),我们在一个分支中直接使用输入数据更新长期记忆,在另一个分支中使用滑动窗口注意力(SWA):

\[\tilde{x} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (26)\] \[y = \text{SW-Attn}^*(\tilde{x}), \quad (27)\] \[o = y \otimes \mathbf{M}(\tilde{x}), \quad (28)\]

其中 $ \text{SW-Attn}^* $ 是带有前缀的滑动窗口注意力(见图 3b)。需要注意的是,与之前的设计不同,我们不对输入数据进行分段。此外,我们简化符号并使用 $ \mathbf{M}(x) $ 表示序列中所有 token 递归后的记忆最终输出。在上述等式中,$ \otimes $ 可以是任何非线性门控。在我们的实验中,我们使用可学习的向量值权重对输出 $ y $ 和 $ \mathbf{M}(\tilde{x}) $ 进行归一化,然后应用非线性函数 $ \sigma(\cdot) $。

该设计的整体注意力掩码如图 3b 所示。在这种设计中,滑动窗口注意力充当精确的短期记忆,而神经记忆模块则充当模型的衰减记忆。这种架构设计也可以被视为一种多头架构,其中头的结构不同(X. Dong 等人,2024)。

4.3 记忆作为层(Memory as a Layer, MAL)

在第三种变体中,我们将神经记忆作为深度神经网络的一层(见图 5)。这种架构设计在文献中更为常见,其中混合模型将循环模型与完整或滑动窗口注意力机制堆叠在一起。给定输入 $ x $,我们有:

\[\tilde{x} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (29)\] \[y = \mathbf{M}(\tilde{x}), \quad (30)\] \[o = \text{SW-Attn}(y), \quad (31)\]

其中 $ \text{SW-Attn} $ 是滑动窗口注意力。这种设计的主要缺点是模型的能力受限于每一层,因此无法充分利用注意力和神经记忆模块的互补数据处理能力。在我们的实验中,为了评估这种设计中的记忆,我们使用了类似于 H3(D. Y. Fu 等人,2023)的架构,其中我们将序列模型替换为我们的神经记忆模块(LMM)。


无注意力的记忆

尽管上述讨论中我们将 MAL 视为 LMM 和注意力机制的顺序组合,但 MAL 的一个简单变体是将 LMM 视为没有任何注意力机制的序列模型。从记忆的角度来看,如第 1 节所述,我们期望记忆系统的每个部分都能独立工作,即使其他组件受到干扰。因此,即使没有短期记忆(即注意力机制),长期记忆模块仍然应该是一个强大的模型。我们在实验中称这种变体为 LMM 或 Titans(LMM)。我们在附录 C 中提供了关于 Titans 与其他现代循环模型联系的更多讨论。

图片名称

图5


4.4 架构细节

为了简洁和清晰,我们避免讨论实现细节,例如使用残差连接、线性层门控和归一化。在所有块中,我们使用残差连接。在我们的实现中,我们使用 SiLU(.) 激活函数(Elfwing, Uchibe, and Doya, 2018)作为计算查询、键和值的非线性激活,并使用 $ \ell_2 $-范数对查询和键进行归一化。


卷积

遵循最近的现代线性循环模型(Gu 和 Dao,2024;S. Yang, Kautz, and Hatamizadeh,2024),我们在每个查询、键和值投影之后加入一维深度可分离卷积层。虽然这些一维卷积对性能的影响不大,但它们已被证明可以提升性能,并且在计算上也很高效。


门控

我们还遵循最近的架构,在最终输出投影之前使用归一化和线性层门控(Mehta 等人,2023)。


定理 4.1

与 Transformers、对角线性循环模型和 DeltaNet 不同,这些模型都受限于 $ \text{TC}^0 $(Merrill, Petty, and Sabharwal, 2024),Titans 能够解决超出 $ \text{TC}^0 $ 的问题,这意味着 Titans 在状态跟踪任务中理论上比 Transformers 和大多数现代线性循环模型更具表达能力。

#

https://arxiv.org/pdf/2501.00663v1

meta Ins在《QuickUpdate: a Real-Time Personalization System for Large-Scale Recommendation Models》给出了它们的系统实现:

摘要

深度学习推荐模型在在线公司中扮演着重要角色,并且占据了用于训练和推理的AI基础设施的主要部分。这些模型的准确性高度依赖于它们在服务端的发布速度。提高模型更新延迟和更新频率的主要挑战之一是:模型大小(model size),这些模型size已经达到了TB级别,并且预计未来还会进一步增加。大的模型size导致了在分布式服务器中更新模型时的高延迟(和写入带宽)。我们提出了QuickUpdate,一个用于大规模推荐模型实时个性化的系统,它能够作为在线训练的一部分,可以高频率地进行模型发布,提供与完全新鲜模型相当的服务准确性。该系统采用了新技术来最小化所需的写入带宽,包括:优先参数更新、间歇性全模型更新、模型转换和宽松一致性(relaxed consistency)。我们使用真实世界的数据在Meta的一个最大生产模型上评估了QuickUpdate。结果表明,QuickUpdate提供了与完全新鲜模型相当的服务准确性,同时将平均发布的更新大小和所需带宽减少了超过13倍。它为实时服务生产模型提供了一个可扩展的解决方案,这在网络和存储带宽有限的情况下,否则是不可能大规模实现的。

1 引言

深度学习推荐模型(DLRM)在许多在线公司中被广泛使用。这些模型通过大规模数据进行训练,以学习用户和产品特征,从而在各种场景中提供个性化推荐。例如,Netflix [7] 和 YouTube [4] 为用户提供电影列表;Amazon [19] 和 Alibaba [20] 根据用户的搜索查询推荐相关产品;Google [3] 和 Meta [23] 则根据用户兴趣展示广告和内容。DLRM 在这些公司中占据了 AI 基础设施的主要部分。以 Meta 为例,DLRM 消耗了超过 80% 的机器学习推理周期 [8] 和超过 50% 的训练周期。

推荐模型有助于业务增长。例如,它们贡献了 Amazon 总购买量的 35% [8, 14]。由于这种广泛的业务影响,准确性成为大规模推荐模型的重要性能指标。特别是在 Meta 的业务中,设计检查点(checkpoint)和量化算法时要求准确性损失小于 0.01%[5]。这是一个非常狭窄的容差范围,表明了推荐模型及其准确性的重要性。

模型新鲜度 是个性化推荐模型准确性的关键因素 [4, 6, 9, 22, 25]。由于模型在高度动态的环境中运行推理,准确性可能会迅速下降。例如,每天都有新用户和物品注册到系统中,用户的兴趣可能会受到近期事件的影响。如果模型没有频繁更新,它将无法反映用户和产品的变化,从而导致准确性逐渐下降。为了进一步强调新鲜度的影响,图 3 展示了模型在数小时未刷新时的显著准确性损失。因此,为了将准确性保持在可接受的水平,推荐模型需要使用最新数据进行重新训练,并使用更新后的模型来服务实时推理。

保持推理模型新鲜的一种常见技术是在线训练。与每次从头开始重新训练模型不同,它使用实时流数据不断训练和优化模型。定期创建模型的快照并将其发布到位于不同地理区域的数百台服务器中。这些服务器随后利用该模型对在线查询进行实时预测。然而,更新服务模型会带来训练集群与分布式服务主机之间的延迟,导致模型刷新延迟,这主要是由于现代模型的规模庞大。

多年来,模型规模迅速增长,达到了 TB 级别,并包含数万亿个参数 [5, 10, 15],以捕获数百万个sparse特征并提高模型准确性。有限的写入带宽在将如此大的模型传输到分布式服务器和存储时构成了挑战。因此,更新延迟可能会延长到数小时。如第 3 节详细讨论的那样,这种长时间的延迟可能会对准确性产生不利影响。

为了解决上述由大模型规模及其导致的更新延迟带来的挑战,我们提出了 QuickUpdate。QuickUpdate 采用以下设计元素来实现大规模 DLRM 的实时个性化:

  1. 优先参数更新:在数百个地理分布的节点中完全更新所有服务模型的参数需要大量的网络和存储带宽,这构成了瓶颈。
    QuickUpdate 通过优先参数选择来最小化更新规模。它对服务模型中的特定参数进行排序和选择,同时从更新中修剪其余参数。这种方法显著减少了总体更新规模并缓解了带宽需求。
    参数排序算法在最小化更新规模时避免准确性下降至关重要。

  2. 间歇性全模型更新:间歇性全模型更新是指在一系列连续的部分更新之后进行一次完整的模型更新。这些完整更新的主要目的是保持服务模型的长期准确性。每次部分更新后,服务模型会偏离训练模型,因为前者仍然使用过时的参数值。随着更多部分更新的进行,这种偏差会变得更大,从而导致潜在的准确性影响。为了提高准确性,间歇性地发布完整模型更新以限制服务模型与训练模型之间的差距。

  3. 实时更新的模型转换:QuickUpdate 采用了几种模型转换技术来减少发布的模型规模,包括推理剪枝和量化。
    量化已在一些研究中成功实施 [11, 24, 27],以在不牺牲准确性的情况下降低浮点精度。它有助于减少推理集群中的存储需求和通信成本。推理剪枝则应用于非常大的查找表。在 DLRM 中,实体(如用户或视频)及其对应的向量以查找嵌入表的形式存储。实际不活动的实体索引(或 ID)从服务平台中修剪,以显著减少服务模型的规模。

  4. 简化的服务设计和宽松的一致性要求:在传统的服务设计中,模型在开始服务查询之前会完全加载到服务平台中,以保持强一致性。在这种设计中,每个推理请求都基于特定版本的模型权重执行,确保一致且可靠的结果。然而,这种方法由于使用额外的缓冲节点而带来了相当大的基础设施开销。
    在 QuickUpdate 中,我们通过放宽一致性要求引入了一种更高效的服务设计。权重直接在服务节点中更新,而不是使用缓冲节点。这消除了对额外基础设施的需求并减少了开销。然而,这种宽松的设计可能会导致嵌入表中的一些不一致性,因为它们可能包含新鲜和过时权重的混合
    尽管嵌入表中可能存在不一致性,但我们的评估表明,服务模型的准确性并未受到影响,反而带来了准确性的提升。

我们使用真实世界的数据和 Meta 生产中部署的最大模型之一对 QuickUpdate 进行了评估。总体而言,我们的结果表明,QuickUpdate 能够提供与完全新鲜模型相当的服务准确性,同时将所需的写入带宽减少了超过 13 倍。它为实时服务生产 DLRM 提供了一个可扩展的解决方案,而这在网络和存储带宽有限的情况下是无法大规模实现的。QuickUpdate 通过利用新颖的技术实现了这一点,包括选择性发布每次更新的最重要部分,同时仍然结合低频的间歇性全模型更新以确保长期准确性

2 背景

2.1 深度学习推荐模型(DLRM)

通常,深度学习推荐模型由dense和dense层组成,如图 1 所示 [5, 10, 26]。sparse层实际上是嵌入表,其中:每个嵌入表表示一个分类特征,表的每一行表示一个特定ID(例如用户 ID 或视频 ID)。嵌入表将每个 ID 转换为一个固定大小的浮点值向量,这些向量是可训练的。模型中其余可训练的部分称为dense层。

图片名称

图 1

图 1 展示了数据在 DLRM 中的流动方式。sparse特征通过嵌入表进行转换;dense特征通过底部的dense层进行转换。转换后的特征随后被连接起来,并在顶部的dense层中进一步转换,以计算输入数据的可能性。

2.1.1 训练 DLRM

并行化是规模化训练推荐模型的主要方法 [5, 8]。sparse层和dense层可以采用不同的并行化逻辑。sparse层占整个模型大小的 99% 以上,可能达到数 TB 的规模。由于将所有sparse层存储在单个节点中是不可行的,因此采用模型并行化的方法将表分片到多个节点上。另一方面,dense层的规模足够小,可以容纳在每个节点中,因此它们被复制到各个节点上以利用数据并行化。

在 Meta,一个典型的训练集群包括 16 个节点,每个节点包含一个多插槽 CPU 和 8 个 GPU:

  • sparse层在所有 GPU 上分片,在正向和反向计算期间进行“all-to-all”通信。
  • dense层在所有 GPU 上复制,在反向传播期间使用“all-reduce”通信来聚合多个 GPU 中计算的梯度 [16]。

在训练期间,sparse层和dense层的新权重被同步计算和更新,以避免准确性下降。

2.1.2 服务 DLRM

为了以高吞吐量的方式高效地服务批量请求,通常使用 GPU 进行模型服务。在 Meta,服务节点位于专用的服务集群中。一个服务节点由主机 CPU 和附加的 GPU 组成。服务模型在服务节点之间复制,并使用数据并行化进行规模化模型服务。

  • 广告嵌入表存储在单个 GPU 中,因为它们需要更高的读取吞吐量。
  • 其他嵌入表存储在 CPU 中,CPU 通常具有更大的内存容量(例如 1.5 TB DRAM)。

为了存储嵌入表,使用紧凑的数据结构来最小化大小并以 GPU 友好的方式存储。特别是,嵌入表是连续存储的,每个嵌入表以行优先顺序存储在数组数据结构中。

为了刷新服务模型,使用额外的缓冲节点来避免暂停当前的服务节点(双buffer机制)。在新发布的模型加载到缓冲节点后,请求流量会切换到缓冲节点。

2.2 在线训练与离线训练

当服务模型需要使用实时数据流不断训练和更新时,会实施在线训练。在在线训练中,服务模型在提供预测的同时,会定期(以分钟到小时为单位)更新。训练在后台(通常在单独的集群中)继续运行以微调模型。训练模型发布到服务平台的速率对其生成的预测准确性有显著影响。

相比之下,离线训练不使用实时数据流进行训练,也没有严格的时间限制来训练模型并将其发布到服务平台。相反,它通常使用已经存储在数据存储中的大量数据。模型使用所有可用数据进行训练,当满足某些最优条件时训练停止,之后将其发布到服务平台

选择在线训练还是离线训练取决于具体的使用场景。当模型需要及时更新时,会实施在线训练系统。典型的用例可能是广告、搜索和视频的 DLRM(例如 [13, 18, 21]),当环境高度动态且需要模型几乎实时更新(以分钟为单位)时。在线训练帮助这些模型整合最新数据并避免准确性下降。当业务需求不需要实时模型更新时,可以使用离线训练方法来更新模型。

2.3 优化器状态作为特征重要性度量

通常,大型 DLRM 可以使用数十万个特征进行训练;然而,其中一些特征及其对应的权重对准确性没有影响。这些特征可能属于不活跃的用户或内容 ID,或者某些其他特征可能无法提供训练信号。在模型发布时,维护所有这些参数会消耗额外的带宽,或者在服务平台中运行推理时会消耗额外的计算和 GPU 存储。

为了减轻这些不利影响,我们可以计算特征重要性,并相应地修剪那些实际上不影响准确性的特征和权重。

QuickUpdate 使用优化器状态(或梯度动量)来计算特征重要性。它属于基于梯度的特征重要性度量家族(例如 [2, 12, 18])。优化器状态是梯度值的历史平均值,比梯度值更稳定,因为梯度值有时会在正值和负值之间振荡。优化器状态可以为我们提供以下指示:

  1. 对准确性的影响:它实际上显示了在训练过程中某个特定参数被更新的频率和幅度。较高的梯度值表明对提高准确性有较大影响;如果这种影响在历史上持续存在,我们可以更有信心地推断该参数对准确性很重要;因此,优化器状态是特征对准确性影响的稳定指示。

  2. 访问率:优化器状态为零(或接近零)表明该参数在训练过程中实际上没有被更新。这有两层含义。首先,这可能意味着该参数没有被访问过。这可能发生在一些不活跃的用户 ID 或过时的视频上。其次,如果实体是活跃的,接近零的值可能意味着从相关数据中可能没有重要的信号可以学习。例如,某个特定用户 ID 可能没有使用平台点击广告。

基于上述信息,QuickUpdate 使用优化器状态来完成两项不同的任务:

  1. 推理剪枝:当发布完整模型时,QuickUpdate 使用优化器状态来减少完整模型更新的规模。在此阶段,QuickUpdate 关注优化器状态值的低尾部分,并修剪那些对准确性没有影响的参数。

  2. 优先参数选择:当发布部分更新时,QuickUpdate 使用优化器状态来选择更重要的参数进行更新。在此阶段,QuickUpdate 关注优化器状态值的高尾部分,以更新更重要的参数。

2.4 推理剪枝

推理剪枝是为了在将完整模型快照发布到服务平台时减少模型的大小。剪枝特别针对查找嵌入表实施,并显著减少其大小(例如减少 50%)。由于查找表占 DLRM 大小的 99% 以上,剪枝可以在不影响准确性的情况下显著减少模型的大小。减少模型大小有助于消耗更少的带宽来发布模型更新;因此,更新可以以更短的延迟发布到数百个地理分布的集群中。此外,它还有助于更快地执行推理,因为计算中涉及的行数更少。

图片名称

图 2

图 2 展示了一个查找表的示例,其中包含索引和对应的行。每个索引代表分配给用户、视频等的唯一 ID。每一行可以被视为一个由可训练的浮点值组成的向量,模型使用这些向量为用户生成个性化推荐。

直观上,推理剪枝算法识别出代表不活跃实体或无法提供训练信号以提高准确性的行。从数学上讲,这是通过使用优化器状态向量来实现的。具体来说,训练器可以为每一行提供一个优化器状态向量。优化器状态向量中的每个元素表示该行中对应元素的梯度动量。优化器状态向量中元素的平均值用于量化行的重要性值。如果行的重要性值接近零,则意味着该行的元素在训练过程中实际上没有被更新;因此,该行可以被剪枝。

图 2 还展示了剪枝前后查找表的变化。使用行重要性值,推理剪枝算法确定最不重要的索引,并在将完整更新发布到服务器之前将其剪枝。出于操作目的,原始索引值会被重新映射到新的索引。新索引会简单地按照它们在原始表中的出现顺序递增。

需要注意的是,在本文中,完整模型更新指的是执行推理剪枝后的模型快照。

3 动机

在本节中,我们通过真实世界的数据来展示开发 QuickUpdate 的动机。首先,我们展示了当模型一小时或更长时间未更新时,准确性如何显著下降。接着,我们讨论了模型规模扩大带来的影响。如果不修改模型发布方法,我们只能选择接受更长的更新延迟和逐渐下降的准确性,或者大量投资基础设施以保持更新延迟的一致性。最后,我们强调了无损模型更新的局限性,强调了优先更新的必要性。

准确性增益

更新完整的服务模型是一个耗时的过程,可能需要数小时。因此,用户的最新行为和兴趣(例如发布新动态或与特定内容互动)在几小时内不会反映在服务模型中,这可能会降低模型的准确性。

图片名称

图 3 展示了在 Meta 的一个大规模模型中,随着模型更新延迟(1 到 7 小时),准确性如何下降。它将陈旧服务模型的准确性损失与完全新鲜模型进行了比较。结果显示,当模型更新延迟时,准确性损失显著增加,7 小时后损失超过 0.6%。

减少更新规模有助于加速模型更新并提高服务模型的准确性。

模型规模

近年来,DLRM 的规模显著增加。这些模型利用大量数据和参数来更好地理解用户兴趣和产品特征,从而提高准确性。这一进展催生了具有数万亿参数 [15] 和数 TB 模型大小 [5] 的复杂模型。此外,这一趋势预计在未来仍将持续。

随着模型规模的增加,由于传输所需的带宽增加,模型更新延迟也随之延长。如果不加以解决,预计这将导致未来模型新鲜度的下降。鉴于模型规模的持续扩大,单纯增加基础设施并不是一个可行的长期解决方案。因此,对大规模 DLRM 进行部分更新似乎是一个有前景的策略,旨在减少更新延迟,而无需增加更多基础设施。

无损模型更新

为了更好地理解模型随时间变化的比例,我们监控了更新的嵌入行,并据此计算了模型中被修改的平均比例。图 4 显示了模型随时间更新的百分比。很明显,模型的大部分在短时间内被更新。例如,在短短 10 分钟的时间间隔内,58% 的模型被更新。更新 58% 的模型是资源密集型的,需要比每小时更新完整模型更多的基础设施。这促使我们探索一种优先更新的方法,以显著减少更新规模。

图片名称

图 4

4 系统概述

图 5 提供了 QuickUpdate 架构的概览。DLRM 系统由训练节点、服务节点和用于保存模型快照的远程存储组成。QuickUpdate 的发布逻辑主要在 UpdateSelectorUpdatePatcher 代理中实现,这两个代理分别部署在训练节点和服务节点中。

  • UpdateSelector 负责决定模型的哪一部分应该更新,并在保存到远程存储之前对其进行量化。
  • UpdatePatcher 根据执行的更新类型实现不同的修补策略。

以下部分提供了更多详细信息。

图片名称

图 5 系统架构

4.1 更新什么

QuickUpdate 专注于对嵌入表执行部分更新,这些表通常占深度学习推荐模型的绝大部分(在我们的工作负载中超过 99%)。在这些模型中,每个表表示一个分类特征(例如用户、视频),表中的每一行对应于与该特征相关的特定 ID。

在我们的探索中,我们考虑了两种更新嵌入表的选项:

    1. 更新选定表的所有行。
    1. 更新所有表中的选定行(不同表的选定行索引可能不同)。

我们发现,以行级粒度进行更新可以在最小化整体更新规模的同时提高准确性。因此,QuickUpdate 决定服务端需要更新表中的哪些特定行。这种方法使 QuickUpdate 能够优先更新更有可能提高准确性的内容或用户 ID,从而确保更新策略的高效性和有效性。

对于模型中的dense层,QuickUpdate 执行完整更新。这是因为这些层的更新规模相对较小,针对这些层的任何优化对整体更新过程的影响不大。

4.2 UpdateSelector

QuickUpdate 的 UpdateSelector 组件在训练集群中实现。这是因为它需要从训练器中获取某些模型信息(例如参数值)以准备模型更新。

在在线训练期间,训练器以批次间隔运行。在每个训练间隔结束时,训练器将模型状态和优化器状态共享给 UpdateSelector。模型状态包括分片的嵌入表和dense参数值,而优化器状态包括梯度值及其动量。这些状态从 GPU 内存复制到主机 CPU 内存中。

UpdateSelector 使用优化器状态对 CPU 中的模型副本执行以下两项任务:

  1. 优先参数选择:此任务的主要目标是仅更新模型参数的一小部分,同时最小化准确性的下降(与完整更新相比)。在此阶段,QuickUpdate 根据优化器状态值选择嵌入行,优先选择那些可能对准确性提升较大的行。
  2. 推理剪枝:此任务在发布完整模型时执行。推理剪枝专注于sparse嵌入表,旨在减少完整模型更新的规模。在此阶段,QuickUpdate 识别低尾优化器状态值,并剪枝值接近零的嵌入行。这些行对模型准确性的影响可以忽略不计。

一旦更新(无论是完整还是部分)准备就绪,它们会经过量化以减少其大小。量化作为一种压缩方法,对模型准确性的影响可以忽略不计。量化后的更新随后存储在远程存储中,准备用于更新过程。

4.3 UpdatePatcher

UpdatePatcher 负责加载发布的快照并更新服务模型。它对部分和完整模型更新都采用了一种高效的非原子更新方法。在非原子更新过程中,多个线程可以访问模型参数,并逐步将参数修补到服务器中。这种方法允许多个线程并发修补参数,而无需锁定服务器或模型。因此,服务器可以在应用更新的同时继续对传入流量进行推理。这种方法确保了在更新过程中实时流量的高效且不间断的服务。

4.4 工作流程

图 6 展示了 QuickUpdate 的工作流程。为简化说明,我们仅展示了训练器、UpdateSelector 和一个服务节点中的时间尺度。模型的演化是一个可重复的模式,因此我们专注于一个周期,该周期进一步分为多个间隔。在周期 $ c $ 的每个间隔 $ i $ 开始时,UpdateSelector 可以访问完整模型 $ F_{c,i} $ 以确定模型的哪一部分应该更新。具体来说,首先会发布一个完整快照(即 $ F_{c,1} $)并加载到服务器中,然后连续的部分更新($ P_{c,i} $ 其中 $ i > 1 $)会被发布并修补到完整快照中,以创建服务快照 $ S_{c,i} $。

图片名称

图6 QuickUpdate工作流

将更多部分更新与服务模型合并可能会导致服务模型 $ S_{c,i} $ 与当前训练器状态 $ F_{c,i} $ 之间的偏差增大。这种偏差可能会导致准确性下降。因此,另一个完整的新鲜快照(即 $ F_{c+1,1} $)将被发布到服务集群,标志着当前周期的结束。服务端的模型演化可以表示如下:

\[S_{c,1} = F_{c,1} \\ S_{c,i} = M(S_{c,i-1}, P_{c,i}) \quad \text{对于} \quad 1 < i \leq I\]

其中:

  • $ I $ 是一个周期中的间隔数
  • $ M $ 是合并操作符。合并操作符简单地复制 $ P_{c,i} $ 的参数值并将其更新到 $ S_{c,i-1} $ 中。

5 设计

在本节中,我们讨论了设计选项及其对准确性指标的影响。我们首先定义了指导设计和评估的具体准确性指标。通过在整个设计过程中优先考虑准确性,我们的目标是创建一个有效的系统,在解决网络和存储带宽瓶颈的同时,提供高服务准确性。需要注意的是,QuickUpdate 是可配置的,并在生产环境中进行监控,以应对罕见的准确性下降情况。

5.1 准确性指标

二元交叉熵(Binary Cross Entropy)或熵(Entropy)[17] 是评估广告模型准确性的一个众所周知的综合指标。在本研究中,我们使用归一化熵(Normalized Entropy, NE),其定义为二元交叉熵除以一个常数。为了理解部分更新相对于完全新鲜快照和过时模型的表现,我们计算了 NE 的以下变体。为简化说明,我们从符号中省略了周期下标 $ c $。

  1. NE 损失:它表示使用模型 $ S_i $ 而不是相应的完全新鲜模型 $ F_i $ 运行推理时的准确性下降。

    \[\text{NE}_{\text{loss}}(S_i) = \frac{\text{NE}_{S_i} - \text{NE}_{F_i}}{\text{NE}_{F_i}} \times 100 \quad (1)\]

    其中:

    • $ NE_{S_i} $ 和 $ NE_{F_i} $ 分别表示模型 $ S_i $ 和 $ F_i $ 的归一化熵。
  2. NE 增益:它表示如果使用 $ S_i $ 进行推理而不是过时模型,可以预期的准确性提升:

    \[\text{NE}_{\text{gain}}(S_i) = \frac{\text{NE}_{S_i} - \text{NE}_{\text{stale}}}{\text{NE}_{\text{stale}}} \times 100 \quad (2)\]

    过时模型被认为是最近发布的完整模型 $ F_1 $。

  3. NE 恢复:它表示模型 $ S_i $ 已达到的最大 NE 增益的百分比。我们假设如果可以使用完全训练的模型 $ F_i $ 进行推理,则可以实现最大 NE 增益。因此,NE 恢复定义为:

    \[\text{NE}_{\text{recovery}}(S_i) = \frac{\text{NE}_{\text{gain}}(S_i)}{\text{NE}_{\text{gain}}(F_i)} \times 100 \quad (3)\]

5.2 选择标准

为了优先更新能够带来更大准确性增益的行,我们需要一个在训练过程中保持稳定的可靠指标。虽然梯度向量可以作为标准,但其在正值和负值之间的振荡引入了数值不稳定性。相反,我们可以使用优化器状态向量(也称为动量),它提供了更稳定的度量。优化器状态向量表示特定行的历史梯度的平均平方和。通过将:

  • $ OS^r_{c,i} $: 表示为模型在间隔 $ i $ 时行 $ r $ 的优化器状态向量
  • $ \overline{OS^r_{c,i}}$: 表示为其元素的平均值,我们可以利用该度量作为行重要性的指示。

直观上,具有较大 $ \overline{OS}^r_{c,i} $ 值的行更有可能提高准确性。例如,这些行可能代表频繁使用平台点击广告的特定用户,或者代表具有高访问率的特定视频。除了给定间隔的 $ \overline{OS}^r_{c,i} $ 的大小外,跟踪其随时间的变化也可能很重要。这对于我们更倾向于优先选择相对于旧版本发生变化的参数的情况可能具有潜在的信息价值。基于这些直觉,我们评估以下选择标准:

  1. 绝对优化器状态

    \[\text{abs}(\overline{OS_{c,i}^r}) \quad \text{对于} \quad i > 1 \quad (4)\]
  2. 增量优化器状态

    \[\text{abs}(\overline{OS_{c,i}^r} - \overline{OS_{c,i-1}^r)} \quad \text{对于} \quad i > 1 \quad (5)\]

选择使用绝对优化器状态还是增量优化器状态作为选择标准取决于它们各自的优势和权衡。虽然绝对优化器状态提供了行对准确性影响的稳定和综合度量,但增量优化器状态捕捉了与前一个间隔相比影响的变化。然而,使用增量优化器状态需要额外的内存来存储前一个间隔的优化器状态。为了评估这些标准的影响,我们进行了实验,间隔长度为 30 分钟,更新规模为 10%。在发布完整快照并再训练一小时后,我们根据这两个标准发布了更新规模为 10% 的部分快照。然后,我们评估了与完全新鲜模型相比的服务准确性。表 1 中的结果(在多次此类实验中一致)表明,增量优化器状态实现了 100% 的 NE 恢复,而绝对优化器状态实现了 70% 的 NE 恢复。这意味着基于增量优化器状态选择行可以减少服务模型与相应完整快照之间的差异

图片名称

表1 不同选择标准的NE recovery

5.3 增量选择的基线

增量优化器状态是基于基线计算的。在计算增量优化器状态时,我们考虑了两种选择基线的方法:

  1. 上一次更新时的模型状态:此选项将上一次更新时的模型状态作为基线。这与上一节中增量优化器状态的定义相同。

  2. 上一次完整更新时的模型状态:如第 5.6 节所述,QuickUpdate 还利用间歇性完整更新。在此基线选项中,将上一次间歇性完整更新时的模型状态用作增量。在这种情况下,增量优化器状态定义为:

    \[\text{abs}(\overline{\text{OS}^r_{c,i}} - \overline{\text{OS}^r_{c,1}}) \quad \text{对于} \quad i > 1 \quad (6)\]

对于第一种选项,需要在每个训练间隔结束时保存基线,而对于第二种选项,只需在周期的第一个间隔中保存一个基线。因此,第一种选项提供了更新鲜的基线,但需要额外的计算资源将其保存在内存中。

我们通过实验检查了不同基线对服务准确性的影响,每个间歇性完整更新后跟随四个部分更新。我们评估了与完全新鲜模型相比的服务准确性(NE 恢复)。表 2 中的结果显示,使用上一次间隔的基线可以实现 3.11%(95.94% 对比 99.05%)更高的 NE 恢复。这表明,使用上一次更新的模型状态作为基线可以更好地反映最近的用户兴趣,因为它在每个更新间隔中都会刷新。此外,使用完整更新作为基线可能会优先选择在前一个间隔中重要但不再对准确性有贡献的参数。随着时间的推移,这些参数的优化器状态可能达到平稳状态,但由于其较大的增量优化器状态值,完整更新基线可能仍然会考虑它们。更频繁地刷新基线有助于消除对此类参数的优先选择,转而优先选择最近变化的参数,这些参数更有可能对准确性提升有贡献。

5.4 实时推理剪枝

如第 2.4 节所述,推理剪枝有助于减少服务模型的大小和所需的 GPU 数量。它实际上会剪枝那些不再活跃或对准确性影响可以忽略的行(或 ID),以减少嵌入表的大小。然后,剪枝后的表以 GPU 访问友好的方式紧凑地存储,以进一步减少服务集群中的大小。

剪枝仅在完整模型发布到服务集群时实施。对于后续的间隔,我们希望部分更新能够与完整模型更新中的剪枝表兼容。理想情况下,部分更新中的行 ID 应存在于剪枝表中。这有助于我们简单地更新现有行的值,而无需重新构建 GPU 中的表。然而,情况并非总是如此。由于部分更新的训练数据与完整模型更新不同,可能会出现某些行 ID 在部分更新中变得重要,而这些 ID 在服务端的剪枝表中不存在的情况。当发生这种情况时,一种简单的实现方法是将缺失的行插入服务端的剪枝表中,但这可能是资源密集型的,并且可能需要重新调整所有 GPU 上的嵌入表以确保可访问性和效率(例如,避免内存碎片)。

为了避免嵌入表的密集重新调整,我们探索了两种与部分更新兼容的推理剪枝策略:

  1. 固定索引剪枝(见图 7a):在此策略中,QuickUpdate 执行优先参数选择以选择要更新的候选行索引。然而,仅更新嵌入表中已存在的行,而剪枝的行保持不变。

  2. 固定剪枝比例(见图 7b):在此策略中,每次完整更新时从嵌入表中剪枝固定比例的行。当 QuickUpdate 执行优先参数选择时,它最多选择 $ X $ 个索引进行更新,其中 $ X $ 是服务平台上给定表中的总行数。这确保了表中的行数保持一致。

图片名称

图 7

第一种策略避免了重新调整,因为只会进行行更新操作,而不会向嵌入表中插入新行。第二种策略通过使用行更新和索引重映射操作来避免重新调整。由于嵌入表的大小在第二种策略中不会改变,因此也避免了在 GPU 之间重新分片嵌入表的需要。

为了评估这两种剪枝策略,我们考虑了三种训练场景:1-无剪枝,2-固定剪枝索引,3-每个表固定剪枝比例。

我们的实验表明,剪枝导致的 NE 损失实际上可以忽略不计(<0.001%),且两种剪枝策略之间没有准确性差异。考虑到实现需求,我们选择了固定剪枝索引策略,因为其实现更简单。与固定剪枝比例策略不同,它不需要在每次新更新时更新索引映射。

6 评估

我们在 Meta 部署的最大推荐模型之一上评估了 QuickUpdate,使用真实世界的数据,并在类似于 [15] 的生产训练集群上进行训练。该模型是 [16] 中提出的 DLRM 模型的扩展,但其规模显著更大,达到 TB 级别。我们在所有实验中使用相同的预记录数据流,使实验可重复且可比较,并消除了由于时间数据变化导致的潜在结果偏差。模型最初使用几周的真实世界数据进行训练作为预热期,以达到稳定状态。对于准确性评估,我们评估了在训练数据之后的时间段内数据流的服务预测(即推理期间评估的数据未在之前的训练中使用)。

6.1 准确性

在本节中,我们比较了不同更新粒度对准确性的影响,并推导出最小完整快照频率,以确保 NE 损失不超过 0.01%。在这些实验中,我们在开始时发布一个完整快照,并继续发布具有不同粒度的部分更新。这些部分更新应用于完整服务快照之上,并使用相同的记录数据集进行准确性评估。

6.1.1 与过时模型相比的 NE 增益

我们首先比较 QuickUpdate 与过时模型的准确性,以量化准确性增益并验证在完整快照之上应用部分更新不会对准确性产生负面影响。这里的过时模型指的是最初发布的完整快照。图 8 显示了不同更新粒度(且无间歇性完整模型更新)相对于过时模型的 NE 增益。所有更新规模的 NE 增益均高于过时模型,并且 NE 增益随时间增加。5% 和 10% 的更新提供了非常相似的 NE 增益,但 1% 的更新返回的 NE 增益较少,表明一些重要的行未包含在 1% 的更新中。总体而言,这些趋势表明,即使在应用部分更新超过 10 小时后,也没有负面影响,并且与过时模型相比,准确性提高了 0.7%。

图片名称

图 8

6.1.2 与完全新鲜模型相比的 NE 损失

在本节中,我们研究了使用部分更新发布的 QuickUpdate 模型的 NE 损失,与理想的完全新鲜服务模式进行比较。图 9 中的结果显示,使用 10% 更新时,NE 损失在整个 10 小时内低于 0.005%。使用 5% 更新时,NE 损失始终高于 10% 更新,但在超过 6 小时内仍低于 0.01%。随着训练周期的增加,NE 损失增加,因为服务模型与相应训练模型之间的差异增加。结果还展示了采用不同更新粒度对完整模型发布延迟的影响,同时确保 NE 损失保持在可接受的 0.01% 阈值以下。通过采用 10% 的粒度,我们可以有效地将完整模型发布的需求延迟超过 10 小时。同样,当使用 5% 的粒度时,我们可以将完整模型发布延迟 6 小时,同时仍将 NE 损失保持在可接受范围内。这突显了 5% 粒度下部分更新在捕获重要更新并在相当长的时间内保持模型准确性和新鲜度方面的有效性。

图片名称

图 9

6.1.3 短期内的 NE 损失

为了分析短期内的 NE 损失,我们进行了一项评估,涉及四个连续的 10 分钟更新。检查的更新粒度为 5%、3% 和 1%。每次更新后,使用未见过的数据测量与完全新鲜模型相比的 NE 损失。图 10 显示了不同 10 分钟间隔内的变化,强调了流数据的波动性。然而,当在多个短时间间隔内平均时,NE 损失趋于稳定。正如预期的那样,结果显示,随着粒度的增加,NE 损失减少。最后一列显示的平均 NE 损失证实,5% 的粒度在我们的工作负载中会返回可接受的 NE 损失(平均而言)。

图片名称

图 10

6.1.4 结论

准确性结果证明了 QuickUpdate 在采用 5% 更新粒度长达 6 小时的有效性,同时保持与完全新鲜模型相当的准确性水平,并确保 NE 损失低于 0.01% 的阈值。

此外,使用 QuickUpdate 的 5% 更新粒度允许在需要发布完整模型之前延迟 6 小时。这种延迟之所以可能,是因为部分更新成功捕获并整合了重要变化,从而生成了准确且最新的模型。

基于这些发现,QuickUpdate 默认每 6 小时触发一次间歇性完整模型发布,从而优化了准确性与更新频率之间的平衡。

6.2 分析长期行收敛性

在之前的分析中,我们的重点是基于准确性指标最小化部分更新粒度,并确定间歇性完整更新的适当频率。结果表明,使用 5% 粒度的部分更新持续 6 小时可以达到令人满意的准确性。

在本实验中,我们的目标是探索部分更新更新了模型中多少百分比的重要行。

为了确定重要行的代理,我们训练模型 6 小时(即与满意准确性相同的持续时间)。我们将重要行定义为训练模型中排名前百分之几的行,使得在 6 小时结束时发布这些行(而不是整个模型)将返回令人满意的准确性(即与完全新鲜模型相比差异低于 0.01%)。

图 11 显示了在 6 小时训练后,不同大小的单次更新与完全新鲜模型相比的 NE 损失。可以看出,单次 5% 的部分更新不足以将 NE 损失降低到可接受的 0.01% 阈值以下。然而,10% 的部分更新证明足以将 NE 损失降低到可接受的水平。这表明排名前 10% 的嵌入行是此时间窗口内重要行的良好代理。

图片名称

图 11

为了了解这些重要行中有多少百分比被多个较小的 5% 更新覆盖,我们运行了 QuickUpdate 6 小时,并使用多个 5% 粒度的部分更新。在将所有这些更新合并为一个联合集后,我们观察到该集合涵盖了上述重要行的 70%,并总体覆盖了模型中所有行的 7.3%。因此,大部分重要行被连续的较小部分更新所覆盖。

6.3 带宽使用

在 QuickUpdate 中,更新大小是带宽使用的代理。带宽使用量取决于粒度、更新间隔和间歇性完整模型更新的频率。通常,这些参数是可配置的,并可能根据 DLRM 的类型和所需的准确性而变化。在本节中,我们评估了基于发布模型百分比的不同策略的带宽使用情况。详细信息如下并在图 12 中展示:

图片名称

图 12

  1. 基线 1:每小时发布一次完整模型。
  2. 基线 2:每 10 分钟发布一次完整模型(未在图中显示)。
  3. 5% 更新(默认策略):每 10 分钟发布一次部分更新,粒度为 5%,每 6 小时发布一次间歇性完整更新(如 6.1 节所述)。
  4. 10% 更新:与之前的策略类似,每 10 分钟发布一次部分更新,但粒度为 10%。每 6 小时发布一次间歇性完整更新。

为了比较这些策略,我们平均了消耗的带宽。结果显示,默认策略(5% 更新粒度,6 小时间歇性完整更新间隔)平均每小时写入模型大小的 43.6%,而策略 3(10% 更新)为 68.2%,基线 1 为 100%。基线 2 提供了与策略 2 和 3 相当的准确性,但需要每小时发布模型大小的 600%。

总体而言,使用默认策略,QuickUpdate 能够将消耗的带宽比基线 1 减少 2.3 倍,同时提供与完全新鲜模型相当的更好准确性。与基线 2 相比(由于网络和存储带宽限制,无法大规模实施),QuickUpdate 能够将所需带宽减少超过 13 倍,同时仍提供相当的准确性。

6.4 宽松一致性

传统上,服务模型以原子方式更新以保持一致的推理。这涉及将所有模型权重加载到缓冲节点中,这些节点随后成为计算推理的服务节点。然而,这种方法由于使用缓冲节点而资源密集。为了解决这个问题,QuickUpdate 放宽了一致性要求,并在执行推理查询的同时直接在服务节点中更新参数。

我们评估了在 QuickUpdate 中间歇性完整模型更新期间的 NE 恢复(与完全新鲜模型相比),作为已更新权重百分比的函数。如图 13 所示,放宽一致性可以在加载期间提高生产中的准确性。随着更多参数的加载,NE 恢复增加。我们的数据显示,通过修补 30% 的参数,我们可以捕获约 54% 的 NE 恢复。在修补仅 70% 的参数后,NE 恢复达到约 94%。

图片名称

图 13

宽松一致性允许早期服务新鲜行(而不是等待整个模型更新),从而整体提高准确性。尽管在加载期间表的视图不一致(意味着不同的行可能属于不同的状态),服务一部分新鲜行已经导致准确性增加。NE 恢复随着时间的推移继续增长,直到整个模型更新完毕。

7 相关工作

异步或部分更新策略已在少数实时 DLRM 中实施 [13,18,21]。在 Kraken [21] 中,dense参数每隔几秒批量更新一次,而sparse参数在训练器中值发生变化时更新。这是一种无损参数更新,对于具有 1000-10000 亿参数 [15] 和地理分布式服务器的大型模型,可能会产生大量流量。Monolith [13] 主要专注于开发具有无冲突嵌入表的sparse特征系统。sparse参数可以在训练时以分钟级粒度更新,其值自上次同步以来发生变化。与 Kraken 类似,这是一种无损更新,可能会产生巨大的流量。总体而言,无损模型更新可能非常资源密集,如第 3 节所述。为了克服这个问题,QuickUpdate 可以执行优先参数选择,从而减少约 78% − 92% 的带宽,并且准确性损失可以忽略不计(< 0.01%)。在另一项研究中,Ekko [18] 被设计为一个高效的系统,用于将更新从训练模型广播到所有服务推理节点。为了快速更新服务模型中的较大嵌入表,他们使用了嵌入表更新中的sparse性和时间局部性。Ekko 系统与 QuickUpdate 正交,两者可以一起实施。在 QuickUpdate 中,我们优化了设计元素,如发布间隔、更新粒度和参数选择标准,以实现所需的准确性并最小化完整模型的发布。优先参数选择是我们在本文中使用的技术之一。在以往的研究中(例如 [1,2,12]),基于梯度的参数选择已在分布式训练系统中探索。Ekko [18] 进一步扩展了这一标准,并额外考虑了每个参数的请求频率和参数新鲜度作为选择标准。在 QuickUpdate 中,我们决定选择梯度动量的增量,这是一个比梯度本身更稳定的度量,并且它发布的参数能够返回与基线快照相比的最高准确性。

8 结论

QuickUpdate 是一个支持在线训练执行低延迟部分更新的系统,同时提供与完全新鲜模型相当的服务准确性。它为实时服务生产规模的 DLRM 提供了一个可扩展的解决方案。这一点尤其有价值,因为由于网络和存储带宽的限制,大规模实时服务此类模型具有挑战性。

QuickUpdate 通过利用创新技术实现了其可扩展性和准确性目标。其中一项技术是选择性发布每次更新的最重要部分,从而在保持准确性的同时减少整体更新规模。此外,QuickUpdate 以低频率结合间歇性完整模型更新,以确保长期准确性。这种选择性部分更新和间歇性完整更新的结合使 QuickUpdate 能够在低延迟服务和长期保持准确性之间取得平衡。

我们使用大规模个性化广告模型的真实世界数据对 QuickUpdate 进行了评估,结果表明 QuickUpdate 能够提供与完全新鲜模型相当的服务准确性,同时将所需的写入带宽减少超过 13 倍。

附录