pyspark - aggregate (sum) vector element-wise

Asked by 南方客 on 2021-01-19 15:56

I have what seems like a simple problem, but I keep banging my head against the wall with no success. I am essentially trying to do the same thing as this post, except that I am working in PySpark rather than Scala.
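For reference, here is a minimal sketch of the kind of DataFrame the answers below work with (the column name Vec comes from the answers; the values are made up):

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # two rows, each holding a length-3 dense vector in a column named 'Vec'
    df = spark.createDataFrame(
        [(Vectors.dense([1.0, 2.0, 3.0]),),
         (Vectors.dense([4.0, 5.0, 6.0]),)],
        ['Vec'])
    # expected element-wise sum: [5.0, 7.0, 9.0]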

2 Answers
  • 2021-01-19 16:32

    I think you have to cast the vector column to an array before you can aggregate it.

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import functions as F
    from pyspark.sql import types as T

    def vec2array(v):
        # densify (handles sparse vectors too), then convert to a plain list of floats
        v = Vectors.dense(v)
        return [float(x) for x in v]

    vec2array_udf = F.udf(vec2array, T.ArrayType(T.FloatType()))

    df = df.withColumn('Vec', vec2array_udf('Vec'))

    # length of the vectors, taken from the first row
    n = len(df.select('Vec').first()[0])
    # sum each array position across all rows
    bla = df.agg(F.array(*[F.sum(F.col("Vec")[i]) for i in range(n)]).alias("sum"))
    bla.show(truncate=False)
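    If you're on Spark 3.0 or newer (an assumption about your environment), the built-in pyspark.ml.functions.vector_to_array does the same vector-to-array conversion without a Python UDF, which is usually faster:

    from pyspark.ml.functions import vector_to_array
    from pyspark.sql import functions as F

    # replace the vector column with a native array column
    df = df.withColumn('Vec', vector_to_array('Vec'))
    n = len(df.select('Vec').first()[0])
    sums = df.agg(F.array(*[F.sum(F.col('Vec')[i]) for i in range(n)]).alias('sum'))
    sums.show(truncate=False)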
    
  • 2021-01-19 16:47

    I eventually figured this out (I'm lying, one of my coworkers figured it out for me), so I'll post the answer here in case anyone else has the same issue.

    You can use fold, similar to how it's done in the Scala example linked in the original question. The PySpark syntax looks like this:

    # find out how many elements we're iterating over, to establish the zero value below
    vec_df = df.select('Vec')
    num_cols = len(vec_df.first().Vec)

    # pull the vector out of each Row, then fold to sum each "column"
    # (without the map, fold would receive Row objects and the zip below
    # would truncate to a single element)
    vec_sums = vec_df.rdd.map(lambda row: row.Vec).fold(
        [0] * num_cols, lambda a, b: [x + y for x, y in zip(a, b)])
    

    Brief explanation: rdd.fold() takes two arguments. The first is a zero value, in this case [0]*num_cols, which is just a list of zeros the same length as the vectors. The second is a function used to fold each row into the accumulator (and also to merge the per-partition results). For each row it runs lambda a, b: [x + y for x, y in zip(a, b)], which adds that row element-wise to what has been accumulated so far.
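    To see what that combine function does, here is the same logic on plain Python lists, outside Spark (toy values):

    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    acc = [0] * 3                                # the zero value
    for b in rows:                               # what fold does with each row
        acc = [x + y for x, y in zip(acc, b)]
    print(acc)                                   # [5.0, 7.0, 9.0]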

    You can use my code in the original question to generate a toy dataframe to test this on. Hope that's helpful to someone.
