A detailed walkthrough of the Spark Word2Vec source code

A brief introduction to Spark Word2Vec

Word2Vec creates vector representation of words in a text corpus.
The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary.
The vector representation can be used as features in natural language processing and machine learning algorithms.
We used the skip-gram model in our implementation and the hierarchical softmax method to train the model. The variable names in the implementation match the original C implementation.
For original C implementation, see https://code.google.com/p/word2vec/
For the research papers, see "Efficient Estimation of Word Representations in Vector Space" (paper 1) and "Distributed Representations of Words and Phrases and their Compositionality" (paper 2).

Source code walkthrough

package org.apache.spark.mllib.feature
import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
*  Entry in vocabulary -- the class describing one vocabulary entry (a plain Scala case class)
*/
private case class VocabWord(
  var word: String,        // the word itself
  var cn: Int,             // its frequency in the corpus
  var point: Array[Int],   // the inner nodes passed on the path from the root to this word's leaf
  var code: Array[Int],    // the word's Huffman code
  var codeLen: Int         // length of the code / path, i.e. how many inner nodes are passed to reach this leaf
)
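To make the point / code / codeLen fields concrete, here is a hypothetical, hand-filled entry; the values are purely illustrative (not taken from a real run), and the length 40 is the MAX_CODE_LENGTH constant defined further down:

// Illustrative sketch only: a word whose Huffman code happens to be 1-0-1.
val example = VocabWord(
  word = "spark",
  cn = 42,                                  // the word occurred 42 times in the corpus
  point = Array(98, 57, 13).padTo(40, 0),   // indices of inner nodes passed from the root towards this leaf (made-up values)
  code = Array(1, 0, 1).padTo(40, 0),       // Huffman code, one bit per inner node on the path
  codeLen = 3                               // the path passes through 3 inner nodes
)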
This post covers only the skip-gram + hierarchical softmax part. The reference C implementation is at https://code.google.com/p/word2vec/,
and the two papers are "Efficient Estimation of Word Representations in Vector Space" and "Distributed Representations of Words and Phrases and their Compositionality".
@Since("1.1.0")
class Word2Vec extends Serializable with Logging {
//default parameters
  private var vectorSize = 100  // dimensionality of the trained vectors
  private var learningRate = 0.025  // initial learning rate
  private var numPartitions = 1   // number of partitions
  private var numIterations = 1   // number of iterations
  private var seed = Utils.random.nextLong()  // random seed
  private var minCount = 5   // minimum word frequency
  private var maxSentenceLength = 1000  // maximum sentence length

// sentences longer than maxSentenceLength are split into multiple chunks
  /**
   * Sets the maximum length (in words) of each sentence in the input data.
   * Any sentence longer than this threshold will be divided into chunks of
   * up to `maxSentenceLength` size (default: 1000)
   */
  @Since("2.0.0")
  def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
    require(maxSentenceLength > 0,
      s"Maximum length of sentences must be positive but got ${maxSentenceLength}")
    this.maxSentenceLength = maxSentenceLength
    this
  }
  /**
   * Sets vector size (default: 100).
   */
  @Since("1.1.0")
  def setVectorSize(vectorSize: Int): this.type = {
    require(vectorSize > 0,
      s"vector size must be positive but got ${vectorSize}")
    this.vectorSize = vectorSize
    this
  }


  /**
   * Sets initial learning rate (default: 0.025).
   */
  @Since("1.1.0")
  def setLearningRate(learningRate: Double): this.type = {
    require(learningRate > 0,
      s"Initial learning rate must be positive but got ${learningRate}")
    this.learningRate = learningRate
    this
  }


  /**
   * Sets number of partitions (default: 1). Use a small number for accuracy. // fewer partitions give better accuracy
   */
  @Since("1.1.0")
  def setNumPartitions(numPartitions: Int): this.type = {
    require(numPartitions > 0,
      s"Number of partitions must be positive but got ${numPartitions}")
    this.numPartitions = numPartitions
    this
  }


  /**
   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
   * partitions.  // the number of iterations should be <= the number of partitions
   */
  @Since("1.1.0")
  def setNumIterations(numIterations: Int): this.type = {
    require(numIterations >= 0,
      s"Number of iterations must be nonnegative but got ${numIterations}")
    this.numIterations = numIterations
    this
  }


  /**
   * Sets random seed (default: a random long integer).
   */
  @Since("1.1.0")
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }


  /**
   * Sets the window of words (default: 5) // set according to the typical text length; for ~40-character titles the author uses 5
   */
  @Since("1.6.0")
  def setWindowSize(window: Int): this.type = {
    require(window > 0,
      s"Window of words must be positive but got ${window}")
    this.window = window
    this
  }


  /**
   * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
   * model's vocabulary (default: 5). // set according to the word-frequency distribution so that most of the corpus is still covered
   */
  @Since("1.3.0")
  def setMinCount(minCount: Int): this.type = {
    require(minCount >= 0,
      s"Minimum number of times must be nonnegative but got ${minCount}")
    this.minCount = minCount
    this
  }


  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40


  /** context words from [-window, window] */  // the sliding window takes up to window words on each side of the center word
  private var window = 5


  private var trainWordsCount = 0L
  private var vocabSize = 0
********* About the transient keyword:
Any object whose class implements the Serializable interface can be serialized. Java's serialization model is convenient: we do not have to care about the details of the process; once a class implements Serializable, all of its fields are serialized automatically.
In practice, however, some fields of a class should be serialized while others should not. For example, if a user object carries sensitive information (passwords, bank card numbers, and so on), we may not want it transmitted over the network or written to a serialized cache; marking those fields with the transient keyword excludes them. In other words, such a field lives only in the caller's memory and is never persisted.
In short, transient is a convenience: implement Serializable and put the transient keyword in front of any field that should not be serialized, and that field will simply be skipped when the object is serialized.
********* (end of transient note)
  @transient private var vocab: Array[VocabWord] = null
  @transient private var vocabHash = mutable.HashMap.empty[String, Int]
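As a minimal, Spark-independent sketch of what @transient does during Java serialization (the Session class and its field names are made up for illustration):

import java.io._

// Hypothetical class: cache can be rebuilt on demand, so it is marked @transient and not serialized.
class Session(val user: String) extends Serializable {
  @transient var cache: Map[String, Int] = Map("hits" -> 1)
}

object TransientDemo extends App {
  val out = new ByteArrayOutputStream()
  new ObjectOutputStream(out).writeObject(new Session("alice"))

  val in = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
  val restored = in.readObject().asInstanceOf[Session]

  println(restored.user)   // "alice" -- serialized normally
  println(restored.cache)  // null    -- the @transient field was skipped
}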

********************************************************************************************************************
From org.apache.spark.ml.feature.Word2Vec#fit (how the ML API wraps the class analyzed here):
override def fit(dataset: Dataset[_]): Word2VecModel = {
  transformSchema(dataset.schema, logging = true)
  val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
  val wordVectors = new feature.Word2Vec()
    .setLearningRate($(stepSize))
    .setMinCount($(minCount))
    .setNumIterations($(maxIter))
    .setNumPartitions($(numPartitions))
    .setSeed($(seed))
    .setVectorSize($(vectorSize))
    .setWindowSize($(windowSize))
    .setMaxSentenceLength($(maxSentenceLength))
    .fit(input)
  copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
*********************************************************************************************************************
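For reference, the mllib class walked through below can also be called directly on an RDD of token sequences. A minimal usage sketch, assuming an existing SparkContext named sc and a whitespace-tokenized text file (both the path and the query word are placeholders):

import org.apache.spark.mllib.feature.Word2Vec

// "data/corpus.txt" is a placeholder path; each line is one sentence of space-separated tokens.
val input = sc.textFile("data/corpus.txt").map(line => line.split(" ").toSeq)

val model = new Word2Vec()
  .setVectorSize(100)
  .setWindowSize(5)
  .setMinCount(5)
  .setNumPartitions(4)
  .setNumIterations(2)     // should stay <= the number of partitions
  .fit(input)

val synonyms = model.findSynonyms("china", 10)   // top-10 nearest words by cosine similarity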
// dataset is the input built in the ml.feature.Word2Vec#fit snippet above: an RDD whose elements are Seq[String]
  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {  // build a VocabWord for every word
    val words = dataset.flatMap(x => x)  // flatten all sentences into words to count frequencies
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)  // keep only words that occur at least minCount times
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .collect()
      .sortWith((a, b) => a.cn > b.cn)  // sort by frequency, descending

    vocabSize = vocab.length
    require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
      "the setting of minCount, which could be large enough to remove all your words in sentences.")


    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a   // vocabHash: mutable.HashMap[String, Int] mapping each word to its index in the sorted vocabulary, for fast lookup
      trainWordsCount += vocab(a).cn    // total number of training tokens
      a += 1
    }
    logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
  }
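The counting, filtering, sorting and indexing above can be mimicked with plain Scala collections. A toy sketch of what vocab, vocabHash and trainWordsCount end up holding (illustrative values, no Spark required):

// Toy corpus: each inner Seq is one "sentence".
val sentences = Seq(Seq("a", "b", "a"), Seq("b", "a", "c"))
val minCount = 2

// word -> frequency, keep words seen at least minCount times, sort by frequency descending
val vocab = sentences.flatten
  .groupBy(identity).mapValues(_.size)
  .filter(_._2 >= minCount)
  .toSeq.sortBy(-_._2)                               // Seq(("a", 3), ("b", 2)); "c" is dropped

// word -> position in the sorted vocabulary, i.e. the role of vocabHash
val vocabHash = vocab.map(_._1).zipWithIndex.toMap   // Map("a" -> 0, "b" -> 1)

val trainWordsCount = vocab.map(_._2).sum            // 5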

// build the sigmoid lookup table
  private def createExpTable(): Array[Float] = {
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }
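The table holds sigmoid values for inputs sampled over [-MAX_EXP, MAX_EXP); during training a dot product f is mapped to a table index instead of calling math.exp every time. A standalone sketch of that lookup using the same constants (the caller is expected to ensure |f| < MAX_EXP, exactly as doFit does):

val EXP_TABLE_SIZE = 1000
val MAX_EXP = 6

// Same construction as createExpTable(): entry i holds sigmoid((2*i/EXP_TABLE_SIZE - 1) * MAX_EXP)
val expTable = Array.tabulate(EXP_TABLE_SIZE) { i =>
  val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
  (tmp / (tmp + 1.0)).toFloat
}

// The lookup used later in doFit: map f from (-MAX_EXP, MAX_EXP) to an index in [0, EXP_TABLE_SIZE)
def fastSigmoid(f: Float): Float = {
  val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
  expTable(ind)
}

println(fastSigmoid(0.0f))               // ~0.5
println(1.0 / (1.0 + math.exp(-2.0)))    // exact sigmoid(2.0), close to fastSigmoid(2.0f)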

// build the Huffman tree
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)  // counts of all nodes in the tree
    val binary = new Array[Int](vocabSize * 2 + 1)  // Huffman bit of each node (the second of the two merged children gets 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)  // parent of each node
    val code = new Array[Int](MAX_CODE_LENGTH)  // scratch buffer for a leaf's Huffman code
    val point = new Array[Int](MAX_CODE_LENGTH)  // scratch buffer for a leaf's path (which nodes it passes)
    var a = 0
    while (a < vocabSize) { // nodes 0 .. vocabSize-1 hold the word frequencies; these are the leaf nodes
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {  // nodes vocabSize .. 2*vocabSize-1 are initialized to 1e9; they become the internal (parent) nodes
      count(a) = 1e9.toInt
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

// min1i and min2i hold the indices of the two smallest-count nodes merged at each step
    var min1i = 0
    var min2i = 0


    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)   // the new internal node's count is the sum of the two smallest counts
      parentNode(min1i) = vocabSize + a    // parent of the first child
      parentNode(min2i) = vocabSize + a    // parent of the second child
      binary(min2i) = 1          // the second (larger-count) child gets bit 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {  // a Huffman tree has 2n-1 nodes, so index vocabSize*2-2 is the root; walk from leaf a up to (but excluding) the root
        code(i) = binary(b)         // Huffman bit (0 or 1) of node b
        point(i) = b                // record that the path passes through node b
        i += 1
        b = parentNode(b)          // move to the parent, i.e. the next node on the way to the root
      }
      vocab(a).codeLen = i         // how many nodes are passed to reach leaf a
      vocab(a).point(0) = vocabSize - 2 // point(0) is the same for every word: the root. point records the root-to-leaf path of leaf a, and every node on that path is an internal node; since there are vocabSize-1 internal nodes, their indices are shifted down by vocabSize so they fall into the range 0 .. vocabSize-2.
      b = 0
      while (b < i) {        // walk the recorded path again, reversing it so code and point are stored root-first
        vocab(a).code(i - b - 1) = code(b)   // write leaf a's Huffman code in root-to-leaf order
        vocab(a).point(i - b) = point(b) - vocabSize  // write the path nodes, shifted down by vocabSize
        b += 1
      }
      a += 1    // next word
    }
  }
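To see what the merging loop actually produces, here is a standalone toy sketch that runs the same two-pointer Huffman construction on a tiny, already-sorted vocabulary (the counts are made up) and prints each leaf's code. It mirrors the logic above but is not the Spark code itself:

// Toy vocabulary, already sorted by frequency descending (as learnVocab guarantees),
// e.g. the counts of four words "a", "b", "c", "d".
val counts = Array[Long](10, 7, 4, 2)
val vocabSize = counts.length

val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
for (i <- 0 until vocabSize) count(i) = counts(i)
for (i <- vocabSize until 2 * vocabSize) count(i) = 1e9.toLong

var pos1 = vocabSize - 1            // scans the leaves, starting from the smallest count
var pos2 = vocabSize                // scans the internal nodes created so far

// pop the index of the smallest remaining node, exactly like the pos1/pos2 logic above
def popSmallest(): Int =
  if (pos1 >= 0 && count(pos1) < count(pos2)) { val i = pos1; pos1 -= 1; i }
  else { val i = pos2; pos2 += 1; i }

for (a <- 0 until vocabSize - 1) {
  val min1i = popSmallest()
  val min2i = popSmallest()
  count(vocabSize + a) = count(min1i) + count(min2i)
  parentNode(min1i) = vocabSize + a
  parentNode(min2i) = vocabSize + a
  binary(min2i) = 1                 // the second (larger-count) child gets bit 1
}

// read each leaf's code by walking up to the root (node vocabSize * 2 - 2)
for (a <- 0 until vocabSize) {
  var b = a
  var code = List.empty[Int]
  while (b != vocabSize * 2 - 2) {
    code = binary(b) :: code        // prepend, so the code ends up in root-to-leaf order
    b = parentNode(b)
  }
  println(s"word $a (count ${counts(a)}): code ${code.mkString}")
}
// expected output for these counts: codes of length 1, 2, 3, 3 (e.g. 0, 11, 101, 100)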


  /**
   * Computes the vector representation of each word in vocabulary.
   * @param dataset an RDD of sentences,
   *                each sentence is expressed as an iterable collection of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {


    learnVocab(dataset)         // build the vocabulary


    createBinaryTree()          // build the Huffman tree


    val sc = dataset.context


    val expTable = sc.broadcast(createExpTable())   // broadcast the sigmoid lookup table
    val bcVocab = sc.broadcast(vocab)               // broadcast the vocabulary
    val bcVocabHash = sc.broadcast(vocabHash)       // broadcast the word -> index map
    try {
      doFit(dataset, sc, expTable, bcVocab, bcVocabHash)  
    } finally {
      expTable.destroy(blocking = false)   // destroy the broadcast variables
      bcVocab.destroy(blocking = false)
      bcVocabHash.destroy(blocking = false)
    }
  }


  private def doFit[S <: Iterable[String]](
    dataset: RDD[S], sc: SparkContext,
    expTable: Broadcast[Array[Float]],
    bcVocab: Broadcast[Array[VocabWord]],
    bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
    // each partition is a collection of sentences,
    // will be translated into arrays of Index integer
    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>  // S is the element type of RDD[S], i.e. one sentence as an Iterable[String]
      // Each sentence will map to 0 or more Array[Int]
      sentenceIter.flatMap { sentence =>
        // Sentence of words, some of which map to a word index
        val wordIndexes = sentence.flatMap(bcVocabHash.value.get) // look up the vocabulary index of each word in the sentence; words not in the vocabulary are silently dropped
        // break wordIndexes into trunks of maxSentenceLength when has more
        wordIndexes.grouped(maxSentenceLength).map(_.toArray)  // if a sentence is longer than maxSentenceLength (default 1000), grouped() splits it into chunks of that size; wordIndexes is an Iterable[Int], so grouped(maxSentenceLength).map(_.toArray) yields an Iterator[Array[Int]]
      }
    }


    val newSentences = sentences.repartition(numPartitions).cache()   // repartition to the requested number of partitions and cache all of it
    val initRandom = new XORShiftRandom(seed)


    if (vocabSize.toLong * vectorSize >= Int.MaxValue) {   // if vocabSize * vectorSize would overflow Int, abort
      throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
        " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
        "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    }


    val syn0Global =
      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)   // leaf nodes: word vectors, randomly initialized
    val syn1Global = new Array[Float](vocabSize * vectorSize)                                   // inner nodes: parameter vectors, initialized to 0
    val totalWordsCounts = numIterations * trainWordsCount + 1                                  // iterations * total token count + 1
    var alpha = learningRate                                                                    // learning rate


    for (k <- 1 to numIterations) {   // start iterating
      val bcSyn0Global = sc.broadcast(syn0Global)
      val bcSyn1Global = sc.broadcast(syn1Global)
      val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount // tokens already processed in earlier iterations


      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
        val syn0Modify = new Array[Int](vocabSize)
        val syn1Modify = new Array[Int](vocabSize)
        /**
        def foldLeft[B](z: B)(op: (B, A) => B): B = {
          var result = z
          this foreach (x => result = op(result, x))
          result
        }
        */
        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { // the block is the op and the initial value z is (bcSyn0Global.value, bcSyn1Global.value, 0L, 0L); foldLeft runs sequentially within the partition, matching ((syn0, syn1, lastWordCount, wordCount), sentence) for each sentence and returning a new (syn0, syn1, lwc, wc) of the same type as z, which becomes the accumulator for the next sentence; the final result is the accumulator after the last sentence
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>  // one sentence of this partition at a time
            var lwc = lastWordCount  // word count at the last alpha update
            var wc = wordCount
            if (wordCount - lastWordCount > 10000) { // every time roughly 10000 more words have been processed, refresh alpha
              lwc = wordCount   // remember the word count of this update
              alpha = learningRate *
                (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
                  totalWordsCounts)   // alpha decays as wordCount grows
              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001  // never let alpha fall below learningRate * 0.0001
              logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " +
                s"alpha = $alpha")
            }
            wc += sentence.length // wc accumulates the sentence lengths, i.e. the running word count
            var pos = 0
            while (pos < sentence.length) {  // iterate over the words of the sentence; pos is the position of the center word, starting at 0
              val word = sentence(pos)
              val b = random.nextInt(window)  // b is a random number in [0, window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {  // a runs from b to window*2 - b, so the effective window is the (window - b) words on each side of pos
                if (a != window) { // skip the center word itself
                  val c = pos - window + a   // absolute position of the context word: pos - (window - a)
                  if (c >= 0 && c < sentence.length) {  // c can fall outside the sentence; keep it in range
                    val lastWord = sentence(c)    // vocabulary index of the context word
                    val l1 = lastWord * vectorSize  // offset of the context word's vector in syn0
                    val neu1e = new Array[Float](vectorSize) // e in the formulas: the accumulated gradient for the context word vector
                    // Hierarchical softmax
                    var d = 0
                    while (d < bcVocab.value(word).codeLen) {  // walk the center word's Huffman path; each inner node is one binary classification
                      val inner = bcVocab.value(word).point(d)  // index of the inner node on the path
                      val l2 = inner * vectorSize               // offset of that node's vector in syn1
                      // Propagate hidden -> output.  blas.sdot(n, sx, _sx_offset, incx, sy, _sy_offset, incy) computes the dot product of sx[_sx_offset until _sx_offset + incx*n] and sy[_sy_offset until _sy_offset + incy*n]
                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)   // dot product of syn0[l1 until l1 + vectorSize] and syn1[l2 until l2 + vectorSize]
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)   // look up sigmoid(f) in the table
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat   // gradient scale
                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)   // neu1e = g * syn1 + neu1e.  blas.saxpy(n, sa, sx, _sx_offset, incx, sy, _sy_offset, incy) computes sy = sa*sx + sy over the same offset/stride ranges as sdot
                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)   // syn1 = g * syn0 + syn1
                        syn1Modify(inner) += 1          // count how often this inner-node vector was updated
                      }
                      d += 1
                    }
                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)  // syn0 = 1.0f * neu1e + syn0: after the whole path, apply the accumulated gradient to the context word's vector
                    syn0Modify(lastWord) += 1
                  }
                }
                a += 1
              }
              pos += 1   // move on to the next center word of the sentence
            }
            (syn0, syn1, lwc, wc)
        }
        val syn0Local = model._1   // syn0: leaf-node vectors, i.e. the word vectors
        val syn1Local = model._2   // syn1: inner-node vectors, i.e. the parameter vectors
        // Only output modified vectors.  Iterator.tabulate creates an iterator producing the values of a given function over a range of integer values starting from 0.
        Iterator.tabulate(vocabSize) { index =>
          if (syn0Modify(index) > 0) {
            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
          if (syn1Modify(index) > 0) {
            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten    // emit the modified word vectors and the modified inner-node vectors as (index, array) pairs; the inner-node indices are shifted by vocabSize so the two index ranges do not collide
      }
      val synAgg = partial.reduceByKey { case (v1, v2) =>   // partial holds the per-partition results; vectors with the same index are aggregated by simply summing them across partitions
          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
          v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {  // split the aggregated result back into word vectors and parameter vectors
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
      bcSyn0Global.destroy(false)
      bcSyn1Global.destroy(false)
    }
    newSentences.unpersist()


    val wordArray = vocab.map(_.word)
    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)  // the model keeps only the word vectors (syn0), indexed by the sorted vocabulary
  }
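The inner training loop is dense because of the flat arrays and BLAS offsets. Written with plain 2-D arrays and an explicit sigmoid, a single (center word, context word) hierarchical-softmax update looks like the sketch below; this is a simplified illustration of the same math, not the Spark implementation:

// One skip-gram hierarchical-softmax step, without BLAS or the exp table.
// syn0: word vectors, syn1: inner-node vectors, both as per-index arrays here.
def trainPair(
    syn0: Array[Array[Float]],        // word vectors, updated in place
    syn1: Array[Array[Float]],        // inner-node vectors, updated in place
    contextWord: Int,                 // index of the context word (its vector is adjusted)
    centerPoint: Array[Int],          // inner nodes on the center word's root-to-leaf path
    centerCode: Array[Int],           // Huffman code of the center word (one bit per node)
    codeLen: Int,
    alpha: Float): Unit = {
  val vectorSize = syn0(contextWord).length
  val neu1e = new Array[Float](vectorSize)          // accumulated gradient for the context vector
  var d = 0
  while (d < codeLen) {
    val inner = centerPoint(d)
    // f = sigmoid(syn0[context] . syn1[inner])
    var dot = 0.0
    var j = 0
    while (j < vectorSize) { dot += syn0(contextWord)(j) * syn1(inner)(j); j += 1 }
    val f = 1.0 / (1.0 + math.exp(-dot))
    // g = (1 - code - f) * alpha, the same expression as in doFit
    val g = ((1 - centerCode(d) - f) * alpha).toFloat
    j = 0
    while (j < vectorSize) {
      neu1e(j) += g * syn1(inner)(j)                // accumulate gradient w.r.t. the context vector
      syn1(inner)(j) += g * syn0(contextWord)(j)    // update the inner-node vector
      j += 1
    }
    d += 1
  }
  var j = 0
  while (j < vectorSize) { syn0(contextWord)(j) += neu1e(j); j += 1 }  // apply the accumulated update
}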


  /**
   * Computes the vector representation of each word in vocabulary (Java version).
   * @param dataset a JavaRDD of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
    fit(dataset.rdd.map(_.asScala))
  }
}


/**
* Word2Vec model
* @param wordIndex maps each word to an index, which can retrieve the corresponding
*                  vector from wordVectors
* @param wordVectors array of length numWords * vectorSize, vector corresponding
*                    to the word mapped with index i can be retrieved by the slice
*                    (i * vectorSize, i * vectorSize + vectorSize)
*/
@Since("1.1.0")
class Word2VecModel private[spark] (
    private[spark] val wordIndex: Map[String, Int],
    private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {


  private val numWords = wordIndex.size
  // vectorSize: Dimension of each word's vector.
  private val vectorSize = wordVectors.length / numWords


  // wordList: Ordered list of words obtained from wordIndex.
  private val wordList: Array[String] = {
    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
    wl.toArray
  }
  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
  //               of the wordVector.
  private val wordVecNorms: Array[Float] = {
    val wordVecNorms = new Array[Float](numWords)
    var i = 0
    while (i < numWords) {
      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
      i += 1
    }
    wordVecNorms
  }
  @Since("1.5.0")
  def this(model: Map[String, Array[Float]]) = {
    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
  }
  override protected def formatVersion = "1.0"
  @Since("1.4.0")
  def save(sc: SparkContext, path: String): Unit = {
    Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
  }
  /**
   * Transforms a word to its vector representation
   * @param word a word
   * @return vector representation of word
   */
  @Since("1.1.0")
  def transform(word: String): Vector = {
    wordIndex.get(word) match {
      case Some(ind) =>
        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }
  /**
   * Find synonyms of a word; do not include the word itself in results.
   * @param word a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  @Since("1.1.0")
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    findSynonyms(vector, num, Some(word))
  }
  /**
   * Find synonyms of the vector representation of a word, possibly
   * including any words in the model vocabulary whose vector representation
   * is the supplied vector.
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  @Since("1.1.0")
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    findSynonyms(vector, num, None)
  }
  /**
   * Find synonyms of the vector representation of a word, rejecting
   * words identical to the value of wordOpt, if one is supplied.
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @param wordOpt optionally, a word to reject from the results list
   * @return array of (word, cosineSimilarity)
   */
  private def findSynonyms(
      vector: Vector,
      num: Int,
      wordOpt: Option[String]): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    val fVector = vector.toArray.map(_.toFloat)
    val cosineVec = new Array[Float](numWords)
    val alpha: Float = 1
    val beta: Float = 0
    // Normalize input vector before blas.sgemv to avoid Inf value
    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
    if (vecNorm != 0.0f) {
      blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
    }
    blas.sgemv(
      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
    var i = 0
    while (i < numWords) {
      val norm = wordVecNorms(i)
      if (norm == 0.0f) {
        cosineVec(i) = 0.0f
      } else {
        cosineVec(i) /= norm
      }
      i += 1
    }
    val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))
    var j = 0
    while (j < numWords) {
      pq += Tuple2(wordList(j), cosineVec(j))
      j += 1
    }
    val scored = pq.toSeq.sortBy(-_._2)
    val filtered = wordOpt match {
      case Some(w) => scored.filter(tup => w != tup._1)
      case None => scored
    }
    filtered
      .take(num)
      .map { case (word, score) => (word, score.toDouble) }
      .toArray
  }
  /**
   * Returns a map of words to their vector representations.
   */
  @Since("1.2.0")
  def getVectors: Map[String, Array[Float]] = {
    wordIndex.map { case (word, ind) =>
      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
    }
  }
}
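A short usage sketch of the model class above (assuming a trained model; the query word is a placeholder that must exist in the vocabulary, otherwise transform throws):

// `model` is a trained org.apache.spark.mllib.feature.Word2VecModel
val vec: org.apache.spark.mllib.linalg.Vector = model.transform("china")   // the word's vector of length vectorSize

// Top 5 most similar words; the query word itself is excluded from the results
model.findSynonyms("china", 5).foreach { case (w, sim) => println(s"$w\t$sim") }

// All vectors as an in-memory map, e.g. for export
val vectors: Map[String, Array[Float]] = model.getVectors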
@Since("1.4.0")
object Word2VecModel extends Loader[Word2VecModel] {
  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
    model.keys.zipWithIndex.toMap
  }
  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
    require(model.nonEmpty, "Word2VecMap should be non-empty")
    val (vectorSize, numWords) = (model.head._2.length, model.size)
    val wordList = model.keys.toArray
    val wordVectors = new Array[Float](vectorSize * numWords)
    var i = 0
    while (i < numWords) {
      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
      i += 1
    }
    wordVectors
  }
  private object SaveLoadV1_0 {
    val formatVersionV1_0 = "1.0"
    val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"
    case class Data(word: String, vector: Array[Float])
    def load(sc: SparkContext, path: String): Word2VecModel = {
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val dataFrame = spark.read.parquet(Loader.dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      Loader.checkSchema[Data](dataFrame.schema)
      val dataArray = dataFrame.select("word", "vector").collect()
      val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
      new Word2VecModel(word2VecMap)
    }
    def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val vectorSize = model.values.head.length
      val numWords = model.size
      val metadata = compact(render(
        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
        ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      // We want to partition the model in partitions smaller than
      // spark.kryoserializer.buffer.max
      val bufferSize = Utils.byteStringAsBytes(
        spark.conf.get("spark.kryoserializer.buffer.max", "64m"))
      // We calculate the approximate size of the model
      // We only calculate the array size, considering an
      // average string size of 15 bytes, the formula is:
      // (floatSize * vectorSize + 15) * numWords
      val approxSize = (4L * vectorSize + 15) * numWords
      val nPartitions = ((approxSize / bufferSize) + 1).toInt
      val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
      spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))
    }
  }
  @Since("1.4.0")
  override def load(sc: SparkContext, path: String): Word2VecModel = {
    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
    val expectedNumWords = (metadata \ "numWords").extract[Int]
    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
    (loadedClassName, loadedVersion) match {
      case (classNameV1_0, "1.0") =>
        val model = SaveLoadV1_0.load(sc, path)
        val vectorSize = model.getVectors.values.head.length
        val numWords = model.getVectors.size
        require(expectedVectorSize == vectorSize,
          s"Word2VecModel requires each word to be mapped to a vector of size " +
          s"$expectedVectorSize, got vector of size $vectorSize")
        require(expectedNumWords == numWords,
          s"Word2VecModel requires $expectedNumWords words, but got $numWords")
        model
      case _ => throw new Exception(
        s"Word2VecModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $loadedVersion).  Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }
}
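Saving and reloading a trained model goes through the Loader / Saveable machinery above; a minimal sketch, with the output path as a placeholder:

import org.apache.spark.mllib.feature.Word2VecModel

// Assumes an existing SparkContext `sc` and a trained `model`.
model.save(sc, "hdfs:///tmp/word2vec-model")                      // metadata as JSON + vectors as Parquet
val reloaded = Word2VecModel.load(sc, "hdfs:///tmp/word2vec-model")
require(reloaded.getVectors.size == model.getVectors.size)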

For the theory behind word2vec, only the hierarchical-softmax skip-gram part is needed to follow this code.

A few issues I ran into when using Spark's Word2Vec:

  1. With too many iterations or partitions, the vectors can blow up to Infinity.
  2. Accuracy drops when the number of partitions is too large.
  3. Memory consumption is high, since all sentences are cached.
  4. The Huffman-tree based training is time-consuming. I have worked through and fixed these issues one by one over the last few days.


Reposted from blog.csdn.net/u014552678/article/details/104001725