Slick 3.0.0 - update row with only non-null values

Backend · Unresolved · 3 answers · 1112 views
攒了一身酷
攒了一身酷 2021-01-13 15:37

Given a table with the following columns:

class Data(tag: Tag) extends Table[DataRow](tag, \"data\") {
  def id = column[Int](\"id\", O.PrimaryKey)
  def name = column         


        
3 answers
  •  迷失自我
    2021-01-13 15:43

    I solved it in the following way.

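    The implementation below works only when the row type is a `Product` (for example, a case class).

    It executes the UPDATE statement for the present fields only, skipping fields that are `None` (for `Option`-typed fields) or `null` (for any other field type).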

    package slick.extensions
    
    import slick.ast._
    import slick.dbio.{ Effect, NoStream }
    import slick.driver.JdbcDriver
    import slick.jdbc._
    import slick.lifted._
    import slick.relational.{ CompiledMapping, ProductResultConverter, ResultConverter, TypeMappingResultConverter }
    import slick.util.{ ProductWrapper, SQLBuilder }
    
    import scala.language.{ existentials, higherKinds, implicitConversions }
    
    trait PatchActionExtensionMethodsSupport { driver: JdbcDriver =>
    
      /** Implicit conversion exposing `patch` on any query whose element type is a `Product`.
        * Mix this trait into the driver's `API` so that importing `api._` brings `.patch` into scope.
        */
      trait PatchActionImplicits {
        implicit def queryPatchActionExtensionMethods[U <: Product, C[_]](
            q: Query[_, U, C]
        ): PatchActionExtensionMethodsImpl[U] =
          // Compile with the driver's `updateCompiler`; the resulting tree is destructured by
          // PatchActionExtensionMethodsImpl into an UPDATE statement and its row converter.
          createPatchActionExtensionMethods(updateCompiler.run(q.toNode).tree, ())
      }
    
      ///////////////////////////////////////////////////////////////////////////////////////////////
      //////////////////////////////////////////////////////////// Patch Actions
      ///////////////////////////////////////////////////////////////////////////////////////////////
    
      type PatchActionExtensionMethods[T <: Product] = PatchActionExtensionMethodsImpl[T]
    
      /** Factory for the patch extension methods over an already-compiled update tree. */
      def createPatchActionExtensionMethods[T <: Product](tree: Node, param: Any): PatchActionExtensionMethods[T] =
        new PatchActionExtensionMethodsImpl[T](tree, param)
    
      /** Extension methods providing `patch`: an UPDATE that writes only the fields of the given
        * row that are neither `None` nor `null`.
        *
        * @param tree  the compiled update tree; the pattern vals below require it to be a
        *              `ResultSetMapping` over a `CompiledStatement`, whose converter is a
        *              `TypeMappingResultConverter` wrapping a flat `ProductResultConverter`
        * @param param the parameter value handed to the WHERE-clause setter of the statement
        */
      class PatchActionExtensionMethodsImpl[T <: Product](tree: Node, param: Any) {
        protected[this] val ResultSetMapping(_, CompiledStatement(_, sres: SQLBuilder.Result, _),
          CompiledMapping(_converter, _)) = tree
        protected[this] val converter = _converter.asInstanceOf[ResultConverter[JdbcResultConverterDomain, Product]]
        protected[this] val TypeMappingResultConverter(childConverter, toBase, toMapped) = converter
        protected[this] val ProductResultConverter(elementConverters @ _ *) =
          childConverter.asInstanceOf[ResultConverter[JdbcResultConverterDomain, Product]]
        // Splits "<prefix>set <assignments>< where ...>" into (prefix incl. "set ", assignments, rest).
        private[this] val updateQuerySplitRegExp = """(.*)(?<=set )((?:(?= where)|.)+)(.*)?""".r
        // Matches a single `column = ?` assignment inside the SET clause.
        private[this] val updateQuerySetterRegExp = """[^\s]+\s*=\s*\?""".r
    
        /** Builds an action that updates the rows selected by this query, writing only the fields
          * of `value` that are neither `None` nor `null`.
          *
          * @param value the row to write; its `None`/`null` fields leave their columns untouched
          * @return an action yielding the update count returned by `executeUpdate`
          * @throws IllegalArgumentException      if every field of `value` is `None` or `null`
          * @throws UnsupportedOperationException if a field uses a converter other than the
          *                                       simple column converters handled below
          * @throws IllegalStateException         if the compiled SQL does not have the expected shape
          */
        def patch(value: T): DriverAction[Int, NoStream, Effect.Write] = {
          // Pair each field (with its position in the row) with its column converter, then drop
          // the absent ones: `None` for Option fields, `null` for any other field.
          val (seq, converters) = value.productIterator.zipWithIndex.toIndexedSeq
            .zip(elementConverters)
            .filter {
              case ((None | null, _), _) => false
              case _                     => true
            }
            .unzip
    
          val (products, indexes) = seq.unzip
          // With nothing to write the SET clause would be empty and the SQL syntactically invalid.
          require(indexes.nonEmpty, "patch requires at least one field that is neither None nor null")
    
          // Re-number the surviving converters so they bind to consecutive 1-based JDBC parameter
          // positions in the shortened SET clause.
          val newConverters = converters.zipWithIndex
            .map(c => (c._1, c._2 + 1))
            .map {
              case (c: BaseResultConverter[_], idx) => new BaseResultConverter(c.ti, c.name, idx)
              case (c: OptionResultConverter[_], idx) => new OptionResultConverter(c.ti, idx)
              case (c: DefaultingResultConverter[_], idx) => new DefaultingResultConverter(c.ti, c.default, idx)
              case (c: IsDefinedResultConverter[_], idx) => new IsDefinedResultConverter(c.ti, idx)
              case (other, _) =>
                throw new UnsupportedOperationException(
                  s"patch supports only flat rows of simple columns; unsupported converter: $other")
            }
    
          val productResultConverter =
            ProductResultConverter(newConverters: _*).asInstanceOf[ResultConverter[JdbcResultConverterDomain, Any]]
          val newConverter = TypeMappingResultConverter(productResultConverter, (p: Product) => p, (a: Any) => toMapped(a))
    
          val newValue: Product = new ProductWrapper(products)
          val newSql = restrictSetClause(sres.sql, indexes.toSet)
    
          new SimpleJdbcDriverAction[Int]("patch", Vector(newSql)) {
            def run(ctx: Backend#Context, sql: Vector[String]): Int =
              ctx.session.withPreparedStatement(sql.head) { st =>
                st.clearParameters
                // SET values occupy parameters 1..width; WHERE parameters follow right after.
                newConverter.set(newValue, st)
                sres.setter(st, newConverter.width + 1, param)
                st.executeUpdate
              }
          }
        }
    
        /** Rewrites the compiled UPDATE statement so its SET clause keeps only the assignments whose
          * zero-based position is contained in `kept`.
          *
          * @throws IllegalStateException if the SET clause cannot be located, or if the number of
          *   `column = ?` assignments differs from the number of row fields (positional binding
          *   would otherwise silently target the wrong columns)
          */
        private[this] def restrictSetClause(sql: String, kept: Set[Int]): String = sql match {
          case updateQuerySplitRegExp(prefix, setter, suffix) =>
            val assignments = updateQuerySetterRegExp.findAllIn(setter).toIndexedSeq
            if (assignments.size != elementConverters.size)
              throw new IllegalStateException(
                s"Expected ${elementConverters.size} assignments in the SET clause but found ${assignments.size}: $sql")
            val keptAssignments = assignments.zipWithIndex.collect { case (a, pos) if kept(pos) => a }
            // `suffix` comes from an optional group; treat a non-participating (null) group as empty.
            prefix + keptAssignments.mkString(", ") + Option(suffix).getOrElse("")
          case _ =>
            throw new IllegalStateException(s"Cannot locate the SET clause of the compiled update statement: $sql")
        }
      }
    }
    

    Example

    // Model
    /** Row type of the `users` table.
      *
      * Every field is an `Option` defaulting to `None`, so a partially populated value can be
      * built with named arguments, e.g. `User(name = Some("x"))`. Passed to `patch`, the `None`
      * fields are skipped and their columns keep their current values.
      */
    case class User(id: Option[Int] = None,
                    name: Option[String] = None,
                    username: Option[String] = None,
                    password: Option[String] = None)
    
    // Table
    /** Slick table definition for `users`.
      *
      * The columns are declared with non-`Option` types, but the default projection lifts each one
      * with `.?` so it lines up with `User`, whose fields are all `Option`s.
      */
    class Users(tag: Tag) extends Table[User](tag, "users") {
      // Primary key, declared with O.AutoInc.
      def id = column[Int]("id", O.PrimaryKey, O.AutoInc)
      def name = column[String]("name")
      def username = column[String]("username")
      def password = column[String]("password")
      // Default projection mapped to/from `User`. `patch` pairs `User`'s fields positionally with
      // the assignments of the generated UPDATE, so keep this tuple in `User`'s field order.
      override def * = (id.?, name.?, username.?, password.?) <>(User.tupled, User.unapply)
    }
    
    // TableQuery
    /** Query entry point for the `users` table, built from the `Users` table definition. */
    object Users extends TableQuery(tag => new Users(tag))
    
    // CustomDriver 
    /** Postgres driver whose `api` additionally mixes in `PatchActionImplicits`, so `.patch` is
      * available on queries once that `api` is imported.
      *
      * NOTE(review): a trait cannot be imported from directly; the example presumably needs a
      * concrete instance such as `object CustomDriver extends CustomDriver` and then
      * `import CustomDriver.api._` — confirm against the real usage.
      */
    trait CustomDriver extends PostgresDriver with PatchActionExtensionMethodsSupport {
      override val api: API = new API {}
      // Redefines API to add the patch implicits on top of the base driver API.
      trait API extends super.API  with PatchActionImplicits
    }
    
    // Insert
    // NOTE(review): like `patch`, this presumably only builds an action; it still has to be
    // executed (e.g. via `db.run`) with the custom driver's `api._` in scope.
    Users += User(Some(1), Some("Test"), Some("test"), Some("1234"))
    
    // User patch: only `name` and `username` are written, because `id` and `password` are None.
    // `username = Some("")` is not None, so it DOES overwrite the column with an empty string.
    Users.filter(_.id === 1).patch(User(name = Some("Change Name"), username = Some("")))
    

    https://gist.github.com/bad79s/1edf9ea83ba08c46add03815059acfca

Submit reply
Hot questions