google在《Titans: Learning to Memorize at Test Time》提出了区别于Transformer的的一种新架构:Titans。我们来看一下它的实现,是否有前景:

摘要

在过去的十多年里,关于如何有效利用循环模型(recurrent models)和注意力机制(attentions)的研究已经非常广泛。循环模型的目标是将数据压缩到一个固定大小的内存中(称为hidden state),而注意力机制则允许模型关注整个上下文窗口,捕捉所有token之间的直接依赖关系。然而,这种更精确的依赖关系建模带来了二次方的计算成本(quadratic cost),限制了模型只能处理固定长度的上下文。我们提出了一种新的神经长期记忆模块(neural long-term memory module),该模块能够学习记忆历史上下文,并帮助注意力机制在利用过去信息的同时关注当前上下文。我们展示了这种神经记忆模块具有快速并行化训练的优势,同时保持了快速的推理能力。从记忆的角度来看,我们认为注意力机制由于其有限的上下文但精确的依赖关系建模,起到了短期记忆的作用;而神经记忆(neural memory)由于其能够记忆数据的能力,起到了长期、更持久的记忆作用。基于这两个模块,我们引入了一系列新的架构,称为Titans,并提出了三种变体,以探讨如何有效地将记忆融入这一架构中。我们在语言建模、常识推理、基因组学和时间序列任务上的实验结果表明,Titans比Transformers和最近的现代线性循环模型更有效。此外,与基线模型相比,Titans能够有效地扩展到超过200万的上下文窗口大小,并在“大海捞针”任务中表现出更高的准确性。

1.介绍

Transformers,一种纯基于注意力机制的架构(Vaswani 等人,2017),已被牢牢确立为序列建模中的最先进模型,主要归功于其上下文学习能力和大规模学习能力(Kaplan 等人,2020)。Transformers 的核心构建模块——注意力模块——充当关联记忆模块(Bietti 等人,2024),它们学习存储key-value关联性(associations),并通过计算query(即搜索信号)和key(即上下文)之间的成对相似性来检索这些关联。因此,从设计上看,Transformer 的输出完全取决于当前上下文窗口中token的直接依赖关系。然而,这种精确的依赖关系建模带来了与上下文长度相关的二次方时间和内存复杂度。在复杂的现实任务中(例如语言建模(N. F. Liu 等人,2024)、视频理解(C.-Y. Wu 等人,2019)、长期时间序列预测(H. Zhou 等人,2021)),上下文窗口可能变得非常大,这使得 Transformers 在这些下游任务中的适用性面临挑战。

为了克服 Transformers 的可扩展性问题,最近的研究旨在设计不同变体的线性 Transformers(Kacham、Mirrokni 和 P. Zhong,2024;Katharopoulos 等人,2020;S. Yang、B. Wang、Shen 等人,2024),其中注意力机制中的 softmax 被核函数取代(详见 §2.1),从而显著降低了内存消耗。尽管线性 Transformers 具有高效性并能够扩展到更长的上下文,但与 Transformers 相比,它们的性能并不具有竞争力,因为核技巧使模型变成了线性循环网络,其中数据被压缩为矩阵值状态(Katharopoulos 等人,2020)。然而,这带来了关于线性循环(或线性 Transformers)模型的一个矛盾事实:一方面,我们使用这些线性模型来增强可扩展性和效率(线性与二次方复杂度),其优势在非常长的上下文中显现;另一方面,非常长的上下文无法被适当地压缩到一个小的向量值或矩阵值状态中(S. Wang,2024)。

此外,除了效率问题外,大多数现有架构——从 Hopfield 网络(Hopfield,1982)到 LSTM(Jürgen Schmidhuber 和 Hochreiter,1997)以及 Transformers(Vaswani 等人,2017)——在处理泛化、长度外推和/或推理(Anil 等人,2022;Qin、Y. Zhong 和 Deng,2024)时都面临挑战,而这些是许多复杂现实任务中不可分割的部分。尽管这些架构从人类大脑中汲取了灵感,但它们都缺少以下关键部分:

  • (1)学习过程中的关键组件——例如短期记忆、长期记忆、元记忆、关注当前上下文等(Cowan,2008);
  • (2)这些组件如何作为可以独立运行的互联系统;以及/或
  • (3)从数据中主动学习并记忆过去历史的抽象能力。

我们认为,在一个有效的学习范式中,类似于人类大脑,存在独立但相互关联的模块,每个模块都负责学习过程中至关重要的组件。

记忆视角

记忆(memory)是一种基本的心理过程,也是人类学习中不可分割的组成部分(Terry,2017)。如果没有一个正常运作的记忆系统,人类和动物将只能局限于基本的反射和刻板行为。因此,记忆一直是机器学习文献中许多开创性研究的灵感来源;例如,Hopfield 网络(Hopfield,1982)、LSTM(Jürgen Schmidhuber 和 Hochreiter,1997)以及 Transformers(Vaswani 等人,2017)。

从神经心理学文献中对记忆和学习的常见定义中汲取灵感(Okano、Hirano 和 Balaban,2000),大多数现有架构将记忆视为由输入引起的神经更新,并将学习定义为在给定目标的情况下获取有效且有用记忆的过程。从这个角度来看,循环神经网络(RNN)(Williams 和 Zipser,1989)可以被定义为具有向量值记忆模块 M(也称为hidden state)的模型,其主要步骤包括:

在时间 𝑡 给定新输入 $𝑥_𝑡$ 时,模型

  • (1)使用函数 $𝑓(M_{𝑡−1}, 𝑥_𝑡) $ 更新记忆(带有压缩);
  • (2)使用函数 $𝑔(M_𝑡, 𝑥_𝑡)$ 检索输入的相应记忆(详见 §2.1)。

类似地,Transformers 可以被视为具有增长记忆和两个相似步骤的架构。即,key和value矩阵对充当模型的记忆,模型:

  • (1)通过将key和value附加到记忆中来更新记忆(无压缩),
  • (2)通过查找query向量与key向量的相似性来检索query向量的相应记忆,然后将其用于加权value向量以生成输出。

这种视角可以帮助我们更好地理解现有范式、它们的关键差异,并设计更有效的架构。例如,Transformers(Vaswani 等人,2017)和线性 Transformers(Katharopoulos 等人,2020)之间的主要区别在于记忆结构以及记忆更新步骤,其中:线性 Transformers 将历史数据压缩为固定大小的矩阵值记忆,而 Transformers 则保留所有历史数据(在上下文长度内)而不进行任何压缩。虽然线性 Transformers 和线性 RNN(包括状态空间模型)都在记忆更新步骤中压缩信息,但关键区别在于记忆的结构,其中线性 RNN(相对于线性 Transformers)使用向量值记忆(相对于矩阵值记忆)。因此,这种视角促使我们提出以下问题:

  • (Q1)什么是良好的记忆结构
  • (Q2)什么是适当的记忆更新机制
  • (Q3)什么是良好的记忆检索过程

重新审视我们对人类记忆的理解,它既不是一个单一的过程,也不服务于单一的功能(Cowan,2008)。事实上,记忆是一个系统的联合体——例如短期记忆、工作记忆(working memory)和长期记忆——每个系统服务于不同的功能,具有不同的神经结构,并且每个系统都能够独立运行(Willingham,1997)。这一事实促使我们提出:

  • (Q4)如何设计一个包含不同互联记忆模块的高效架构

最后,存储记忆是一个神经过程,需要对过去的抽象进行编码和存储。假设一个单一向量或矩阵(其参数以线性方式编码数据)足以存储长期历史可能过于简化。

  • (Q5)是否需要深度记忆模块来有效存储/记住遥远的过去?

贡献与路线图

在本文中,我们旨在通过设计一个长期神经记忆模块来回答上述五个问题,该模块能够在测试时高效且有效地学习记忆。基于其设计,我们讨论了如何将其融入架构中。

神经记忆(§3)。我们提出了一种(深度)神经长期记忆模块,它(作为元上下文模型)学习如何在测试时将数据记忆/存储到其参数中。受人类长期记忆系统(Mandler,2014)的启发,

我们设计了这个记忆模块,使得违反预期的事件(即令人惊讶的事件: surprising)更容易被记住。为此,我们通过神经网络在关联记忆损失中对输入的梯度来衡量输入的“惊讶度(surprise)”(详见 §3.1)。为了更好地处理有限的内存,我们提出了一种衰减机制,该机制考虑了内存大小与数据惊讶度的比例,从而实现更好的内存管理。我们展示了这种衰减机制实际上是现代循环模型中遗忘机制的泛化(Dao 和 Gu,2024;Gu 和 Dao,2024;S. Yang、Kautz 和 Hatamizadeh,2024)。有趣的是,我们发现这种机制等同于使用小批量梯度下降、动量和权重衰减来优化元神经网络。基于张量化小批量梯度下降以使用更多矩阵乘法操作(Yu Sun 等人,2024),我们提出了一种快速且可并行化的算法来训练我们的深度神经长期记忆模块。

Titans 架构(§4)

在设计完长期神经记忆模块后,一个重要的问题是:如何高效且有效地将记忆融入深度学习架构中。我们提出了 Titans,这是一个由三个超头部(hyper-heads)组成的深度模型家族:

  • (1)核心模块:该模块包含短期记忆,负责数据处理的主要流程(我们使用有限窗口大小的注意力机制);
  • (2)长期记忆模块:这一分支是我们的神经长期记忆模块,负责存储/记住遥远的过去;
  • (3)持久记忆模块:这是一组可学习但与数据无关的参数,用于编码任务相关知识。

最后,作为概念验证,我们提出了 Titans 的三种变体,其中我们将记忆分别融入为:

  • (i)一个上下文(context)
  • (ii)层(layer)
  • (iii)一个门控分支(gated branch)

实验结果(§5)

我们在语言建模、常识推理、记忆密集型任务、“大海捞针”任务、时间序列预测和 DNA 建模任务上进行了实验评估。我们观察到,Titans 架构在所有现代循环模型及其混合变体(结合滑动窗口注意力机制)的综合基准测试中均表现优异。此外,Titans 在相同上下文窗口下优于 Transformers,并在使用整个上下文的 Transformers 中表现出竞争力。这些结果是在 Titans 能够扩展到超过 200 万上下文窗口大小的情况下实现的,而 Transformers 则无法做到这一点

2 预备知识

在本节中,我们将讨论本文中使用的符号和一些背景概念。我们令:

  • $ x \in \mathbb{R}^{N \times d_{\text{in}}} $ 表示输入
  • $ \mathbf{M} $ 表示神经网络(神经记忆模块:neural memory)
  • $ \mathbf{Q} $、$ \mathbf{K} $、$ \mathbf{V} $ 分别表示注意力机制中的query、key和value
  • $ M $ 表示注意力掩码(attention mask)
  • $ S^{(i)} $ 表示:在对序列进行分段时,使用第 $ i $ 段。

在本文中,我们简化符号并使用下标来指代矩阵、向量或段中的特定元素。例如,我们令:

  • $ S^{(i)}_j $ 表示第 $ i $ 段中的第 $ j $ 个 token。

唯一的例外是下标为 $ t $ 的情况,我们保留它来表示时间上的递归或神经网络在时间 $ t $ 的状态。

给定:神经网络 $ \mathbf{N} $ 和数据样本 $ x $,我们使用:

  • $ \mathbf{N}(x) $(或 $ \mathbf{N}^*(x) $)表示带权重调整(或不带权重调整)的前向传播

此外,我们简化符号并使用:

  • $ \mathbf{N}^{(k)} $ 表示神经网络的第 $ k $ 层

接下来,我们首先讨论注意力机制及其高效变体的背景,然后回顾现代线性 RNN,最后讨论这些架构的记忆视角,这促使我们设计了 Titans。

2.1 背景

注意力机制。Transformers(Vaswani 等人,2017)作为许多深度学习模型的实际骨干,基于注意力机制。给定:

  • 输入 $ x \in \mathbb{R}^{N \times d_{\text{in}}} $

因果注意力机制基于输入依赖的key、value和query矩阵计算输出 $ y \in \mathbb{R}^{N \times d_{\text{in}}} $:

\[\mathbf{Q} = x \mathbf{W}_Q, \quad \mathbf{K} = x \mathbf{W}_K, \quad \mathbf{V} = x \mathbf{W}_V, \quad (1)\] \[y_i = \frac{\sum_{j=1}^i \exp\left(\frac{\mathbf{Q}_i^\top \mathbf{K}_j}{\sqrt{d_{\text{in}}}}\right) \mathbf{V}_j}{\sum_{\ell=1}^i \exp\left(\frac{\mathbf{Q}_i^\top \mathbf{K}_\ell}{\sqrt{d_{\text{in}}}}\right)}, \quad (2)\]

其中:

  • $ W_Q, W_K, W_V \in R^{d_{in} \times d_{in}} $ 是可学习参数

尽管 Transformers 在召回能力上表现出色,但它们至少需要 $ N \times d $ 次操作来计算输出,导致内存消耗较大且对较长序列的吞吐量较低。

高效注意力机制。为了提高软注意力机制在长序列上的内存消耗和吞吐量,许多研究集中在注意力机制的 I/O 感知实现(Dao 2024;Dao, D. Fu 等人,2022),通过稀疏化注意力矩阵(B. Chen 等人,2021;Choromanski 等人,2021;Dai 等人,2019)、近似 softmax(Arora 等人,2024)或开发基于核的(线性)注意力机制(Aksenov 等人,2024;Kacham, Mirrokni 和 P. Zhong,2024;Schlag, Irie 和 Jürgen Schmidhuber,2021;S. Yang, B. Wang, Shen 等人,2024)来设计更高效的注意力机制。在本部分,我们重点关注后者,即线性注意力机制,其中标准注意力中的 softmax 被替换为替代核函数 $ \phi(\cdot, \cdot) $,使得 $ \phi(x, y) = \phi(x) \phi(y) $。因此,注意力可以写成:

\[y_i = \frac{\sum_{j=1}^i \phi(\mathbf{Q}_i^\top \mathbf{K}_j) \mathbf{V}_j}{\sum_{\ell=1}^i \phi(\mathbf{Q}_i^\top \mathbf{K}_\ell)} = \frac{\phi(\mathbf{Q}_i)^\top \sum_{j=1}^i \phi(\mathbf{K}_j) \mathbf{V}_j}{\phi(\mathbf{Q}_i)^\top \sum_{\ell=1}^i \phi(\mathbf{K}_\ell)}, \quad (3)\]

由于:

  • 项 $ \sum_{j=1}^i \phi(K_j), \sum_{\ell=1}^i \phi(K_\ell) $ 在每一步中重复使用,因此吞吐量更高。

当选择核函数为单位矩阵时(Yutao Sun 等人,2023),上述公式可以写成递归形式:

\[\mathbf{M}_t = \mathbf{M}_{t-1} + \mathbf{K}_t^\top \mathbf{V}_t, \quad (4)\] \[y_t = \mathbf{Q}_t \mathbf{M}_t, \quad (5)\]

这使得线性注意力机制能够高效推理。

现代线性模型及其记忆视角。如前所述,可以将学习定义为获取有效且有用记忆的过程。基于此,可以将循环神经网络(RNN)的hidden state视为记忆单元,模型旨在将信息压缩到其中。因此,在一般形式的循环神经网络中,hidden state可以被视为记忆单元,递归过程可以分为记忆单元的读和写操作。即,令:

  • $ x \in \mathbb{R}^{N \times d_{\text{in}}} $ 为输入
  • $ \mathbf{M} \in \mathbb{R}^d $ 为记忆单元
  • $ y \in \mathbb{R}^{d_{\text{in}}} $ 为输出

则循环神经网络的一般形式定义为:

\[\mathbf{M}_t = f(\mathbf{M}_{t-1}, x_t), \quad \text{写操作} \quad (6)\] \[y_t = g(\mathbf{M}_t, x_t), \quad \text{读操作} \quad (7)\]

其中:

  • $ f(\cdot, \cdot) $ 是读操作,
  • $ g(\cdot, \cdot) $ 是写操作。

注意,这里的 $ \mathbf{M}_t $ 下标表示记忆在时间 $ t $ 的状态。

从这一视角来看,线性 Transformers 的递归公式(见公式 4)等同于将键和值 $ (\mathbf{K}_t, \mathbf{V}_t) $ 加性地压缩并写入矩阵值记忆单元 $ \mathbf{M}_t $ 中。因此,在处理长上下文数据时,这种加性特性会导致内存溢出,显著损害模型性能。为了解决这一问题,研究集中在两个有前景的方向上:

  • (1)添加遗忘机制:一些研究提出了线性模型的自适应(数据依赖)遗忘门机制,可以在需要时擦除记忆。例如,GLA(S. Yang, B. Wang, Shen 等人,2024)、LRU(Orvieto 等人,2023)、Griffin(De 等人,2024)、xLSTM(Beck 等人,2024)和 Mamba2(Dao 和 Gu,2024)等模型,后者还与离散化的传统状态空间模型(Gu 和 Dao,2024)相关联。
  • (2)改进写操作:为了克服传统循环模型中记忆写操作的加性特性,Widrow 和 Hoff(1988)提出了 Delta 规则,在添加记忆(即键值对)之前,模型首先移除其过去的值。为了增强可并行化训练和扩展性,S. Yang, B. Wang, Yu Zhang 等人(2024)提出了一种快速并行化算法。最后,最近 S. Yang, Kautz 和 Hatamizadeh(2024)通过添加遗忘门改进了 DeltaNets。

记忆模块。记忆一直是神经网络设计的核心部分之一(Graves, Wayne 和 Danihelka,2014;JH Schmidhuber,1992;Jürgen Schmidhuber 和 Hochreiter,1997;J. Zhang 等人,2024)。将线性层视为键值(关联)记忆系统(key-value (associative) memory system)的思想可以追溯到快速权重程序(fast weight programs),其中动态快速程序被纳入循环神经网络中作为可写记忆(JH Schmidhuber,1992)。Hebbian(Hebb,2005)和 delta(Prados 和 Kak,1989)学习规则是快速权重程序中最流行的学习规则,已在各种研究中广泛探索(Irie, Schlag 等人,2021;Munkhdalai, Sordoni 等人,2019;Munkhdalai 和 H. Yu,2017;Schlag, Irie 和 Jürgen Schmidhuber,2021;JH Schmidhuber,1992;S. Yang, Kautz 和 Hatamizadeh,2024;S. Yang, B. Wang, Yu Zhang 等人,2024)。然而,所有这些模型都基于瞬时惊讶度,忽略了序列中的 token 流(见第 3.1 节),并且大多数模型缺乏遗忘门,导致内存管理不佳。

我们在附录 C 中进一步讨论了我们的架构与最近模型的联系。其他相关工作在附录 A 中讨论。

3 测试时的记忆学习

为了克服长期记忆的不足,并使模型能够学习、遗忘和检索信息,本节提出了一种神经长期记忆模块,这是一种在测试时学习记忆的元模型。

  • 在 3.1 节中,我们首先讨论神经记忆的动机和设计。
  • 在 3.2 节中,我们讨论如何通过快速且可并行化的训练使我们的架构设计受益。
  • 在 3.3 节中,我们通过持久记忆模块增强我们的架构,其中使用可学习但与数据无关的参数来学习任务的元信息。

图片名称

图1 关于如何并行训练神经记忆并使用矩阵乘法(matmuls)的图示说明。

3.1 长期记忆

为了设计一个神经长期记忆模块,我们需要一个能够将过去历史的抽象编码到其参数中的模型。一个例子是大语言模型(LLMs),它们被证明能够记忆训练数据(Leybzon 和 Kervadec,2024;Schwarzschild 等人,2024;Staab 等人,2024)。因此,一个简单的想法是:训练一个神经网络并期望它记忆其训练数据。然而,记忆化几乎总是被认为是神经网络中的不良现象,因为它限制了模型的泛化能力(Bayat 等人,2024),引发隐私问题(Staab 等人,2024),并导致测试时性能不佳。此外,训练数据的记忆化在测试时可能没有帮助,因为数据可能是分布外的。我们认为,我们需要一个在测试时学习如何记忆/遗忘数据的在线元模型。在这种设置中,模型学习的是一个能够记忆的函数,但它不会过度拟合训练数据,从而在测试时实现更好的泛化。

学习过程与惊讶度度量(Surprise Metric)。训练长期记忆的关键思想是:将其训练视为一个在线学习问题,我们的目标是将过去的信息 $ x_1, \ldots, x_{t-1} $ 压缩到长期神经记忆模块 $ \mathbf{M}_t $ 的参数中。如前所述,违反预期的事件(即令人惊讶的事件)对人类来说更容易被记住(Mandler,2014)。受此启发,模型惊讶度的一个简单定义可以是:其相对于输入的梯度。梯度越大,输入数据与过去数据的差异越大。因此,使用这种惊讶度评分,我们可以更新记忆如下:

\[\mathbf{M}_t = \mathbf{M}_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) \quad \text{(惊讶度)} \quad (8)\]

然而,这种惊讶度度量可能会导致错过在重大惊讶时刻之后的重要信息。也就是说,梯度在几次惊讶步骤后可能变得非常小,导致陷入平坦区域(即局部最小值),并错过序列某些部分的信息。从人类记忆的角度来看,一个事件可能不会在长时间内持续让我们感到惊讶,尽管它是值得记忆的。原因是初始时刻足够令人惊讶,足以在长时间内吸引我们的注意力,从而记住整个时间段。为了改进上述惊讶度度量(公式 8),我们将惊讶度度量分为:

  • (1)过往惊讶度(past surprise),衡量最近过去的惊讶程度;
  • (2)瞬时惊讶度(momentary surprise),衡量输入数据的惊讶程度:
\[\mathbf{M}_t = \mathbf{M}_{t-1} + S_t, \quad (9)\] \[S_t = \eta_t S_{t-1} \quad \text{(过去惊讶度)} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) \quad \text{(瞬时惊讶度)} \quad (10)\]

有趣的是,这个公式类似于带有动量的梯度下降,其中:

  • $ S_t $ 是动量项

因此,这里的动量充当了跨时间(序列长度)的惊讶度记忆。在这个公式中:

  • 项 $ \eta_t $ 是一个数据依赖的惊讶度衰减($ x_t $ 的函数),控制惊讶度随时间衰减的程度
  • 项 $ \theta_t $ 则控制瞬时惊讶度应以数据依赖的方式纳入最终惊讶度度量的多少

这种数据依赖性在这个设计中尤为重要:虽然前一个 token 的惊讶度可能需要影响下一个 token 的惊讶度,但这只有在所有 token 都相关且处于同一上下文中时才有效。因此,数据依赖的 $ \eta $ 可以控制记忆是否需要:

  • (1)通过设置 $ \eta_t \to 0 $ 忽略上一次的惊讶度(可能由于上下文的变化),
  • (2)通过设置 $ \eta_t \to 1 $ 完全纳入上一次的惊讶度(可能因为 token 与其最近的过去 token 高度相关)。

目标。我们上述的惊讶度度量基于损失函数 $ \ell(\cdot; \cdot) $,这是我们的记忆模块在测试时学习的目标。也就是说,我们的记忆模块是一个元模型,它基于损失函数 $ \ell(\cdot; \cdot) $ 学习一个函数。

在本节中,我们重点讨论关联记忆,其目标是将过去的数据存储为k-V对。给定输入 $ x_t $,类似于 Transformers(Vaswani 等人,2017),我们使用两个线性层将 $ x_t $ 投影为key和value:

\[\mathbf{k}_t = x_t \mathbf{W}_K, \quad \mathbf{v}_t = x_t \mathbf{W}_V, \quad (11)\]

其中 $ W_K $ 和 $ W_V \in R^{d_{in} \times d_{in}} $。接下来,我们希望记忆模块能够学习键和值之间的关联。为此,我们定义损失函数如下:

\[\ell(\mathbf{M}_{t-1}; x_t) = \|\mathbf{M}_{t-1}(\mathbf{k}_t) - \mathbf{v}_t\|_2^2, \quad (12)\]

通过在元模型(记忆)的内循环中优化上述损失函数,模型学习如何在测试时记忆键和值之间的映射。需要注意的是,类似于元学习模型(Nichol,2018;Zintgraf 等人,2019),记忆的训练是在内循环中进行的,因此参数 $ \mathbf{W}_K $ 和 $ \mathbf{W}_V $ 是上述损失函数中的超参数。因此,在内循环中,我们优化记忆模块 $ \mathbf{M} $ 的权重,而在外循环中,我们优化整个架构的其他参数。

遗忘机制

当处理非常长的序列(例如数百万个 token)时,管理哪些过去信息应该被遗忘至关重要——即使使用深度或非常大的矩阵值记忆。为此,我们使用一种自适应遗忘机制,允许记忆遗忘不再需要的信息,从而更好地管理记忆的有限容量。具体来说,给定下一个 token $ x_t $,我们修改更新规则如下:

\[\mathbf{M}_t = (1 - \alpha_t) \mathbf{M}_{t-1} + S_t, \quad (13)\] \[S_t = \eta_t S_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t), \quad (14)\]

其中:

  • $ \alpha_t \in [0, 1] $ 是一个门控机制,灵活控制记忆;即决定应该遗忘多少信息。例如:

  • 通过设置 $ \alpha_t \to 0 $,可以在不影响过去抽象的情况下更新记忆;
  • 通过设置 $ \alpha_t \to 1 $,可以清除整个记忆

在本节后面,我们将展示这种权重衰减机制与现代 RNN 中的门控机制密切相关(Dao 和 Gu,2024;Orvieto 等人,2023)。

记忆架构

在本文中,我们专注于使用具有 $ L_M \geq 1 $ 层的简单多层感知机(MLP)作为长期记忆的架构。选择这种架构的主要原因是,我们希望集中精力更好地激励长期记忆的设计及其融入架构的方式。然而,我们的公式和架构设计为设计在数据记忆方面更有效和高效的神经架构开辟了新的研究方向。最近,有一些有前景的工作致力于设计此类架构(Berges 等人,2024;Cetin 等人,2024;J. Zhang 等人,2024),将这些架构融入我们的框架(即用此类架构替换简单的 MLP)可能是一个有趣的未来工作方向。

当使用向量值或矩阵值记忆(De 等人,2024;Orvieto 等人,2023;S. Yang, B. Wang, Shen 等人,2024)时,记忆模块会压缩过去的数据并将其拟合到一条线上。也就是说,从元学习或在线学习的角度来看(Yu Sun 等人,2024),使用矩阵值记忆 $ \mathbf{M} = \mathbf{W} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{in}}} $ 等同于优化 $ \ell(\mathbf{W}{t-1}; x_t) = |\mathbf{W}{t-1} \mathbf{k}_t - \mathbf{v}_t|_2^2 $,这是一个在线线性回归目标,因此最优解假设历史数据的潜在依赖关系是线性的。另一方面,我们认为深度记忆模块(即 $ L_M \geq 2 $ 层)在实践中更有效。这与理论结果一致,即至少具有两层的 MLP 严格比线性模型更具表达能力(Hornik, Stinchcombe, and White, 1989)。在第 5.5 节中,我们展示了深度记忆模块在实际应用中的有效性。


记忆检索

在上面,我们讨论了如何设计和训练一个在测试时学习记忆的长期记忆模块。一个关键的问题是:如何从记忆中检索信息?我们简单地使用不更新权重的前向传播(即推理)来检索与查询对应的记忆。形式上,给定输入 $ x_t $,我们使用线性层 $ \mathbf{W}_Q $ 投影输入,即 $ \mathbf{q}_t = x_t \mathbf{W}_Q $,并通过以下方式从记忆中检索相应的(或有用的)信息 $ y_t $:

\[y_t = \mathbf{M}^*(\mathbf{q}_t). \quad (15)\]

图片名称

图2:记忆作为上下文(MAC)架构。该架构包括三个分支:(1)核心分支,(2)上下文(长期)记忆分支,以及(3)持久记忆分支。核心分支将相应的长期记忆和持久记忆与输入序列连接起来。接下来,注意力机制在序列上执行,并决定哪些信息应存储在长期记忆中。在测试时,与上下文记忆对应的参数仍在学习,与核心分支对应的参数负责上下文学习,而持久记忆的参数负责存储任务知识,因此是固定的。

3.2 如何并行化长期记忆的训练

如上所述,我们的长期记忆模块的设计等同于通过优化关联记忆损失函数 $ \ell(\mathbf{M}{t-1}; x_t) = |\mathbf{M}{t-1}(\mathbf{k}_t) - \mathbf{v}_t|_2^2 $ 来训练一个元模型,使用带有动量和权重衰减的梯度下降。因此,理论上,长期记忆模块的训练需要 $ O(N) $ 的浮点运算(FLOPs),其中 $ N $ 是序列长度。然而,在实践中,我们需要并行化训练过程,并充分利用硬件加速器(例如 TPU、GPU),因此需要将过程张量化并使用更多的矩阵乘法(matmuls)。

接下来,我们展示如何通过小批量梯度下降、数据依赖的学习率和权重衰减来重新表述内循环中的权重计算,使其仅使用矩阵乘法和求和。我们基于 Yu Sun 等人(2024)的工作,该工作表明,使用小批量梯度下降(具有恒定学习率)优化的模型的前向传播可以通过矩阵乘法计算。我们可以将序列分割为大小为 $ b \geq 1 $ 的块,并将小批量梯度下降表示为:

\[\mathbf{M}_t = (1 - \alpha_t) \mathbf{M}_{t-1} - \theta_t \nabla \ell(\mathbf{M}_{t-1}; x_t) = \beta_t \mathbf{M}_0 - \sum_{i=1}^t \theta_i \frac{\beta_t}{\beta_i} \nabla \ell(\mathbf{M}_{t'}; x_i), \quad (16)\]

其中 $ t’ = t - \text{mod}(t, b) $,且 $ \beta_i = \prod_{j=1}^i (1 - \alpha_j) $。为了简化,我们专注于第一个块,即 $ t = b $,因此 $ t’ = 0 $。此外,我们解释当 $ \mathbf{M}_t = \mathbf{W}_t $ 是线性时的情况。对于具有 $ N_p \geq 2 $ 层的 MLP,过程类似。使用我们的损失函数,我们有:

\[\nabla \ell(\mathbf{W}_0; x_t) = (\mathbf{W}_0 x_t - x_t) x_t^\top \Rightarrow \sum_{i=1}^b \theta_i \frac{\beta_b}{\beta_i} \nabla \ell(\mathbf{W}_0; x_i) = \Theta_b \mathbf{B}_b (\mathbf{W}_0 \mathbf{X} - \mathbf{X}) \mathbf{X}^\top, \quad (17)\]

其中 $ \Theta_b = \text{diag}(\theta_1, \theta_2, \ldots, \theta_b) $,且 $ \mathbf{B}b $ 类似地定义在 $ \frac{\beta_b}{\beta_i} $ 上。需要注意的是,我们不需要存储所有 $ \Theta{kb} $ 和 $ \mathbf{B}_{kb} $($ k = 1, \ldots, N/b $),而是为每个块存储这些矩阵,从而减少内存使用。接下来,我们扩展这种表示,以便还可以纳入动量项。在带有动量的小批量梯度下降中,如果我们看动量项,我们有:

\[S_t = \eta_t S_{t-1} - \theta_t u_t, \quad (18)\]

其中 $ u_t = \nabla \ell(\mathbf{M}_{t’}; x_t) $。需要注意的是,我们可以同时计算所有 $ u_t $,因此公式 (18) 是一个线性递归,其中 $ u_t $ 是输入,$ S_t $ 是隐藏状态,$ \eta_t $ 是输入依赖的转移值。因此,我们可以使用并行关联扫描(J. T. Smith, Warrington, and Linderman, 2023)来计算该块中的 $ S_t $。

参数作为块的函数

与其让参数 $ \alpha_t $、$ \theta_t $ 和 $ \eta_t $ 依赖于输入(即 token $ x_t $ 的函数),我们可以让它们成为块的函数。尽管这会降低表达能力,但这种表述可以帮助使训练更快。在这种情况下,我们在每个块中对 $ \alpha $、$ \theta $ 和 $ \eta $ 使用相同的值。因此,在公式 (17) 中,我们可以使用单个标量存储 $ \Theta $。类似地,我们可以使公式 (18) 更快。也就是说,当 $ \eta $ 和 $ \theta $ 在每个块内可学习但时间不变时,该方程变为线性时不变系统(LTI),可以通过全局卷积计算(Gu, Goel, and Re, 2022)。在我们的实验中,我们将这些参数作为 token 的函数。然而,这种简化(即作为块的函数)可能是未来工作的兴趣点,以便以更高效的方式训练更大的模型。

3.3 持久记忆

我们的长期记忆也可以被视为一种上下文记忆,这意味着输出完全依赖于上下文。因此,除了长期记忆外,我们还使用一组可学习但与输入无关的参数来充当任务相关的记忆。这种类型的记忆在文献中被称为持久记忆或元记忆(X. Dong 等人,2024;Sukhbaatar, Grave 等人,2019)。给定 $ N_p \geq 1 $,我们使用可学习参数 $ P = [p_1, p_2, \ldots, p_{N_p}] $ 并将其附加到序列的开头:即,给定上下文窗口大小为 $ N $,我们将输入修改为:

\[x_{\text{new}} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (19)\]

其中 $ \parallel $ 表示连接操作。接下来,我们从三个角度讨论持久记忆的动机:


记忆视角

如前所述,我们的神经长期记忆是一种上下文记忆,其中所有参数都依赖于输入。然而,一个有效的记忆系统还需要与输入无关的参数来存储任务知识的抽象。也就是说,掌握一个任务需要记忆如何完成该任务的知识,而这些参数负责存储此类知识。


前馈网络视角

在 Transformer 架构中,注意力模块之后有全连接层,这些层被证明类似于注意力权重,但具有与数据无关的参数。即,Sukhbaatar, Grave 等人(2019)表明,将全连接层中的 ReLU 替换为 Softmax 可以产生类似注意力的权重,其中权重与数据无关:

\[FFN(x) = W_V \text{Softmax}(W_K x), \quad (20)\]

实际上,当 $ W_K $ 和 $ W_V $ 与输入无关时,它们的作用类似于注意力模块中的 $ K $ 和 $ V $ 矩阵。持久记忆权重预计具有相同的功能,这意味着在序列的开头部分使用它们会导致具有与输入无关的注意力权重(Sukhbaatar, Grave 等人,2019)。


技术视角

带有因果掩码的注意力机制对序列中的初始 token 具有隐式偏差,因此注意力权重几乎总是对初始 token 高度活跃,从而导致性能下降。从技术角度来看,序列开头的这些可学习参数可以通过更有效地重新分配注意力权重来缓解这种影响(Han 等人,2024;Xiao 等人,2024)。


总结

  • 持久记忆的作用:存储任务知识的抽象,与输入无关。
  • 前馈网络的类比:持久记忆权重类似于注意力机制中的 $ K $ 和 $ V $ 矩阵,但具有与数据无关的特性。
  • 技术优势:通过在序列开头引入可学习参数,持久记忆可以缓解注意力机制对初始 token 的偏差,从而提升模型性能。

持久记忆的引入为模型提供了任务知识的存储能力,并通过优化注意力权重的分配进一步提升了模型的性能。

图片名称

图3

4 如何融入记忆?

一个重要但尚未解答的问题是:如何有效且高效地将设计的神经记忆融入深度学习架构中?如前所述,从记忆的角度来看,Transformers 中的键值对矩阵可以解释为关联记忆块。由于其依赖关系的精确建模以及有限的上下文窗口,我们将其解释为短期记忆模块,专注于当前上下文窗口大小。另一方面,我们的神经记忆能够从数据中持续学习并将其存储在其权重中,可以扮演长期记忆的角色。在本节中,我们旨在通过提出 Titans 的三种不同变体来回答上述问题。在后续实验中,我们将展示每种变体的优缺点,以及在超长上下文中的效率与有效性之间的权衡。


4.1 记忆作为上下文(Memory as a Context, MAC)

在第一种架构设计中(见图 2),我们将记忆视为当前信息的上下文。即,给定一个长序列 $ x \in \mathbb{R}^{N \times d_{\text{in}}} $,我们首先将序列分割为固定大小的段 $ S^{(i)} $($ i = 1, \ldots, N/C $)。给定传入的段 $ S^{(t)} $,我们将其视为当前上下文,并将其过去的段视为历史信息。因此,设 $ \mathbf{M}{t-1} $ 为段 $ S^{(t)} $ 之前的长期记忆状态,我们使用输入上下文作为查询,从长期记忆 $ \mathbf{M}{t-1} $ 中检索相应的信息。即,我们检索与 $ S^{(t)} $ 对应的过去信息:

\[h_t = \mathbf{M}^*_{t-1}(\mathbf{q}_t), \quad (21)\]

其中 $ \mathbf{q}_t = S^{(t)} \mathbf{W}_Q $。接下来,我们将此历史信息与持久记忆参数一起作为注意力模块的输入序列:

\[\tilde{S}^{(t)} = [p_1, p_2, \ldots, p_{N_p}] \parallel h_t \parallel S^{(t)}, \quad (22)\] \[y_t = \text{Attn}(\tilde{S}^{(t)}). \quad (23)\]

整个序列的注意力图结构如图 3a 所示。然后,我们使用 $ y_t $ 更新长期记忆模块以用于下一个段和最终输出:

\[\mathbf{M}_t = \mathbf{M}_{t-1}(y_t), \quad (24)\] \[o_t = y_t \otimes \mathbf{M}^*_t(y_t). \quad (25)\]

需要注意的是,在上述过程中,我们通过前向传播更新 $ \mathbf{M}_{t-1} $ 的权重。

图片名称

图4

该架构具有两个关键优势:

  1. 注意力机制:通过同时拥有历史和当前上下文,能够决定在当前数据下是否需要长期记忆信息。
  2. 记忆管理:注意力模块帮助长期记忆仅存储当前上下文中有用的信息,避免记忆溢出。
  3. 测试时的行为
    • 持久记忆参数固定,因为它们编码了任务知识,不应更改。
    • 注意力模块权重是上下文学习者。
    • 长期记忆模块在测试时仍在学习(记忆)信息,即神经记忆的权重在测试时也会更新,以编码过去的抽象。

4.2 门控记忆(Gated Memory)

在第二种变体中(见图 4),我们在一个分支中直接使用输入数据更新长期记忆,在另一个分支中使用滑动窗口注意力(SWA):

\[\tilde{x} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (26)\] \[y = \text{SW-Attn}^*(\tilde{x}), \quad (27)\] \[o = y \otimes \mathbf{M}(\tilde{x}), \quad (28)\]

其中 $ \text{SW-Attn}^* $ 是带有前缀的滑动窗口注意力(见图 3b)。需要注意的是,与之前的设计不同,我们不对输入数据进行分段。此外,我们简化符号并使用 $ \mathbf{M}(x) $ 表示序列中所有 token 递归后的记忆最终输出。在上述等式中,$ \otimes $ 可以是任何非线性门控。在我们的实验中,我们使用可学习的向量值权重对输出 $ y $ 和 $ \mathbf{M}(\tilde{x}) $ 进行归一化,然后应用非线性函数 $ \sigma(\cdot) $。

该设计的整体注意力掩码如图 3b 所示。在这种设计中,滑动窗口注意力充当精确的短期记忆,而神经记忆模块则充当模型的衰减记忆。这种架构设计也可以被视为一种多头架构,其中头的结构不同(X. Dong 等人,2024)。

4.3 记忆作为层(Memory as a Layer, MAL)

在第三种变体中,我们将神经记忆作为深度神经网络的一层(见图 5)。这种架构设计在文献中更为常见,其中混合模型将循环模型与完整或滑动窗口注意力机制堆叠在一起。给定输入 $ x $,我们有:

\[\tilde{x} = [p_1, p_2, \ldots, p_{N_p}] \parallel x, \quad (29)\] \[y = \mathbf{M}(\tilde{x}), \quad (30)\] \[o = \text{SW-Attn}(y), \quad (31)\]

其中 $ \text{SW-Attn} $ 是滑动窗口注意力。这种设计的主要缺点是模型的能力受限于每一层,因此无法充分利用注意力和神经记忆模块的互补数据处理能力。在我们的实验中,为了评估这种设计中的记忆,我们使用了类似于 H3(D. Y. Fu 等人,2023)的架构,其中我们将序列模型替换为我们的神经记忆模块(LMM)。


无注意力的记忆

尽管上述讨论中我们将 MAL 视为 LMM 和注意力机制的顺序组合,但 MAL 的一个简单变体是将 LMM 视为没有任何注意力机制的序列模型。从记忆的角度来看,如第 1 节所述,我们期望记忆系统的每个部分都能独立工作,即使其他组件受到干扰。因此,即使没有短期记忆(即注意力机制),长期记忆模块仍然应该是一个强大的模型。我们在实验中称这种变体为 LMM 或 Titans(LMM)。我们在附录 C 中提供了关于 Titans 与其他现代循环模型联系的更多讨论。

图片名称

图5


4.4 架构细节

为了简洁和清晰,我们避免讨论实现细节,例如使用残差连接、线性层门控和归一化。在所有块中,我们使用残差连接。在我们的实现中,我们使用 SiLU(.) 激活函数(Elfwing, Uchibe, and Doya, 2018)作为计算查询、键和值的非线性激活,并使用 $ \ell_2 $-范数对查询和键进行归一化。


卷积

遵循最近的现代线性循环模型(Gu 和 Dao,2024;S. Yang, Kautz, and Hatamizadeh,2024),我们在每个查询、键和值投影之后加入一维深度可分离卷积层。虽然这些一维卷积对性能的影响不大,但它们已被证明可以提升性能,并且在计算上也很高效。


门控

我们还遵循最近的架构,在最终输出投影之前使用归一化和线性层门控(Mehta 等人,2023)。


定理 4.1

与 Transformers、对角线性循环模型和 DeltaNet 不同,这些模型都受限于 $ \text{TC}^0 $(Merrill, Petty, and Sabharwal, 2024),Titans 能够解决超出 $ \text{TC}^0 $ 的问题,这意味着 Titans 在状态跟踪任务中理论上比 Transformers 和大多数现代线性循环模型更具表达能力。

#

https://arxiv.org/pdf/2501.00663v1

meta Ins在《QuickUpdate: a Real-Time Personalization System for Large-Scale Recommendation Models》给出了它们的系统实现:

摘要

深度学习推荐模型在在线公司中扮演着重要角色,并且占据了用于训练和推理的AI基础设施的主要部分。这些模型的准确性高度依赖于它们在服务端的发布速度。提高模型更新延迟和更新频率的主要挑战之一是:模型大小(model size),这些模型size已经达到了TB级别,并且预计未来还会进一步增加。大的模型size导致了在分布式服务器中更新模型时的高延迟(和写入带宽)。我们提出了QuickUpdate,一个用于大规模推荐模型实时个性化的系统,它能够作为在线训练的一部分,可以高频率地进行模型发布,提供与完全新鲜模型相当的服务准确性。该系统采用了新技术来最小化所需的写入带宽,包括:优先参数更新、间歇性全模型更新、模型转换和宽松一致性(relaxed consistency)。我们使用真实世界的数据在Meta的一个最大生产模型上评估了QuickUpdate。结果表明,QuickUpdate提供了与完全新鲜模型相当的服务准确性,同时将平均发布的更新大小和所需带宽减少了超过13倍。它为实时服务生产模型提供了一个可扩展的解决方案,这在网络和存储带宽有限的情况下,否则是不可能大规模实现的。

1 引言

深度学习推荐模型(DLRM)在许多在线公司中被广泛使用。这些模型通过大规模数据进行训练,以学习用户和产品特征,从而在各种场景中提供个性化推荐。例如,Netflix [7] 和 YouTube [4] 为用户提供电影列表;Amazon [19] 和 Alibaba [20] 根据用户的搜索查询推荐相关产品;Google [3] 和 Meta [23] 则根据用户兴趣展示广告和内容。DLRM 在这些公司中占据了 AI 基础设施的主要部分。以 Meta 为例,DLRM 消耗了超过 80% 的机器学习推理周期 [8] 和超过 50% 的训练周期。

推荐模型有助于业务增长。例如,它们贡献了 Amazon 总购买量的 35% [8, 14]。由于这种广泛的业务影响,准确性成为大规模推荐模型的重要性能指标。特别是在 Meta 的业务中,设计检查点(checkpoint)和量化算法时要求准确性损失小于 0.01%[5]。这是一个非常狭窄的容差范围,表明了推荐模型及其准确性的重要性。

模型新鲜度 是个性化推荐模型准确性的关键因素 [4, 6, 9, 22, 25]。由于模型在高度动态的环境中运行推理,准确性可能会迅速下降。例如,每天都有新用户和物品注册到系统中,用户的兴趣可能会受到近期事件的影响。如果模型没有频繁更新,它将无法反映用户和产品的变化,从而导致准确性逐渐下降。为了进一步强调新鲜度的影响,图 3 展示了模型在数小时未刷新时的显著准确性损失。因此,为了将准确性保持在可接受的水平,推荐模型需要使用最新数据进行重新训练,并使用更新后的模型来服务实时推理。

保持推理模型新鲜的一种常见技术是在线训练。与每次从头开始重新训练模型不同,它使用实时流数据不断训练和优化模型。定期创建模型的快照并将其发布到位于不同地理区域的数百台服务器中。这些服务器随后利用该模型对在线查询进行实时预测。然而,更新服务模型会带来训练集群与分布式服务主机之间的延迟,导致模型刷新延迟,这主要是由于现代模型的规模庞大。

多年来,模型规模迅速增长,达到了 TB 级别,并包含数万亿个参数 [5, 10, 15],以捕获数百万个sparse特征并提高模型准确性。有限的写入带宽在将如此大的模型传输到分布式服务器和存储时构成了挑战。因此,更新延迟可能会延长到数小时。如第 3 节详细讨论的那样,这种长时间的延迟可能会对准确性产生不利影响。

为了解决上述由大模型规模及其导致的更新延迟带来的挑战,我们提出了 QuickUpdate。QuickUpdate 采用以下设计元素来实现大规模 DLRM 的实时个性化:

  1. 优先参数更新:在数百个地理分布的节点中完全更新所有服务模型的参数需要大量的网络和存储带宽,这构成了瓶颈。
    QuickUpdate 通过优先参数选择来最小化更新规模。它对服务模型中的特定参数进行排序和选择,同时从更新中修剪其余参数。这种方法显著减少了总体更新规模并缓解了带宽需求。
    参数排序算法在最小化更新规模时避免准确性下降至关重要。

  2. 间歇性全模型更新:间歇性全模型更新是指在一系列连续的部分更新之后进行一次完整的模型更新。这些完整更新的主要目的是保持服务模型的长期准确性。每次部分更新后,服务模型会偏离训练模型,因为前者仍然使用过时的参数值。随着更多部分更新的进行,这种偏差会变得更大,从而导致潜在的准确性影响。为了提高准确性,间歇性地发布完整模型更新以限制服务模型与训练模型之间的差距。

  3. 实时更新的模型转换:QuickUpdate 采用了几种模型转换技术来减少发布的模型规模,包括推理剪枝和量化。
    量化已在一些研究中成功实施 [11, 24, 27],以在不牺牲准确性的情况下降低浮点精度。它有助于减少推理集群中的存储需求和通信成本。推理剪枝则应用于非常大的查找表。在 DLRM 中,实体(如用户或视频)及其对应的向量以查找嵌入表的形式存储。实际不活动的实体索引(或 ID)从服务平台中修剪,以显著减少服务模型的规模。

  4. 简化的服务设计和宽松的一致性要求:在传统的服务设计中,模型在开始服务查询之前会完全加载到服务平台中,以保持强一致性。在这种设计中,每个推理请求都基于特定版本的模型权重执行,确保一致且可靠的结果。然而,这种方法由于使用额外的缓冲节点而带来了相当大的基础设施开销。
    在 QuickUpdate 中,我们通过放宽一致性要求引入了一种更高效的服务设计。权重直接在服务节点中更新,而不是使用缓冲节点。这消除了对额外基础设施的需求并减少了开销。然而,这种宽松的设计可能会导致嵌入表中的一些不一致性,因为它们可能包含新鲜和过时权重的混合
    尽管嵌入表中可能存在不一致性,但我们的评估表明,服务模型的准确性并未受到影响,反而带来了准确性的提升。

我们使用真实世界的数据和 Meta 生产中部署的最大模型之一对 QuickUpdate 进行了评估。总体而言,我们的结果表明,QuickUpdate 能够提供与完全新鲜模型相当的服务准确性,同时将所需的写入带宽减少了超过 13 倍。它为实时服务生产 DLRM 提供了一个可扩展的解决方案,而这在网络和存储带宽有限的情况下是无法大规模实现的。QuickUpdate 通过利用新颖的技术实现了这一点,包括选择性发布每次更新的最重要部分,同时仍然结合低频的间歇性全模型更新以确保长期准确性

2 背景

2.1 深度学习推荐模型(DLRM)

通常,深度学习推荐模型由dense和dense层组成,如图 1 所示 [5, 10, 26]。sparse层实际上是嵌入表,其中:每个嵌入表表示一个分类特征,表的每一行表示一个特定ID(例如用户 ID 或视频 ID)。嵌入表将每个 ID 转换为一个固定大小的浮点值向量,这些向量是可训练的。模型中其余可训练的部分称为dense层。

图片名称

图 1

图 1 展示了数据在 DLRM 中的流动方式。sparse特征通过嵌入表进行转换;dense特征通过底部的dense层进行转换。转换后的特征随后被连接起来,并在顶部的dense层中进一步转换,以计算输入数据的可能性。

2.1.1 训练 DLRM

并行化是规模化训练推荐模型的主要方法 [5, 8]。sparse层和dense层可以采用不同的并行化逻辑。sparse层占整个模型大小的 99% 以上,可能达到数 TB 的规模。由于将所有sparse层存储在单个节点中是不可行的,因此采用模型并行化的方法将表分片到多个节点上。另一方面,dense层的规模足够小,可以容纳在每个节点中,因此它们被复制到各个节点上以利用数据并行化。

在 Meta,一个典型的训练集群包括 16 个节点,每个节点包含一个多插槽 CPU 和 8 个 GPU:

  • sparse层在所有 GPU 上分片,在正向和反向计算期间进行“all-to-all”通信。
  • dense层在所有 GPU 上复制,在反向传播期间使用“all-reduce”通信来聚合多个 GPU 中计算的梯度 [16]。

在训练期间,sparse层和dense层的新权重被同步计算和更新,以避免准确性下降。

2.1.2 服务 DLRM

为了以高吞吐量的方式高效地服务批量请求,通常使用 GPU 进行模型服务。在 Meta,服务节点位于专用的服务集群中。一个服务节点由主机 CPU 和附加的 GPU 组成。服务模型在服务节点之间复制,并使用数据并行化进行规模化模型服务。

  • 广告嵌入表存储在单个 GPU 中,因为它们需要更高的读取吞吐量。
  • 其他嵌入表存储在 CPU 中,CPU 通常具有更大的内存容量(例如 1.5 TB DRAM)。

为了存储嵌入表,使用紧凑的数据结构来最小化大小并以 GPU 友好的方式存储。特别是,嵌入表是连续存储的,每个嵌入表以行优先顺序存储在数组数据结构中。

为了刷新服务模型,使用额外的缓冲节点来避免暂停当前的服务节点(双buffer机制)。在新发布的模型加载到缓冲节点后,请求流量会切换到缓冲节点。

2.2 在线训练与离线训练

当服务模型需要使用实时数据流不断训练和更新时,会实施在线训练。在在线训练中,服务模型在提供预测的同时,会定期(以分钟到小时为单位)更新。训练在后台(通常在单独的集群中)继续运行以微调模型。训练模型发布到服务平台的速率对其生成的预测准确性有显著影响。

相比之下,离线训练不使用实时数据流进行训练,也没有严格的时间限制来训练模型并将其发布到服务平台。相反,它通常使用已经存储在数据存储中的大量数据。模型使用所有可用数据进行训练,当满足某些最优条件时训练停止,之后将其发布到服务平台

选择在线训练还是离线训练取决于具体的使用场景。当模型需要及时更新时,会实施在线训练系统。典型的用例可能是广告、搜索和视频的 DLRM(例如 [13, 18, 21]),当环境高度动态且需要模型几乎实时更新(以分钟为单位)时。在线训练帮助这些模型整合最新数据并避免准确性下降。当业务需求不需要实时模型更新时,可以使用离线训练方法来更新模型。

2.3 优化器状态作为特征重要性度量

通常,大型 DLRM 可以使用数十万个特征进行训练;然而,其中一些特征及其对应的权重对准确性没有影响。这些特征可能属于不活跃的用户或内容 ID,或者某些其他特征可能无法提供训练信号。在模型发布时,维护所有这些参数会消耗额外的带宽,或者在服务平台中运行推理时会消耗额外的计算和 GPU 存储。

为了减轻这些不利影响,我们可以计算特征重要性,并相应地修剪那些实际上不影响准确性的特征和权重。

QuickUpdate 使用优化器状态(或梯度动量)来计算特征重要性。它属于基于梯度的特征重要性度量家族(例如 [2, 12, 18])。优化器状态是梯度值的历史平均值,比梯度值更稳定,因为梯度值有时会在正值和负值之间振荡。优化器状态可以为我们提供以下指示:

  1. 对准确性的影响:它实际上显示了在训练过程中某个特定参数被更新的频率和幅度。较高的梯度值表明对提高准确性有较大影响;如果这种影响在历史上持续存在,我们可以更有信心地推断该参数对准确性很重要;因此,优化器状态是特征对准确性影响的稳定指示。

  2. 访问率:优化器状态为零(或接近零)表明该参数在训练过程中实际上没有被更新。这有两层含义。首先,这可能意味着该参数没有被访问过。这可能发生在一些不活跃的用户 ID 或过时的视频上。其次,如果实体是活跃的,接近零的值可能意味着从相关数据中可能没有重要的信号可以学习。例如,某个特定用户 ID 可能没有使用平台点击广告。

基于上述信息,QuickUpdate 使用优化器状态来完成两项不同的任务:

  1. 推理剪枝:当发布完整模型时,QuickUpdate 使用优化器状态来减少完整模型更新的规模。在此阶段,QuickUpdate 关注优化器状态值的低尾部分,并修剪那些对准确性没有影响的参数。

  2. 优先参数选择:当发布部分更新时,QuickUpdate 使用优化器状态来选择更重要的参数进行更新。在此阶段,QuickUpdate 关注优化器状态值的高尾部分,以更新更重要的参数。

2.4 推理剪枝

推理剪枝是为了在将完整模型快照发布到服务平台时减少模型的大小。剪枝特别针对查找嵌入表实施,并显著减少其大小(例如减少 50%)。由于查找表占 DLRM 大小的 99% 以上,剪枝可以在不影响准确性的情况下显著减少模型的大小。减少模型大小有助于消耗更少的带宽来发布模型更新;因此,更新可以以更短的延迟发布到数百个地理分布的集群中。此外,它还有助于更快地执行推理,因为计算中涉及的行数更少。

图片名称

图 2

图 2 展示了一个查找表的示例,其中包含索引和对应的行。每个索引代表分配给用户、视频等的唯一 ID。每一行可以被视为一个由可训练的浮点值组成的向量,模型使用这些向量为用户生成个性化推荐。

直观上,推理剪枝算法识别出代表不活跃实体或无法提供训练信号以提高准确性的行。从数学上讲,这是通过使用优化器状态向量来实现的。具体来说,训练器可以为每一行提供一个优化器状态向量。优化器状态向量中的每个元素表示该行中对应元素的梯度动量。优化器状态向量中元素的平均值用于量化行的重要性值。如果行的重要性值接近零,则意味着该行的元素在训练过程中实际上没有被更新;因此,该行可以被剪枝。

图 2 还展示了剪枝前后查找表的变化。使用行重要性值,推理剪枝算法确定最不重要的索引,并在将完整更新发布到服务器之前将其剪枝。出于操作目的,原始索引值会被重新映射到新的索引。新索引会简单地按照它们在原始表中的出现顺序递增。

需要注意的是,在本文中,完整模型更新指的是执行推理剪枝后的模型快照。

3 动机

在本节中,我们通过真实世界的数据来展示开发 QuickUpdate 的动机。首先,我们展示了当模型一小时或更长时间未更新时,准确性如何显著下降。接着,我们讨论了模型规模扩大带来的影响。如果不修改模型发布方法,我们只能选择接受更长的更新延迟和逐渐下降的准确性,或者大量投资基础设施以保持更新延迟的一致性。最后,我们强调了无损模型更新的局限性,强调了优先更新的必要性。

准确性增益

更新完整的服务模型是一个耗时的过程,可能需要数小时。因此,用户的最新行为和兴趣(例如发布新动态或与特定内容互动)在几小时内不会反映在服务模型中,这可能会降低模型的准确性。

图片名称

图 3 展示了在 Meta 的一个大规模模型中,随着模型更新延迟(1 到 7 小时),准确性如何下降。它将陈旧服务模型的准确性损失与完全新鲜模型进行了比较。结果显示,当模型更新延迟时,准确性损失显著增加,7 小时后损失超过 0.6%。

减少更新规模有助于加速模型更新并提高服务模型的准确性。

模型规模

近年来,DLRM 的规模显著增加。这些模型利用大量数据和参数来更好地理解用户兴趣和产品特征,从而提高准确性。这一进展催生了具有数万亿参数 [15] 和数 TB 模型大小 [5] 的复杂模型。此外,这一趋势预计在未来仍将持续。

随着模型规模的增加,由于传输所需的带宽增加,模型更新延迟也随之延长。如果不加以解决,预计这将导致未来模型新鲜度的下降。鉴于模型规模的持续扩大,单纯增加基础设施并不是一个可行的长期解决方案。因此,对大规模 DLRM 进行部分更新似乎是一个有前景的策略,旨在减少更新延迟,而无需增加更多基础设施。

无损模型更新

为了更好地理解模型随时间变化的比例,我们监控了更新的嵌入行,并据此计算了模型中被修改的平均比例。图 4 显示了模型随时间更新的百分比。很明显,模型的大部分在短时间内被更新。例如,在短短 10 分钟的时间间隔内,58% 的模型被更新。更新 58% 的模型是资源密集型的,需要比每小时更新完整模型更多的基础设施。这促使我们探索一种优先更新的方法,以显著减少更新规模。

图片名称

图 4

4 系统概述

图 5 提供了 QuickUpdate 架构的概览。DLRM 系统由训练节点、服务节点和用于保存模型快照的远程存储组成。QuickUpdate 的发布逻辑主要在 UpdateSelectorUpdatePatcher 代理中实现,这两个代理分别部署在训练节点和服务节点中。

  • UpdateSelector 负责决定模型的哪一部分应该更新,并在保存到远程存储之前对其进行量化。
  • UpdatePatcher 根据执行的更新类型实现不同的修补策略。

以下部分提供了更多详细信息。

图片名称

图 5 系统架构

4.1 更新什么

QuickUpdate 专注于对嵌入表执行部分更新,这些表通常占深度学习推荐模型的绝大部分(在我们的工作负载中超过 99%)。在这些模型中,每个表表示一个分类特征(例如用户、视频),表中的每一行对应于与该特征相关的特定 ID。

在我们的探索中,我们考虑了两种更新嵌入表的选项:

    1. 更新选定表的所有行。
    1. 更新所有表中的选定行(不同表的选定行索引可能不同)。

我们发现,以行级粒度进行更新可以在最小化整体更新规模的同时提高准确性。因此,QuickUpdate 决定服务端需要更新表中的哪些特定行。这种方法使 QuickUpdate 能够优先更新更有可能提高准确性的内容或用户 ID,从而确保更新策略的高效性和有效性。

对于模型中的dense层,QuickUpdate 执行完整更新。这是因为这些层的更新规模相对较小,针对这些层的任何优化对整体更新过程的影响不大。

4.2 UpdateSelector

QuickUpdate 的 UpdateSelector 组件在训练集群中实现。这是因为它需要从训练器中获取某些模型信息(例如参数值)以准备模型更新。

在在线训练期间,训练器以批次间隔运行。在每个训练间隔结束时,训练器将模型状态和优化器状态共享给 UpdateSelector。模型状态包括分片的嵌入表和dense参数值,而优化器状态包括梯度值及其动量。这些状态从 GPU 内存复制到主机 CPU 内存中。

UpdateSelector 使用优化器状态对 CPU 中的模型副本执行以下两项任务:

  1. 优先参数选择:此任务的主要目标是仅更新模型参数的一小部分,同时最小化准确性的下降(与完整更新相比)。在此阶段,QuickUpdate 根据优化器状态值选择嵌入行,优先选择那些可能对准确性提升较大的行。
  2. 推理剪枝:此任务在发布完整模型时执行。推理剪枝专注于sparse嵌入表,旨在减少完整模型更新的规模。在此阶段,QuickUpdate 识别低尾优化器状态值,并剪枝值接近零的嵌入行。这些行对模型准确性的影响可以忽略不计。

一旦更新(无论是完整还是部分)准备就绪,它们会经过量化以减少其大小。量化作为一种压缩方法,对模型准确性的影响可以忽略不计。量化后的更新随后存储在远程存储中,准备用于更新过程。

4.3 UpdatePatcher

UpdatePatcher 负责加载发布的快照并更新服务模型。它对部分和完整模型更新都采用了一种高效的非原子更新方法。在非原子更新过程中,多个线程可以访问模型参数,并逐步将参数修补到服务器中。这种方法允许多个线程并发修补参数,而无需锁定服务器或模型。因此,服务器可以在应用更新的同时继续对传入流量进行推理。这种方法确保了在更新过程中实时流量的高效且不间断的服务。

4.4 工作流程

图 6 展示了 QuickUpdate 的工作流程。为简化说明,我们仅展示了训练器、UpdateSelector 和一个服务节点中的时间尺度。模型的演化是一个可重复的模式,因此我们专注于一个周期,该周期进一步分为多个间隔。在周期 $ c $ 的每个间隔 $ i $ 开始时,UpdateSelector 可以访问完整模型 $ F_{c,i} $ 以确定模型的哪一部分应该更新。具体来说,首先会发布一个完整快照(即 $ F_{c,1} $)并加载到服务器中,然后连续的部分更新($ P_{c,i} $ 其中 $ i > 1 $)会被发布并修补到完整快照中,以创建服务快照 $ S_{c,i} $。

图片名称

图6 QuickUpdate工作流

将更多部分更新与服务模型合并可能会导致服务模型 $ S_{c,i} $ 与当前训练器状态 $ F_{c,i} $ 之间的偏差增大。这种偏差可能会导致准确性下降。因此,另一个完整的新鲜快照(即 $ F_{c+1,1} $)将被发布到服务集群,标志着当前周期的结束。服务端的模型演化可以表示如下:

\[S_{c,1} = F_{c,1} \\ S_{c,i} = M(S_{c,i-1}, P_{c,i}) \quad \text{对于} \quad 1 < i \leq I\]

其中:

  • $ I $ 是一个周期中的间隔数
  • $ M $ 是合并操作符。合并操作符简单地复制 $ P_{c,i} $ 的参数值并将其更新到 $ S_{c,i-1} $ 中。

5 设计

在本节中,我们讨论了设计选项及其对准确性指标的影响。我们首先定义了指导设计和评估的具体准确性指标。通过在整个设计过程中优先考虑准确性,我们的目标是创建一个有效的系统,在解决网络和存储带宽瓶颈的同时,提供高服务准确性。需要注意的是,QuickUpdate 是可配置的,并在生产环境中进行监控,以应对罕见的准确性下降情况。

5.1 准确性指标

二元交叉熵(Binary Cross Entropy)或熵(Entropy)[17] 是评估广告模型准确性的一个众所周知的综合指标。在本研究中,我们使用归一化熵(Normalized Entropy, NE),其定义为二元交叉熵除以一个常数。为了理解部分更新相对于完全新鲜快照和过时模型的表现,我们计算了 NE 的以下变体。为简化说明,我们从符号中省略了周期下标 $ c $。

  1. NE 损失:它表示使用模型 $ S_i $ 而不是相应的完全新鲜模型 $ F_i $ 运行推理时的准确性下降。

    \[\text{NE}_{\text{loss}}(S_i) = \frac{\text{NE}_{S_i} - \text{NE}_{F_i}}{\text{NE}_{F_i}} \times 100 \quad (1)\]

    其中:

    • $ NE_{S_i} $ 和 $ NE_{F_i} $ 分别表示模型 $ S_i $ 和 $ F_i $ 的归一化熵。
  2. NE 增益:它表示如果使用 $ S_i $ 进行推理而不是过时模型,可以预期的准确性提升:

    \[\text{NE}_{\text{gain}}(S_i) = \frac{\text{NE}_{S_i} - \text{NE}_{\text{stale}}}{\text{NE}_{\text{stale}}} \times 100 \quad (2)\]

    过时模型被认为是最近发布的完整模型 $ F_1 $。

  3. NE 恢复:它表示模型 $ S_i $ 已达到的最大 NE 增益的百分比。我们假设如果可以使用完全训练的模型 $ F_i $ 进行推理,则可以实现最大 NE 增益。因此,NE 恢复定义为:

    \[\text{NE}_{\text{recovery}}(S_i) = \frac{\text{NE}_{\text{gain}}(S_i)}{\text{NE}_{\text{gain}}(F_i)} \times 100 \quad (3)\]

5.2 选择标准

为了优先更新能够带来更大准确性增益的行,我们需要一个在训练过程中保持稳定的可靠指标。虽然梯度向量可以作为标准,但其在正值和负值之间的振荡引入了数值不稳定性。相反,我们可以使用优化器状态向量(也称为动量),它提供了更稳定的度量。优化器状态向量表示特定行的历史梯度的平均平方和。通过将:

  • $ OS^r_{c,i} $: 表示为模型在间隔 $ i $ 时行 $ r $ 的优化器状态向量
  • $ \overline{OS^r_{c,i}}$: 表示为其元素的平均值,我们可以利用该度量作为行重要性的指示。

直观上,具有较大 $ \overline{OS}^r_{c,i} $ 值的行更有可能提高准确性。例如,这些行可能代表频繁使用平台点击广告的特定用户,或者代表具有高访问率的特定视频。除了给定间隔的 $ \overline{OS}^r_{c,i} $ 的大小外,跟踪其随时间的变化也可能很重要。这对于我们更倾向于优先选择相对于旧版本发生变化的参数的情况可能具有潜在的信息价值。基于这些直觉,我们评估以下选择标准:

  1. 绝对优化器状态

    \[\text{abs}(\overline{OS_{c,i}^r}) \quad \text{对于} \quad i > 1 \quad (4)\]
  2. 增量优化器状态

    \[\text{abs}(\overline{OS_{c,i}^r} - \overline{OS_{c,i-1}^r)} \quad \text{对于} \quad i > 1 \quad (5)\]

选择使用绝对优化器状态还是增量优化器状态作为选择标准取决于它们各自的优势和权衡。虽然绝对优化器状态提供了行对准确性影响的稳定和综合度量,但增量优化器状态捕捉了与前一个间隔相比影响的变化。然而,使用增量优化器状态需要额外的内存来存储前一个间隔的优化器状态。为了评估这些标准的影响,我们进行了实验,间隔长度为 30 分钟,更新规模为 10%。在发布完整快照并再训练一小时后,我们根据这两个标准发布了更新规模为 10% 的部分快照。然后,我们评估了与完全新鲜模型相比的服务准确性。表 1 中的结果(在多次此类实验中一致)表明,增量优化器状态实现了 100% 的 NE 恢复,而绝对优化器状态实现了 70% 的 NE 恢复。这意味着基于增量优化器状态选择行可以减少服务模型与相应完整快照之间的差异

图片名称

表1 不同选择标准的NE recovery

5.3 增量选择的基线

增量优化器状态是基于基线计算的。在计算增量优化器状态时,我们考虑了两种选择基线的方法:

  1. 上一次更新时的模型状态:此选项将上一次更新时的模型状态作为基线。这与上一节中增量优化器状态的定义相同。

  2. 上一次完整更新时的模型状态:如第 5.6 节所述,QuickUpdate 还利用间歇性完整更新。在此基线选项中,将上一次间歇性完整更新时的模型状态用作增量。在这种情况下,增量优化器状态定义为:

    \[\text{abs}(\overline{\text{OS}^r_{c,i}} - \overline{\text{OS}^r_{c,1}}) \quad \text{对于} \quad i > 1 \quad (6)\]

对于第一种选项,需要在每个训练间隔结束时保存基线,而对于第二种选项,只需在周期的第一个间隔中保存一个基线。因此,第一种选项提供了更新鲜的基线,但需要额外的计算资源将其保存在内存中。

我们通过实验检查了不同基线对服务准确性的影响,每个间歇性完整更新后跟随四个部分更新。我们评估了与完全新鲜模型相比的服务准确性(NE 恢复)。表 2 中的结果显示,使用上一次间隔的基线可以实现 3.11%(95.94% 对比 99.05%)更高的 NE 恢复。这表明,使用上一次更新的模型状态作为基线可以更好地反映最近的用户兴趣,因为它在每个更新间隔中都会刷新。此外,使用完整更新作为基线可能会优先选择在前一个间隔中重要但不再对准确性有贡献的参数。随着时间的推移,这些参数的优化器状态可能达到平稳状态,但由于其较大的增量优化器状态值,完整更新基线可能仍然会考虑它们。更频繁地刷新基线有助于消除对此类参数的优先选择,转而优先选择最近变化的参数,这些参数更有可能对准确性提升有贡献。

5.4 实时推理剪枝

如第 2.4 节所述,推理剪枝有助于减少服务模型的大小和所需的 GPU 数量。它实际上会剪枝那些不再活跃或对准确性影响可以忽略的行(或 ID),以减少嵌入表的大小。然后,剪枝后的表以 GPU 访问友好的方式紧凑地存储,以进一步减少服务集群中的大小。

剪枝仅在完整模型发布到服务集群时实施。对于后续的间隔,我们希望部分更新能够与完整模型更新中的剪枝表兼容。理想情况下,部分更新中的行 ID 应存在于剪枝表中。这有助于我们简单地更新现有行的值,而无需重新构建 GPU 中的表。然而,情况并非总是如此。由于部分更新的训练数据与完整模型更新不同,可能会出现某些行 ID 在部分更新中变得重要,而这些 ID 在服务端的剪枝表中不存在的情况。当发生这种情况时,一种简单的实现方法是将缺失的行插入服务端的剪枝表中,但这可能是资源密集型的,并且可能需要重新调整所有 GPU 上的嵌入表以确保可访问性和效率(例如,避免内存碎片)。

为了避免嵌入表的密集重新调整,我们探索了两种与部分更新兼容的推理剪枝策略:

  1. 固定索引剪枝(见图 7a):在此策略中,QuickUpdate 执行优先参数选择以选择要更新的候选行索引。然而,仅更新嵌入表中已存在的行,而剪枝的行保持不变。

  2. 固定剪枝比例(见图 7b):在此策略中,每次完整更新时从嵌入表中剪枝固定比例的行。当 QuickUpdate 执行优先参数选择时,它最多选择 $ X $ 个索引进行更新,其中 $ X $ 是服务平台上给定表中的总行数。这确保了表中的行数保持一致。

图片名称

图 7

第一种策略避免了重新调整,因为只会进行行更新操作,而不会向嵌入表中插入新行。第二种策略通过使用行更新和索引重映射操作来避免重新调整。由于嵌入表的大小在第二种策略中不会改变,因此也避免了在 GPU 之间重新分片嵌入表的需要。

为了评估这两种剪枝策略,我们考虑了三种训练场景:1-无剪枝,2-固定剪枝索引,3-每个表固定剪枝比例。

我们的实验表明,剪枝导致的 NE 损失实际上可以忽略不计(<0.001%),且两种剪枝策略之间没有准确性差异。考虑到实现需求,我们选择了固定剪枝索引策略,因为其实现更简单。与固定剪枝比例策略不同,它不需要在每次新更新时更新索引映射。

6 评估

我们在 Meta 部署的最大推荐模型之一上评估了 QuickUpdate,使用真实世界的数据,并在类似于 [15] 的生产训练集群上进行训练。该模型是 [16] 中提出的 DLRM 模型的扩展,但其规模显著更大,达到 TB 级别。我们在所有实验中使用相同的预记录数据流,使实验可重复且可比较,并消除了由于时间数据变化导致的潜在结果偏差。模型最初使用几周的真实世界数据进行训练作为预热期,以达到稳定状态。对于准确性评估,我们评估了在训练数据之后的时间段内数据流的服务预测(即推理期间评估的数据未在之前的训练中使用)。

6.1 准确性

在本节中,我们比较了不同更新粒度对准确性的影响,并推导出最小完整快照频率,以确保 NE 损失不超过 0.01%。在这些实验中,我们在开始时发布一个完整快照,并继续发布具有不同粒度的部分更新。这些部分更新应用于完整服务快照之上,并使用相同的记录数据集进行准确性评估。

6.1.1 与过时模型相比的 NE 增益

我们首先比较 QuickUpdate 与过时模型的准确性,以量化准确性增益并验证在完整快照之上应用部分更新不会对准确性产生负面影响。这里的过时模型指的是最初发布的完整快照。图 8 显示了不同更新粒度(且无间歇性完整模型更新)相对于过时模型的 NE 增益。所有更新规模的 NE 增益均高于过时模型,并且 NE 增益随时间增加。5% 和 10% 的更新提供了非常相似的 NE 增益,但 1% 的更新返回的 NE 增益较少,表明一些重要的行未包含在 1% 的更新中。总体而言,这些趋势表明,即使在应用部分更新超过 10 小时后,也没有负面影响,并且与过时模型相比,准确性提高了 0.7%。

图片名称

图 8

6.1.2 与完全新鲜模型相比的 NE 损失

在本节中,我们研究了使用部分更新发布的 QuickUpdate 模型的 NE 损失,与理想的完全新鲜服务模式进行比较。图 9 中的结果显示,使用 10% 更新时,NE 损失在整个 10 小时内低于 0.005%。使用 5% 更新时,NE 损失始终高于 10% 更新,但在超过 6 小时内仍低于 0.01%。随着训练周期的增加,NE 损失增加,因为服务模型与相应训练模型之间的差异增加。结果还展示了采用不同更新粒度对完整模型发布延迟的影响,同时确保 NE 损失保持在可接受的 0.01% 阈值以下。通过采用 10% 的粒度,我们可以有效地将完整模型发布的需求延迟超过 10 小时。同样,当使用 5% 的粒度时,我们可以将完整模型发布延迟 6 小时,同时仍将 NE 损失保持在可接受范围内。这突显了 5% 粒度下部分更新在捕获重要更新并在相当长的时间内保持模型准确性和新鲜度方面的有效性。

图片名称

图 9

6.1.3 短期内的 NE 损失

为了分析短期内的 NE 损失,我们进行了一项评估,涉及四个连续的 10 分钟更新。检查的更新粒度为 5%、3% 和 1%。每次更新后,使用未见过的数据测量与完全新鲜模型相比的 NE 损失。图 10 显示了不同 10 分钟间隔内的变化,强调了流数据的波动性。然而,当在多个短时间间隔内平均时,NE 损失趋于稳定。正如预期的那样,结果显示,随着粒度的增加,NE 损失减少。最后一列显示的平均 NE 损失证实,5% 的粒度在我们的工作负载中会返回可接受的 NE 损失(平均而言)。

图片名称

图 10

6.1.4 结论

准确性结果证明了 QuickUpdate 在采用 5% 更新粒度长达 6 小时的有效性,同时保持与完全新鲜模型相当的准确性水平,并确保 NE 损失低于 0.01% 的阈值。

此外,使用 QuickUpdate 的 5% 更新粒度允许在需要发布完整模型之前延迟 6 小时。这种延迟之所以可能,是因为部分更新成功捕获并整合了重要变化,从而生成了准确且最新的模型。

基于这些发现,QuickUpdate 默认每 6 小时触发一次间歇性完整模型发布,从而优化了准确性与更新频率之间的平衡。

6.2 分析长期行收敛性

在之前的分析中,我们的重点是基于准确性指标最小化部分更新粒度,并确定间歇性完整更新的适当频率。结果表明,使用 5% 粒度的部分更新持续 6 小时可以达到令人满意的准确性。

在本实验中,我们的目标是探索部分更新更新了模型中多少百分比的重要行。

为了确定重要行的代理,我们训练模型 6 小时(即与满意准确性相同的持续时间)。我们将重要行定义为训练模型中排名前百分之几的行,使得在 6 小时结束时发布这些行(而不是整个模型)将返回令人满意的准确性(即与完全新鲜模型相比差异低于 0.01%)。

图 11 显示了在 6 小时训练后,不同大小的单次更新与完全新鲜模型相比的 NE 损失。可以看出,单次 5% 的部分更新不足以将 NE 损失降低到可接受的 0.01% 阈值以下。然而,10% 的部分更新证明足以将 NE 损失降低到可接受的水平。这表明排名前 10% 的嵌入行是此时间窗口内重要行的良好代理。

图片名称

图 11

为了了解这些重要行中有多少百分比被多个较小的 5% 更新覆盖,我们运行了 QuickUpdate 6 小时,并使用多个 5% 粒度的部分更新。在将所有这些更新合并为一个联合集后,我们观察到该集合涵盖了上述重要行的 70%,并总体覆盖了模型中所有行的 7.3%。因此,大部分重要行被连续的较小部分更新所覆盖。

6.3 带宽使用

在 QuickUpdate 中,更新大小是带宽使用的代理。带宽使用量取决于粒度、更新间隔和间歇性完整模型更新的频率。通常,这些参数是可配置的,并可能根据 DLRM 的类型和所需的准确性而变化。在本节中,我们评估了基于发布模型百分比的不同策略的带宽使用情况。详细信息如下并在图 12 中展示:

图片名称

图 12

  1. 基线 1:每小时发布一次完整模型。
  2. 基线 2:每 10 分钟发布一次完整模型(未在图中显示)。
  3. 5% 更新(默认策略):每 10 分钟发布一次部分更新,粒度为 5%,每 6 小时发布一次间歇性完整更新(如 6.1 节所述)。
  4. 10% 更新:与之前的策略类似,每 10 分钟发布一次部分更新,但粒度为 10%。每 6 小时发布一次间歇性完整更新。

为了比较这些策略,我们平均了消耗的带宽。结果显示,默认策略(5% 更新粒度,6 小时间歇性完整更新间隔)平均每小时写入模型大小的 43.6%,而策略 3(10% 更新)为 68.2%,基线 1 为 100%。基线 2 提供了与策略 2 和 3 相当的准确性,但需要每小时发布模型大小的 600%。

总体而言,使用默认策略,QuickUpdate 能够将消耗的带宽比基线 1 减少 2.3 倍,同时提供与完全新鲜模型相当的更好准确性。与基线 2 相比(由于网络和存储带宽限制,无法大规模实施),QuickUpdate 能够将所需带宽减少超过 13 倍,同时仍提供相当的准确性。

6.4 宽松一致性

传统上,服务模型以原子方式更新以保持一致的推理。这涉及将所有模型权重加载到缓冲节点中,这些节点随后成为计算推理的服务节点。然而,这种方法由于使用缓冲节点而资源密集。为了解决这个问题,QuickUpdate 放宽了一致性要求,并在执行推理查询的同时直接在服务节点中更新参数。

我们评估了在 QuickUpdate 中间歇性完整模型更新期间的 NE 恢复(与完全新鲜模型相比),作为已更新权重百分比的函数。如图 13 所示,放宽一致性可以在加载期间提高生产中的准确性。随着更多参数的加载,NE 恢复增加。我们的数据显示,通过修补 30% 的参数,我们可以捕获约 54% 的 NE 恢复。在修补仅 70% 的参数后,NE 恢复达到约 94%。

图片名称

图 13

宽松一致性允许早期服务新鲜行(而不是等待整个模型更新),从而整体提高准确性。尽管在加载期间表的视图不一致(意味着不同的行可能属于不同的状态),服务一部分新鲜行已经导致准确性增加。NE 恢复随着时间的推移继续增长,直到整个模型更新完毕。

7 相关工作

异步或部分更新策略已在少数实时 DLRM 中实施 [13,18,21]。在 Kraken [21] 中,dense参数每隔几秒批量更新一次,而sparse参数在训练器中值发生变化时更新。这是一种无损参数更新,对于具有 1000-10000 亿参数 [15] 和地理分布式服务器的大型模型,可能会产生大量流量。Monolith [13] 主要专注于开发具有无冲突嵌入表的sparse特征系统。sparse参数可以在训练时以分钟级粒度更新,其值自上次同步以来发生变化。与 Kraken 类似,这是一种无损更新,可能会产生巨大的流量。总体而言,无损模型更新可能非常资源密集,如第 3 节所述。为了克服这个问题,QuickUpdate 可以执行优先参数选择,从而减少约 78% − 92% 的带宽,并且准确性损失可以忽略不计(< 0.01%)。在另一项研究中,Ekko [18] 被设计为一个高效的系统,用于将更新从训练模型广播到所有服务推理节点。为了快速更新服务模型中的较大嵌入表,他们使用了嵌入表更新中的sparse性和时间局部性。Ekko 系统与 QuickUpdate 正交,两者可以一起实施。在 QuickUpdate 中,我们优化了设计元素,如发布间隔、更新粒度和参数选择标准,以实现所需的准确性并最小化完整模型的发布。优先参数选择是我们在本文中使用的技术之一。在以往的研究中(例如 [1,2,12]),基于梯度的参数选择已在分布式训练系统中探索。Ekko [18] 进一步扩展了这一标准,并额外考虑了每个参数的请求频率和参数新鲜度作为选择标准。在 QuickUpdate 中,我们决定选择梯度动量的增量,这是一个比梯度本身更稳定的度量,并且它发布的参数能够返回与基线快照相比的最高准确性。

8 结论

QuickUpdate 是一个支持在线训练执行低延迟部分更新的系统,同时提供与完全新鲜模型相当的服务准确性。它为实时服务生产规模的 DLRM 提供了一个可扩展的解决方案。这一点尤其有价值,因为由于网络和存储带宽的限制,大规模实时服务此类模型具有挑战性。

QuickUpdate 通过利用创新技术实现了其可扩展性和准确性目标。其中一项技术是选择性发布每次更新的最重要部分,从而在保持准确性的同时减少整体更新规模。此外,QuickUpdate 以低频率结合间歇性完整模型更新,以确保长期准确性。这种选择性部分更新和间歇性完整更新的结合使 QuickUpdate 能够在低延迟服务和长期保持准确性之间取得平衡。

我们使用大规模个性化广告模型的真实世界数据对 QuickUpdate 进行了评估,结果表明 QuickUpdate 能够提供与完全新鲜模型相当的服务准确性,同时将所需的写入带宽减少超过 13 倍。

附录

wechat在《Ekko: A Large-Scale Deep Learning Recommender System with Low-Latency Model Update》给出了它们的实时方案:

摘要

深度学习推荐系统(DLRSs)需要在低延迟下更新模型,以便及时为新用户和内容提供服务。然而,现有的DLRSs未能做到这一点。它们通常:在离线状态下训练/验证模型,并将整个模型广播到全局推理集群中。因此,它们会带来显著的模型更新延迟(例如几十分钟),这对服务级别目标(SLOs:Service-Level Objectives)产生了不利影响。

本文介绍了一种名为Ekko的新型DLRS,它能够实现低延迟的模型更新。其设计理念是允许模型更新立即传播到所有推理集群,从而绕过长时间延迟的模型checkpoint、验证和广播。为了实现这一理念:

  • 首先,我们设计了一种高效的点对点模型更新传播算法。该算法利用DLRS模型更新中的稀疏性和时间局部性,以提高模型更新的吞吐量和延迟。
  • 此外,Ekko还配备了一个模型更新调度器,可以在网络繁忙时优先发送对SLOs影响较大的模型更新。
  • 最后,Ekko还包含一个推理模型状态管理器,用于监控推理模型的SLOs,并在检测到对SLOs有害的偏差更新(detrimental biased update)时回滚模型。

评估结果表明,Ekko比最先进的DLRS系统快了几个数量级。Ekko已在生产环境中部署超过一年,每天为超过十亿用户提供服务,并将模型更新延迟从最先进系统的几十分钟减少到2.4秒。

1 引言

深度学习推荐系统(DLRSs)是大型技术组织(如Meta [54]、字节跳动 [23]、谷歌 [15] 和英伟达 [56])中的关键基础设施。DLRS通常包含一组大型参数服务器,这些服务器托管着众多机器学习(ML)模型(例如嵌入表 [10, 26, 54] 和深度神经网络 [18])。参数服务器在地理分布的数据中心中进行复制,以实现容错和与客户端的低延迟通信。每个数据中心都有一组推理服务器,这些服务器从本地参数服务器中拉取模型,并为客户端提供推荐结果。

为了确保能够及时为新用户和内容提供服务,DLRS必须持续更新ML模型:它首先使用训练服务器收集新的训练数据并计算模型梯度,然后通过广域网(WAN)将模型更新传播到模型副本。

大规模的DLRS需要为数十亿用户提供服务 [15, 23, 54],并且必须实现与延迟相关的服务级别目标(SLOs)[49],例如将新创建的内容提供给用户的延迟。为了最好地实现SLOs,DLRS的操作者对实现低延迟模型更新提出了新的需求。这有几个原因:

  • (i)最近的DLRS应用(如YouTube [24] 或 TikTok [8])使用户能够创建大量的短视频、文章和图像。所有这些内容都需要尽快提供给客户端,通常在几分钟甚至几秒钟内;
  • (ii)数据保护法律(如GDPR [60])允许DLRS用户匿名。匿名用户的行为需要在线学习
  • (iii)许多在线ML模型(如强化学习 [74])已被用于生产中以提高推荐质量。这些模型必须在线持续更新以实现最佳性能。

然而,在现有的DLRS中实现低延迟模型更新极为困难。现有系统(如Merlin [56]、TFRA [66]、Check-N-Run [21] 和 BigGraph [39])采用离线方式更新模型:在收集新训练数据后,这些系统离线计算模型梯度,验证模型checkpoint,并将checkpoint广播到所有数据中心。这样的模型更新过程可能需要几分钟甚至几小时 [21]。

另一种方法是使用WAN优化的ML系统 [28] 或联邦学习系统 [37]。这些系统使用本地收集的数据更新副本模型,并延迟同步副本。然而,延迟同步引入了显著的异步性,这通常会对SLOs的实现产生不利影响 [28, 42]。

我们希望探索一种DLRS设计,能够在实现低延迟模型更新的同时不损害SLOs。我们的核心思想是:允许训练服务器在线更新模型(使用梯度),并立即将模型更新传播到所有推理集群。这种设计使我们能够绕过长时间延迟的更新步骤,包括离线训练、模型checkpoint、验证和广播,从而减少模型更新延迟。为了使这一设计可行,我们需要解决几个挑战:

  • (i)如何在带宽有限且网络路径异构的WAN上高效传播大量模型更新 [28];
  • (ii)如何保护SLOs免受网络拥塞的影响,这种拥塞可能会延迟关键更新;
  • (iii)如何保护SLOs免受对模型准确性有害的偏差更新的影响。

本文介绍了Ekko,一种新型的大规模DLRS,能够以低延迟更新全局复制的模型。Ekko的设计做出了以下几个关键贡献:

(1) 高效的点对点模型更新传播

现有的参数服务器通常采用主备数据复制协议[11, 41, 67] 来实现模型更新。然而,在大模型更新的情况下,主备协议由于更新延迟长 [67] 和领导者瓶颈 [2] 而表现出不足的可扩展性。

为了解决这些问题,我们探索了如何实现点对点(P2P)[20] 模型更新传播。我们为地理分布的DLRS设计了一种高效的无日志状态同步算法(见§4)。该算法在DLRS中非常有效,因为模型更新通常集中在热门参数上 [21],并且它只传输模型参数的最新版本(即状态)。Ekko必须允许参数服务器以P2P方式高效发现模型状态的差异。为此,我们设计了:

  • (i)模型更新缓存:使参数服务器能够高效跟踪和比较模型状态;
  • (ii)分片版本:可以显著减少比较模型状态时的网络带宽消耗;
  • (iii)WAN优化的传播拓扑:使参数服务器能够优先选择带宽充足的区域内网络路径,而不是带宽有限的跨区域网络路径。

(2) SLO保护机制

Ekko允许模型更新在没有离线模型验证的情况下到达推理集群。这种设计可能会使SLOs(特别是与推荐结果的新鲜度和质量相关的SLOs)容易受到网络拥塞和偏差更新的影响,这两种情况在生产环境中都可能发生。

为了处理网络拥塞,我们设计了一个SLO感知的模型更新调度器(见§5)。该调度器计算包括:更新新鲜度优先级、更新重要性优先级、模型优先级在内的指标。这些指标预测模型更新对推理SLOs的影响。调度器根据这些指标在线计算每个模型更新的优先级。我们将调度器集成到参数服务器中,而不改变Ekko中P2P模型更新传播的分布式架构。

Ekko通过一种新颖的推理模型状态管理器处理偏差更新。该管理器为每组推理模型创建一个基线模型。该基线模型接收少量用户流量,并作为推理模型的基准。管理器持续监控基线和推理模型的质量相关SLOs。当偏差更新损坏推理模型的状态时,管理器通知见证服务器将模型回滚到健康状态。

我们使用测试床和大规模生产集群对Ekko进行了评估(见§6)。测试床实验结果表明,与最先进的参数服务器(如Adam [11])相比,Ekko将模型更新延迟减少了最多7倍

我们进一步在包含40 TB模型和分布在多个地理区域的4,600多台服务器的大规模生产环境中进行了实验。实验结果表明,Ekko在每秒执行10亿次更新(即212 GB/s)的情况下,传播更新的延迟仅为2.4秒。Ekko仅使用总网络带宽的3.0%进行同步,其余带宽用于训练和推理。这种秒级延迟性能比最先进的DLRS基础设施(如TFRA [66] 和 Check-N-Run [21])实现的分钟级延迟(即5分钟 [69])快了几个数量级。

2 DLRS中的低延迟模型更新

在本节中,我们将介绍DLRS及其更新模型的算法。然后,我们描述那些能够从减少模型更新延迟中受益的服务级别目标(SLOs)。最后,我们讨论实现低延迟模型更新所面临的系统挑战。

2.1 DLRS与模型更新

大多数技术组织采用如图1所示的系统架构来构建DLRS。DLRS通常服务于分布在全球的客户端(1)。为了最小化服务延迟,DLRS模型(例如嵌入表 [10, 26, 54] 和深度神经网络 [18])在多个数据中心中进行地理复制。当客户端的请求到达时,推理服务器从本地参数服务器中拉取模型参数,并基于该模型进行推理以响应请求。

图片名称

图1 一个典型的DLRS架构

数据管道在运行时从客户端收集训练数据(例如新内容和用户活动)。收集到的数据到达数据中心的训练服务器(2)。训练服务器使用优化器 [33] 计算梯度以修正相应的模型。所有更新后的模型(通常有数百到数千个)被持久化为checkpoint(3)。这些checkpoint首先经过验证,只有那些能够改善SLOs的checkpoint才会通过广域网(WAN)传播到面向推理的数据中心的参数服务器(4),从而完成模型更新过程。

在实践中,更新DLRS模型的延迟包括:计算模型更新和将更新传播到全球数据中心的时间。这个延迟定义假设我们已经使用了低延迟的消息队列(例如Kafka [36])来加速训练数据的摄取。最近的DLRS(如NVIDIA Merlin [56] 和 Meta Check-N-Run [21])报告了模型更新的分钟级和小时级延迟。假设我们想要更新一个包含大型嵌入表(通常大小为几TB)的DLRS模型。在这种情况下,将该模型持久化为checkpoint并验证模型可能需要几十分钟。再通过WAN传播该模型还需要十几分钟(假设该WAN提供数Gbps的带宽 [72])

2.2 低延迟模型更新的原因

DLRS需要实现许多服务级别目标(SLOs),这些目标通常与推荐结果的新鲜度和质量相关。以短视频推荐服务(如TikTok)为例,DLRS模型的准确性决定了该服务的质量SLOs,而将新制作的视频快速提供给用户的时间则决定了该服务的新鲜度SLOs。

在实际的DLRS中,我们观察到SLOs通常依赖于完成模型更新的延迟,这使得低延迟模型更新成为一项关键的系统需求。这有以下几个原因:

(1) 短时间内产生的大量新内容
全球DLRS(如YouTube [24]、TikTok [8] 和 Instagram [22])通常服务于数十亿用户,并且允许用户快速创建大量内容。DLRS需要通过低延迟更新模型,快速将这些新内容整合到推荐结果中,否则会影响用户参与度。

(2) 匿名用户的增加
数据保护法律(如GDPR [60])禁止许多DLRS跟踪用户活动。因此,即使这些用户之前使用过相同的服务,DLRS也可能拥有推荐模型未知的匿名用户。因此,DLRS必须快速响应匿名用户的在线活动,以满足他们的推荐需求。这种快速响应依赖于低延迟的模型更新。

(3) 在线推荐模型的增加
DLRS中越来越多的在线机器学习模型(例如使用强化学习 [74] 和持续学习 [69] 的模型)被用于提高推荐质量。这些模型需要从在线用户活动中收集训练数据,因此必须以低延迟持续更新模型参数。

2.3 我们的核心思想及相关挑战

我们希望探索如何在更新DLRS模型时实现低延迟。我们的观察是,更新延迟主要是由于几个离线步骤累积而成的:模型训练、验证和广播。假设我们绕过这些离线步骤,允许更新后的模型直接传播到推理集群,那么我们可以大幅减少更新模型的步骤,从而实现低延迟。然而,要实现这样的设计,我们必须解决以下几个挑战:

(1) 缺乏高效传播大规模模型更新的算法

现实中的DLRS通常拥有大量模型(例如通常有数百到数千个)。它需要在线更新其中许多模型。这些模型包括多阶段推荐管道 [10, 15] 中的模型以及用于A/B测试 [69] 的模型。这些模型通常占用数十TB的内存,并且需要在线完成大规模模型更新(例如每秒数百GB)。

假设我们使用传统的数据复制协议,例如链式复制 [41] 和两阶段提交 [11]。这些协议针对的是通用数据复制,缺乏在带宽有限的网络(即WAN)上协调ML模型更新的机制(这些更新可能对推理SLOs产生不同的影响)。此外,这些传统协议存在领导者瓶颈问题,并且由于异构的WAN路径和网络滞后节点而导致较长的更新延迟。因此,这些协议不适合满足我们的高吞吐量、低延迟需求。另一种选择是使用地理复制协议 [72]。然而,这些协议无法处理训练数据中心中的服务器故障,因此无法满足我们的系统可用性要求。

我们还考虑了网络高效的分布式ML系统,例如Gaia [28] 和 Google Federated [35]。这些系统 [7, 28, 35, 37, 46] 允许模型在每个数据中心中独立训练,从而提高模型更新的吞吐量和延迟。然而,它们延迟同步其状态,因此会导致模型状态过时 [47],这可能对推荐质量产生不利影响。因此,松散同步的分布式ML系统无法满足我们的模型准确性要求。

(2) 缺乏保护SLOs的机制
在DLRS中启用在线模型更新对SLOs提出了挑战。这样的DLRS可能会出现模型更新竞争网络带宽的情况,从而延迟关键更新(例如那些显著影响模型准确性或上线新内容的更新)。尽管有一些系统可以调度模型梯度的发送 [6],但这些系统针对的是训练集群。因此,它们基于梯度 [6, 28] 优先处理模型更新,而缺乏对这些更新如何影响推理模型SLOs的认识。

在线模型更新甚至可能是有害的。由于在线更新通常是基于一小批数据(在短时间内收集的数据:几秒或几分钟)计算的,它们通常包含噪声 [34]。当更新变得特别嘈杂时,它们会对推理SLOs产生不利影响(即降低推理模型的准确性)。为了解决这个问题,

  • 现有的模型服务系统(如Clipper [16] 和 Clockwork [25])使用离线模型验证,这种方法会对长时间(例如几小时)累积的模型更新进行平均
  • 其他模型服务系统(如Google TFRA [66])跟踪推理模型的SLO指标,并在SLOs恶化时重新加载checkpoint

然而,这样的设计在DLRS中实现起来具有挑战性。大型DLRS模型(例如面向推荐的Transformer模型 [18])越来越常见,重新加载这些模型会影响服务的可用性。

3 Ekko系统架构

本文介绍了Ekko,一种能够实现低延迟模型更新的新型DLRS系统。在本节中,我们将描述Ekko的系统模型,并概述其核心创新组件。

3.1 系统模型

Ekko是一个地理分布的DLRS系统。它在中心数据中心更新模型,然后将更新后的模型传播到靠近全球用户(即客户端)的地理分布数据中心。Ekko将模型表示为键值对,并将模型划分为分片(例如在我们的生产环境中有100,000个分片)。它将模型分片存储在键值存储中(在Ekko中称为参数存储)。参数存储通过哈希将键值对分配到分片中。由于模型经常在线整合新项目和特征过期 [32],模型大小可能会随时间变化。

Ekko使用基于软件的路由器将参数请求定向到模型分片。这些路由器将训练数据中心中的参数服务器指定为模型分片的主节点。它们还确保主节点的选择能够平衡参数请求的工作负载。路由器的实现遵循典型的键值存储和数据库 [38]。本文中我们省略了路由器实现的细节。

在路由器中,分片管理器可以处理资源过载、故障域 [55] 和副本集问题 [12]。与传统的分片管理器不同,Ekko的分片管理器实现了几个DLRS特定的优化:

  • (i)为了分摊请求处理开销,Ekko将针对同一模型的并发推理请求进行批处理 [16]。然而,批处理请求可能会查询不同参数服务器上的大量参数(例如数千个),从而导致长尾查询延迟 [19]。为了防止长尾延迟,Ekko限制分配给模型分片的服务器数量;
  • (ii)Ekko支持多个需要性能隔离的DLRS应用程序。它将不同应用程序的分片映射到不同的服务器上,因此一个应用程序的分片请求激增不会影响其他应用程序的分片。

3.2 架构概述

我们在图2中突出了Ekko的创新设计。如图所示,Ekko使参数服务器能够实现高效的点对点(P2P)模型更新( 1 )(见§4)。P2P模型更新算法避免了中心训练数据中心广播更新后的模型,而是利用数据中心内部和跨数据中心的所有网络路径(图中的实线),从而在传播模型更新时实现高吞吐量。在没有中央协调器的情况下,每个数据中心可以独立选择优化同步模型更新的间隔。

图片名称

图2 Ekko架构总览

Ekko支持大规模模型更新的并发传播。这些更新可能会竞争网络资源,从而延迟对SLOs有显著益处的更新。为了解决这个问题,Ekko依赖于一个SLO感知的模型更新调度器( 2 )(见§5.2)。该调度器预测每个模型更新将如何影响推理结果。预测结果有助于计算每个模型更新的优先级。基于优先级,Ekko协调在训练数据中心优先传播哪些模型更新,从而提高推理服务器上SLOs的整体满意度。

Ekko可以保护推理服务器免受有害模型更新的影响。为了实现这一点,它在推理集群中运行一个模型状态管理器( 3 )(见§5.3)。该模型状态管理器监控推理模型的SLO相关指标。如果某个推理模型表现出性能下降(由在线更新引起),管理器会将模型状态回滚到性能更好的状态,从而恢复推理模型的性能

4 高效的点对点模型更新

本节介绍Ekko中高效的点对点(P2P)模型更新机制。为了实现参数服务器中的P2P模型更新,Ekko的设计实现了以下目标:

  • 协调大量参数服务器:Ekko需要协调大量(例如数千个)分布在全球的参数服务器完成模型更新。为了避免网络延迟导致的滞后问题,我们为Ekko中的参数服务器设计了无日志同步机制(§4.3)。
  • 支持大规模模型更新:作为一个共享的DLRS,Ekko需要托管数千个模型。这些模型可以在线生成大量(例如每秒数十亿次)更新。为了支持这一点,Ekko使参数服务器能够通过对等节点高效发现模型更新,并在不过度消耗计算和网络资源的情况下拉取更新(§4.4)。
  • 支持地理分布式部署:Ekko需要支持地理分布式部署,这通常涉及跨WAN的异构网络路径以及服务器/网络故障。为此,Ekko设计了系统机制,以提高在WAN上发送模型更新的吞吐量/延迟,并容忍服务器/网络故障(§4.5)。

接下来,我们将概述P2P模型更新机制,并详细描述其实现。

4.1 模型更新概述

图3展示了Ekko中模型更新涉及的组件和步骤。假设我们需要在两个副本(分别称为副本1和副本2)之间同步一个分片(称为分片1)。与所有其他分片类似,分片1具有:

  • (i)分片知识(shard knowledge),用于总结参数更新;
  • (ii)更新缓存(update cache),用于基于参数版本跟踪最近的模型更新。每个分片还关联一个分片版本,用于指示该分片是否可能有需要同步的参数。

分片知识、更新缓存和分片版本共同加速了参数服务器之间的参数同步。

图片名称

图3 Ekko P2P模型更新总览

为了完成模型更新,副本2从副本1请求最近修改的分片版本( 1 )。副本1收到请求后,返回最近修改的分片版本列表( 2 )。副本2将副本1的所有分片版本与本地分片版本进行比较,然后向副本1发送相关的分片知识( 3 )。最后,副本1将所有更新的参数发送给副本2( 4 )。通过这些步骤,Ekko可以确保模型更新最终以低延迟传播到所有副本(即最终一致性)。

我们发现最终一致性在实际DLRS中是可接受的。尽管DNN副本可能在一个小时间窗口内存在差异,但它们通常表现出接近(甚至完全相同)的推理结果 [11]。这是因为DNN通常使用浮点数表示模型参数,因此即使本地参数值存在微小差异,DNN副本也会做出接近的预测。

4.2 DLRS中的参数版本

为了跟踪模型参数的状态,Ekko为每个键值对(即模型参数的存储格式)分配一个参数版本,定义如下:

定义1(参数版本):参数版本 $ v $ 是一个由时间戳 $ t $ 和唯一标识副本的 $ id $ 组成的对 $ (t, id) $。时间戳 $ t $ 基于现代物理时间源 [14, 43] 提供的时间范围生成。Ekko确保 $ t $ 在每个副本中单调递增,并使用计数器填充物理时间戳,以确保来自单个副本的任何两个更新不会共享相同的时间戳。我们定义参数版本的总序关系:

\[v_1 \geq v_2 \iff (t_1 > t_2) \lor ((t_1 = t_2) \land (id_1 \geq id_2))\]

在冲突解决期间,具有较大参数版本的参数将覆盖另一个参数 [62]。

在Ekko中,值得注意的是,时间戳基于实时时钟而不是逻辑时钟(逻辑时钟通常用于键值存储和存储服务)。我们发现这种设计在分布式DLRS中非常有效,原因如下:DLRS具有嵌入表,其中参数是稀疏更新的。假设主副本中有一个嵌入参数,该参数有大量更新计数,但主副本在失败之前未传播该参数。当主副本恢复时,计数器可能会用较小的更新计数覆盖当前主副本。这种覆盖可能会对推荐质量产生不利影响,因为被覆盖的主副本可能具有更新的参数(由最近收集的训练数据更新),从而产生更好的推荐结果。因此,逻辑计数器不足以解决分布式DLRS中的冲突。

4.3 无日志参数同步

一旦为参数分配了版本号,Ekko需要决定如何同步不同的副本。我们观察到DLRS通常会覆盖参数,只有最后一次写入决定参数的状态。因此,我们决定发送参数的最后一个版本。

Ekko需要决定同步副本的间隔。我们可以使用基于日志的同步算法 [9, 11]:这些算法选择同步间隔,以便模型更新可以以不超过网络中最慢链路带宽的速率发送。然而,这些算法会导致许多网络链路的利用率不足。更重要的是,它会导致滞后问题,从而显著增加同步延迟,使参数服务器在从故障中恢复时更有可能具有过时的状态。因此,我们希望实现参数服务器中的无日志参数同步,以便这些服务器可以根据每个链路的带宽动态选择与对等节点的同步间隔。

参数服务器中的分片知识:我们建议使用分片知识 [50, 51] 来实现无日志参数同步。更正式地说,在每个副本中,其所有分片都维护相应的分片知识。分片知识使用版本向量 [58] 实现,总结了它们学到的参数更新。与分片知识 $ VV_{\text{shard}} $ 关联的分片数据反映了应用来自每个副本 $ r $ 的所有历史参数更新后的空分片状态,其中更新对应的参数版本 $ v \leq VV_{\text{shard}}[r] $。假设在副本 $ r $ 中有一个参数 $ p $ 的更新需要处理。为了维护分片知识,该副本生成一个新的参数版本 $ v_p = (t, id) $ 并设置 $ VV_{\text{shard}}[id] = v_p $。

分片同步过程:为了同步一个分片,副本 $ r $ 将其分片知识 $ VV_{r1} $ 发送到选定的副本 $ s $。副本 $ s $ 记录其当前分片知识 $ VV_s $ —— 即原子地读取 $ VV_s $ 并从其存储中选择所有参数 $ p $,其参数版本 $ v_p = (t_p, id_p) > VV_{r1}[id_p] $ —— 并用 $ VV_s $ 响应 $ r $。然后,$ r $ 根据 $ s $ 的响应原子地应用所有参数更新,并进一步将 $ VV_s $ 与其当前分片知识 $ VV_{r2} $ 合并。

在同步过程中有几个注意事项:(i)当副本 $ r $ 与副本 $ s $ 同步时,$ r $ 可能同时与另一个副本(称为副本 $ k $)进行同步操作。这些操作可能在 $ r $ 完成处理 $ s $ 的响应之前完成。因此,$ VV_{r2} $(即 $ VV_r \cup VV_k $ 的结果)不一定等于 $ VV_{r1} $。(ii)在无故障场景中,同步过程会省略所有被覆盖的参数版本,这些场景中更新参数的请求总是路由到同一个主副本。我们发现这些无故障场景在我们的生产环境中很常见。

4.4 提高同步效率

Ekko必须确保参数同步对参数服务器的性能开销可以忽略不计。否则,同步可能会消耗过多的计算和通信资源,从而影响参数服务器在服务模型推理和训练请求时的性能。接下来,我们将讨论如何通过参数更新缓存(减少计算成本)和分片版本(减少通信成本)来提高参数同步的效率。

4.4.1 参数更新缓存

由于一个分片可能包含大量参数,简单地遍历所有参数来响应同步请求会带来巨大的计算成本。尽管我们可以使用索引来加速参数遍历,但维护这样的索引会消耗大量内存资源,而这些资源在参数服务器上难以提供。

我们设计了参数更新缓存来减少参数同步的计算成本。这种缓存的设计利用了我们在DLRS中经常观察到的稀疏性时间局部性 [21]。与密集的DNN训练系统(每次迭代更新整个模型)不同,DLRS只更新其参数的一个子集(即稀疏性)。例如,在我们的生产DLRS中,每小时只有3.08%的参数被更新。此外,模型更新通常会在一个时间窗口内覆盖某些参数(即时间局部性)。这是因为DLRS通常有热门项目和用户,它们的参数更新在短时间内占主导地位。

具体来说,参数更新缓存包含指向最近更新参数的指针。它利用主导版本向量(Dominator Version Vector,简称DVV)来判断同步请求到达时是否命中缓存。

缓存维护算法:缓存的维护保证两个不变性:(i)对于所有存在于分片中但不在缓存中的参数 $ p_{\text{uncached}} $,满足 $ DVV[id_{p_{\text{uncached}}}] \geq v_{p_{\text{uncached}}} $;(ii)对于所有缓存的参数 $ p_{\text{cached}} $,满足 $ DVV[id_{p_{\text{cached}}}] < v_{p_{\text{cached}}} $。

算法1描述了Ekko中参数更新缓存的维护。维护依赖于估计的更新传播时间 $ D_{\text{prop}} $。考虑更新缓存的函数:

1
UpdateCache
(第1行)。$ t_{\text{pruneto}} $ 是一个时间戳,用于描述 $ DVV_{\text{proposed}} $ —— 一个用于判断是否应修剪参数的版本向量。对于每个修改请求,如果修改参数 $ p $ 的参数版本 $ v_p = (t_p, id_p) $ 大于 $ DVV_{\text{proposed}}[id_p] $,则缓存记录指向该参数的指针(第5行)。否则,缓存将参数版本与 $ DVV $ 合并(第3行)。

图片名称

算法1

考虑修剪参数指针的函数:

1
PruneCache
(第7行)。该函数接收 $ D_{\text{prop}} $,这本质上允许Ekko利用对缓存命中率的在线观察来指导缓存修剪操作。假设我们希望在缓存大小超过限制时修剪参数指针,缓存首先确定 $ DVV’{\text{proposed}} $,它严格主导 $ DVV{\text{proposed}} $(第8行)。然后,缓存移除被 $ DVV’_{\text{proposed}} $ 主导的参数指针(第11行)。最后,缓存通过将 $ DVV $ 与修剪参数版本合并来更新 $ DVV $(第12行)。通过这种方式,Ekko实现了缓存大小的自适应管理,从而减少了其内存占用。

缓存命中分析:我们分析了参数更新何时命中缓存。假设副本 $ s $ 接收到来自副本 $ r $ 的同步请求,副本 $ r $ 持有分片知识 $ VV_r $。如果 $ VV_r $ 主导 $ DVV_s $,则请求命中缓存,其后续操作(例如选择参数)仅涉及缓存中的参数。

Ekko确保更新缓存的使用不会影响无日志参数同步的最终一致性:同步过程需要选择副本 $ s $ 中满足 $ v_p > VV_r[id_p] $ 的参数 $ p $。由于更新缓存保持不变性 $ DVV_s[id_{p_{\text{uncached}}}] \geq v_{p_{\text{uncached}}} $ 且 $ VV_r $ 主导 $ DVV_s $,因此该过程选择的参数集与之前的算法相同。

参数更新缓存在减少选择参数的成本方面特别有效。根据我们生产环境中部署的缓存跟踪数据,99.4%的同步请求可以命中缓存,从而将选择参数的成本降低了99%。

4.4.2 分片版本

我们引入分片版本来减少同步副本时的网络成本。分片版本捕获了副本上分片数据的部分因果关系,并且它们比版本向量小得多。我们可以允许副本维护分片版本列表,每个列表与一个邻居副本相关联。通过这种方式,副本可以通过交换和比较分片版本来识别可能更新的分片。正式地,我们将分片版本定义如下:

定义2(分片版本):分片版本 $ sv = (c, id) $ 是一个由计数器 $ c $ 和标识生成该版本的副本的 $ id $ 组成的对。计数器 $ c $ 在每个副本的每个分片中单调递增。对于同一个分片 $ s $,当且仅当 $ id_1 = id_2 $ 且 $ c_1 \geq c_2 $ 时,$ sv_1 \succeq sv_2 $。

分片版本维护:在初始化时,每个副本为其分片生成分片版本。当训练工作器发出参数更新时,副本会生成一个新的分片版本。由于每个分片都有一个主副本,因此在正常情况下,只有一个副本生成分片版本。

一旦接收到同步请求,响应副本(记为 $ s $)会回复其分片版本 $ sv_s $,同时附带 $ VV_s $ 和更新的参数。一旦请求副本(记为 $ r $)收到此回复,它会以原子方式完成以下操作:(1)将其分片知识 $ VV_r $ 与接收到的 $ VV_s $ 合并(合并结果记为 $ VV’_r $);(2)如果 $ VV’_r = VV_s $,则将其分片版本 $ sv’_r $ 更新为 $ sv_s $;否则,如果 $ VV’_r \neq VV_r $,则生成一个新的分片版本。需要注意的是,当 $ VV_r = VV_s $ 时,为了避免活锁,Ekko会根据确定性规则(例如选择数值较大的分片版本)从 $ s $ 和 $ r $ 中选择一个分片版本。

我们实现了簿记技术 [51],用于维护与不同副本相关联的分片版本列表。通过结合分片版本和簿记技术,Ekko可以有效地减少同步相关的网络流量。例如,在我们的一个生产DLRS中,Ekko在同步过程中过滤掉了98%的分片。

使用分片版本进行同步:我们讨论分片版本如何促进同步。Ekko维护一个不变性:只有当分片知识 $ VV_1 $ 主导 $ VV_2 $ 时,$ sv_1 \succeq sv_2 $ 才成立。因此,只有当 $ sv_r \nsucceq sv_s $ 时,副本 $ r $ 才需要与副本 $ s $ 同步分片。此外,考虑具有相同分片可比分片版本的不同副本,Ekko更倾向于与具有最大分片版本的副本同步,因为较大的分片版本表示参数的最新版本。

4.5 实现细节

WAN优化:Ekko针对地理分布式部署进行了优化,这种部署包括多个数据中心内部网络和一个跨数据中心的广域网(WAN)。为了提高在这种部署下的性能,Ekko采用了WAN优化的模型更新传播策略。该策略为P2P同步构建了一个灵活的通信拓扑。它允许每个数据中心使用Zookeeper [31] 为每个分片选举一个本地领导者。领导者从其他数据中心拉取模型更新,而其他副本则从该领导者拉取更新。通过这种方式,Ekko使得大部分同步流量通过带宽充足的数据中心内部网络,只有少量同步流量通过WAN。需要注意的是,参数同步的实现并不依赖于特定的通信拓扑。Ekko可以使用其他覆盖拓扑来进一步提高同步性能。

故障容忍:Ekko使用请求路由器来容忍故障。路由器决定客户端请求的路由,并通过心跳检测副本的健康状态。如果路由器推测某个副本发生故障(无论是完全故障还是性能下降 [30]),它会阻止客户端(推理服务器和训练服务器)向该副本发送请求。同时,路由器会跟踪集群中副本的分片知识。如果之前被怀疑故障的副本恢复并向路由器发送心跳,路由器将指示该副本与集群中更新充分的副本同步。当同步完成后,路由器会将客户端请求重新定向到该副本。如果某个副本丢失了状态,它将使用新的ID重新加入集群。如果训练服务器在给定时间内无法联系到路由器,它们将停止发送参数更新,从而在网络分区的情况下尽力保护模型参数免受分歧 [5]。


5 SLO保护机制

Ekko允许模型更新直接传播到推理集群中的参数服务器。然而,这为推荐服务的SLOs带来了两个挑战:(i)网络拥塞可能导致关键模型更新被延迟;(ii)基于小批量偏差数据的模型更新可能对推理结果产生不利影响。

本节介绍了保护推理SLOs免受网络拥塞和偏差更新影响的机制。我们首先定义SLOs(见§5.1),然后描述一个SLO感知的模型更新调度器(见§5.2),最后讨论一个处理偏差更新的推理模型状态管理器(见§5.3)。

5.1 DLRS中的SLOs

DLRS有两类主要的SLOs:

  • 新鲜度SLOs:衡量将新内容和用户纳入模型推理的延迟。这对于实时与用户交互的推荐服务(如TikTok和YouTube)至关重要。例如,这些服务通常需要及时捕捉新用户的兴趣,以确保他们有足够的参与度;否则,他们可能会因为失去兴趣而离开推荐应用。提高新鲜度SLOs通常会带来更好的用户体验。此外,新内容将获得更好的曝光,从而确保DLRS的繁荣。
  • 质量SLOs:衡量用户体验和参与度。它们对DLRS的盈利能力有直接影响。例如,这类目标包括观看的视频数量和用户观看时间。

图4描述了推理服务器如何影响新鲜度和质量SLOs。一旦接收到请求,推理服务器会选择相关的用户和项目嵌入,然后聚合这些嵌入并将聚合后的嵌入发送给一个DNN,该DNN返回推荐项目的分数。DLRS最终返回一个按分数排序的项目列表。在这种情况下,新鲜度SLO基于推荐项目的最新时间戳来衡量(理想情况下,该时间戳应尽可能接近当前时间)。质量SLO可以基于项目的观看时间和点击次数来衡量。在实践中,Ekko在线维护大量新鲜度和质量SLOs。这些SLOs的实现由DLRS应用开发者贡献。

图片名称

图4

5.2 SLO感知的模型更新调度器

Ekko通过SLO感知的模型更新调度器及其与P2P模型更新传播的集成,防止新鲜度和质量SLOs受到网络拥塞的影响。

5.2.1 模型更新的SLO感知优先级

Ekko在调度模型更新时计算一组优先级:

更新新鲜度优先级:Ekko计算更新新鲜度优先级 $ p_u $。该优先级基于以下观察设计:如果参数是最近创建的,则具有高优先级;否则,优先级相对较低。这是因为新创建的参数对推理结果的影响比长期服务的参数更大。例如,如果用户的嵌入在推理服务器中不可用,但她的请求已经到达,DLRS将无法回答该请求,从而影响质量SLOs。另一个例子是,如果推理服务器上的嵌入表未包含某个项目,DLRS将不会推荐该项目,从而影响新鲜度SLOs。

更新重要性优先级:Ekko根据梯度 $ g $ 为每个模型更新计算更新重要性优先级 $ p_g $。该优先级最初受到研究的启发,研究表明梯度大小 $ \mid g \mid$ 如何影响DNN的推理结果 [6, 28]。然而,在Ekko中,简单地采用梯度大小是不够的。作为一个共享的DLRS,Ekko在共享网络上复用了来自不同模型的更新。因此,Ekko必须有办法比较具有不同分布的梯度大小。为此,我们定义 $ p_g = \mid g\mid / \bar{\mid g \mid} $,其中 $ g $ 表示梯度的1-范数,$ \bar{\mid g \mid} $ 表示最近模型更新的平均梯度大小。直观地说,该定义对梯度大小进行了归一化,从而使它们具有可比性。

模型优先级:在DLRS中,模型通常以不同的速率接收推理请求,这表明它们在衡量SLOs整体满意度中的重要性不同。为了考虑这一点,Ekko允许处理大多数请求的模型被分配更高的优先级,而很少接收请求的模型优先级较低。为此,我们定义模型优先级为 $ p_m = c_m / \sum_{i=1}^M c_i $,其中 $ c_m $ 是模型 $ m $ 的请求计数,$ \sum_{i=1}^M c_i $ 表示所有 $ M $ 个模型的总请求计数。

优先级组合:我们将上述优先级组合起来,计算模型更新的总体优先级 $ p $:

\[p = (p_g + p_u) \cdot p_m\]

其中,重要性优先级 $ p_g $ 和新鲜度优先级 $ p_u $ 都已归一化,以便可以相加。然后将它们的和乘以模型优先级 $ p_m $。

需要注意的是,Ekko并不要求用户仅使用上述优先级。一些Ekko用户有自定义的优先级定义,包括更新计数、更新间隔和嵌入表中参数的位置。这些自定义优先级针对某些DLRS工作负载 [69],但它们不够通用,无法包含在默认设置中。Ekko通过支持用户定义函数(UDFs)来定义优先级,从而适应这些自定义优先级。

5.2.2 调度器实现

模型更新调度器在每次更新生成时计算其优先级。它需要确保优先级计算的开销可以忽略不计,否则它可能成为模型更新的瓶颈。为了实现这一点,调度器将优先级相关统计信息(例如每个模型 $ m $ 的 $ \bar{\mid g \mid} $ 和 $ p_m $)的维护卸载到一个后台线程中。此外,为了限制内存开销,它使用分位数草图(例如DDSketch [52])计算时间窗口内的第 $ k $ 百分位优先级 $ p_k $,其中 $ k $ 是由算法管理者设置的比率。Ekko使用WebAssembly [27] 执行用户定义的优先级计算,以实现UDFs之间的高效隔离。

将调度器集成到参数服务器中:为了实现优先级调度的承诺,我们必须将调度器集成到已启用无日志P2P同步的参数服务器中。为此,我们为每个参数引入重要版本(记为 $ sigv $),并为每个分片引入重要知识(记为 $ SVV $)。此外,Ekko为每个分片分配一个临时重要参数存储 $ store_{\text{significant}} $ 和相应的临时重要知识 $ T SVV $,以支持带有优先级调度的P2P同步。

算法2描述了带有优先级调度器的无日志P2P同步。假设我们有一个来自副本的模型更新,Ekko计算 $ p $。如果 $ p \geq p_k $,Ekko设置 $ sigv = v $,其中 $ v $ 是该更新的参数版本;否则,$ sigv $ 保持不变。然后,Ekko使用 $ sigv $ 构建 $ SVV_{\text{other}} $ 并调用

1
UPDATESVV
函数(第1行)。如果Ekko在同步中未应用优先级,副本将交换 $ SVV $ 并执行
1
UPDATESVV
函数。在将参数写入持久参数存储时,Ekko通过执行
1
WRITESTOREPARAMETER
函数修剪被覆盖的参数(第4行)。需要注意的是,副本会估计模型更新到达自身所需的时间。因此,当网络拥塞发生时,服务器将出现更新超时。在这种情况下,Ekko使用
1
PRIORITISEDSYNC
函数(第15行)触发同步中的优先级调度器。一旦接收到请求,副本优先返回重要参数存储中的参数。

图片名称

算法2

5.3 推理模型状态管理器

Ekko使用推理模型状态管理器来保护SLOs免受有害模型更新的影响。该管理器监控推理模型的健康状况(即质量SLOs),并根据需求进行低延迟的模型状态回滚。

5.3.1 监控模型健康状况

Ekko基于以下思想监控模型健康状况:对于DLRS应用程序,它为推理模型创建基线模型。基线模型处理少量用户流量(通常小于1%)。它们与在线推理模型不同,因为它们携带延迟的状态。换句话说,它们使用先前的训练样本进行训练,通常比当前推理模型的训练样本早几分钟。

Ekko基于从推理服务器和客户端(例如用户设备)收集的指标来衡量模型健康状况。为了计算这些指标,Ekko定义了自定义的水印和触发器 [3]。其状态管理器仅在确信时(即观察到一段时间的监控数据)发出异常检测事件。需要注意的是,Ekko并不局限于使用特定的异常检测算法。它支持自定义的异常检测算法,例如常用于时间序列数据的算法 [61]。

我们将模型状态(即健康或不健康)的转换建模为复制状态机 [63],并在模型状态管理器中实现。该管理器通过检查与健康状况相关的指标和模型更新延迟,在时间戳 $ t $ 处评估并记录模型健康状况。时间戳 $ t $ 单调递增。管理器判断模型状态是健康、损坏还是不确定。当管理器确信模型状态发生变化(即健康或损坏)时,它会将此信息记录在其复制状态中。如果模型状态已损坏,管理器将客户端请求重定向到其他健康的推理模型,然后启动模型状态回滚。

5.3.2 低延迟模型状态回滚

Ekko使用见证服务器以低延迟回滚损坏的模型状态。见证服务器参与副本同步,但不参与模型训练。与参数服务器不同,见证服务器(i)不会立即将更新的参数刷新到参数存储中,(ii)在同步中不运行优先级调度。具体来说,Ekko将尚未刷新的参数更新插入日志中。日志附带有同步的物理时间戳(记为 $ t $)。如果在短时间内有多个同步操作,Ekko会合并它们的日志以节省空间。

模型状态管理器控制见证服务器启动状态回滚。假设模型状态在时间 $ t $ 处被认为是健康的,见证服务器会找到一个满足以下两个条件的时间戳 $ t_{\text{max}} $:(i)它小于等于 $ t $;(ii)它不在任何发生损坏状态的时间间隔内。然后,见证服务器刷新时间戳小于等于 $ t_{\text{max}} $ 的日志。模型状态管理器记录此 $ t_{\text{max}} $,$ t_{\text{max}} $ 稍后将用于见证服务器以恢复健康的模型状态。通过这种方式,我们可以确保见证服务器上的参数存储 $ store_{\text{healthy}} $ 始终保存健康的模型状态。

回滚过程:图5展示了回滚模型状态的过程。假设发现某个模型已损坏,模型状态管理器首先通知参数服务器停止接受该模型的训练请求( 1 )。然后,它指示参数服务器停止基于优先级的同步,清除其 $ store_{\text{significant}} $,并重置 $ T SVV = SVV $。接着,管理器等待参数服务器和见证服务器上的模型分片收敛。随后,管理器选择见证服务器启动状态回滚( 2 )。我们需要确保恢复的模型分片可以一起使用。因此,管理器仅选择见证服务器上 $ store_{\text{healthy}} $ 中 $ t_{\text{max}} $ 在一个小时间窗口内的分片。

图片名称

图5

一个关键设计是,见证服务器会比较 $ store_{\text{healthy}} $ 和其当前状态以找到状态差异( 3 )。由于更新参数的局部性,这种差异通常很小。因此,我们只需将差异写入参数服务器以恢复状态。我们需要确保写入操作能够成功。因此,写入的参数被分配比参数服务器上当前参数版本更大的参数版本( 4 )。之后,管理器等待参数服务器和见证服务器上的模型分片收敛。最后,Ekko会在恢复的模型上恢复少量流量。当该模型的健康状况指标恢复正常时,管理器通知参数服务器恢复接受请求( 5 )。

需要注意的是,如果见证服务器发生故障,其未刷新的更新日志将被丢弃。这有助于Ekko防止潜在的损坏更新被刷新。如果参数服务器或见证服务器发生故障(或重新加入集群),回滚过程将重新执行。


6 评估

在本节中,我们通过测试床和生产环境实验评估Ekko的以下方面:(i)Ekko的更新延迟及其与数据中心数量的可扩展性(§6.1.1);(ii)Ekko在异构WAN中的更新延迟(§6.1.1);(iii)Ekko中实现的优化的性能分解(§6.1.2);(iv)Ekko在大规模生产DLRS中的实际延迟和可用性(§6.2.1);(v)低延迟模型更新在在线服务中的好处(§6.2.1);(vi)在繁忙网络中使用模型更新调度器的有效性(§6.2.2);(vii)模型损坏时回滚模型的延迟(§6.2.2)。

除非另有说明,更新延迟是指更新提交时间与更新在所有副本中可见时间 [68] 之间的最大时间差(无故障场景)。在所有实验中,我们测量更新延迟并报告所有更新的平均值。

6.1 测试床实验

我们在一个30台服务器的集群中进行测试床实验。每台服务器配备24核CPU、64 GB内存和5 Gbps网络链路。我们将每三台服务器分组为一个数据中心(DC),以模拟多DC场景,最多形成10个DC。我们选择其中一个DC作为训练导向的DC,该DC从一台服务器(充当DLRS客户端)接收模型更新。我们让其他DC作为推理导向的DC,并将它们与训练导向的DC连接。DC间的带宽为4,800 Mbps(除非另有说明),模拟WAN。

我们的测试床实验包括两个工作负载。第一个工作负载训练一个通常用于生产环境的大型排序模型。在此工作负载中,我们选择分片大小为0.4 MB。第二个工作负载使用按时间顺序排序的Criteo Terabyte Click Logs [17] 训练Wide & Deep模型 [10]。我们使用21天的数据日志初始化嵌入表。为了确保实验可重复,我们记录模型更新轨迹并在实验期间重放它们。

6.1.1 更新延迟

我们在同构WAN和异构WAN中评估Ekko的更新延迟。这两种WAN在现实世界中都很常见。第一个基线是Adam [11],它通常用于参数服务器中,通过两阶段提交协议同步模型更新。我们的Adam实现移除了更新广播之间的等待时间,从而提高了网络利用率。第二个基线是Checkpoint-Broadcast,这是DLRS中应用模型更新的事实标准方法 [1, 21]。我们省略了与通用键值存储(如PaxosStore [73] 和 TiKV [29])的实验,这些存储提供写入操作的线性一致性。我们的早期采用结果表明,这些键值存储的写入吞吐量较低,比生产DLRS所需的吞吐量低几个数量级。

为了公平比较,Ekko和基线都使用DRAM进行存储 [57],并采用相同的主节点分配和负载均衡方案。我们进一步确保它们的传播都是网络受限的,并使用相同数量的分片。

同构WAN结果:我们首先在同构WAN中将Ekko与Adam进行比较。我们分别测量了1个DC(3个副本)、5个DC(15个副本)和10个DC(30个副本)的延迟。图6a和图6b显示了结果。可以看出,Ekko在生产环境和Criteo工作负载中的延迟显著低于Adam。具体来说,在运行生产工作负载的10个DC中,Ekko实现了2.6秒的延迟,比Adam的18.8秒延迟低7倍。我们还观察到,随着DC数量的增加,Ekko和Adam之间的性能差距也在扩大。原因是Ekko具有可扩展的P2P同步架构,并针对WAN优化了其传播拓扑。相比之下,Adam依赖主副本发送更新,受限于训练DC中有限的带宽。

图片名称

图6

我们还将Ekko与Checkpoint-Broadcast进行比较。根据我们的实验结果,Checkpoint-Broadcast在WAN中同步4 GB参数需要超过7秒。总参数为113 GB。在10个DC的情况下,训练DC需要向所有其他推理DC发送113×9=1,017 GB的参数。因此,训练DC需要花费超过29分钟完成参数广播(因为WAN的带宽为4,800 Mbps)。这种广播延迟比Ekko实现的秒级延迟(例如2.6秒)高几个数量级。

异构WAN结果:然后我们在异构WAN中评估Ekko和基线。在此WAN中,我们将DC间带宽默认设置为256 Mbps。为了引入异构性,我们选择训练DC与另一个推理DC之间的一个链路,并将其带宽设置为128 Mbps。实验在每个DC中运行3个副本,总共10个DC。如图7a和图7b所示,Ekko在生产环境和Criteo工作负载中都能有效缓解慢速异构链路的影响。它允许副本以独立的速率同步,保持秒级同步延迟。这种低延迟性能显示了Ekko的无日志P2P同步在缓解异构网络路径不利影响方面的有效性。相比之下,Adam在WAN中受到慢速路径的影响,导致在生产工作负载中花费超过150秒同步副本,在Criteo工作负载中花费超过100秒。

图片名称

图7

除了Adam,我们还考虑了其他基于日志的同步方法,例如Multi-Paxos [9]。我们可以让这些方法将一段时间内到达的更新聚合到一个日志条目中,以节省WAN带宽。然而,这些方法仍然受到异构链路的影响。这是因为它们基于网络中最慢的链路选择聚合间隔,导致许多其他链路利用率不足。


6.1.2 性能分解

我们希望了解Ekko同步中各个组件的有效性。因此,我们对生产工作负载(10个DC)进行了性能分解分析。我们首先配置Ekko仅使用分片知识(见§4.3)进行同步。此配置是本实验的基线,相当于最先进的P2P同步技术——版本向量(VV)[50, 51]。

图8显示了结果。仅使用VV时,Ekko需要76.3秒同步所有参数。启用更新缓存(§4.4.1)后,Ekko将延迟减少到27.4秒(即2.8倍加速)。通过分析更新缓存的跟踪数据,我们发现缓存在生产工作负载中实现了100%的命中率。需要注意的是,测试床服务器上每个副本的总内存比生产服务器少10倍,这意味着分片中的参数比实际场景中少。随着分片中参数数量的增加,VV将花费更多时间进行同步,而更新缓存可以保持低延迟。

图8还显示了分片版本(§4.4.2)的效果。通过进一步启用分片版本,Ekko将延迟从27.4秒减少到6.0秒(即4.6倍加速)。这表明跳过未更新的分片可以有效减少同步带来的网络消耗。

图片名称

图8

最后,启用WAN优化(§4.5)后,Ekko将延迟从6.0秒进一步减少到2.6秒(即2.3倍加速)。这表明P2P同步必须考虑WAN中每个链路的可用带宽,否则无法充分发挥其潜力。总之,启用Ekko中的所有组件使P2P同步总共加速了29.3倍(即2.6秒 vs. 76.3秒)。


6.2 生产集群实验

我们已经将Ekko部署到生产中超过一年。生产环境包括分布在6个地理分布式DC中的4,600台服务器。截至2022年,我们已使用Ekko支持多种推荐服务,包括短视频推荐、搜索和广告。每天有超过10亿用户使用这些服务。在本节中,我们报告Ekko在该生产环境中的性能。

6.2.1 模型更新

我们从生产环境中收集跟踪数据,以分析Ekko在更新模型中的性能。生产环境中有数百个DLRS模型(总共40 TB参数或2500亿个键值对)。每个参数分片的大小从0.1 MB到20 MB不等,具体取决于模型大小。Ekko每秒可以执行10亿次更新(即212 GB/s)。

关于延迟性能,Ekko在所有DC中同步参数花费2.4秒,仅在训练DC中花费0.7秒。同步流量仅占总网络流量的3.0%,反映了Ekko作为参数服务器后台同步服务的有效性。Ekko的低延迟、高吞吐量性能并未影响系统可用性。自部署以来,Ekko在参数读写操作中实现了>99.999%的可用性。

更新缓存分析:我们对更新缓存在各种现实推荐服务中的性能特别感兴趣。我们的跟踪数据显示:更新缓存只需缓存0.13%-0.2%的参数,即可实现>99.4%的命中率。这些性能结果验证了更新局部性的广泛存在。事实上,我们的生产推荐服务每小时平均更新3.08%的参数。

我们选择一个更新密集的DLRS模型来揭示最坏情况下的更新局部性。图9显示了480分钟窗口内更新参数的比例。此时间窗口涵盖了我们生产DLRS一天中最繁忙的时间。我们报告了不同时间间隔的比例。在10分钟间隔内,只有4.3%的参数被更新,并且这一比例在480分钟的时间窗口内保持稳定。在60分钟间隔内,我们观察到类似的模式,比例仅略微增加到约10%。实际上,许多其他模型的更新工作负载较少,其更新参数的比例低于此模型。

图片名称

图9

低延迟模型更新的好处:我们希望了解低延迟模型更新是否真的能提高推荐服务的质量。为此,我们在短视频推荐服务 [65] 中进行了为期15天的在线A/B测试 [64]。该服务包括一个多阶段管道 [10, 15]。我们仅在排序阶段进行实验。我们将排序模型分为两组:实验组和对照组。每组接收1%的总流量用于训练和推理。我们通过将实时日志缓存到分布式文件系统中,将用于训练对照组模型的数据(即事件日志)延迟20分钟。

我们的A/B测试结果显示:与对照组相比,实验组在所有推荐视频中新鲜视频(发布在一小时内)的比例增加了3.82%。这意味着系统向实验组用户推荐了更多新鲜视频。此外,实验组用户滑动视频列表的比例减少了1.30%,而浏览视频的总时间增加了1.68%。这意味着实验组用户花费更多时间观看视频,并对推荐视频更感兴趣。

最后,实验组用户点击评论的比例增加了2.17%。这意味着实验组中的用户互动增加。值得注意的是,在现实世界的多阶段DLRS中,1%-3%的改进被认为是显著的 [10, 21, 71]。事实上,自从在DLRS的更多阶段启用低延迟模型更新以来,我们观察到推荐质量的更显著改进。

6.2.2 SLO保护机制

我们还运行A/B测试来评估Ekko的SLO保护机制的有效性。

图片名称

图10

SLO感知的模型更新调度器:我们将排序模型分为实验组(启用优先级调度器)和对照组。每组有1%的训练和推理流量,并部署到专用服务器以避免流量干扰。我们监控反映新鲜度SLOs的指标:推荐结果中新鲜视频(即过去一小时内发布)的数量。为了模拟网络拥塞,我们将模型更新的可用带宽减少了92%。模型更新调度器(i)使用默认的优先级计算规则(定义见§5.2.1),(ii)将百分位优先级 $ k $ 设置为99($ k $ 定义见§5.2.2)。

A/B测试结果显示,在实验组中,Ekko将同步流量减少了92%,并保持重要更新的低延迟。相比之下,对照组在繁忙网络中发送模型更新时无法区分更新。因此,对照组延迟了SLO关键更新,其SLO指标下降了2.32%。这种下降在实践中是显著的,因为该SLO指标是决定DLRS利润的关键因素。

在线模型状态回滚:我们评估在线回滚模型状态的延迟。我们将Ekko与checkpoint恢复方法进行比较。为了公平比较,我们让回滚延迟排除(i)Ekko中收集SLO指标的时间,以及(ii)等待分歧参数收敛的时间。我们部署了5个见证服务器。对于每个见证服务器,我们分配113 GB参数和800 Mbps网络带宽。

在实验过程中,我们通知Ekko的模型状态管理器将DLRS模型的状态回滚到1分钟前的版本。然后,管理器通知所有见证服务器识别过去1分钟内更新的参数。因此,见证服务器只需重新加载当前状态与早期状态之间的差异。因此,整个回滚操作仅需6.4秒即可完成。相比之下,checkpoint恢复方法无法感知模型状态的最近更新。因此,它必须重新加载整个状态,花费1,157秒完成(比Ekko慢180倍)。

7 相关工作

数据复制系统:Ekko中探索的参数同步问题与之前的数据复制工作相关。现有的数据复制系统通常探索如何利用应用程序的特性来提高数据复制的延迟性能 [13, 40, 45, 53]。例如,Egalitarian Paxos [53] 利用了状态机命令的低干扰率,Gemini [40] 利用了混合一致性操作,而COPS [45] 和PNUTS [13] 则利用了互联网服务对宽松一致性的容忍性。与这些系统不同,Ekko利用了DLRS特有的模型更新局部性和最终一致性模型来加速模型参数(而非通用数据)的同步,使Ekko在设计空间中独树一帜。

ML系统中的带宽节省技术:优先处理模型更新的问题与分布式ML系统中的带宽节省技术相关。这些技术通常涉及梯度压缩 [4, 6, 28, 44],在繁忙网络中优先处理大梯度,预期这些大梯度对训练模型的最终准确性有显著影响。与这些技术不同,Ekko针对模型推理场景,人们关心的是众多推理SLO指标,而不仅仅是模型的准确性。因此,Ekko不仅依赖梯度大小,还进一步考虑模型的新鲜度和优先级来调度模型更新。

ML系统中的SLO感知调度:在调度中考虑SLOs的问题在之前的ML系统中已有探索。模型服务系统通常将推理延迟作为主要SLO,以指导与推理相关的计算任务的调度 [16, 25, 70]。模型训练系统,例如Pollux [59] 和KungFu [48],使用ML特定的SLOs(例如训练吞吐量和梯度统计)来决定如何调度训练工作器。与这些系统相比,Ekko关注新鲜度和质量SLOs,并支持在调度模型更新时使用这些SLOs。


8 结论

本文提出了Ekko,一种能够在秒级延迟下更新大规模模型参数的新型DLRS。Ekko具有高效的P2P模型更新算法,能够协调数十亿次模型更新,并将其高效传播到地理分布的数据中心中的副本。此外,它还具备SLO保护机制,能够保护模型状态免受网络拥塞和在线有害模型更新的影响。实验结果表明,Ekko比最先进的DLRS快几个数量级,证明了其新颖设计的有效性。

附录

摘要

观看时间是视频推荐系统中衡量用户满意度的重要指标。然而,将观看时间作为目标变量进行预测常常受到其高度不平衡分布的阻碍,对于较大的目标值观察稀缺,而对于小值样本过多。最先进的观看时间预测模型将连续的观看时间离散化为一组桶,以考虑观看时间的分布。然而,如何从连续的观看时间分布中创建这些离散桶的问题尚未得到充分研究,现有的离散化方法要么存在较大的学习误差(learning error),要么存在较大的恢复误差(restoration error)。为了解决这一挑战,我们提出了一个带有错误自适应离散化(CREAD)的分类-恢复框架,以准确预测观看时间。所提出的框架包含一个离散化模块、一个分类模块和一个恢复模块。它通过多个分类问题来预测观看时间。离散化过程是CREAD框架的关键贡献。我们从理论上分析了离散化对学习误差和恢复误差的影响,然后提出了错误自适应离散化(EAD:error-adaptive discretization)技术,以更好地平衡这两种误差,这比传统的离散化方法实现了更好的性能。我们在公共数据集和工业数据集上进行了详细的离线评估,两者都显示出所提出方法的性能提升。此外,我们已经将我们的框架全面推广到快手应用,这是一个在线视频平台,通过A/B测试,用户的视频观看时间显著增加了0.29%。这些结果突出了CREAD框架在视频推荐系统中预测观看时间的有效性。

1 引言

推荐系统在匹配用户感兴趣的item方面取得了巨大成功(Herlocker 等人,2004)。其中最受欢迎的应用之一是短视频社交媒体(Tang 等人,2017;Wu, Rizoiu, 和 Xie 2018),如 TikTok 和 Instagram Reels,用户屏幕上会出现短视频,而无需任何主动操作,例如点击。因此,传统的指标如点击率不再适用。直观地说,观看时间(watch time)成为衡量用户参与度的关键指标(Covington, Adams, 和 Sargin 2016)。为确保最佳的用户体验,准确预测在线推荐系统中的观看时间至关重要。通过这样做,这些平台可以更好地了解用户偏好,并根据他们的兴趣提供个性化的视频推荐。大量研究(Zhan 等人,2022;Gong 等人,2022;Lin 等人,2022;Wang 等人,2022;Cai 等人,2023;Zhao 等人,2023)致力于开发神经网络模型,并显著提高了传统回归方法的预测准确性。

由于用户观看时间值的连续性和广泛性,观看时间的预测提出了一个回归问题,这增加了对异常值的敏感性和潜在预测偏差。如图1所示,短视频的观看时间分布是右偏的,大量的集中在短时间内:30% 的观看时间在 3 秒内,80% 在 32 秒内。这种分布给早期的观看时间预测尝试(Zhan 等人,2022;Covington, Adams, 和 Sargin 2016)带来了挑战,它们忽视了回归问题的长尾特性,因此对于尾部实例产生了次优结果。此外,在推荐系统中,预测之间的序数关系起着关键作用,以观看时间为视频比较的指标,突出了序数排序的重要性。然而,标准的回归损失如 ℓ1 和 ℓ2 并不考虑排序比较,只关注差异的大小。

图片名称

图1 观看时间的概率密度图

因此,在不平衡的连续标签分布中保持实时推荐系统中的预测准确性和排序效率面临巨大挑战。

为了解决这个问题,我们引入了一个有效的基于分类-恢复(classification-restoration)的框架,用于从现实世界中的不平衡连续目标中学习。该框架包含三个关键组件:

  • 一个有效的离散化:将连续标签转换为序数区间,
  • 一个分类模块:训练多个二元分类器跨越这些段以确保排序准确性,
  • 一个恢复模块:用于根据这些分类器的预测预测观看时间

这种方法的挑战在于从连续分布中创建离散类别的模糊性。这涉及到解决两种误差类型:

  • 与样本桶计数相关的学习误差
  • 影响从离散化预测中估计观看时间的恢复误差

平衡这些误差证明是复杂的;较窄的桶间隔通过降低样本概率降低了学习误差,而较宽的间隔减少了信息并提高了恢复误差。我们检查了离散化对学习和恢复误差的影响,并提出了一种错误自适应离散化(EAD)方法,以在实际分布中协调这些误差。

我们全面的框架,命名为带有 EAD 的 Classification-Restoration(CREAD),提供了一个适用于现有学习方法(如 D2Q)的多功能扩展。

总之,我们的主要贡献是:

  • 我们提出了一个用于学习观看时间的序数信息的通用分类-恢复框架,以及减少分类和恢复误差的训练算法。
  • 我们分析了离散化引入的误差界限,并提出了一种新的离散化方法,根据现实数据集分布平衡学习误差和恢复误差。

现实大规模数据集的离线和在线实验表明,我们的框架与最先进的方法相比取得了竞争性的结果。

2 相关工作

2.1 观看时间预测

观看时间预测的任务是:预测用户在给定用户画像、互动历史和一系列候选短视频的情况下的观看时间。

  • VR:值回归(Value Regression)直接预测观看时间的绝对值,其中学习函数的准确性通过均方误差来评估。
  • WLR:Covington, Adams 和 Sargin (2016) 将观看时间作为样本权重纳入(WLR)印象深刻的视频的逻辑回归中,将直接回归观看时间转化为学习视频点击率的概率。然而,这种假设仅在展示率较低时成立,而不适用于短视频设置中自动播放的视频。
  • D2Q:最近,Zhan 等人(2022)研究了视频推荐中观看时间预测的持续时间偏差(D2Q),通过基于持续时间的分箱数据去除不需要的偏差。尽管他们基于等频的方法去除了偏差,但他们忽略了观看时间不平衡分布对长尾样本的影响,导致与头部样本相比准确性较低。尽管有效,但他们没有利用在离散化过程中丢失的额外桶内信息,而我们在建模过程中施加了一个错误自适应框架

2.2 通过分类进行回归

最近,有一种趋势是将回归问题表述为一组分类问题,这在直接回归上取得了显著改进。第一个相关工作是序数回归(OR: Ordinal Regression)。它最初用于因变量表现出相对排序的分类问题,后来扩展到包括年龄估计(Niu 等人,2016;Beckham 和 Pal,2017)和深度估计(Diaz 和 Marathe,2019)等多个领域。

OR的一个关键问题是:从分布中创建离散类别的模糊性。大多数工作使用固定标准方法,如等宽或等频离散化来划分连续变量,而其他人则手动选择多个阈值(Crammer 和 Singer,2001;Shashua 和 Levin,2002)作为超参数。正如第 4 节所分析的,这些方法引入了较大的离散化误差,特别是当数据遵循不平衡分布时。相比之下,我们的方法通过提出一种自适应离散化方法来最小化总误差,从而缓解了这个问题。

3 方法

记号。设: ${(x_i, y_i)}^N_{i=1}$ 为训练集,

其中:

  • $y_i \in Y \subset \mathbb{R}^+$ :是对应的真实观看时间(ground truth)
  • $x_i \in \mathbb{R}^d$: 表示第 $i$ 个输入,包括与用户相关的特征(如人口统计特征和浏览历史)和与视频相关的特征(如标签和转发计数)。

不失一般性,我们假设:目标变量的值域通过 $T_{\text{max}} \in \mathbb{R}^+$ 来限制。

我们使用 $M-1$ 个阈值 $ \lbrace t_m \rbrace_{m=1}^{M-1} $ 将值域划分为$M$个离散桶 $ D \equiv \lbrace d_m \rbrace^M_{m=1} $,其中:

  • 第 $m$ 个桶 $d_m = [t_{m-1}, t_m)$ 对于 $m = 1, \cdots, M$,并且 $t_0 = 0$,$t_M = T_{\text{max}}$。
  • 设 $\widehat{y}_i$ 表示 $x_i$ 的预测观看时间。

设:

  • $1(\cdot)$ 表示指示函数。

为了简化,当我们不特指某个样本时,我们省略下标 $i$。

3.1 整体框架

CREAD框架,如图2 所示,包括三个模块,即离散化、分类和恢复。以下,我们解释每个组件的设计。

图片名称

图2 CREAD框架

离散化(Discretization)

这个模块是一个独立的预处理模块,与训练和评估过程无关。它根据数据分布获得阈值 ${t_m}{m=1}^{M-1}$,并将目标域 $Y$ 分割成 $M$ 个不重叠的桶 $D \equiv {d_m = [t{m-1}, t_m)}^M_{m=1}$。这些桶用于将观看时间 $y$ 转换为 $m$ 个离散标签:

\[y_m = 1(y > t_m). \quad (1)\]

离散化策略对预测精度至关重要,我们将在第 4.3 节详细讨论。

分类(Classification)

训练 $M$ 个分类器来预测观看时间 $y$ 是否大于第 $m$ 个阈值 $t_m$,即方程 (1) 中的 $y_m$,并输出一系列概率:

\[\widehat{\phi}_m(x_i; \Theta_m) = P(y > t_m | x_i), \quad 1 \leq i \leq N. \quad (2)\]

分类器是具有可学习参数 $\Theta_m$ 的神经网络。我们在第 3.2 节介绍如何训练这些模型。

恢复(Restoration)

给定 $\lbrace \widehat{\phi}m \rbrace^M{m=1}$,我们能够恢复预测的观看时间。恢复基于以下期望的事实:

\[\begin{align} E(y | x_i) & = \int_{t=0}^{t_M} tP(y = t | x_i)dt \\ & = \int_{t=0}^{t_M} P(y > t | x_i)dt \\ & \approx \sum_{m=1}^M P(y > t_m | x_i) (t_m - t_{m-1}). \end{align} \quad (3)\]

根据方程 (2) 中的 $\widehat{\phi}_m$ 的定义,我们可以从这些离散预测 $\widehat{\phi}_m$ 重建预测的观看时间:

\[\widehat{y} = \sum_{m=1}^M \widehat{\phi}_m (t_m - t_{m-1}). \quad (4)\]

3.2 模型训练

这里,我们提供了 $M$ 个分类器训练中的loss函数。损失函数包含三部分,其中:

第一部分是:通过标准分类的cross-entropy loss。

\[L_{ce} = \sum_{m=1}^M -y_m \log(\widehat{\phi}_m) - (1 - y_m) \log(1 - \widehat{\phi}_m). \quad (5)\]

第二部分是:restore loss,以减少方程 (4) 中重建观看时间的误差:

\[L_{restore} = \ell(\widehat{y}, y), \quad (6)\]

其中:

  • $\ell$ 是衡量 $\widehat{y}$ 到 $y$ 偏差的损失函数。

我们发现使用 Huber 损失 (Huber 1992) 作为 $L_{restore}$ 是有益的,这将在第 5.3 节详细分析。

第三部分是:通过序数先验的正则化项。根据定义,$M$ 个分类器的输出 $\lbrace \widehat{\phi}m \rbrace{m=1}^M$ 有一个先验,即 $\widehat{\phi}_m$ 随着 $m$ 的增长而单调递减。因此,我们通过最小化以下正则化项将先验纳入所提出的框架:

\[L_{ord} = \sum_{m=1}^{M-1} \max(\widehat{\phi}_{m+1} - \widehat{\phi}_m, 0). \quad (7)\]

总之,最终的优化目标是:

\[L_{CREAD} = \lambda_{ce}L_{ce} + \lambda_{restore}L_{restore} + \lambda_{ord}L_{ord}, \quad (8)\]

其中:

  • $\lambda_{ce}$、$\lambda_{restore}$ 和 $\lambda_{ord}$ 是超参数。

3.3 离散化的挑战

在 CREAD 框架中,一个关键模块是离散化模块,离散化方法在很大程度上影响最终预测精度。如图3 所示,离散化引入了两种误差:

图片名称

图3 离散化中的两种error

  • 学习误差:由于每个桶中的实例数量是有限的,$M$ 个分类器不能无限精确。随着我们增加桶的数量 $M$,落入每个桶的实例数量减少,从而限制了分类性能
  • 恢复误差:方程 (4) 中的恢复是期望的一个近似函数,省略了每个桶 [$t_{m-1}$, $t_m$] 中的详细概率密度,这也会引入误差。

不幸的是,这两种误差不能同时减少。为了减少学习误差,需要更大的桶宽度,导致更大的恢复误差(见图 3)。现有方法通常使用等宽或等频方法 (Gai 等人,2017) 来启发式地设置离散化集 $D$。我们将在第 4 节展示等宽和等频方法都不能很好地平衡这两种误差,并提出我们的 EAD 方法。

4 在离散化中平衡误差

本节旨在平衡离散化过程中引入的误差。我们首先提供对离散化过程中引入的学习误差和恢复误差的理论分析,然后提出EAD方法来有效平衡这两种误差。

4.1 离散化误差的分解

假设:训练数据集 $\lbrace (x_i, y_i)\rbrace^N_{i=1} \sim \mu(x, y) = \mu(x)\mu(y \mid x)$ 是独立同分布的。设:

  • $p_m(x) = P(y \in d_m \mid x)$: 表示标签 $y$ 属于第 $m$ 个桶 $d_m$ 给定 $x$ 的概率。
  • $v_m(x) = E(y \mid x, y \in d_m)$: 是样本 $x$ 的观看时间的期望值,假设它属于第 $m$ 个桶。
  • $w_m = E_{x \sim \mu(x)}v_m(x)$: 表示区间 $d_m$ 中观看时间的期望值

我们添加帽子上标来表示预测值,例如:

  • $\widehat{p}_m(x)$ 作为 $p_m(x)$ 的预测
  • $\widehat{w}_m$ 作为 $w_m$ 的预测。

然后我们可以将观看时间表示为:

\[\widehat{y} = \sum_m \widehat{p}_m(x) \widehat{w}_m.\]

注意,这种形式等同于方程 (4) 中的累积形式,其中:

\[\widehat{p}_m = \widehat{\phi}_m - \widehat{\phi}_{m-1}\]

现在我们的目标是:估计预测观看时间 $\widehat{y}$ 和真实值 $y$ 之间的误差界限。为了实现这一点,我们首先提供一个误差分解:

引理 4.1。假设 $\widehat{p}_m(x)$ 和 $\widehat{w}_m$ 分别是 $p_m(x)$ 和 $w_m$ 的无偏估计,我们有:

\(E(\widehat{y} - y)^2 = V_p + V_w + V_b + V_y,\) …(9)

其中:

\[V_p = E_x[E_{\widehat{p}}E_{\widehat{w}}\sum_m (\widehat{p}_m(x) - p_m(x)) \widehat{w}_m]^2,\] \[V_w = E_x[E_{\widehat{w}}\sum_m p_m(x) (\widehat{w}_m - w_m)]^2,\]

\(V_b = E_x[\sum_m p_m(x) (w_m - v_m(x))^2],\) \(V_y = E_{x,y}[\sum_m p_m(x)v_m(x) - y]^2.\)

详细证明请参见附录 A。直观上,

  • $V_p$ 由学习误差决定,即 $y$ 落入每个桶的概率 $p_m(x)$。
  • $V_w$ 描述了学习误差对代表性值 $\widehat{w}_m$ 的影响,
  • $V_b$ 是由离散化重建观看时间引起的误差,
  • $V_y$ 是观看时间 $y$ 的内在方差。

两个预测误差 $V_p$ 和 $V_w$ 受到学习算法误差的影响。因此,这两个误差项对应于学习误差。相比之下,$V_b$ 对应于与具体学习算法无关的恢复误差。最后,$V_y$ 与学习或离散化过程无关,后续不再讨论。

4.2 离散化的误差界限

本节分析离散化过程如何影响误差界限。为了理论分析的简便,我们只考虑表格输入的情况,并假设 $\mu(x, y)$ 足够平滑。但我们强调,受理论分析启发的离散化方法在现实世界设置中也将有效。

定理 4.2。假设输入 $x$ 从有限集合 $X$ 中采样。此外,假设 $\widehat{p}_m(x)$,$x \in X$ 和 $\widehat{w}_m$ 从最大似然估计中获得。此外,假设 $\mu(x, y)$ 具有有界的二阶偏导数。那么我们有:

\(V_p \leq V_p \triangleq \frac{C_p |X|}{N} \cdot A_p(D),\) \(V_w \leq V_w \triangleq \frac{C_w}{N} \cdot A_w(D),\) \(V_b \leq V_b \triangleq C_b \cdot A_b(D),\)

其中 $C_p$,$C_w$ 和 $C_b$ 是与离散化 $D$ 无关的常数,$A_p$,$A_w$,$A_b$ 是 $D$ 的函数:

\(A_p(D) = M E_{y \sim {\Psi}} y^2,\) \(A_w(D) = \sum_{m \in M} [\Psi(t_m) - \Psi(t_{m-1})]^2 \cdot \sum_{m \in M} \frac{(t_m - t_{m-1})^2}{\Psi(t_m) - \Psi(t_{m-1})},\) \(A_b(D) = \sum_{m \in M} [\Psi(t_m) - \Psi(t_{m-1})]^2 \cdot \sum_{m \in M} (t_m - t_{m-1})^2,\)

其中 $\Psi$ 是观看时间 $y$ 的累积分布函数(CDF):

\[\Psi(t) \triangleq P\{y \leq t\} = E_x \int_0^t \mu(y|x)dy.\]

证明。见附录 B。

因此,我们发现预测误差受到仅依赖于观看时间分布 $\Psi$ 和离散化 $D$ 的几个函数 $A_p$,$A_w$ 和 $A_b$ 的限制。现在我们提供一些直观的解释。

讨论 不同离散化方法对每个误差项的影响是什么?这里我们主要讨论误差项 $A_w$ 和 $A_b$,因为它们依赖于 $D$ 的分割点 ${t_m}_{m=1}^M$。我们考虑一个在 $[0, 1]$ 上截断的指数分布,即 $\Psi(t) = (1 - e^{-5t})/(1 - e^{-5})$,由 10 个桶离散化。表 1 显示了等宽和等频离散化的不同项。结果表明,等宽方法导致较低的 $A_b$,而等频方法导致较低的 $A_w$。这个结果有一个非常直观的解释,显示了 $A_b$ 和 $A_w$ 的含义:

  • 学习误差:学习误差由 $A_w$ 显示,受每个桶中的样本数量影响。具体来说,$A_w$ 的分母中存在一个 $\Psi(t_m) - \Psi(t_{m-1})$ 项。因此,如果某些桶中的样本很少,相应的概率 $\Psi(t_m) - \Psi(t_{m-1})$ 将很小,导致较大的误差。
  • 恢复误差:$A_b$ 显示恢复误差界限。它包含一个 $t_m - t_{m-1}$ 项作为乘数。当某些 $m$ 的 $t_m - t_{m-1}$ 较大时,误差项将增加,这与较大的桶宽度将导致较大的恢复误差的直觉一致。

图片名称

图4

现在我们再次讨论离散化引入的误差困境:学习误差和恢复误差通常相互矛盾。如果我们要减少学习误差,我们需要增加每个桶中的样本数量,但较大的桶宽度会导致较大的恢复误差。正式地,桶概率 $\Psi(t_m) - \Psi(t_{m-1})$ 通常与桶宽度 $t_m - t_{m-1}$ 正相关。根据上述讨论,我们即将提供 EAD 方法来平衡这两种误差。

4.3 EAD 方法

这里我们主要讨论给定桶的数量 $M$ 时的离散化策略 $D$。我们不讨论 $M$ 的选择,因为它是一个单一变量,可以被视为超参数。根据第 4.2 节,离散化策略 $D$ 需要平衡学习误差 $A_w$ 和恢复误差 $A_b$。因此,EAD 的离散化策略最小化以下目标:

\[\min_D J(D) = A_w(D) + \beta A_b(D),\]

其中:

  • $\beta$ 根据定理 4.2 由 $C_w$,$C_b$ 和 $N$ 确定。

由于 $C_w$ 和 $C_b$ 依赖于数据集的特征,无法从理论上获得,我们将 $\beta$ 视为超参数。

方程 (21) 是一个维度为 $M - 1$ 的优化问题。找到最优离散化策略是具有挑战性的,因为 $M$ 通常是几十或几百。这里我们提出一个更轻量级的方法。

图片名称

图5

我们首先正式表达等宽和等频离散化方法。具体来说,等宽方法写作:

\[t_m = \frac{m}{M} T_{\text{max}},\]

这保证了固定的 $\Delta t_m = \frac{T_{\text{max}}}{M}$,但在长尾桶中导致 $\Delta \Psi(t_m)$ 太小(见图 4(a))。相比之下,等频方法写作:

\[t_m = \Psi^{-1}\left(\frac{m}{M}\right),\]

这保证了固定的 $\Delta \Psi(t_m) = \frac{1}{M}$,但在长尾桶中导致 $\Delta t_m$ 太大(见图 4(b))。

关键是通过以下方式重写方程 (22) 和 (23):

\[t_m = \Psi^{-1}\left(\gamma\left(\frac{m}{M}\right)\right),\]

其中 $\gamma$ 是一个校准函数 $\gamma: [0, 1] \rightarrow [0, 1]$,满足 $\gamma(0) = 0$,$\gamma(1) = 1$。注意,当 $\gamma(z) = \Psi(z T_{\text{max}})$ 时,我们得到方程 (22);而当 $\gamma(z) = z$ 时,我们得到方程 (23)。

因此,方程 (24) 包含了离散化方法的两个极端情况。这启发我们通过适当地选择校准函数 $\gamma$ 在真实分布 $\Psi$ 和均匀分布之间,可以获得任何中间的分桶策略。图 4(c) 显示了校准函数 $\gamma$ 对离散化过程的影响。

根据上述讨论,我们将 $\gamma$ 设置为一组函数 $\gamma(\cdot; \alpha)$,由 $\alpha$ 参数化,然后可以通过网格搜索在方程 (21) 的优化问题下找到最优的 $\gamma$。

作为一个示例,我们使用与第 4.2 节讨论 2 相同的设置,并将 $\gamma$ 设置为 $\gamma(z; \alpha) = \frac{1 - e^{-\alpha z}}{1 - e^{-\alpha}}$,参数 $\alpha$ 从 0 到 5,并设置 $\beta = 50/100/200$。具体来说,$\alpha = 0$ 降级为等频方法,而 $\alpha = 5$ 降级为等宽方法。图 5 显示了在 $\alpha$ 和 $\beta$ 下的目标函数 $J(D)$(方程 (21)),这表明:

  • 超参数 $\beta$ 允许我们灵活地平衡学习误差和恢复误差。
  • 通过将最优的 $\alpha$ 与等频方法($\alpha = 0$)和等宽方法($\alpha = 5$)进行比较,表明通过适当地选择校准函数 $\gamma$,可以找到比传统的等频和等宽方法更合适的分桶策略。

#

https://arxiv.org/pdf/2401.07521

meta在《Understanding Scaling Laws for Recommendation Models》讨论了推荐系统中的scaling law问题。

摘要

规模(scale)一直是提高机器学习性能的主要驱动力,理解规模法则(scaling laws)对于可持续的模型质量性能增长的战略规划、长期资源规划以及开发支持大规模模型的高效系统基础设施至关重要。在本文中,我们研究了DLRM风格的推荐模型的经验规模法则,特别是点击率(CTR)。我们观察到模型质量与模型大小、数据大小和训练所用计算量呈幂律加常数的规模。我们通过比较这些轴上的不同规模方案,对数据、参数和计算三个不同的资源维度的规模效率进行了表征。我们展示了参数规模对于所研究的模型架构已经力不从心,而在出现更高性能的模型架构之前,数据规模是前进的道路。本研究解决的关键研究问题包括:

  • 推荐模型是否如规模法则预测的那样可持续地规模?
  • 我们是否远离规模法则的预测?
  • 规模的极限是什么?
  • 规模法则对长期硬件/系统开发有何影响?

1. 引言

在过去十年中,深度学习总体上,特别是基于深度学习的推荐模型(DLRM),在数据集规模、模型规模和系统资源方面经历了指数级的增长(Elkahky等人,2015年;Covington等人,2016年;Sullivan,2016年;Liu等人,2017年;Yi等人,2018年;Zhou等人,2019年;Zhao等人,2019年;Naumov等人,2020年;Zhao等人,2020年;Lui等人,2021年;Acun等人,2021年;Steck等人,2021年;Lian等人,2021年),将人工智能行业推向了万亿参数时代。实现万亿参数模型需要在人工智能系统基础设施上进行大量投资(Mudigere等人,2022年)。从系统设计的角度来看,主要问题/关注点是:

  • 如何扩展?
  • 哪种扩展方案提供更好的投资回报率(ROI)?
  • 如何战略性地结合不同的扩展方案以提供更好的ROI?

图1显示了在5年时间(2016-2021)内,语言建模任务和DLRMs的模型规模增长了10000倍。这些结果只反映了已发布模型的增长。我们预计DLRMs的增长速度甚至更快。推荐系统是许多互联网公司的主要收入来源。因此,这些模型的细节通常是保密的。最近的研究表明,在仅仅2年多的时间里(2019-2021),Facebook的推荐模型在参数数量上增长了20倍,在训练集大小上增长了2.4倍,系统基础设施增长了2.5-2.9倍(Wu等人,2021年;Mudigere等人,2022年),并且超过50%的数据中心AI训练周期都致力于推荐模型(Acun等人,2021年)。尽管它们很重要,但对于DLRM模型如何扩展,人们的认识有限。识别和理解模型的扩展属性对于设计服务于这些模型的人工智能系统和基础设施至关重要。我们的论文是首次尝试解决这一差距。

图片名称

图1 深度学习总体上,特别是基于深度学习的推荐模型近年来在参数规模上经历了指数级的增长(Sevilla等人,2021年;Mudigere等人,2022年;Lian等人,2021年)。请注意不同领域增长趋势的差异。

最近的工作(Hestness等人,2017年;Kaplan等人,2020年;Hernandez等人,2021年;Henighan等人,2020年;Gordon等人,2021年;Zhai等人,2021年;Brown等人,2020年;Hestness等人,2019年;Prato等人,2021年;Bahri等人,2021年)显示,在包括语言建模、机器翻译、视觉变换器、迁移学习和其他自回归模型在内的广泛领域中,高度可预测的扩展趋势。然而,推荐系统如何扩展尚不清楚

此外,先前的研究在他们的扩展分析中没有包括embedding参数。embedding参数占推荐模型容量的大部分(>90%),因此,研究它们对模型质量性能扩展的影响至关重要。

我们在这项工作中的目标是表征深度学习推荐模型的扩展规律,特别是点击率(CTR)预测模型。CTR模型是推荐系统中最重要的机器学习任务之一,为数十亿用户提供个性化体验。通过研究许多不同模型规模N(跨越三个数量级)、计算预算C(跨越五个数量级)和数据集规模D(跨越三个数量级),我们展示了一个简单的幂律加常数可以解释CTR模型在一个周期内的性能与N、D和C之间的关系。

图11概述了一个典型的DLRM架构。在高层次上,有两个主要组件可以扩展:嵌入表和多层感知器(MLP)。

图片名称

图11 深度学习模型架构的示意图。

  • 嵌入表(embedding table)可以通过垂直扩展(增加每个表的嵌入行数)或水平扩展(扩展嵌入的维度)来扩展。
  • MLP层可以通过加宽或加深层来扩展。

我们研究了在四种扩展方法上的推荐系统的经验扩展规律:扩展嵌入表(垂直和水平)、扩展顶层MLP层(我们称之为总架构层)以及扩展所有MLP层(包括通过增加宽度来扩展密集层、总架构层和密集-稀疏交互层)。

1.1 摘要

我们对CTR预测模型的关键发现如下:

幂律加常数:我们观察到,在训练一个周期后,推荐模型的性能(测试损失)与资源投入遵循幂律加常数关系(αx−β + γ)(见图2)。资源包括数据集大小、模型大小和计算浮点运算量。幂律加常数函数中的常数γ标识了扩展的极限:即我们假设可以无限扩展资源时能达到的最佳水平。表1显示了不同扩展方案和不同资源投入情景下经验收集的α、β和γ值。

图片名称

图2 推荐系统的性能随着数据规模、模型规模以及训练计算量(FLOPs)的增加而呈现出幂律增长加上一个常数的特性:

  • (a) 通过增加多层感知机(MLP)层的宽度来扩展模型规模。
  • (b) 通过增加顶层网络层的宽度来扩展模型规模。
  • (c) 通过增加嵌入表的维度来扩展模型规模。
  • (d) 通过增加嵌入表中的行数来扩展模型规模。

幂律函数的两个阶段:如图3所示,幂律函数可以被一个高回报阶段和随后的低回报/饱和阶段所特征化。收益递减点是过渡发生的地方。如果使用幂律函数来比较两种扩展方案的效率,需要关注幂律函数的指数(β)以及操作阶段。指数较大且衰减更快的幂律函数更适合扩展。然而,处于饱和阶段的操作方案无论其指数如何,都不如非饱和方法。

图片名称

图3 幂律函数特征曲线

性能强烈依赖于数据集大小和计算能力,而与模型参数大小关系较弱:模型性能强烈依赖于训练集中的样本数量(D)和计算浮点运算量(C),而与参数数量(P)关系较弱。

扩展的极限:幂律趋势中的常数(γ)捕获了不可减少的错误。这意味着通过扩展资源(模型参数、数据大小和/或计算浮点运算)到无限大所能达到的最佳归一化测试损失将饱和在0.98。

数据扩展效率:所有扩展方案的数据扩展效率相似(β在[0.09, 0.12]范围内),并且对模型大小不敏感。所有扩展方案都处于高回报阶段。根据图4中显示的幂律指数,可以看出垂直扩展嵌入表(V)比水平扩展嵌入表(H)更好,而水平扩展嵌入表又比顶层MLP层扩展(O)更好,后者又比MLP层扩展(M)在数据扩展效率方面更好。这意味着在固定参数预算下,通过同时扩展数据集大小和模型大小来扩展模型性能,对参数扩展方法有些敏感。

图片名称

图4 不同模型扩展方案中的数据扩展效率。尽管每条线显示了在固定模型规模下的数据扩展趋势,每个图表中的虚线及其对应的方程捕捉了帕累托最优曲线。如图所示,不论扩展方案如何,当模型和数据一起扩展时,所有模型或多或少具有相同的幂律扩展特性(幂指数为-0.1),这意味着所有模型扩展方案中的数据扩展效率是相同的。

计算扩展效率:所有扩展方案的计算扩展效率相似(β在[0.12, 0.15]范围内)。所有扩展方案都处于高回报阶段。根据图5中显示的幂律指数,可以看出MLP扩展比顶层扩展更计算效率高,顶层扩展又略比嵌入维度扩展更计算效率高。

图片名称

图5 计算扩展效率 - 两种视角:(a) 同时扩展计算量(FLOPs)和数据集规模 (b) 同时扩展计算量(FLOPs)和模型规模。

参数扩展效率:不同扩展方案的参数扩展效率不同(α在[0.4, 7.6]范围内)。然而,所有扩展方案都处于饱和阶段(见图6)。对于一个工业级模型,所有参数扩展技术在参数扩展效率方面相似。这意味着在固定数据预算下,通过增加模型中的参数数量来扩展模型性能,对参数扩展方法不敏感。

图片名称

图6 不同参数扩展方案中的参数扩展效率。在所有扩展方案中可见的模式是,准确性与参数规模之间的弱依赖性。

2. 扩展效率

在给定固定预算/资源的情况下,主要问题是哪种扩展方案可以提供更好的投资回报率(ROI)。我们针对三种不同的资源,即数据、参数和计算浮点运算量,对扩展效率进行了表征。我们展示了所有扩展方案在数据扩展和计算扩展效率上都相似,并且仍有改进空间。另一方面,参数扩展效率非常低,因为它已经超出了收益递减点。

2.1 数据扩展效率

为了研究数据扩展效率,我们在广泛范围内(三个数量级)扩展数据集大小,同时保持模型大小不变。从概念上讲,线的斜率捕捉了模型在面对问题时吸收新信息的有效性。结果如图4所示。每个图表捕捉了不同的模型扩展方案(垂直嵌入、水平嵌入、顶层和MLP扩展)。

正如所有扩展策略所示,推荐系统的性能强烈依赖于数据集大小,而与参数/模型大小关系较弱。这是违反直觉且非常有趣的。我们继续看到在过去5年中嵌入表的大小和嵌入表的数量不断增长。这些结果意味着工业级模型在过拟合范围内运行。

虽然图4中的每条线显示了固定模型大小的数据扩展趋势,但每个图表中的虚线捕捉了帕累托前沿线。如图所示,无论扩展方案如何,所有模型都有类似的幂律趋势。这意味着所有模型扩展方案的数据扩展效率相似。

摘要 推荐系统的性能强烈依赖于数据大小,而与参数/模型大小关系较弱。与大规模语言模型(Hestness等人,2017年;Kaplan等人,2020年)相比,其中性能与模型大小强烈相关,推荐系统对模型大小的敏感性较弱,这在设计下一代推荐系统时需要考虑。所有扩展方案的数据扩展效率相似。这意味着所研究的模型以相同的速率从新数据中吸收信息,无论其背后的扩展方案如何。输入粒度/词汇量大小对扩展趋势没有显著影响。

2.2 计算扩展效率

我们的目标是表征捕捉模型质量性能与计算浮点运算量之间关系的线的斜率。从概念上讲,线的斜率捕捉了模型在面对问题时对新计算浮点运算量吸收新信息的速度。在计算效率分析中,我们保持数据(或模型大小)不变,同时扩展模型大小(或数据大小)。当我们扩展模型大小或数据大小时,我们间接地增加了计算浮点运算量。还有另一种方法可以在不改变数据大小或模型大小的情况下扩展计算浮点运算量,那就是训练更长时间的模型。我们留待未来的工作。

图5显示了这种扩展的结果。每个图表捕捉了不同的模型扩展方案(水平嵌入、顶层和MLP扩展。注意我们没有显示垂直扩展的计算扩展,因为增加行数对计算浮点运算量没有任何影响。)如图所示,所有扩展策略中,推荐系统的性能强烈依赖于计算浮点运算量的数量。我们以两种不同的方式呈现相同的结果:(1)通过模型扩展增加计算浮点运算量,同时保持数据大小不变(图5,顶行)。(2)或者,我们通过数据扩展增加计算浮点运算量,同时保持模型大小不变(图5,底行)。

图片名称

图7 何时选择垂直扩展(Vertical Scaling)与水平扩展(Horizontal Scaling)?

图片名称

图8 何时选择顶层网络扩展(Over-arch Scaling)与多层感知机扩展(MLP Scaling)?

图片名称

图9 嵌入维度对表大小的敏感性:每条线展示了不同的垂直扩展因子(VSF)。大的蓝色圆圈显示了每条线的最小损失。然而,曲线的拐点在64处始终如一地出现。

同时扩展计算和数据 图5顶行显示了通过扩展模型大小对性能的计算浮点运算量扩展影响。在每条线内,我们保持数据大小不变,同时通过模型大小扩展增加计算浮点运算量。注意不同扩展方案的幂律方程的幂之间的轻微差异。看来,MLP扩展略优于顶层扩展,顶层扩展又略优于嵌入维度扩展,在相同增加的计算预算下提高模型准确性(0.15对-0.14对-0.12)。此外,如图所示,在固定的计算预算下,更大的数据集大小会带来更好的性能。同时,在固定的准确性目标下,更小的数据集大小更具计算效率。

同时扩展计算和模型大小 图5底行显示了通过扩展数据大小对性能的计算浮点运算量扩展影响。在每条线内,我们保持模型大小不变,同时通过扩展数据集大小增加计算浮点运算量。如图所示,在固定的计算预算下,更大的模型获得更低的性能。同时,在固定的准确性目标下,更小的模型大小更具计算效率。虚线捕捉了在每个计算预算下获得最佳性能的最佳模型大小。图5(a)和(b)基本上是同一组点,从两个不同的视角呈现(一次基于数据集大小对点进行分组,一次基于模型大小进行分组),因此,帕累托最优线(虚线)将是相同的。

摘要 在固定的计算预算下,需要在在更大的数据集大小上训练模型或训练具有更多参数的模型之间做出权衡。我们观察到,在固定的计算预算下,具有更多参数的模型显示出更低/更差的性能,而用更大的数据集大小训练的模型显示出更好的性能。从计算效率的角度来看,我们观察到,在第一个周期,MLP扩展优于顶层扩展,顶层扩展优于水平扩展嵌入表。注意,垂直扩展嵌入表对计算浮点运算量没有任何影响。

3. 敏感性分析

3.1 如何有效地按行数扩展嵌入维度?

图9展示了随着我们在表中增加行数(增加垂直扩展因子)时最佳嵌入维度的变化情况。如图所示,随着垂直扩展因子的增大,最佳嵌入维度趋于变小(对于0.125×和0.25×的垂直扩展因子,256是最佳嵌入维度,而对于0.5×、1×和2×的垂直扩展因子则是128)。然而,最佳性能和最具资源效率的嵌入维度并不一定是相同的。如图所示,曲线的拐点(收益递减点)对于所有表大小在嵌入维度=64左右开始出现。这意味着从资源效率的角度来看,嵌入维度的资源高效设计点对垂直扩展因子的依赖性较弱。这一结果暗示,从资源效率的角度来看,超过64将不会提供高投资回报率。

3.2 训练与测试

如图10所示,训练数据的学习曲线比测试数据的学习曲线更陡峭(-0.20对比-0.12)。两条曲线都捕捉了在相同数据上训练的相同模型的扩展,但在两个不同的数据集上进行了评估。左侧的曲线在训练集的数据点上进行了评估,右侧的模型在测试集上进行了评估。这种差距意味着模型从额外的训练点吸收的信息在预测来自相同分布(训练分布而非测试分布)的数据时更有效,这是意料之中的。

图片名称

图10 training loss与testing loss上的数据扩展效率。请注意训练曲线和测试曲线之间幂律指数的差异。

4. 讨论

特征化不同扩展方案的幂律曲线提供了每种扩展技术的数据效率、参数效率和计算效率的见解。人们可以通过比较它们在三个不同轴(数据、计算、参数)上的幂律曲线,潜在地比较任何成对扩展技术的效率。表2显示了这种比较的结果。如图所示,没有单一的扩展技术在所有扩展效率维度上都脱颖而出。例如,水平嵌入扩展(H)在数据效率方面优于MLP扩展(M),但在计算效率方面则较差。

最近的分析显示,在短短5年多的时间里,工业级推荐模型增长了四个数量级(Mudigere等人,2022年;Lian等人,2021年)。幂律分析支持了过去的趋势。当按幂律趋势近似时,参数扩展的指数幅度最大。然而,工业级推荐模型已经过于庞大且饱和,因此进一步的参数增长不会从资源效率的角度提供高投资回报率。

与此同时,数据扩展和计算扩展仍然处于高收益递减的范围内。这意味着在更好的模型架构出现之前,应该将数据扩展视为一流的扩展方法。话虽如此,我们应该意识到,由于数据保留的限制,数据扩展从长远来看(以原始形式)并不是一种可持续的方法。

为了克服这一点,我们需要考虑替代方案。以下是一些建议,其中一些我们将作为下一步探索:(1) 记录更多数据,特别是通过记录更多负样本和减少正样本下采样;(2) 探索使用历史数据作为教师模型来训练模型,以合成从历史数据中学习到的有价值信息,供更近期的模型使用;(3) 水平扩展数据量而不是垂直扩展,即增加更多特征而不是增加更多行。

扩展法则也可以用来指导长期硬件开发。硬件设计通常提前3-5年开始,依靠对未来3-5年模型增长的准确预测。我们的分析表明,展望未来,硬件不需要增长来支持更大的模型。相反,我们需要设计硬件/系统来支持使用更大的数据集进行训练。

另一个关键的收获是,幂律加常数方程中的常数在0.98(以归一化熵度量的损失)处有界。这个常数捕获了在无限扩展极限下模型的准确性,可以用作衡量工业级模型与无限极限的距离的指南。在NLP领域的先前分析表明,模型架构的创新(例如,从LSTM过渡到Transformer)可以改善幂律的系数(即α.x−β + γ中的α),并向下移动曲线,但它们对幂律的指数(β)几乎没有影响(Hestness等人,2017年;Brown等人,2020年)。这表明模型架构探索是性能增长的短期解决方案。长期解决方案将需要改善幂律趋势的指数。至今,是什么控制了幂律的斜率仍然是一个开放的研究问题。幂律曲线的斜率似乎对每个领域都是独特的,与模型架构无关(Hestness等人,2017年;2019年)。先前的分析表明,改善数据分布可以改善幂律的指数(Bahri等人,2021年)。最近的工作表明,通过有效的数据修剪,我们可以打败幂律并实现指数级扩展(Sorscher等人,2022年)。

https://arxiv.org/pdf/2208.08489