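I am trying to learn metaprogramming in Dotty, specifically compile-time code generation. I thought learning by building something would be a good approach, so I decided to write a small `Decoder` type class that parses a comma-separated string into a case class, and to derive its instances at compile time.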
Using standard type class derivation (`scala.deriving.Mirror`) in Dotty
import scala.deriving._
import scala.compiletime._
/** A decoding failure: the input `str` that could not be decoded and a human-readable `msg`. */
case class ParseError(
  str: String,
  msg: String
)
/**
 * Type class that turns a raw string into a `T`.
 * A failure is reported as `Left` of a [[ParseError]] rather than as an exception.
 */
trait Decoder[T] {
  def decode(str: String): Either[ParseError, T]
}
object Decoder {
  /** Decodes a field verbatim; every string is a valid `String`. */
  inline given stringDec as Decoder[String] = new Decoder[String] {
    override def decode(str: String): Either[ParseError, String] = Right(str)
  }

  /** Decodes a field as an `Int`, failing when `toIntOption` cannot parse it. */
  inline given intDec as Decoder[Int] = new Decoder[Int] {
    override def decode(str: String): Either[ParseError, Int] =
      str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
  }

  /**
   * Entry point used by `derives Decoder`.
   *
   * Summons a `Decoder` for every element type of `T`. For product types, the decoders are
   * combined with [[productDecoder]]. Sum types are not supported yet and hit `???`.
   */
  inline def derived[T](using m: Mirror.Of[T]): Decoder[T] = {
    val elemInstances = summonAll[m.MirroredElemTypes]
    inline m match {
      case p: Mirror.ProductOf[T] => productDecoder(p, elemInstances)
      case s: Mirror.SumOf[T] => ???
    }
  }

  /** Summons a `Decoder` for each type in the tuple type `T`, preserving order. */
  inline def summonAll[T <: Tuple]: List[Decoder[_]] = inline erasedValue[T] match {
    case _: Unit /* EmptyTuple in 0.25 */ => Nil
    case _: (t *: ts) => summonInline[Decoder[t]] :: summonAll[ts]
  }

  /**
   * Builds a decoder for a product type from one decoder per field.
   *
   * The input is split on ',' and the i-th piece is decoded with the i-th decoder. The error
   * of the first (leftmost) failing field is returned. If the number of pieces does not match
   * the number of fields, a [[ParseError]] for the whole input is returned instead of letting
   * `fromProduct` throw.
   */
  def productDecoder[T](p: Mirror.ProductOf[T], elems: List[Decoder[_]]): Decoder[T] =
    new Decoder[T] {
      def decode(str: String): Either[ParseError, T] = {
        // Limit -1 keeps trailing empty pieces ("10," -> "10", ""), which plain split(',') drops.
        // A field-less product is encoded as the empty string.
        val pieces =
          if (elems.isEmpty && str.isEmpty) Nil
          else str.split(",", -1).toList
        if (pieces.length != elems.length)
          Left(ParseError(str, s"expected ${elems.length} fields but got ${pieces.length}"))
        else
          elems.zip(pieces)
            .map { case (dec, piece) => dec.decode(piece).map(_.asInstanceOf[AnyRef]) }
            .sequence
            .map(ts => p.fromProduct(new ArrayProduct(ts.toArray)))
      }
    }

  /** Turns a list of results into a result of a list; yields the first `Left` if any. */
  def [E,A](es: List[Either[E,A]]) sequence: Either[E,List[A]] =
    traverse(es)(x => x)

  /** Maps every element with `f`, collecting the results or yielding the first `Left`. */
  def traverse[E,A,B](es: List[A])(f: A => Either[E, B]): Either[E, List[B]] =
    es.foldRight[Either[E, List[B]]](Right(Nil))((h, tRes) => map2(f(h), tRes)(_ :: _))

  /** Combines two results with `f`; `a`'s `Left` takes precedence over `b`'s. */
  def map2[E, A, B, C](a: Either[E, A], b: Either[E, B])(f: (A, B) => C): Either[E, C] =
    for { a1 <- a; b1 <- b } yield f(a1,b1)
}
/** Example case class whose `Decoder` instance is generated via `derives Decoder`. */
case class A(i: Int, s: String) derives Decoder
/** Prints the decoding result for a few sample inputs. */
@main def test = {
  val decoder = summon[Decoder[A]]
  // Expected output:
  //   Right(A(10,abc))
  //   Left(ParseError(xxx,value is not valid Int))
  Seq("10,abc", "xxx,abc").foreach(input => println(decoder.decode(input)))
  // println(summon[Decoder[A]].decode(","))
}
Tested with Dotty 0.24.0.
Using shapeless 3
import shapeless.{K0, Typeable}
/** A decoding failure: the input `str` that could not be decoded and a human-readable `msg`. */
case class ParseError(
  str: String,
  msg: String
)
/**
 * Type class that turns a raw string into a `T`.
 * A failure is reported as `Left` of a [[ParseError]] rather than as an exception.
 */
trait Decoder[T] {
  def decode(str: String): Either[ParseError, T]
}
object Decoder {
  /** Decodes a field verbatim; every string is a valid `String`. */
  inline given stringDec as Decoder[String] = new Decoder[String] {
    override def decode(str: String): Either[ParseError, String] = Right(str)
  }

  /** Decodes a field as an `Int`, failing when `toIntOption` cannot parse it. */
  inline given intDec as Decoder[Int] = new Decoder[Int] {
    override def decode(str: String): Either[ParseError, Int] =
      str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
  }

  /**
   * Entry point used by `derives Decoder`; delegates to shapeless' `K0.Generic.derive`.
   * NOTE(review): `null` is passed for the coproduct case, so sum types are not supported.
   * Consider failing explicitly instead.
   */
  inline def derived[A](using gen: K0.Generic[A]): Decoder[A] =
    gen.derive(productDecoder, null)

  /**
   * Decoder for a product type `T`, built from the field decoders in `inst`.
   *
   * The input is split on ',' and fields are decoded left to right. The accumulator carries the
   * pieces not yet consumed and the first decoding error, if any. The result is:
   *  - the first field's decoding error, if one failed;
   *  - "value is not valid T" when there are too few pieces or pieces are left over;
   *  - otherwise the decoded `T`.
   */
  given productDecoder[T](using inst: K0.ProductInstances[Decoder, T], typeable: Typeable[T]) as Decoder[T] = new Decoder[T] {
    def decode(str: String): Either[ParseError, T] = {
      type Acc = (List[String], Option[ParseError])
      inst.unfold[Decoder, T, Acc](str.split(',').toList, None)([t] => (acc: Acc, dec: Decoder[t]) =>
        acc._1 match {
          case head :: tail => dec.decode(head) match {
            case Right(t) => ((tail, None), Some(t))
            // Clear the remaining input so later fields stop decoding and the error is kept.
            case Left(e) => ((Nil, Some(e)), None)
          }
          // Ran out of input (or a previous field failed): keep the accumulator unchanged.
          case Nil => (acc, None)
        }
      ) match {
        case ((_, Some(e)), None) => Left(e)
        case ((_, None), None) => Left(ParseError(str, s"value is not valid ${typeable.describe}"))
        case ((Nil, _), Some(t)) => Right(t)
        // Every field decoded, but extra comma-separated values were left over.
        case (_, Some(_)) => Left(ParseError(str, s"value is not valid ${typeable.describe}"))
      }
    }
  }
}
/** Example case class whose `Decoder` instance is generated via `derives Decoder`. */
case class A(i: Int, s: String) derives Decoder
/** Prints the decoding result for a few sample inputs. */
@main def test = {
  val decoder = summon[Decoder[A]]
  // Expected output:
  //   Right(A(10,abc))
  //   Left(ParseError(xxx,value is not valid Int))
  //   Left(ParseError(,,value is not valid A))
  for (input <- List("10,abc", "xxx,abc", ","))
    println(decoder.decode(input))
}
build.sbt
// Dotty compiler version the examples were tested with.
scalaVersion := "0.24.0"
// shapeless 3 milestone; provides the `shapeless.{K0, Typeable}` imports used by the shapeless example.
libraryDependencies += "org.typelevel" %% "shapeless-core" % "3.0.0-M1"
project/plugins.sbt
// sbt-dotty plugin, used so sbt can build with the Dotty `scalaVersion` set in build.sbt.
addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % "0.4.1")
Using Dotty macros + TASTy reflection, as in dotty-macro-examples/macroTypeclassDerivation (this approach is even more low-level than the one based on `scala.deriving.Mirror`).
import scala.quoted._
/** A decoding failure: the input `str` that could not be decoded and a human-readable `msg`. */
case class ParseError(
  str: String,
  msg: String
)
/**
 * Type class that turns a raw string into a `T`.
 * A failure is reported as `Left` of a [[ParseError]] rather than as an exception.
 */
trait Decoder[T] {
  def decode(str: String): Either[ParseError, T]
}
object Decoder {
  /** Decodes a field verbatim; every string is a valid `String`. */
  inline given stringDec as Decoder[String] = new Decoder[String] {
    override def decode(str: String): Either[ParseError, String] = Right(str)
  }

  /** Decodes a field as an `Int`, failing when `toIntOption` cannot parse it. */
  inline given intDec as Decoder[Int] = new Decoder[Int] {
    override def decode(str: String): Either[ParseError, Int] =
      str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
  }

  /** Entry point used by `derives Decoder`; expands to the macro below. */
  inline def derived[T]: Decoder[T] = ${ derivedImpl[T] }

  /**
   * Macro implementation of [[derived]].
   *
   * Case classes get a generated product decoder. Sealed traits are not supported yet and
   * hit `???`. Anything else aborts expansion with an error listing the symbol's flags.
   */
  def derivedImpl[T](using qctx: QuoteContext, tpe: Type[T]): Expr[Decoder[T]] = {
    import qctx.tasty._
    val tpeSym = tpe.unseal.symbol
    if (tpeSym.flags.is(Flags.Case)) productDecoder[T]
    // `|` builds the flag set {Trait, Sealed} and `is` checks that all of them are present.
    // `&` would be the intersection, i.e. an empty set, which every symbol trivially "has".
    else if (tpeSym.flags.is(Flags.Trait | Flags.Sealed)) ???
    else sys.error(s"Unsupported combination of flags: ${tpeSym.flags.show}")
  }

  /**
   * Generates a decoder for the case class `T`.
   *
   * One `Decoder` is looked up per case field at compile time. At runtime the input is split on
   * ',' and each piece is decoded with the matching decoder. The error of the first failing
   * field is returned. A field-count mismatch yields a [[ParseError]] rather than an
   * out-of-bounds exception during construction.
   */
  def productDecoder[T](using qctx: QuoteContext, tpe: Type[T]): Expr[Decoder[T]] = {
    import qctx.tasty._
    val fields: List[Symbol] = tpe.unseal.symbol.caseFields
    val fieldTypeTrees: List[TypeTree] = fields.map(_.tree.asInstanceOf[ValDef].tpt)
    val fieldTypes: List[Type] = fieldTypeTrees.map(_.tpe)
    val decoderTerms: List[Term] = fieldTypes.map(lookupDecoderFor(_))
    val decoders: Expr[List[Decoder[_]]] = Expr.ofList(decoderTerms.map(_.seal.cast[Decoder[_]]))

    // Builds `new T(fields(0).asInstanceOf[F0], fields(1).asInstanceOf[F1], ...)`.
    def mkT(fields: Expr[List[_]]): Expr[T] = {
      Apply(
        Select.unique(New(tpe.unseal), "<init>"),
        fieldTypeTrees.zipWithIndex.map((fieldType, i) =>
          TypeApply(
            Select.unique(
              Apply(
                Select.unique(
                  fields.unseal,
                  "apply"),
                List(Literal(Constant(i)))
              ), "asInstanceOf"),
            List(fieldType)
          )
        )
      ).seal.cast[T]
    }

    '{
      new Decoder[T]{
        override def decode(str: String): Either[ParseError, T] = {
          val decs = $decoders
          // Limit -1 keeps trailing empty pieces ("10," -> "10", ""), which plain split(',') drops.
          // A field-less case class is encoded as the empty string.
          val parts =
            if (decs.isEmpty && str.isEmpty) Nil
            else str.split(",", -1).toList
          if (parts.length != decs.length)
            Left(ParseError(str, s"expected ${decs.length} fields but got ${parts.length}"))
          else
            parts.zip(decs).map((part, decoder) =>
              decoder.decode(part)
            ).sequence.map(fields =>
              ${mkT('fields)}
            )
        }
      }
    }
  }

  /**
   * Resolves a given `Decoder[t]` at the macro call site.
   * Aborts expansion with a readable message when no instance exists.
   */
  def lookupDecoderFor(using qctx: QuoteContext)(t: qctx.tasty.Type): qctx.tasty.Term = {
    import qctx.tasty._
    val tpe = AppliedType(Type(classOf[Decoder[_]]), List(t))
    searchImplicit(tpe) match {
      case res: ImplicitSearchSuccess => res.tree
      case _ => sys.error(s"No Decoder instance found for ${t.show}")
    }
  }

  /** Turns a list of results into a result of a list; yields the first `Left` if any. */
  def [E,A](es: List[Either[E,A]]) sequence: Either[E,List[A]] =
    traverse(es)(x => x)

  /** Maps every element with `f`, collecting the results or yielding the first `Left`. */
  def traverse[E,A,B](es: List[A])(f: A => Either[E, B]): Either[E, List[B]] =
    es.foldRight[Either[E, List[B]]](Right(Nil))((h, tRes) => map2(f(h), tRes)(_ :: _))

  /** Combines two results with `f`; `a`'s `Left` takes precedence over `b`'s. */
  def map2[E, A, B, C](a: Either[E, A], b: Either[E, B])(f: (A, B) => C): Either[E, C] =
    for { a1 <- a; b1 <- b } yield f(a1,b1)
}
/** Example case class whose `Decoder` instance is generated via `derives Decoder`. */
case class A(i: Int, s: String) derives Decoder
/** Prints the decoding result for a few sample inputs. */
@main def test = {
  val decoder = summon[Decoder[A]]
  // Expected output:
  //   Right(A(10,abc))
  //   Left(ParseError(xxx,value is not valid Int))
  Seq("10,abc", "xxx,abc").foreach(input => println(decoder.decode(input)))
  // println(summon[Decoder[A]].decode(","))
}
Tested with Dotty 0.24.0.
For comparison, see how type classes are derived in Scala 2:
Use the lowest subtype in a typeclass?