Use Scala macros to generate methods

鱼传尺愫 2020-12-30 10:37

I want to generate aliases for methods using macro annotations in Scala 2.11+. I'm not even sure that is possible. If it is, how?

Example - Given this below, I wa

1 answer
  • 2020-12-30 11:06

    This doesn't seem possible exactly as stated. Using a macro annotation on a class member does not allow you to manipulate the tree of the class itself. That is, when you annotate a method within a class with a macro annotation, macroTransform(annottees: Any*) will be called, but the only annottee will be the method itself.

    I was able to get a proof-of-concept working with two annotations. It's obviously not as nice as simply annotating the class, but I can't think of another way around it.

    You'll need:

    import scala.annotation.{ StaticAnnotation, compileTimeOnly }
    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox.Context
    

    The idea is that you annotate each method with this annotation, so that a macro annotation on the enclosing class can find which methods you want to expand.

    class alias(aliases: String*) extends StaticAnnotation
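    On its own, alias is just an ordinary static annotation: it records the alias names in the source but generates nothing. A minimal plain-Scala check (no macro paradise involved; PlainSocket and PlainSocketCheck are hypothetical names for illustration) shows that only load is emitted when the class-level macro does not run:

```scala
import scala.annotation.StaticAnnotation

// Same marker annotation as above: it carries the alias names but has no behaviour of its own.
class alias(aliases: String*) extends StaticAnnotation

// Hypothetical class that uses @alias without the class-level @aliased macro.
class PlainSocket {
  @alias("ask", "read")
  def load(n: Int): Seq[Byte] = Seq(1, 2, 3).map(_.toByte)
}

object PlainSocketCheck {
  // Names of the public methods the compiler actually emitted for PlainSocket.
  def methodNames: Set[String] = classOf[PlainSocket].getMethods.map(_.getName).toSet
}
```

    Only the class-level macro can add ask and read, because it is the only one that sees the whole class body.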
    

    Then the macro:

    // Annotate the containing class to expand aliased methods within
    @compileTimeOnly("You must enable the macro paradise plugin.")
    class aliased extends StaticAnnotation {
        def macroTransform(annottees: Any*): Any = macro AliasMacroImpl.impl
    }
    
    object AliasMacroImpl {
    
      def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
        import c.universe._
    
        val result = annottees map (_.tree) match {
          // Match a class (optionally followed by its companion object) and expand the class.
          case (classDef @ q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }") :: companion =>
    
            // One forwarding def per alias literal on each @alias-annotated method.
            val aliasedDefs = for {
              q"@alias(..$aliases) def $tname[..$tparams](...$paramss): $tpt = $expr" <- stats
              Literal(Constant(alias)) <- aliases
              ident = TermName(alias.toString)
            } yield {
              // Turn each parameter list into argument references, e.g. (a: Int, b: Int) becomes (a, b).
              val args = paramss map { paramList =>
                paramList.map { case q"$_ val $param: $_ = $_" => q"$param" }
              }
    
              q"def $ident[..$tparams](...$paramss): $tpt = $tname(...$args)"
            }
    
            val expandedClass =
              if (aliasedDefs.nonEmpty) {
                q"""
                  $mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self =>
                    ..$stats
                    ..$aliasedDefs
                  }
                """
              } else classDef
    
            // Re-emit the companion object (if any) unchanged; returning only the class would drop it.
            q"..${expandedClass :: companion}"
    
          // Not a class.
          case _ => c.abort(c.enclosingPosition, "Invalid annotation target: not a class")
        }
    
        c.Expr[Any](result)
      }
    
    }
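    The for-comprehension in impl does two things: it keeps only the members that match the @alias def pattern, and it emits one forwarding def per alias literal, rebuilding each parameter list as a list of plain argument references. The sketch below models that logic on hypothetical case classes (Param, Method) instead of compiler trees and renders the forwarders as source strings. It only illustrates the shape of the expansion; it is not the scala.reflect API.

```scala
// Hypothetical, simplified stand-ins for the trees the macro matches on.
final case class Param(name: String, tpe: String)
final case class Method(
    name: String,
    aliases: List[String],      // literal arguments of @alias(...); empty if unannotated
    paramss: List[List[Param]], // one inner list per parameter list
    resultType: String
)

object AliasModel {
  // Mirrors the macro's for-comprehension: one forwarder per (annotated method, alias) pair.
  def forwarders(stats: List[Method]): List[String] =
    for {
      m     <- stats
      alias <- m.aliases
    } yield {
      // Declaration side, e.g. "(a: Int, b: Int)(c: String)".
      val params = m.paramss.map(_.map(p => s"${p.name}: ${p.tpe}").mkString("(", ", ", ")")).mkString
      // Call side, e.g. "(a, b)(c)": the counterpart of splicing ...$args.
      val args = m.paramss.map(_.map(_.name).mkString("(", ", ", ")")).mkString
      s"def $alias$params: ${m.resultType} = ${m.name}$args"
    }
}
```

    Methods without aliases contribute nothing, which is why the macro returns the class unchanged when aliasedDefs is empty.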
    

    Keep in mind that this implementation is brittle. It only checks that the first annottee is a ClassDef. It then looks for members of the class that are methods annotated with @alias and creates one forwarding method per alias to splice back into the class. If there are no annotated methods, it simply returns the original class tree. As is, it will not detect duplicate method names, and it does not handle modifiers: the generated aliases get none (the compiler would not let me match annotations and modifiers at the same time).
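    One possible guard against duplicate names, sketched here on plain strings rather than trees (AliasCollisions and collisions are hypothetical names), is to compare the alias names against the existing member names and against each other before splicing. Inside the macro, a non-empty result could be reported with c.abort.

```scala
object AliasCollisions {
  // Returns every alias name that clashes with an existing member or with another alias.
  // Assumption: memberNames would be collected from the class body, aliases from all @alias annotations.
  def collisions(memberNames: Set[String], aliases: Seq[String]): Set[String] = {
    val repeated = aliases.groupBy(identity).collect {
      case (name, occurrences) if occurrences.size > 1 => name
    }.toSet
    aliases.filter(memberNames).toSet ++ repeated
  }
}
```

    This check is deliberately conservative: it also flags a name that would be a legitimate overload of an existing member.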

    This can easily be extended to expand aliased methods in companion objects as well, but I left that out to keep the code small. See the quasiquotes syntax summary for the matchers I used. Doing so would require splitting the result match into separate cases for classDef :: objDef :: Nil and objDef :: Nil.
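    The annottee shapes mentioned above can be sketched as a match on a list. Annottee, ClassTree and ObjectTree are hypothetical stand-ins for the ClassDef and ModuleDef trees a macro annotation receives; in the real macro the same cases would be written with q"..." patterns.

```scala
// Hypothetical stand-ins for the trees passed as annottees.
sealed trait Annottee
final case class ClassTree(name: String) extends Annottee
final case class ObjectTree(name: String) extends Annottee

object AnnotteeShapes {
  // Describes what a companion-aware expansion would have to do for each annottee shape.
  def plan(annottees: List[Annottee]): String = annottees match {
    case ClassTree(c) :: Nil                  => s"expand class $c"
    case ClassTree(c) :: ObjectTree(o) :: Nil => s"expand class $c and companion $o"
    case ObjectTree(o) :: Nil                 => s"expand object $o"
    case _                                    => "abort: invalid annotation target"
  }
}
```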

    In use:

    @aliased
    class Socket {
        @alias("ask", "read")
        def load(n: Int): Seq[Byte] = Seq(1, 2, 3).map(_.toByte)
    }
    
    scala> val socket = new Socket
    socket: Socket = Socket@7407d2b8
    
    scala> socket.load(5)
    res0: Seq[Byte] = List(1, 2, 3)
    
    scala> socket.ask(5)
    res1: Seq[Byte] = List(1, 2, 3)
    
    scala> socket.read(5)
    res2: Seq[Byte] = List(1, 2, 3)
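    For reference, the class the macro produces for Socket should behave like the hand-written version below. This is a sketch of the expected expansion, not actual compiler output; it compiles without macro paradise, and the @alias annotation that the expansion keeps on load is omitted because it has no runtime effect.

```scala
// Hand-written equivalent of what @aliased should generate for Socket.
class Socket {
  def load(n: Int): Seq[Byte] = Seq(1, 2, 3).map(_.toByte)

  // Generated forwarders: same parameter lists and result type, body delegates to load.
  def ask(n: Int): Seq[Byte] = load(n)
  def read(n: Int): Seq[Byte] = load(n)
}
```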
    

    It can also handle multiple parameter lists:

    @aliased
    class Foo {
        @alias("bar", "baz")
        def test(a: Int, b: Int)(c: String) = a + b + c
    }
    
    scala> val foo = new Foo
    foo: Foo = Foo@3857a375
    
    scala> foo.baz(1, 2)("4")
    res0: String = 34
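    With several parameter lists, each list is forwarded as its own argument list. A hand-written sketch of the expected expansion for Foo (not compiler output): since test declares no result type, its tpt is empty and the generated forwarders leave their result type to inference as well.

```scala
// Hand-written equivalent of what @aliased should generate for Foo.
class Foo {
  def test(a: Int, b: Int)(c: String) = a + b + c

  // Each parameter list becomes one argument list in the call: test(a, b)(c).
  def bar(a: Int, b: Int)(c: String) = test(a, b)(c)
  def baz(a: Int, b: Int)(c: String) = test(a, b)(c)
}
```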
    