SPARK SQL - update MySql table using DataFrames and JDBC

旧巷少年郎 2020-11-29 05:17

I'm trying to insert and update some data in MySQL using Spark SQL DataFrames and a JDBC connection.

I've succeeded in inserting new data using SaveMode.Append. Is

6 answers
  • 2020-11-29 05:41

    It is not possible. As of now (Spark 1.6.0 / 2.2.0 SNAPSHOT), Spark's DataFrameWriter supports only four write modes:

    • SaveMode.Overwrite: overwrite the existing data.
    • SaveMode.Append: append the data.
    • SaveMode.Ignore: ignore the operation (i.e. a no-op) if data already exists.
    • SaveMode.ErrorIfExists: the default; throw an exception at runtime if data already exists.

    You can insert manually, for example using mapPartitions (since you want an UPSERT, the operation should be idempotent and as such easy to implement), write to a temporary table and execute the upsert manually, or use triggers.

    In general, achieving upsert behavior for batch operations while keeping decent performance is far from trivial. Remember that in the general case there will be multiple concurrent transactions in place (one per partition), so you have to ensure that there will be no write conflicts (typically by using application-specific partitioning) or provide appropriate recovery procedures. In practice it may be better to perform batch writes to a temporary table and resolve the upsert part directly in the database.
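
    The temporary-table route can be sketched as follows, assuming the DataFrame has already been written to a staging table with SaveMode.Overwrite. The helper only builds the MySQL statement that merges the staging table into the target; the table and column names (users, users_staging, id, name, email) are hypothetical, and the target is assumed to have a primary or unique key on the key columns. On MySQL 8.0.20 and later the VALUES() function is deprecated in favor of row aliases, so treat this as a starting point rather than a finished implementation.

```python
def build_mysql_upsert(target, staging, key_columns, value_columns):
    """Build a MySQL statement that merges `staging` into `target`.

    Rows whose unique key already exists in `target` are updated;
    all other rows are inserted.
    """
    if not value_columns:
        raise ValueError("at least one non-key column is needed for the UPDATE part")
    columns = list(key_columns) + list(value_columns)
    column_list = ", ".join(f"`{c}`" for c in columns)
    # VALUES(col) refers to the value the INSERT part would have written.
    assignments = ", ".join(f"`{c}` = VALUES(`{c}`)" for c in value_columns)
    return (
        f"INSERT INTO `{target}` ({column_list}) "
        f"SELECT {column_list} FROM `{staging}` "
        f"ON DUPLICATE KEY UPDATE {assignments}"
    )


# Hypothetical table and column names, for illustration only.
sql = build_mysql_upsert("users", "users_staging", ["id"], ["name", "email"])
print(sql)
```

    The resulting statement would be run once, from the driver, over an ordinary database connection after the Spark write has finished, so the database resolves all conflicts in a single transaction.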

  • 2020-11-29 05:41

    Copy JdbcUtils.scala from org.apache.spark.sql.execution.datasources.jdbc and change INSERT INTO to REPLACE INTO:

    import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}
    
    import scala.collection.JavaConverters._
    import scala.util.control.NonFatal
    import com.typesafe.scalalogging.Logger
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{DataFrame, Row}
    
    /**
      * Util functions for JDBC tables.
      */
    object UpdateJdbcUtils {
    
      val logger = Logger(this.getClass)
    
      /**
        * Returns a factory for creating connections to the given JDBC URL.
        *
        * @param options - JDBC options that contains url, table and other information.
        */
      def createConnectionFactory(options: JDBCOptions): () => Connection = {
        val driverClass: String = options.driverClass
        () => {
          DriverRegistry.register(driverClass)
          val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
            case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
            case d if d.getClass.getCanonicalName == driverClass => d
          }.getOrElse {
            throw new IllegalStateException(
              s"Did not find registered driver with class $driverClass")
          }
          driver.connect(options.url, options.asConnectionProperties)
        }
      }
    
      /**
        * Returns a PreparedStatement that writes a row into table via conn.
        * Uses MySQL's REPLACE INTO, so a row with an existing key is replaced.
        */
      def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
      : PreparedStatement = {
        val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
        val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
        val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
        conn.prepareStatement(sql)
      }
    
      /**
        * Retrieve standard jdbc types.
        *
        * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
        * @return The default JdbcType for this DataType
        */
      def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
        dt match {
          case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
          case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
          case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
          case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
          case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
          case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
          case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
          case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
          case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
          case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
          case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
          case t: DecimalType => Option(
            JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
          case _ => None
        }
      }
    
      private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
        dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
          throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
      }
    
      // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
      // for `MutableRow`. The last argument `Int` means the index for the value to be set in
      // the row and also used for the value in `ResultSet`.
      private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
    
      // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
      // `PreparedStatement`. The last argument `Int` means the index for the value to be set
      // in the SQL statement and also used for the value in `Row`.
      private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
    
      /**
        * Saves a partition of a DataFrame to the JDBC database.  This is done in
        * a single database transaction (unless isolation level is "NONE")
        * in order to avoid repeatedly inserting data as much as possible.
        *
        * It is still theoretically possible for rows in a DataFrame to be
        * inserted into the database more than once if a stage somehow fails after
        * the commit occurs but before the stage can return successfully.
        *
        * This is not a closure inside saveTable() because apparently cosmetic
        * implementation changes elsewhere might easily render such a closure
        * non-Serializable.  Instead, we explicitly close over all variables that
        * are used.
        */
      def savePartition(
                         getConnection: () => Connection,
                         table: String,
                         iterator: Iterator[Row],
                         rddSchema: StructType,
                         nullTypes: Array[Int],
                         batchSize: Int,
                         dialect: JdbcDialect,
                         isolationLevel: Int): Iterator[Byte] = {
        val conn = getConnection()
        var committed = false
    
        var finalIsolationLevel = Connection.TRANSACTION_NONE
        if (isolationLevel != Connection.TRANSACTION_NONE) {
          try {
            val metadata = conn.getMetaData
            if (metadata.supportsTransactions()) {
              // Update to at least use the default isolation, if any transaction level
              // has been chosen and transactions are supported
              val defaultIsolation = metadata.getDefaultTransactionIsolation
              finalIsolationLevel = defaultIsolation
              if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
                // Finally update to actually requested level if possible
                finalIsolationLevel = isolationLevel
              } else {
                logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
                  s"falling back to default isolation level $defaultIsolation")
              }
            } else {
              logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
            }
          } catch {
            case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
          }
        }
        val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
    
        try {
          if (supportsTransactions) {
            conn.setAutoCommit(false) // Everything in the same db transaction.
            conn.setTransactionIsolation(finalIsolationLevel)
          }
          val stmt = insertStatement(conn, table, rddSchema, dialect)
          val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
            .map(makeSetter(conn, dialect, _))
          val numFields = rddSchema.fields.length
    
          try {
            var rowCount = 0
            while (iterator.hasNext) {
              val row = iterator.next()
              var i = 0
              while (i < numFields) {
                if (row.isNullAt(i)) {
                  stmt.setNull(i + 1, nullTypes(i))
                } else {
                  setters(i).apply(stmt, row, i)
                }
                i = i + 1
              }
              stmt.addBatch()
              rowCount += 1
              if (rowCount % batchSize == 0) {
                stmt.executeBatch()
                rowCount = 0
              }
            }
            if (rowCount > 0) {
              stmt.executeBatch()
            }
          } finally {
            stmt.close()
          }
          if (supportsTransactions) {
            conn.commit()
          }
          committed = true
          Iterator.empty
        } catch {
          case e: SQLException =>
            val cause = e.getNextException
            if (cause != null && e.getCause != cause) {
              if (e.getCause == null) {
                e.initCause(cause)
              } else {
                e.addSuppressed(cause)
              }
            }
            throw e
        } finally {
          if (!committed) {
            // The stage must fail.  We got here through an exception path, so
            // let the exception through unless rollback() or close() want to
            // tell the user about another problem.
            if (supportsTransactions) {
              conn.rollback()
            }
            conn.close()
          } else {
            // The stage must succeed.  We cannot propagate any exception close() might throw.
            try {
              conn.close()
            } catch {
              case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
            }
          }
        }
      }
    
      /**
        * Saves the DataFrame to the database, using one transaction per partition.
        */
      def saveTable(
                     df: DataFrame,
                     url: String,
                     table: String,
                     options: JDBCOptions): Unit = {
        val dialect = JdbcDialects.get(url)
        val nullTypes: Array[Int] = df.schema.fields.map { field =>
          getJdbcType(field.dataType, dialect).jdbcNullType
        }
    
        val rddSchema = df.schema
        val getConnection: () => Connection = createConnectionFactory(options)
        val batchSize = options.batchSize
        val isolationLevel = options.isolationLevel
        df.foreachPartition(iterator => savePartition(
          getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
        )
      }
    
      private def makeSetter(
                              conn: Connection,
                              dialect: JdbcDialect,
                              dataType: DataType): JDBCValueSetter = dataType match {
        case IntegerType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setInt(pos + 1, row.getInt(pos))
    
        case LongType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setLong(pos + 1, row.getLong(pos))
    
        case DoubleType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setDouble(pos + 1, row.getDouble(pos))
    
        case FloatType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setFloat(pos + 1, row.getFloat(pos))
    
        case ShortType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setInt(pos + 1, row.getShort(pos))
    
        case ByteType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setInt(pos + 1, row.getByte(pos))
    
        case BooleanType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setBoolean(pos + 1, row.getBoolean(pos))
    
        case StringType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setString(pos + 1, row.getString(pos))
    
        case BinaryType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
    
        case TimestampType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    
        case DateType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    
        case t: DecimalType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
    
        case ArrayType(et, _) =>
          // remove type length parameters from end of type name
          val typeName = getJdbcType(et, dialect).databaseTypeDefinition
            .toLowerCase.split("\\(")(0)
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            val array = conn.createArrayOf(
              typeName,
              row.getSeq[AnyRef](pos).toArray)
            stmt.setArray(pos + 1, array)
    
        case _ =>
          (_: PreparedStatement, _: Row, pos: Int) =>
            throw new IllegalArgumentException(
              s"Can't translate non-null value for field $pos")
      }
    }
    

    Usage:

    val url = s"jdbc:mysql://$host/$database?useUnicode=true&characterEncoding=UTF-8"
    
    val parameters: Map[String, String] = Map(
      "url" -> url,
      "dbtable" -> table,
      "driver" -> "com.mysql.jdbc.Driver",
      "numPartitions" -> numPartitions.toString,
      "user" -> user,
      "password" -> password
    )
    val options = new JDBCOptions(parameters)
    
    for (d <- data) {
      UpdateJdbcUtils.saveTable(d, url, table, options)
    }
    

    P.S. Watch out for deadlocks. Do not update data frequently this way; use it only for re-runs in an emergency. I think that is why Spark does not officially support this.
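
    One more thing worth knowing before switching to REPLACE INTO: it deletes the conflicting row and inserts a new one, so any column that is not part of the DataFrame falls back to its default. The sketch below illustrates that behavior with Python's built-in sqlite3 module, where INSERT OR REPLACE is used as a stand-in for MySQL's REPLACE INTO; the users table and its columns are hypothetical, and the MySQL specifics (delete triggers, foreign-key cascades) are not reproduced here.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, visits INTEGER DEFAULT 0)"
)
conn.execute("INSERT INTO users (id, name, visits) VALUES (1, 'old', 42)")

# SQLite's INSERT OR REPLACE stands in for MySQL's REPLACE INTO here:
# the conflicting row is deleted, then a brand-new row is inserted.
conn.execute("INSERT OR REPLACE INTO users (id, name) VALUES (1, 'new')")

row = conn.execute("SELECT id, name, visits FROM users WHERE id = 1").fetchone()
print(row)  # `visits` is back at its default because it was not supplied
conn.close()
```

    If only some columns should change, an INSERT ... ON DUPLICATE KEY UPDATE statement that lists those columns is usually the safer choice.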

  • 2020-11-29 05:50

    zero323's answer is right. I just wanted to add that you could use the JayDeBeApi package to work around this and update data in your MySQL table: https://pypi.python.org/pypi/JayDeBeApi/

    It might be low-hanging fruit, since you already have the MySQL JDBC driver installed.

    The JayDeBeApi module allows you to connect from Python code to databases using Java JDBC. It provides a Python DB-API v2.0 to that database.

    We use the Anaconda distribution of Python, and the JayDeBeApi Python package comes standard with it.

    See examples in that link above.
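
    Because the module exposes a DB-API v2.0 interface, the update logic itself can be written against any DB-API connection. Here is a minimal sketch of a batched update helper; it is exercised with the built-in sqlite3 module as a stand-in, on the assumption (not tested here) that a JayDeBeApi connection accepts the same `?` placeholders. The function name execute_batches and the users table are hypothetical.

```python
import sqlite3


def execute_batches(conn, sql, rows, batch_size=1000):
    """Run `sql` once per row through a DB-API 2.0 connection, in batches.

    Commits once at the end and rolls back if anything fails.
    Returns the number of parameter rows that were sent.
    """
    cursor = conn.cursor()
    batch = []
    total = 0
    try:
        for row in rows:
            batch.append(row)
            if len(batch) >= batch_size:
                cursor.executemany(sql, batch)
                total += len(batch)
                batch = []
        if batch:
            # Flush the last, partially filled batch.
            cursor.executemany(sql, batch)
            total += len(batch)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        cursor.close()
    return total


# Stand-in database; with JayDeBeApi, `conn` would come from its connect() call.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "a"), (2, "b"), (3, "c")])
conn.commit()

updated = execute_batches(
    conn,
    "UPDATE users SET name = ? WHERE id = ?",
    [("x", 1), ("y", 3)],
    batch_size=1,
)
names = [r[0] for r in conn.execute("SELECT name FROM users ORDER BY id")]
print(updated, names)
```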

  • 2020-11-29 05:56

    It is a pity that Spark has no SaveMode.Upsert mode for a case as common as upserting.

    zero323 is right in general, but I think it should be possible (with compromises in performance) to offer such a replace feature.

    I also wanted to provide some Java code for this case. Of course it is not as performant as Spark's built-in writer, but it should be a good basis for your requirements. Just adapt it to your needs:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.PreparedStatement;
    import java.sql.SQLException;
    import java.util.Arrays;
    import java.util.Iterator;
    import org.apache.spark.sql.Row;

    // repartition() returns a new Dataset, so chain it; one connection per partition, see below
    myDF.repartition(20).foreachPartition((Iterator<Row> t) -> {
        // better than REPLACE INTO, less cycles
        final String sql = "INSERT INTO mytable VALUES (?, ?) "
                + "ON DUPLICATE KEY UPDATE _id = ?";
        final int batchSize = 100000;

        // try-with-resources closes the statement and the connection in every case
        try (Connection conn = DriverManager.getConnection(
                     Constants.DB_JDBC_CONN,
                     Constants.DB_JDBC_USER,
                     Constants.DB_JDBC_PASS);
             PreparedStatement statement = conn.prepareStatement(sql)) {

            // commit explicitly, once per partition
            conn.setAutoCommit(false);
            try {
                int i = 0;
                while (t.hasNext()) {
                    Row row = t.next();
                    // bind the values instead of concatenating them into the SQL text
                    Object id = row.getAs("_id");
                    Object value = row.getStruct(1).get(0);
                    statement.setString(1, String.valueOf(id));
                    statement.setString(2, String.valueOf(value));
                    statement.setString(3, String.valueOf(id));
                    statement.addBatch();

                    if (++i % batchSize == 0) {
                        statement.executeBatch();
                    }
                }
                // flush the last, partially filled batch
                int[] ret = statement.executeBatch();
                conn.commit();
                System.out.println("Ret val: " + Arrays.toString(ret));
            } catch (SQLException e) {
                // undo the partial partition and fail the task
                conn.rollback();
                throw e;
            }
        }
    });
    
  • 2020-11-29 06:03

    If your table is small, you can read the existing SQL data into Spark, perform the upsert on the DataFrame, and then overwrite the existing SQL table with the result.
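
    The merge step itself is simple. The sketch below shows the intended semantics in plain Python, with lists of dicts standing in for the two DataFrames; in Spark the same result would come from a join or union on the key column. The function name merge_rows and the sample columns are hypothetical. As far as I know, Spark reads the source table lazily, so the merged result should be materialized (or written to a different table first) before the original table is overwritten.

```python
def merge_rows(existing, incoming, key):
    """Return `existing` with rows from `incoming` inserted or updated by `key`."""
    merged = {row[key]: dict(row) for row in existing}
    for row in incoming:
        # Combine field by field so columns missing from the update are kept.
        merged[row[key]] = {**merged.get(row[key], {}), **row}
    return sorted(merged.values(), key=lambda r: r[key])


existing = [{"id": 1, "name": "a", "age": 30}, {"id": 2, "name": "b", "age": 40}]
incoming = [{"id": 2, "name": "B"}, {"id": 3, "name": "c"}]
result = merge_rows(existing, incoming, "id")
print(result)
```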

  • 2020-11-29 06:04

    In PySpark I was not able to do that, so I decided to use ODBC (pyodbc):

    import pyodbc

    url = "jdbc:sqlserver://xxx:1433;databaseName=xxx;user=xxx;password=xxx"
    # Stage the DataFrame in a scratch table first
    df.write.jdbc(url=url, table="__TableInsert", mode='overwrite')
    cnxn = pyodbc.connect('Driver={ODBC Driver 17 for SQL Server};Server=xxx;Database=xxx;Uid=xxx;Pwd=xxx;', autocommit=False)
    try:
        crsr = cnxn.cursor()
        # DO UPSERTS OR WHATEVER YOU WANT
        crsr.execute("DELETE FROM Table")
        crsr.execute("INSERT INTO Table (Field) SELECT Field FROM __TableInsert")
        cnxn.commit()
    except Exception:
        cnxn.rollback()
        raise  # surface the failure instead of swallowing it
    finally:
        cnxn.close()
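
    For a real upsert in place of the DELETE and INSERT pair, SQL Server's MERGE statement is one option. The following is a minimal sketch that only builds the statement text; the key column Id is hypothetical (the snippet above does not name one), and the result would be passed to crsr.execute() inside the same try block.

```python
def build_tsql_merge(target, staging, key_columns, value_columns):
    """Build a T-SQL MERGE that upserts rows from `staging` into `target`."""
    if not key_columns:
        raise ValueError("at least one key column is required")
    columns = list(key_columns) + list(value_columns)
    on = " AND ".join(f"t.[{c}] = s.[{c}]" for c in key_columns)
    insert_cols = ", ".join(f"[{c}]" for c in columns)
    insert_vals = ", ".join(f"s.[{c}]" for c in columns)
    parts = [
        f"MERGE INTO [{target}] AS t",
        f"USING [{staging}] AS s ON {on}",
    ]
    if value_columns:
        updates = ", ".join(f"t.[{c}] = s.[{c}]" for c in value_columns)
        parts.append(f"WHEN MATCHED THEN UPDATE SET {updates}")
    parts.append(
        f"WHEN NOT MATCHED BY TARGET THEN INSERT ({insert_cols}) VALUES ({insert_vals})"
    )
    # MERGE statements must be terminated with a semicolon.
    return "\n".join(parts) + ";"


# "Id" is an assumed key column; "Table", "__TableInsert" and "Field" come from the snippet above.
merge_sql = build_tsql_merge("Table", "__TableInsert", ["Id"], ["Field"])
print(merge_sql)
```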
    