mllib中lda源码分析(二)

Reading time ~3 minutes

mllib中lda源码分析(一)中,我们描述了LDA的原理、以及变分推断batch算法和online算法的推导。这一节会再描述下spark MLlib中的ml实现。

spark MLlib的实现,是基于变分推断算法实现的,后续的版本会增加Gibbs sampling。本文主要关注online版本的lda方法。(代码主要以1.6.1为主)

一、Batch算法

二、Online算法

ml中的lda相当于一层wrapper,更好地适配ml中的架构(tranformer/pipeline等等),调用的实际上是mllib中的lda实现。

1.LDA.run(主函数)

  def run(documents: RDD[(Long, Vector)]): LDAModel = {
    // 初始化(em=>em, online=>online).
    val state = ldaOptimizer.initialize(documents, this)

    // 迭代开始. 
    var iter = 0
    val iterationTimes = Array.fill[Double](maxIterations)(0)
    while (iter < maxIterations) {
      val start = System.nanoTime()
      
      // optimzier对应的迭代函数.
      state.next()
      
      val elapsedSeconds = (System.nanoTime() - start) / 1e9
      iterationTimes(iter) = elapsedSeconds
      iter += 1
    }

    // 获取最终完整的的LDAModel. 
    state.getLDAModel(iterationTimes)
  }

OnlineOptimizer.next

  override private[clustering] def next(): OnlineLDAOptimizer = {
    // 先进行sample
    val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
      randomGenerator.nextLong())
    
    // 如果为空,返回
    if (batch.isEmpty()) return this

    // 提交batch,进行lda迭代.
    submitMiniBatch(batch)
  }

submitMinibatch

/**
   * 提供语料的minibatch给online LDA模型.为出现在minibatch中的terms它会自适应更新topic distribution。
   * 
   * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA
   * model, and it will update the topic distribution adaptively for the terms appearing in the
   * subset.
   */
  private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = {
    iteration += 1
    val k = this.k
    val vocabSize = this.vocabSize
    val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t      // β ~ Dirichlet(λ)
    val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)       // 
    val alpha = this.alpha.toBreeze
    val gammaShape = this.gammaShape

    // step 1: 对每个partition做map,单独计算E-Step. stats(stat, gammaPart)
    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>

      // 1.过滤掉空的.
      val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)

      // 2.stat 是一个DenseMatrix:  k x vocabSize
      val stat = BDM.zeros[Double](k, vocabSize)

      // 
      var gammaPart = List[BDV[Double]]()

      // 3.遍历partition上的所有document:执行EM算法.
      // E-step可以并行计算.
      // M-step需要reduce,作合并,然后再计算.
      nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
        // 3.1  取出document所对应的词的id。(支持dense/sparse).
        val ids: List[Int] = termCounts match {
          case v: DenseVector => (0 until v.size).toList
          case v: SparseVector => v.indices.toList
        }

        // 3.2 更新状态 E-step => gammad ().
        val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
          termCounts, expElogbetaBc.value, alpha, gammaShape, k)

        // 3.3 根据所对应id取出对应列. 更新sstats(对应主题状态) 
        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats

        // 3.4 更新该partition上每个文档的gammad的gammaPart列表中.
        gammaPart = gammad :: gammaPart
      }

      // 4.mapPartition返回iterator,每个partition返回一个迭代器(stat, gammaPart)
      // stat: k x V matrix. 针对该partition上的文档,所更新出的每个词在各主题上的分布.
      // gammaPart: list[vector[K]] 该分区上每个文档的gammad列表.
      Iterator((stat, gammaPart))
    }

    // step 2: 对mini-batch的所有partition上的stats(stat,gammaPart)中的stat进行求总和.
    val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
    expElogbetaBc.unpersist()

    // step 3: 对mini-batch的所有partition上的stats(stat,gammaPart)中的gammaPart进行更新.
    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
      stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)

    // step 4: 更新batchResult ( K x V), 每个元素上乘以 E[log(β)]
    val batchResult = statsSum :* expElogbeta.t

    // M-step:
    // 更新λ.
    // Note that this is an optimization to avoid batch.count
    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)

    // 如果需要优化DocConentration,是否更新alpha
    if (optimizeDocConcentration) updateAlpha(gammat)
    this
  }

variationalTopicInference

里面比较核心的一个方法是variationalTopicInference:

  private[clustering] def variationalTopicInference(
      termCounts: Vector,
      expElogbeta: BDM[Double],
      alpha: breeze.linalg.Vector[Double],
      gammaShape: Double,
      k: Int): (BDV[Double], BDM[Double]) = {

    // step 1: 词的id和cnt: (ids, cnts) 
    val (ids: List[Int], cts: Array[Double]) = termCounts match {
      case v: DenseVector => ((0 until v.size).toList, v.values)
      case v: SparseVector => (v.indices.toList, v.values)
    }

    // step 2: 初始化: gammad ~ Γ(100, 0.01), 100维
    // gammd: mini-batch的变分分布: q(θ | γ) 
    // expElogthetad: paper中的exp(E[logθ]), 其中θ ~ Dirichlet(γ)
    // expElogbetad:  paper中的exp(E[logβ]), 其中β : V * K
    // phiNorm:  ϕ ∝ exp{E[logθ] + E[logβ]},ϕ ~ θ*β. 非零.
    // Initialize the variational distribution q(theta|gamma) for the mini-batch
    val gammad: BDV[Double] =
      new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
    val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
    val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K

    // 加上一个很小的数,让它非零.
    val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
    var meanGammaChange = 1D
    val ctsVector = new BDV[Double](cts)                                         // ids

    // 单个mini-batch里的loop.
    // 迭代,直到 β 和 ϕ 收敛. paper中是0.000001,此处用的是1e-3.
    // Iterate between gamma and phi until convergence
    while (meanGammaChange > 1e-3) {
      val lastgamma = gammad.copy

      // breeze操作:https://github.com/scalanlp/breeze/wiki/Universal-Functions
      // ":"为elementwise操作符;其中的:=,对象相同,内容赋值
      // 计算:γ_dk, 先对ctsVector进行归一化,再乘以转置E(log(β)]^T,再pointwise乘E(log(θ)).
      // 各矩阵或向量的维度: K为topic, ids:为词汇表V
      // gammad: Vector(K),单个文档上每个主题各有一γ值.
      // expElogthetad: Vector(K),单个文档上每个主题各有一θ值.
      // expElogbetad: Matrix(ids*K),每个词,每个主题都有一β值.
      // ctsVector: Vector(ids),每个词上有一个词频.
      // phiNorm: Vector(ids),每个词上有一个ϕ分布值。
      // 
      // 
      //        K                  K * ids               ids
      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
      
      // 更新 exp(E[logθ]), expElogbetad不需要更新
      expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
      
      // 更新 phiNorm, 
      // TODO: Keep more values in log space, and only exponentiate when needed.
      phiNorm := expElogbetad * expElogthetad :+ 1e-100

      // 平均变化量.
      meanGammaChange = sum(abs(gammad - lastgamma)) / k
    }

    // sstats: mini-batch下每个主题下的词分布.
    // n 
    // sstatsd: k x 1 * 1 x ids => k x ids
    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
    (gammad, sstatsd)
  }

ok。关于alpha的更新等,此处不再解释。详见源码。

google Titans介绍

google在《Titans: Learning to Memorize at Test Time》提出了区别于Transformer的的一种新架构:Titans。我们来看一下它的实现,是否有前景:# 摘要在过去的十多年里,关于如何有效利用循环模型(recurrent mo...… Continue reading

meta QuickUpdate介绍

Published on January 02, 2025

kuaishou CREAD介绍

Published on August 05, 2024