Spark SQL
简介
- SparkSQL 的前身是 Shark, SparkSQL 产生的根本原因是其完全脱离了 Hive 的限制。(Shark 底层依赖于 Hive 的解析器, 查询优化器)
- SparkSQL 支持查询原生的 RDD。
- 能够在 scala/java 中写 SQL 语句。 支持简单的 SQL 语法检查, 能够在 Scala 中 写Hive 语句访问 Hive 数据, 并将结果取回作为RDD使用
- SparkSQL 的前身是 Shark, SparkSQL 产生的根本原因是其完全脱离了 Hive 的限制。(Shark 底层依赖于 Hive 的解析器, 查询优化器)
Spark on Hive 和 Hive on Spark
- Spark on Hive: Hive 只作为储存角色, Spark负责 sql 解析优化, 执行。
- Hive on Spark: Hive 即作为存储又负责 sql 的解析优化, Spark 负责执行。
Dataset 与 DataFrame
- Dataset 是一个分布式数据容器,与 RDD 类似, 然而 DataSet 更像 传统数据库的二维表格, 除了数据以外, 还掌握的结构信息, 即schema。
- 同时, 与 Hive 类似, Dataset 也支持嵌套数据类型 (struct、array 和 map)。
- 从 API 易用性角度上看, DataSet API 提供的是一套高层的关系操作, 比函数式的 RDD API 更加友好, 门槛更低。
- Dataset 的底层封装的是RDD, 当 RDD 的泛型是 Row 类型的时候, 我们可以可以称它为 DataFrame。即 Dataset
= DataFrame
SparkSQL 的数据源
SparkSQL的数据源可以是JSON类型的字符串, JDBC, Parquent, Hive, HDFS 等。
SparkSQL 底层架构
首先拿到 sql 后解析一批未被解决的逻辑计划, 再经过分析得到分析后的逻辑计划, 再经过一批优化规则转换成一批最佳优化的逻辑计划, 再经过一批优化规则转换成一批最佳优化的逻辑计划, 再经过 SparkPlanner 测策略转化成一批物理计划, 随后经过消费模型转换成一个个的Spark任务执行。
谓词下推 (predicate Pushdown)
- 从关系型数据库借鉴而来, 关系型数据中谓词下推到外部数据库用以减少数据传输
- 基本思想: 尽可能早的处理表达式
- 属于逻辑优化, 优化器将谓词过滤下推到数据源, 使物理执行跳过无关数据
- 参数打开设置: hive.optimize.ppd=true
创建 Dataset 的几种方式
读取 json 格式的文件创建 DataSet
注意事项:
json 文件中的 json 数据不能嵌套 json 格式数据。
Dataset 是一个一个 Row 类型的 RDD, ds.rdd()/ds.javaRDD()。
可以两种方式读取json格式的文件。
df.show() 默认显示前 20 行数据。
Dataset 原生 API 可以操作 Dataset(不方便)。
注册成临时表时, 表中的列默认按 ascii 顺序显示列。
案例:
json:
{"name":"burning","age": 18} {"name":"atme"} {"name":"longdd","age":18} {"name":"yyf","age":28} {"name":"zhou","age":20} {"name":"blaze"} {"name":"ocean","age":18} {"name":"xiaoliu","age":28} {"name":"zhangsan","age":28} {"name":"lisi"} {"name":"wangwu","age":18}
Java代码:
package com.ronnie.java.json; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; public class ReadJson { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("jsonFile").master("local").getOrCreate(); /** * Dataset的底层是一个个的 RDD, RDD的泛型是Row类型。 * 以下两种方式都可以读取json格式的文件 */ Dataset<Row> ds = sparkSession.read().format("json").load("./resources/json"); // Dataset<Row> ds = sparkSession.read().json("data/json"); ds.show(); /** * +----+--------+ * | age| name| * +----+--------+ * | 18| burning| * |null| atme| * | 18| longdd| * | 28| yyf| * | 20| zhou| * |null| blaze| * | 18| ocean| * | 28| xiaoliu| * | 28|zhangsan| * |null| lisi| * | 18| wangwu| * +----+--------+ */ /** * Dataset 转换为 RDD */ JavaRDD<Row> javaRDD = ds.javaRDD(); /** * 显示 Dataset 中的内容, 默认显示前20行. 如果显示多行要指定多少行show(行数) * 注意: 当有多个列时, 显示的列先后顺序是按列的ascii码顺序先后显示 */ /** * 树形的形式显示schema信息 */ ds.printSchema(); /** * root * |-- age: long (nullable = true) * |-- name: string (nullable = true) */ /** * Dataset自带的API 操作Dataset */ // select name from table ds.select("name").show(); /** * +--------+ * | name| * +--------+ * | burning| * | atme| * | longdd| * | yyf| * | zhou| * | blaze| * | ocean| * | xiaoliu| * |zhangsan| * | lisi| * | wangwu| * +--------+ */ // select name age+10 as addage from table ds.select(ds.col("name"),ds.col("age").plus(10).alias("addage")).show(); /** * +--------+------+ * | name|addage| * +--------+------+ * | burning| 28| * | atme| null| * | longdd| 28| * | yyf| 38| * | zhou| 30| * | blaze| null| * | ocean| 28| * | xiaoliu| 38| * |zhangsan| 38| * | lisi| null| * | wangwu| 28| * +--------+------+ */ // select name, age from table where age > 19 ds.select(ds.col("name"),ds.col("age")).where(ds.col("age").gt(19)).show(); /** * +--------+---+ * | name|age| * +--------+---+ * | yyf| 28| * | zhou| 20| * | xiaoliu| 28| * |zhangsan| 28| * +--------+---+ */ // select count(*) from table group by age ds.groupBy(ds.col("age")).count().show(); /** * +----+-----+ * | age|count| * +----+-----+ * |null| 3| * | 28| 3| * | 18| 4| * | 20| 1| * +----+-----+ */ /** * 将Dataset 注册成临时的一张表, 这张表临时注册到内存中, 是逻辑上的表, 不会雾化到磁盘 */ ds.createOrReplaceTempView("jtable"); Dataset<Row> result = sparkSession.sql("select age, count(*) as gege from jtable group by age"); result.show(); /** * +----+----+ * | age|gege| * +----+----+ * |null| 3| * | 28| 3| * | 18| 4| * | 20| 1| * +----+----+ */ sparkSession.stop(); } }
通过 json 格式的 RDD 创建 DataSet
package com.ronnie.java.json; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import java.util.Arrays; public class CreateDatasetFromJsonRDD { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("jsonrdd").master("local").getOrCreate(); SparkContext sc = sparkSession.sparkContext(); JavaSparkContext jsc = new JavaSparkContext(sc); JavaRDD<String> nameRDD = jsc.parallelize(Arrays.asList( "{'name':'hao','age':\"24\"}", "{\"name\":\"mu\",\"age\":\"26\"}", "{\"name\":\"xiao8\",\"age\":\"27\"}" )); JavaRDD<String> scoreRDD = jsc.parallelize(Arrays.asList( "{\"name\":\"zhangsan\",\"score\":\"100\"}", "{\"name\":\"mu\",\"score\":\"200\"}", "{\"name\":\"wangwu\",\"score\":\"300\"}" )); Dataset<Row> nameds = sparkSession.read().json(nameRDD); Dataset<Row> scoreds = sparkSession.read().json(scoreRDD); nameds.createOrReplaceTempView("nameTable"); scoreds.createOrReplaceTempView("scoreTable"); Dataset<Row> result = sparkSession.sql("select nameTable.name, nameTable.age, scoreTable.score " + "from nameTable join scoreTable " + "on nameTable.name = scoreTable.name"); result.show(); /** * +----+---+-----+ * |name|age|score| * +----+---+-----+ * | mu| 26| 200| * +----+---+-----+ */ sparkSession.stop(); } }
非 json 格式的 RDD 创建 DataSet
通过反射的方式将非 json 格式的RDD转换成 Dataset
- 自定义类要可序列化
- 自定义类的访问级别是 Public
- RDD 转成 Dataset 后会根据映射将字段 ASCII 码排序
- 将 Dataset 转换成 RDD时获取字段的两种方式:
- ds.getInt(0) 下标获取(不推荐使用)
- ds.getAs("列名") 获取(推荐使用)
- person.txt
1,longdd,27 2,yyf,26 3,zhou,27 4,burning,30 5,atme,21
- Person.java
package com.ronnie.java.entity; import java.io.Serializable; public class Person implements Serializable { private static final long serialVersionUID = 1L; private String id ; private String name; private Integer age; public String getId() { return id; } public void setId(String id) { this.id = id; } public String getName() { return name; } public void setName(String name) { this.name = name; } public Integer getAge() { return age; } public void setAge(Integer age) { this.age = age; } @Override public String toString() { return "Person{" + "id='" + id + '\'' + ", name='" + name + '\'' + ", age=" + age + '}'; } }
CreateDatasetRDDWithReflect
package com.ronnie.java.json; import com.ronnie.java.entity.Person; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; public class CreateDatasetFromRDDWithReflect { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("reflect").master("local").getOrCreate(); SparkContext sc = sparkSession.sparkContext(); JavaSparkContext jsc = new JavaSparkContext(sc); JavaRDD<String> lineRDD = jsc.textFile("./resources/person.txt"); JavaRDD<Person> personRDD = lineRDD.map(new Function<String, Person>() { private static final long serialVersionUID = 1L; @Override public Person call(String line) throws Exception { Person p = new Person(); p.setId(line.split(",")[0]); p.setName(line.split(",")[1]); p.setAge(Integer.valueOf(line.split(",")[2])); return p; } }); /** * 传入进去Person.class的时候,sqlContext是通过反射的方式创建DataFrame * 在底层通过反射的方式获得Person的所有field,结合RDD本身,就生成了DataFrame */ Dataset<Row> dataFrame = sparkSession.createDataFrame(personRDD, Person.class); dataFrame.show(); /** * +---+---+-------+ * |age| id| name| * +---+---+-------+ * | 27| 1| longdd| * | 26| 2| yyf| * | 27| 3| zhou| * | 30| 4|burning| * | 21| 5| atme| * +---+---+-------+ */ dataFrame.printSchema(); /** * root * |-- age: integer (nullable = true) * |-- id: string (nullable = true) * |-- name: string (nullable = true) */ dataFrame.registerTempTable("person"); Dataset<Row> sql = sparkSession.sql("select name, id, age from person where id = 2"); sql.show(); /** * +----+---+---+ * |name| id|age| * +----+---+---+ * | yyf| 2| 26| * +----+---+---+ */ /** * 将Dataset转成JavaRDD * 注意: * 1.可以使用row.getInt(0),row.getString(1)...通过下标获取返回Row类型的数据,但是要注意列顺序问题---不常用 * 2.可以使用row.getAs("列名")来获取对应的列值。 */ JavaRDD<Row> javaRDD = dataFrame.javaRDD(); JavaRDD<Person> map = javaRDD.map(new Function<Row, Person>() { private static final long serialVersionUID = 1L; @Override public Person call(Row row) throws Exception { Person p = new Person(); p.setId(row.getAs("id")); p.setName(row.getAs("name")); p.setAge(row.getAs("age")); return p; } }); map.foreach(x -> System.out.println(x)); /** * Person{id='1', name='longdd', age=27} * Person{id='2', name='yyf', age=26} * Person{id='3', name='zhou', age=27} * Person{id='4', name='burning', age=30} * Person{id='5', name='atme', age=21} */ sc.stop(); } }
动态创建 Schema 将非 json 格式的 RDD 转换成 Dataset
package com.ronnie.java.json; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.Arrays; import java.util.List; public class CreateDatasetFromRDDWithStruct { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("schema").master("local").getOrCreate(); SparkContext sc = sparkSession.sparkContext(); JavaSparkContext jsc = new JavaSparkContext(sc); JavaRDD<String> lineRDD = jsc.textFile("./resources/person.txt"); /** * 转换成Row类型的RDD */ final JavaRDD<Row> rowRDD = lineRDD.map(new Function<String, Row>() { private static final long serialVersionUID = 1L; @Override public Row call(String line) throws Exception { return RowFactory.create( line.split(",")[0], line.split(",")[1], Integer.valueOf(line.split(",")[2]) ); } }); /** * 动态构建DataFrame中的元数据,一般来说这里的字段可以来源自字符串,也可以来源于外部数据库 */ List<StructField> asList = Arrays.asList( DataTypes.createStructField("id", DataTypes.StringType, true), DataTypes.createStructField("name", DataTypes.StringType, true), DataTypes.createStructField("age", DataTypes.IntegerType, true) ); StructType schema = DataTypes.createStructType(asList); Dataset<Row> df = sparkSession.createDataFrame(rowRDD, schema); df.printSchema(); /** * root * |-- id: string (nullable = true) * |-- name: string (nullable = true) * |-- age: integer (nullable = true) */ df.show(); /** * +---+-------+---+ * | id| name|age| * +---+-------+---+ * | 1| longdd| 27| * | 2| yyf| 26| * | 3| zhou| 27| * | 4|burning| 30| * | 5| atme| 21| * +---+-------+---+ */ sc.stop(); } }
读取 parquet 文件创建 DataSet
- SaveMode: 指文件保存时的模式
- Overwrite: 覆盖
- Append: 追加
- getOrCreate: 获取或创建
package com.ronnie.java.parquet; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession; public class CreateDataFrameFromParquet { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("parquet").master("local").getOrCreate(); Dataset<Row> df = sparkSession.read().json("./resources/json"); df.show(); /** * +----+--------+ * | age| name| * +----+--------+ * | 18| burning| * |null| atme| * | 18| longdd| * | 28| yyf| * | 20| zhou| * |null| blaze| * | 18| ocean| * | 28| xiaoliu| * | 28|zhangsan| * |null| lisi| * | 18| wangwu| * +----+--------+ */ /** * 将Dataset保存成parquet文件, * SaveMode指定存储文件时的保存模式: * Overwrite:覆盖 * Append:追加 * ErrorIfExists:如果存在就报错 * Ignore:如果存在就忽略 * 保存成parquet文件有以下两种方式: */ // df.write().mode(SaveMode.Overwrite).format("parquet").save("./resources/parquet"); df.write().mode(SaveMode.Overwrite).parquet("./resources/parquet"); /** * { * "type" : "struct", * "fields" : [ { * "name" : "age", * "type" : "long", * "nullable" : true, * "metadata" : { } * }, { * "name" : "name", * "type" : "string", * "nullable" : true, * "metadata" : { } * } ] * } * and corresponding Parquet message type: * message spark_schema { * optional int64 age; * optional binary name (UTF8); * } */ /** * 加载parquet文件成Dataset * 加载parquet文件有以下两种方式: */ Dataset<Row> load = sparkSession.read().format("parquet").load("./resources/parquet"); load.show(); /** * +----+--------+ * | age| name| * +----+--------+ * | 18| burning| * |null| atme| * | 18| longdd| * | 28| yyf| * | 20| zhou| * |null| blaze| * | 18| ocean| * | 28| xiaoliu| * | 28|zhangsan| * |null| lisi| * | 18| wangwu| * +----+--------+ */ sparkSession.stop(); } }
读取 JDBC 中的数据创建 DataSet
package com.ronnie.java.jdbc; import org.apache.spark.sql.*; import java.util.HashMap; import java.util.Map; import java.util.Properties; public class CreateDatasetFromMysql { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("jdbc").master("local").getOrCreate(); /** * 第一种方式读取MySql数据库表,加载为Dataset */ Map<String, String> options = new HashMap<>(); options.put("url", "jdbc:mysql://localhost:3306/spark"); options.put("driver", "com.mysql.jdbc.Driver"); options.put("user", "root"); options.put("password", "123456"); options.put("dbtable", "person"); Dataset<Row> person = sparkSession.read().format("jdbc").options(options).load(); person.show(); /** * +---+------+---+ * | id| name|age| * +---+------+---+ * | 1| slark| 70| * | 2| pom| 40| * | 3|huskar| 60| * | 4| axe| 80| * +---+------+---+ */ person.createOrReplaceTempView("person"); /** * 第二种方式读取MySql数据表加载为Dataset */ DataFrameReader reader = sparkSession.read().format("jdbc"); reader.option("url", "jdbc:mysql://localhost:3306/spark"); reader.option("driver", "com.mysql.jdbc.Driver"); reader.option("user", "root"); reader.option("password", "123456"); reader.option("dbtable", "score"); Dataset<Row> score = reader.load(); score.show(); /** * +---+------+-----+ * | id| name|score| * +---+------+-----+ * | 1|dragon| 80| * | 2| axe| 99| * | 3| slark| 81| * +---+------+-----+ */ score.createOrReplaceTempView("score"); Dataset<Row> result = sparkSession.sql("select person.id,person.name,person.age,score.score " + "from person,score " + "where person.name = score.name and score.score> 82"); result.show(); /** * +---+----+---+-----+ * | id|name|age|score| * +---+----+---+-----+ * | 4| axe| 80| 99| * +---+----+---+-----+ */ result.registerTempTable("result"); /** * 将Dataset结果保存到Mysql中 */ // Properties properties = new Properties(); properties.setProperty("user", "root"); properties.setProperty("password", "123456"); /** * SaveMode: * Overwrite:覆盖 * Append:追加 * ErrorIfExists:如果存在就报错 * Ignore:如果存在就忽略 * */ result.write().mode(SaveMode.Overwrite).jdbc("jdbc:mysql://127.0.0.1:3306/spark", "result", properties); // System.out.println("----Finish----"); sparkSession.stop(); } }
Hive 中的数据创建 DataSet
package com.ronnie.java.hive; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession; public class CreateDatasetFromHive { public static void main(String[] args) { SparkSession sparkSession = SparkSession .builder() .appName("hive") //开启hive的支持,接下来就可以操作hive表了 // 前提需要是需要开启hive metastore 服务 .enableHiveSupport() .getOrCreate(); sparkSession.sql("USE spark"); sparkSession.sql("DROP TABLE IF EXISTS student_infos"); //在hive中创建student_infos表 sparkSession.sql("CREATE TABLE IF NOT EXISTS student_infos (name STRING,age INT) row format delimited fields terminated by '\t' "); sparkSession.sql("load data local inpath '/root/student_infos' into table student_infos"); //注意:此种方式,程序需要能读取到数据(如/root/student_infos),同时也要能读取到 metastore服务的配置信息。 sparkSession.sql("DROP TABLE IF EXISTS student_scores"); sparkSession.sql("CREATE TABLE IF NOT EXISTS student_scores (name STRING, score INT) row format delimited fields terminated by '\t'"); sparkSession.sql("LOAD DATA " + "LOCAL INPATH '/root/student_scores'" + "INTO TABLE student_scores"); // Dataset<Row> df = hiveContext.table("student_infos");//读取Hive表加载Dataset方式 /** * 查询表生成Dataset */ Dataset<Row> goodStudentsDF = sparkSession.sql("SELECT si.name, si.age, ss.score " + "FROM student_infos si " + "JOIN student_scores ss " + "ON si.name=ss.name " + "WHERE ss.score>=80"); goodStudentsDF.registerTempTable("goodstudent"); Dataset<Row> result = sparkSession.sql("select * from goodstudent"); result.show(); /** * 将结果保存到hive表 good_student_infos */ sparkSession.sql("DROP TABLE IF EXISTS good_student_infos"); goodStudentsDF.write().mode(SaveMode.Overwrite).saveAsTable("good_student_infos"); sparkSession.stop(); } }
序列化问题
- Java 中以下几种情况下不被序列化的问题:
- 反序列化时 serializable 版本号不一致导致不能反序列化
- 子类中实现了serializable 接口, 但父类中没有实现, 父类中的变量不能被序列化, 序列化后父类中的变量会得到 null。
- 被关键字 transient 修饰的变量不能被序列化。
- 静态变量不能被序列化, 属于类, 不属于方法和对象, 所以不能被序列化。
储存 DataSet
- 将 DataSet 存储为 parquet 文件。
- 将 DataSet 存储到 JDBC 数据库。
- 将DataSet 存储到 Hive 表
自定义函数 UDP 和 UDAF
UDF(User Defined Function): 用户自定义函数
package com.ronnie.java.udf_udaf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.Arrays; public class UDF { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("udf").master("local").getOrCreate(); JavaSparkContext jsc = new JavaSparkContext(sparkSession.sparkContext()); JavaRDD<String> parallelize = jsc.parallelize(Arrays.asList("atme", "maybe", "chalice")); JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() { private static final long serialVersionUID = 1L; @Override public Row call(String s) throws Exception { return RowFactory.create(s); } }); /** * 动态创建Schema方式加载DF */ ArrayList<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("name", DataTypes.StringType,true)); StructType schema = DataTypes.createStructType(fields); Dataset<Row> df = sparkSession.createDataFrame(rowRDD, schema); df.registerTempTable("user"); /** * 根据UDF函数参数的个数来决定是实现哪一个UDF UDF1,UDF2。。。。UDF1xxx */ sparkSession.udf().register("StrLen", new UDF2<String, Integer, Integer>(){ private static final long serialVersionUID = 1L; @Override public Integer call(String t1, Integer t2) throws Exception { return t1.length() + t2; } }, DataTypes.IntegerType); sparkSession.sql("select name ,StrLen(name,100) as length from user").show(); /** * +-------+------+ * | name|length| * +-------+------+ * | atme| 104| * | maybe| 105| * |chalice| 107| * +-------+------+ */ sparkSession.stop(); } }
UDAF(User Defined Aggregate Function): 用户自定义聚合函数
实现UDAF函数如果要自定义类要实现UserDefinedAggregateFunction
package com.ronnie.java.udf_udaf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.Arrays; import java.util.List; public class UDAF { public static void main(String[] args) { SparkSession sparkSession = SparkSession.builder().appName("udaf").master("local").getOrCreate(); JavaSparkContext sc = new JavaSparkContext(sparkSession.sparkContext()); JavaRDD<String> parallelize = sc.parallelize( Arrays.asList("zeus", "lina", "wind ranger", "zeus", "zeus", "lina","zeus", "lina", "wind ranger", "zeus", "zeus", "lina"),2); JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() { @Override public Row call(String s) throws Exception { return RowFactory.create(s); } }); List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("name", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); Dataset<Row> df = sparkSession.createDataFrame(rowRDD, schema); df.registerTempTable("user"); /** * 注册一个UDAF函数,实现统计相同值得个数 * 注意:这里可以自定义一个类继承UserDefinedAggregateFunction类也是可以的 * 数据: * zeus * zeus * lina * lina * * select count(*) from user group by name */ sparkSession.udf().register("StringCount", new UserDefinedAggregateFunction() { /** * 指定输入字段的字段及类型 * @return */ @Override public StructType inputSchema() { return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true))); } /** * 指定UDAF函数计算后返回的结果类型 * @return */ @Override public DataType dataType() { return DataTypes.IntegerType; } /** * 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。 * @return */ @Override public boolean deterministic() { return true; } /** * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑 * buffer.getInt(0)获取的是上一次聚合后的值 * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合 * 大聚和发生在reduce端. * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算 * @param buffer * @param input */ @Override public void update(MutableAggregationBuffer buffer, Row input) { buffer.update(0, buffer.getInt(0) + 1); System.out.println("update......buffer" + buffer.toString() + " | row" + input); } /** * 在进行聚合操作的时候所要处理的数据的结果的类型 * @return */ @Override public StructType bufferSchema() { return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("buffer", DataTypes.IntegerType, true))); } /** * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理 * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来 * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值 * buffer2.getInt(0) : 这次计算传入进来的update的结果 * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作 */ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { /* 2 3 4 5 6 7 0 + 2 = 2 2 + 3 = 5 5 + 4 = 9 */ buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0)); System.out.println("merge.....buffer : " + buffer1.toString() + "| row" + buffer2.toString()); } /** * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果 * @param buffer */ @Override public void initialize(MutableAggregationBuffer buffer) { buffer.update(0, 0); System.out.println("init ......" + buffer.get(0)); } /** * 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果 * @param row * @return */ @Override public Object evaluate(Row row) { return row.getInt(0); } }); sparkSession.sql("select name, StringCount(name) as number from user group by name").show(); /** * +-----------+------+ * | name|number| * +-----------+------+ * |wind ranger| 2| * | lina| 4| * | zeus| 6| * +-----------+------+ */ sc.stop(); } }
开窗函数
package com.ronnie.java.windowFun; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; /** * row_number()开窗函数: * 主要是按照某个字段分组,然后取另一字段的前几个的值,相当于 分组取topN * row_number() over (partition by xxx order by xxx desc) xxx * */ public class RowNumberWindowFun { //-Xms800m -Xmx800m -XX:PermSize=64M -XX:MaxNewSize=256m -XX:MaxPermSize=128m public static void main(String[] args) { SparkSession sparkSession = SparkSession .builder() .appName("window") .master("local") //开启hive的支持,接下来就可以操作hive表了 // 前提需要是需要开启hive metastore 服务 .enableHiveSupport() .getOrCreate(); sparkSession.sql("use spark"); sparkSession.sql("drop table if exists sales"); sparkSession.sql("create table if not exists sales (riqi string,leibie string,jine Int) " + "row format delimited fields terminated by '\t'"); sparkSession.sql("load data local inpath './data/sales.txt' into table sales"); /** * 开窗函数格式: * 【 row_number() over (partition by XXX order by XXX) as rank】 * 注意:rank 从1开始 */ /** * 以类别分组,按每种类别金额降序排序,显示 【日期,种类,金额】 结果,如: * * 1 A 100 * 2 B 200 * 3 A 300 * 4 B 400 * 5 A 500 * 6 B 600 * * 排序后: * 5 A 500 --rank 1 * 3 A 300 --rank 2 * 1 A 100 --rank 3 * 6 B 600 --rank 1 * 4 B 400 --rank 2 * 2 B 200 --rank 3 * * 2018 A 400 1 * 2017 A 500 2 * 2016 A 550 3 * * * 2016 A 550 1 * 2017 A 500 2 * 2018 A 400 3 * */ Dataset<Row> result = sparkSession.sql("select riqi,leibie,jine,rank " + "from (" + "select riqi,leibie,jine," + "row_number() over (partition by leibie order by jine desc) rank " + "from sales) t " + "where t.rank<=3"); result.show(100); /** * 将结果保存到hive表sales_result */ // result.write().mode(SaveMode.Overwrite).saveAsTable("sales_result"); sparkSession.stop(); } }