In "LDA source code analysis in mllib (Part 1)", we covered the theory behind LDA and the derivation of both the batch and online variational-inference algorithms. This section walks through the ml implementation in Spark MLlib.
Spark MLlib's implementation is based on variational inference; later versions are expected to add Gibbs sampling. This article focuses on the online LDA method. (Code references are primarily to version 1.6.1.)
I. Batch Algorithm
II. Online Algorithm
The lda in ml is essentially a thin wrapper that adapts LDA to the ml architecture (Transformer, Pipeline, and so on); the actual work is done by the lda implementation in mllib.
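For orientation, here is a minimal usage sketch of the ml-side wrapper; the `dataset` variable (a DataFrame with a Vector-typed "features" column) is an assumption of the example, not something from the walkthrough below:

import org.apache.spark.ml.clustering.LDA

// Configure the ml wrapper; under the hood it delegates to the mllib optimizers below.
val lda = new LDA()
  .setK(10)               // number of topics
  .setMaxIter(20)         // outer iterations, i.e. calls to state.next()
  .setOptimizer("online") // "online" => OnlineLDAOptimizer, "em" => EMLDAOptimizer
val model = lda.fit(dataset)               // dataset: DataFrame with a "features" column
val withTopics = model.transform(dataset)  // adds a "topicDistribution" column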
1. LDA.run (the main entry point)
def run(documents: RDD[(Long, Vector)]): LDAModel = {
  // Initialize the optimizer state (em => EMLDAOptimizer, online => OnlineLDAOptimizer).
  val state = ldaOptimizer.initialize(documents, this)
  // Main iteration loop.
  var iter = 0
  val iterationTimes = Array.fill[Double](maxIterations)(0)
  while (iter < maxIterations) {
    val start = System.nanoTime()
    // One step of the configured optimizer.
    state.next()
    val elapsedSeconds = (System.nanoTime() - start) / 1e9
    iterationTimes(iter) = elapsedSeconds
    iter += 1
  }
  // Build the final, complete LDAModel.
  state.getLDAModel(iterationTimes)
}
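The loop above is optimizer-agnostic: which next() implementation runs depends on how the optimizer was configured before run. A minimal mllib-side sketch, assuming `corpus` is an RDD[(Long, Vector)] of (document id, term-count vector):

import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}

val onlineLDA = new LDA()
  .setK(10)
  .setMaxIterations(50)
  .setOptimizer(new OnlineLDAOptimizer()
    .setMiniBatchFraction(0.05)          // fraction of the corpus sampled per iteration
    .setOptimizeDocConcentration(true))  // also update alpha in the M-step
val ldaModel = onlineLDA.run(corpus)     // drives the loop above via OnlineLDAOptimizer.next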
2. OnlineLDAOptimizer.next
override private[clustering] def next(): OnlineLDAOptimizer = {
  // Sample a mini-batch of documents.
  val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
    randomGenerator.nextLong())
  // If the sample is empty, skip this iteration.
  if (batch.isEmpty()) return this
  // Submit the mini-batch and run one LDA update on it.
  submitMiniBatch(batch)
}
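Note that sample takes a fraction, not a count: the batch size only equals miniBatchFraction * corpusSize in expectation, and for small fractions the sample can legitimately come back empty, which is what the early return guards against. An illustrative calculation (the numbers are made up):

// With these illustrative values, each call to next() processes about
// 5000 documents in expectation; submitMiniBatch reuses this same
// expression as its batch-size estimate to avoid a batch.count().
val corpusSize = 100000L
val miniBatchFraction = 0.05
val expectedBatchSize = (miniBatchFraction * corpusSize).ceil.toInt // 5000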
3. submitMiniBatch
/**
 * Submit a subset (like 1%, decided by the miniBatchFraction) of the corpus to the Online LDA
 * model, and it will update the topic distribution adaptively for the terms appearing in the
 * subset.
 */
private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = {
  iteration += 1
  val k = this.k
  val vocabSize = this.vocabSize
  val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t // β ~ Dirichlet(λ)
  val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
  val alpha = this.alpha.toBreeze
  val gammaShape = this.gammaShape
  // step 1: map over each partition and run the E-step locally;
  // each partition yields one pair (stat, gammaPart).
  val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
    // 1. Filter out empty documents.
    val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
    // 2. stat is a DenseMatrix of sufficient statistics: k x vocabSize.
    val stat = BDM.zeros[Double](k, vocabSize)
    var gammaPart = List[BDV[Double]]()
    // 3. Run the EM algorithm over every document in the partition.
    //    The E-step is computed in parallel; the M-step requires a reduce
    //    to merge the partial results before updating.
    nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
      // 3.1 Extract the term ids of this document (dense and sparse both supported).
      val ids: List[Int] = termCounts match {
        case v: DenseVector => (0 until v.size).toList
        case v: SparseVector => v.indices.toList
      }
      // 3.2 E-step: infer gammad and this document's sufficient statistics.
      val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
        termCounts, expElogbetaBc.value, alpha, gammaShape, k)
      // 3.3 Accumulate sstats into the columns of stat selected by ids.
      stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
      // 3.4 Prepend this document's gammad to the partition's gammaPart list.
      gammaPart = gammad :: gammaPart
    }
    // 4. mapPartitions returns an iterator; each partition emits one (stat, gammaPart):
    //    stat: k x V matrix, per-topic term statistics from this partition's documents.
    //    gammaPart: List[Vector[K]], the gammad of each document on this partition.
    Iterator((stat, gammaPart))
  }
  // step 2: sum the stat matrices over all partitions of the mini-batch.
  val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
  expElogbetaBc.unpersist()
  // step 3: concatenate the per-document gammad vectors from all partitions into one matrix.
  val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
    stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
  // step 4: batchResult (K x V): multiply each element by exp(E[log(β)]).
  val batchResult = statsSum :* expElogbeta.t
  // M-step: update λ.
  // Note that this is an optimization to avoid batch.count.
  updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
  // Optionally update alpha as well, if optimizeDocConcentration is set.
  if (optimizeDocConcentration) updateAlpha(gammat)
  this
}
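For context on the M-step: updateLambda blends the old λ with the mini-batch estimate using the learning rate ρ_t = (τ0 + t)^(-κ) from the online LDA paper, i.e. λ ← (1 - ρ_t)·λ + ρ_t·(η + (D/S)·stat). A sketch of what the two private helpers do, paraphrased from the 1.6 source rather than quoted verbatim:

private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = {
  // Weight of this mini-batch: ρ_t = (τ0 + iteration)^(-κ).
  val weight = rho()
  // Blend old λ with the rescaled mini-batch statistics plus the prior η.
  lambda := (1 - weight) * lambda +
    weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
}

private def rho(): Double = math.pow(getTau0 + this.iteration, -getKappa)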
4. variationalTopicInference
The core method underneath is variationalTopicInference:
private[clustering] def variationalTopicInference(
    termCounts: Vector,
    expElogbeta: BDM[Double],
    alpha: breeze.linalg.Vector[Double],
    gammaShape: Double,
    k: Int): (BDV[Double], BDM[Double]) = {
  // step 1: term ids and counts: (ids, cts).
  val (ids: List[Int], cts: Array[Double]) = termCounts match {
    case v: DenseVector => ((0 until v.size).toList, v.values)
    case v: SparseVector => (v.indices.toList, v.values)
  }
  // step 2: initialization.
  // gammad ~ Gamma(gammaShape, 1/gammaShape), a K-dimensional vector
  //   (with the default gammaShape = 100 this is Gamma(100, 0.01)).
  //   gammad parameterizes this document's variational distribution q(θ | γ).
  // expElogthetad: the paper's exp(E[logθ]), where θ ~ Dirichlet(γ).
  // expElogbetad: the paper's exp(E[logβ]), restricted to this document's ids (ids x K).
  // phiNorm: the normalizer of ϕ ∝ exp{E[logθ] + E[logβ]}; kept nonzero.
  // Initialize the variational distribution q(theta|gamma) for the mini-batch
  val gammad: BDV[Double] =
    new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                  // K
  val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
  val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                       // ids x K
  // Add a very small constant so phiNorm stays nonzero.
  val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100           // ids
  var meanGammaChange = 1D
  val ctsVector = new BDV[Double](cts)                                        // ids
  // Inner loop over this single document.
  // Iterate until γ and ϕ converge: the paper uses a threshold of 0.000001,
  // while the implementation uses 1e-3.
  // Iterate between gamma and phi until convergence
  while (meanGammaChange > 1e-3) {
    val lastgamma = gammad.copy
    // Breeze ufunc notes: https://github.com/scalanlp/breeze/wiki/Universal-Functions
    // ":"-prefixed operators are elementwise; ":=" assigns contents in place.
    // Compute γ_dk: divide ctsVector elementwise by phiNorm, multiply by the
    // transpose E[logβ]^T, take the pointwise product with exp(E[logθ]), then add alpha.
    // Dimensions (K = topics, ids = this document's vocabulary subset):
    //   gammad:        Vector(K),       one γ value per topic for this document.
    //   expElogthetad: Vector(K),       one θ value per topic for this document.
    //   expElogbetad:  Matrix(ids x K), one β value per (term, topic) pair.
    //   ctsVector:     Vector(ids),     one term count per term.
    //   phiNorm:       Vector(ids),     one ϕ normalizer per term.
    //            K                K x ids              ids
    gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
    // Refresh exp(E[logθ]); expElogbetad does not change inside this loop.
    expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
    // Refresh phiNorm.
    // TODO: Keep more values in log space, and only exponentiate when needed.
    phiNorm := expElogbetad * expElogthetad :+ 1e-100
    // Mean absolute change of γ.
    meanGammaChange = sum(abs(gammad - lastgamma)) / k
  }
  // sstatsd: this document's contribution to the per-topic term statistics.
  // Dimensions: (K x 1) * (1 x ids) => K x ids.
  val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
  (gammad, sstatsd)
}
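Both expElogbeta and expElogthetad rely on LDAUtils.dirichletExpectation, which computes E[log X_k | α] = ψ(α_k) - ψ(Σ_j α_j) for X ~ Dirichlet(α), with ψ the digamma function. A sketch of the helper, essentially what LDAUtils contains (the matrix overload treats each row as one Dirichlet parameter vector):

import breeze.linalg.{sum, DenseMatrix => BDM, DenseVector => BDV}
import breeze.numerics.digamma

// For θ ~ Dirichlet(α): E[log θ_k] = digamma(α_k) - digamma(sum(α)).
private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = {
  digamma(alpha) - digamma(sum(alpha))
}

// Row-wise version for a matrix of Dirichlet parameters (e.g. λ: K x V).
private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
  // Row sums: each row of alpha is one Dirichlet parameter vector.
  val rowSum = sum(alpha(breeze.linalg.*, ::))
  val digAlpha = digamma(alpha)
  val digRowSum = digamma(rowSum)
  // Subtract digamma(rowSum) from every column.
  digAlpha(::, breeze.linalg.*) - digRowSum
}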
OK. The details of the alpha update and the rest are not covered here; see the source for details.