How to get item id from cosine similarity matrix?

Submitted by 前提是你 on 2019-12-10 11:29:49

Question


I am using Spark Scala to calculate cosine similarity between the Dataframe rows.

Dataframe schema is below:

root
    |-- itemId: string (nullable = true)
    |-- features: vector (nullable = true)

Sample of the dataframe below

    +-------+--------------------+
    | itemId|            features|
    +-------+--------------------+
    | ab    |[4.7143,0.0,5.785...|
    | cd    |[5.5,0.0,6.4286,4...|
    | ef    |[4.7143,1.4286,6....|
    ........
    +-------+--------------------+

Code to compute the cosine similarities:

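import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

// IndexedRow takes the index first, then the vector
val irm = new IndexedRowMatrix(myDataframe.rdd.zipWithIndex().map {
      case (row, index) => IndexedRow(index, row.getAs[Vector]("features"))
}).toCoordinateMatrix.transpose.toRowMatrix.columnSimilarities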

The irm matrix contains entries (i, j, score), where i and j are the row indexes of two items in my original dataframe. What I would like is (itemIdA, itemIdB, score), where itemIdA and itemIdB are the ids at indexes i and j respectively. Should I join irm with the initial dataframe, or is there a better option?


Answer 1:


Create a row index before converting the dataframe to a matrix, and build a map from index to id. After the computation, use that map to convert each column index (originally a row index, before the transpose) back to its id.

val rdd = myDataframe.as[(String, org.apache.spark.mllib.linalg.Vector)].rdd.zipWithIndex()
val indexMap = rdd.map{case ((id, vec), index) => (index, id)}.collectAsMap()

Calculate the cosine similarities as before, using the indexed rdd:

val irm = new IndexedRowMatrix(rdd.map{case ((id, vec), index) => IndexedRow(index, vec)})
  .toCoordinateMatrix().transpose().toRowMatrix().columnSimilarities()

Convert column indices back to the ids:

irm.entries.map(e => (indexMap(e.i), indexMap(e.j), e.value)) 
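
To see what the index map does without a cluster, here is a minimal plain-Scala sketch of the same idea. The object name IndexMapSketch, the toy feature values and the hand-written cosine helper are all made up for illustration and are not part of the answer above. It only emits pairs with i < j, on the assumption that this mirrors the upper-triangular output of columnSimilarities.

```scala
object IndexMapSketch {
  // Toy stand-in for the dataframe rows: (itemId, features). Values are invented.
  val items: Seq[(String, Array[Double])] = Seq(
    "ab" -> Array(1.0, 0.0),
    "cd" -> Array(0.0, 2.0),
    "ef" -> Array(3.0, 3.0)
  )

  // Same role as zipWithIndex + collectAsMap in the answer: row position -> itemId.
  val indexMap: Map[Long, String] =
    items.zipWithIndex.map { case ((id, _), index) => index.toLong -> id }.toMap

  // Plain cosine similarity between two equally sized vectors.
  def cosine(a: Array[Double], b: Array[Double]): Double = {
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)
    dot / (norm(a) * norm(b))
  }

  // Index-based entries (i, j, score), upper triangle only (i < j).
  val entries: Seq[(Long, Long, Double)] =
    for {
      i <- items.indices
      j <- items.indices if i < j
    } yield (i.toLong, j.toLong, cosine(items(i)._2, items(j)._2))

  // Translate both indexes back to item ids, as in the last step of the answer.
  val named: Seq[(String, String, Double)] =
    entries.map { case (i, j, score) => (indexMap(i), indexMap(j), score) }
}
```

The lookup is only valid as long as the indexes in the similarity entries come from the same zipWithIndex pass that built the map, which is why the answer reuses a single rdd for both steps.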

This should give you what you are looking for.
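
Since the question also asks about joining, here is a hedged sketch of a join-based variant. It may be preferable when there are too many items to collect the index map on the driver. The object SimilarityJoinSketch, the method attachIds and the column names are hypothetical; the sketch assumes indexToId is a two-column dataframe of (index: Long, itemId: String) built from the same zipWithIndex pass as the matrix.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix

object SimilarityJoinSketch {
  // sims: result of columnSimilarities; indexToId: (index, itemId) lookup dataframe.
  def attachIds(spark: SparkSession, sims: CoordinateMatrix, indexToId: DataFrame): DataFrame = {
    import spark.implicits._

    // Flatten the matrix entries into a dataframe of (i, j, score).
    val entries = sims.entries.map(e => (e.i, e.j, e.value)).toDF("i", "j", "score")

    // Two renamed copies of the lookup table, one per index column.
    val idsA = indexToId.toDF("i", "itemIdA")
    val idsB = indexToId.toDF("j", "itemIdB")

    // Join once per index, then keep only the id-based columns.
    entries.join(idsA, "i").join(idsB, "j").select("itemIdA", "itemIdB", "score")
  }
}
```

With the rdd from the answer, the lookup table could be built as rdd.map { case ((id, _), index) => (index, id) }.toDF("idx", "itemId"). The two joins add shuffles, so for a small number of items the collected map is likely simpler and faster.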



Source: https://stackoverflow.com/questions/51163248/how-to-get-item-id-from-cosine-similarity-matrix
