pyspark - get consistent random value across Spark sessions


Question


I want to add a column of random values to a dataframe (which has an id for each row) for something I am testing. I am struggling to get reproducible results across Spark sessions - the same random value for each row id. I am able to reproduce the results by using

from pyspark.sql.functions import rand

new_df = my_df.withColumn("rand_index", rand(seed = 7))

but it only works within the same Spark session. I do not get the same results once I relaunch Spark and rerun my script.

I also tried defining a UDF, to test whether I could generate random integers within an interval using Python's random module with random.seed set:

import random
from pyspark.sql.types import LongType

random.seed(7)
# Register a UDF that returns a random integer in [x, y]
spark.udf.register("getRandVals", lambda x, y: random.randint(x, y), LongType())

but to no avail.

Is there a way to ensure reproducible random number generation across Spark sessions, such that a row id always gets the same random value? I would really appreciate some guidance :) Thanks for the help!


Answer 1:


I suspect you are getting the same values from the seed, but in a different order, depending on your partitioning - which in turn is influenced by the data distribution when reading from disk and by how much data there is at a given time. But I am not privy to your actual code.

The rand function generates the same sequence of random data for a given seed (what would be the point of the seed otherwise), and each partition gets a slice of it. If you look at the output below you can spot the pattern.
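
As a quick illustration (a sketch of my own, not from the original answer, assuming a SparkSession named spark): forcing different partition layouts onto the same ids with the same seed generally attaches different values to each id.

from pyspark.sql.functions import col, rand

base = spark.range(1, 5).select(col("id").cast("double"))

# Same ids, same seed, but different partition counts: each partition draws
# from its own stream derived from the seed (in practice the seed is combined
# with the partition index), so the value attached to a given id will
# generally differ between the two layouts.
base.repartition(1).withColumn("rand_index", rand(seed=7)).show()
base.repartition(4).withColumn("rand_index", rand(seed=7)).show()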

Here is an example with two dataframes of different cardinality. You can see that the same seed gives the same values, or a superset of them, so ordering and partitioning play a role, in my opinion.

from pyspark.sql.functions import col, rand

df1 = spark.range(1, 5).select(col("id").cast("double"))
df1 = df1.withColumn("rand_index", rand(seed=7))
df1.show()

print(str(df1.rdd.getNumPartitions()) + ' partitions & Partitioning distribution: ' + str(df1.rdd.glom().map(len).collect()))

returns:

+---+-------------------+
| id|         rand_index|
+---+-------------------+
|1.0|0.06498948189958098|
|2.0|0.41371264720975787|
|3.0|0.12030715258495939|
|4.0| 0.2731073068483362|
+---+-------------------+

8 partitions & Partitioning distribution: [0, 1, 0, 1, 0, 1, 0, 1]

The same again with more data:

...
df1 = spark.range(1, 10).select(col("id").cast("double"))
...

returns:

+---+-------------------+
| id|         rand_index|
+---+-------------------+
|1.0| 0.9147159860432812|
|2.0|0.06498948189958098|
|3.0| 0.7069655052310547|
|4.0|0.41371264720975787|
|5.0| 0.1982919638208397|
|6.0|0.12030715258495939|
|7.0|0.44292918521277047|
|8.0| 0.2731073068483362|
|9.0| 0.7784518091224375|
+---+-------------------+

8 partitions & Partitioning distribution: [1, 1, 1, 1, 1, 1, 1, 2]

You can see the 4 random values from the first run reappear in the second, but attached to different row ids - whether you run this within one Spark session or across sessions.
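
If the goal is simply a value that is stable for a given row id across sessions, one workaround (my own sketch, not part of the original answer; it assumes Spark 3.0+ for xxhash64, the question's my_df dataframe, and that deriving the value from a hash of the id is acceptable) is to hash the id together with a fixed seed and scale the result into [0, 1):

from pyspark.sql.functions import abs as abs_, col, lit, xxhash64

SEED = 7  # any fixed seed of our choosing

# xxhash64 is deterministic, so a given (id, SEED) pair always maps to the
# same value, regardless of partitioning or which Spark session evaluates it.
new_df = my_df.withColumn(
    "rand_index",
    (abs_(xxhash64(col("id"), lit(SEED))) % lit(10**9)) / 10**9
)

The result is not produced by rand, but it is roughly uniform on [0, 1) and, more importantly here, the same id always gets the same value.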



Source: https://stackoverflow.com/questions/59077897/pyspark-get-consistent-random-value-across-spark-sessions
