How to find the argmax of a vector in PySpark ML


Question


My model has output a DenseVector column, and I'd like to find the argmax. This page suggests an argmax function should be available, but I'm not sure what the syntax should be.

Is it df.select("mycolumn").argmax()?


Answer 1:


I could not find documentation for an argmax operation on vectors in Python, but you can achieve it by converting the vector column to an array.

For PySpark >= 3.0.0, use vector_to_array (added in 3.0.0):

from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

# Convert the vector column to an array, then compute each row's maximum
tst_arr = tst_df.withColumn("arr", vector_to_array(F.col('vector_column')))
tst_max = tst_arr.withColumn("max_value", F.array_max("arr"))
# posexplode adds ("pos", "col"); keeping rows where col equals the max makes "pos" the argmax
tst_max_exp = tst_max.select('*', F.posexplode("arr"))
tst_fin = tst_max_exp.where('col == max_value')
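If it helps, here is a minimal self-contained sketch of the same approach. The SparkSession setup, the DataFrame, and the vector values are made up for illustration, not taken from the original question:

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
# Toy data: one DenseVector per row
tst_df = spark.createDataFrame(
    [(Vectors.dense([0.1, 0.7, 0.2]),), (Vectors.dense([0.6, 0.3, 0.1]),)],
    ["vector_column"],
)
tst_fin = (
    tst_df.withColumn("arr", vector_to_array(F.col("vector_column")))
          .withColumn("max_value", F.array_max("arr"))
          .select("*", F.posexplode("arr"))
          .where("col == max_value")
          .select("vector_column", F.col("pos").alias("argmax"))
)
tst_fin.show()  # argmax should be 1 for the first row and 0 for the second

Note that if the maximum value appears more than once in a vector, this approach returns one row per tie, so deduplicate on a row key if that matters for your use case.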

For PySpark < 3.0.0, use a Python UDF:

import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.functions import udf

@udf("int")
def vect_argmax(vec):
    # Convert the Spark vector to a numpy array and return the index of its maximum
    return int(np.argmax(vec.toArray()))

tst_fin = tst_df.withColumn("argmax", vect_argmax(F.col('probability')))
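For completeness, a usage sketch of the UDF variant, again with toy data that is purely illustrative. The explicit "int" return type matters, since a bare @udf defaults to StringType:

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()
tst_df = spark.createDataFrame([(Vectors.dense([0.1, 0.7, 0.2]),)], ["probability"])

@udf("int")
def vect_argmax(vec):
    # numpy argmax over the vector's values
    return int(np.argmax(vec.toArray()))

tst_df.withColumn("argmax", vect_argmax(F.col("probability"))).show()
# Expect argmax == 1 for [0.1, 0.7, 0.2]

The UDF runs row-at-a-time in Python, so on large DataFrames the native vector_to_array route above will generally be faster.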



Answer 2:


Have you tried

from pyspark.sql.functions import col
df.select(col("mycolumn").argmax())


Source: https://stackoverflow.com/questions/63046572/how-to-find-the-argmax-of-a-vector-in-pyspark-ml
