How to divide or multiply every non-string column of a PySpark DataFrame by a float constant?

Submitted by 老子叫甜甜 on 2021-02-16 08:43:54

Question


My input DataFrame looks like this:

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Basics").getOrCreate()

df=spark.createDataFrame(data=[('Alice',4.300,None),('Bob',float('nan'),897)],schema=['name','High','Low'])

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice| 4.3|null|
|  Bob| NaN| 897|
+-----+----+----+

Expected output when dividing by 10.0:

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice|0.43|null|
|  Bob| NaN|89.7|
+-----+----+----+

Answer 1:


I don't know of any library function that could do this, but this snippet seems to do the job just fine:

from pyspark.sql.functions import col

CONSTANT = 10.0

# Divide every numeric column by the constant; non-numeric columns are left untouched.
for field in df.schema.fields:
    if str(field.dataType) in ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']:
        name = str(field.name)
        df = df.withColumn(name, col(name) / CONSTANT)

df.show()

outputs:

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice|0.43|null|
|  Bob| NaN|89.7|
+-----+----+----+
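
Note that matching on the string form of the data type is brittle: DecimalType's repr includes its precision and scale (e.g. DecimalType(10,0)), so decimal columns would slip past the list above. A minimal sketch of a more robust variant, checking against pyspark.sql.types.NumericType instead (an alternative, not from the original answer):

from pyspark.sql.functions import col
from pyspark.sql.types import NumericType

CONSTANT = 10.0

# isinstance covers every numeric type, including DecimalType,
# regardless of how its repr is formatted.
for field in df.schema.fields:
    if isinstance(field.dataType, NumericType):
        df = df.withColumn(field.name, col(field.name) / CONSTANT)

df.show()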



Answer 2:


The code below should solve your problem in a time-efficient manner:

from pyspark.sql.functions import col

allowed_types = ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']

# Build the whole projection in a single select instead of chaining withColumn calls.
df = df.select(
    *[
        (col(field.name) / 10).name(field.name)
        if str(field.dataType) in allowed_types
        else col(field.name)
        for field in df.schema.fields
    ]
)

Using "withColumn" iteratively might not be a good idea when the number of columns is large.
This is because PySpark dataframes are immutable, so essentially we will be creating a new DataFrame for each column casted using withColumn, which will be a very slow process.

This is where the above code comes in handy.
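
The same single-select pattern works for multiplication as well. A minimal sketch, assuming the same df and allowed_types defined above:

CONSTANT = 10.0

# Multiply every numeric column by the constant in one pass;
# non-numeric columns pass through unchanged.
df = df.select(
    *[
        (col(field.name) * CONSTANT).alias(field.name)
        if str(field.dataType) in allowed_types
        else col(field.name)
        for field in df.schema.fields
    ]
)
df.show()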



Source: https://stackoverflow.com/questions/44807818/how-divide-or-multiply-every-non-string-columns-of-a-pyspark-dataframe-with-a-fl
