Question
I want to add a column of random values to a dataframe (which has an id for each row) for something I am testing. I am struggling to get reproducible results across Spark sessions, i.e. the same random value against each row id. I am able to reproduce the results by using
from pyspark.sql.functions import rand
new_df = my_df.withColumn("rand_index", rand(seed = 7))
but it only works when I am running it in the same Spark session. I do not get the same results once I relaunch Spark and run my script.
I also tried defining a UDF, to see whether I could generate random integers within an interval, using Python's random module with random.seed set
import random
from pyspark.sql.types import LongType
random.seed(7)
spark.udf.register("getRandVals", lambda x, y: random.randint(x, y), LongType())
but to no avail.
Is there a way to ensure reproducible random number generation across Spark sessions, such that a row id always gets the same random value? I would really appreciate some guidance :) Thanks for the help!
Answer 1:
I suspect that you are getting the same set of values for the seed, but in a different order depending on your partitioning, which is influenced by how the data is distributed when it is read from disk; there could be more or less data from one run to the next. But I have not seen your actual code.
The rand function generates the same random data for a given seed (what would be the point of the seed otherwise?), and each partition gets a slice of it. If you look at the output you should be able to spot the pattern!
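To make that idea concrete, here is a toy model in plain Python, not Spark itself. It assumes, as I understand Spark's behaviour, that each partition gets its own generator seeded from the seed plus the partition index. The function name `simulate_rand` is made up for this sketch, Python's `random.Random` stands in for Spark's generator (so the numbers will not match Spark's), and the two partition layouts are the ones reported in the example output further down.

```python
import random


def simulate_rand(partition_sizes, seed=7):
    """Toy model: one generator per partition, seeded with seed + partition index."""
    values = {}
    row_id = 1  # ids are handed out in partition order, like spark.range(1, n)
    for index, size in enumerate(partition_sizes):
        rng = random.Random(seed + index)  # assumption: per-partition seeding
        for _ in range(size):
            values[row_id] = rng.random()
            row_id += 1
    return values


small = simulate_rand([0, 1, 0, 1, 0, 1, 0, 1])  # 4 rows over 8 partitions
large = simulate_rand([1, 1, 1, 1, 1, 1, 1, 2])  # 9 rows over 8 partitions

# A row that comes first in the same partition gets the same value,
# whatever its id happens to be.
for small_id, large_id in [(1, 2), (2, 4), (3, 6), (4, 8)]:
    print(small_id, large_id, small[small_id] == large[large_id])
```

In this model the value a row receives depends only on which partition it lands in and its position there, so the same seed gives the same pool of values while the id they attach to moves with the layout.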
Here is an example with 2 dataframes of different cardinality. You can see that the same seed gives the same values, or a superset of them. So ordering and partitioning play a role, in my opinion.
from pyspark.sql.functions import col, rand
# 4 rows with ids 1..4, cast to double
df1 = spark.range(1, 5).select(col("id").cast("double"))
df1 = df1.withColumn("rand_index", rand(seed=7))
df1.show()
# Number of partitions, and how many rows landed in each one
print(df1.rdd.getNumPartitions())
print('Partitioning distribution: ' + str(df1.rdd.glom().map(len).collect()))
returns:
+---+-------------------+
| id|         rand_index|
+---+-------------------+
|1.0|0.06498948189958098|
|2.0|0.41371264720975787|
|3.0|0.12030715258495939|
|4.0| 0.2731073068483362|
+---+-------------------+
8 partitions & Partitioning distribution: [0, 1, 0, 1, 0, 1, 0, 1]
The same again with more data:
...
df1 = spark.range(1, 10).select(col("id").cast("double"))
...
returns:
+---+-------------------+
| id|         rand_index|
+---+-------------------+
|1.0| 0.9147159860432812|
|2.0|0.06498948189958098|
|3.0| 0.7069655052310547|
|4.0|0.41371264720975787|
|5.0| 0.1982919638208397|
|6.0|0.12030715258495939|
|7.0|0.44292918521277047|
|8.0| 0.2731073068483362|
|9.0| 0.7784518091224375|
+---+-------------------+
8 partitions & Partitioning distribution: [1, 1, 1, 1, 1, 1, 1, 2]
You can see that the 4 random values from the first run reappear in the second run, against different ids, whether you are in the same Spark session or a new one.
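If partitioning is what moves the values around, one possible way to pin a value to a row id is to derive it from the id itself rather than from a shared generator. The following is only a sketch of that idea and is not part of the example above. The helper names `stable_rand` and `stable_randint` are hypothetical, and it uses `hashlib` from the Python standard library rather than anything Spark-specific.

```python
import hashlib


def stable_rand(row_id, seed=7):
    """Map (seed, row_id) to a float in [0, 1), independent of partitioning."""
    digest = hashlib.sha256(f"{seed}:{row_id}".encode("utf-8")).digest()
    # Keep 53 bits so the result is exactly representable and strictly below 1.0.
    return (int.from_bytes(digest[:8], "big") >> 11) / float(1 << 53)


def stable_randint(row_id, low, high, seed=7):
    """Integer in [low, high] inclusive, derived from the same per-row hash."""
    return low + int(stable_rand(row_id, seed) * (high - low + 1))


for row_id in range(1, 5):
    print(row_id, stable_rand(row_id), stable_randint(row_id, 1, 100))
```

Wrapped as a Python UDF (for instance with `udf(stable_rand, DoubleType())` from `pyspark.sql.functions` and `pyspark.sql.types`), this could be applied to the id column. Two caveats: the id has to be formatted identically on every run, since 1 and 1.0 hash to different values, and a Python UDF is typically slower than the built-in rand.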
Source: https://stackoverflow.com/questions/59077897/pyspark-get-consistent-random-value-across-spark-sessions