spark mllib HashingTF解析

社会主义新天地 提交于 2019-11-27 22:02:29

在处理文本数据,尤其是自然语言处理的场景中,hashingTF使用的比较多;
Mllib使用hashing trick实现词频。元素的特征应用一个hash`函数映射到一个索引(即词),通过这个索引计算词频。这个方法避免计算全局的词-索引映射,因为全局的词-索引映射在大规模语料中花费较大。
但是,它会出现哈希冲突,这是因为不同的元素特征可能得到相同的哈希值。为了减少碰撞冲突,我们可以增加目标特征的维度,例如哈希表的桶数量。默认的特征维度是1048576。
1、spark ML中使用的hash方法基本上都是murmurhash实现,
private var binary = false
private var hashAlgorithm = HashingTF.Murmur3
// math.pow(2,20)=1048576 代表hashingTF中能表征的特征个数
def this() = this(1 << 20)
private[spark] val seed = 42

2、获取hash的方法

/**
   * Returns the index of the input term.
   */
  @Since("1.1.0")
  def indexOf(term: Any): Int = {
    Utils.nonNegativeMod(getHashFunction(term), numFeatures)
  }
/**
   用mur hash计算输入特征的hash code
   */
  private[spark] def murmur3Hash(term: Any): Int = {
    term match {
      case null => seed
      case b: Boolean => hashInt(if (b) 1 else 0, seed)
      case b: Byte => hashInt(b, seed)
      case s: Short => hashInt(s, seed)
      case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
      case s: String =>
        val utf8 = UTF8String.fromString(s)
        hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
      case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
        s"support type ${term.getClass.getCanonicalName} of input data.")
    }
  }
  
/* *
对返回的hash code取模 保证特征的索引在numFeatures范围内
  */
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

所有输入的特征统一有该方法返回特征对应的hash索引,用于解决大数据量的计算压力;
Tips:
1)、当你使用HashingTF和IDF训练完模型后,一定要保存你的IDFModel,还有HashingTF的参数,当后续你使用模型的时候
需要使用HashingTF相同的参数和模型生成时的同一个IDFModel,比如在spark-streaming中,切记!
2)、切记对自己语料库中特征数量要有预估,为了减少碰撞,将numFeatures设置为1048576

3、重点的实现过程

def transform(document: Iterable[_]): Vector = {
    val termFrequencies = mutable.HashMap.empty[Int, Double]
    // 返回特征的出现次数 经典的wordcount
    val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
    val hashFunc: Any => Int = getHashFunction
    document.foreach { term =>
      val i = Utils.nonNegativeMod(hashFunc(term), numFeatures)
      termFrequencies.put(i, setTF(i))
    }
    Vectors.sparse(numFeatures, termFrequencies.toSeq)
  }

计算单个预料中的每个特征的出现次数

举例子:

 val nGram = new NGram().setN(3).setInputCol("l_behavior").setOutputCol("elements")
 val nGramDF = nGram.transform(leftDF)
 val hashingTF = new HashingTF().
      setInputCol("elements").
      setOutputCol("features").
      setNumFeatures(1500000)
 val hashingDF = hashingTF.transform(nGramDF)

以上就是spark ML中HashingTF的实现;其数据类型为具有iterator性质的,一般会把单个预料切分成Array[String]。
这个实现比较简单,后续有时间会更新CountVectorizer;
HashingTF没有保留原有语料库中的原始的词语。记住这点才能更好的区分后续的transform feature方法。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!