Tensorflow in Scala reflection

后端 未结 2 1076
无人及你
无人及你 2020-12-20 09:58

I am trying to get TensorFlow for Java to work in Scala. I am using the TensorFlow Java library directly, without any Scala wrapper.

At sbt I have

相关标签:
2条回答
  • 2020-12-20 10:37

    The cause is the following bug, which appears when reflective compilation is combined with Scala-Java interop:

    https://github.com/scala/bug/issues/8956

    Toolbox can't typecheck a value (s.runner()) of a path-dependent type (s.Runner) if that type comes from a Java non-static inner class, and Runner is exactly such a class inside org.tensorflow.Session.

    You can run the compiler manually (similarly to how Toolbox runs it)

    import org.tensorflow.Tensor
    import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
    import scala.reflect.io.{AbstractFile, VirtualDirectory}
    import scala.reflect.runtime
    import scala.reflect.runtime.universe
    import scala.reflect.runtime.universe._
    import scala.tools.nsc.{Global, Settings}
    
    // Scala source that is compiled at runtime (not by the enclosing build).
    // It defines `object Main` with a method `foo()` that returns a function; when invoked,
    // that function builds a graph holding a single "MyConst" constant, opens a session on it
    // and fetches the constant's tensor.
    val code: String =
      """
        |import org.tensorflow.Graph
        |import org.tensorflow.Session
        |import org.tensorflow.Tensor
        |import org.tensorflow.TensorFlow
        |
        |object Main {
        |  def foo() = () => {
        |      val g = new Graph()
        |      val value = "Hello from " + TensorFlow.version()
        |      val t = Tensor.create(value.getBytes("UTF-8"))
        |      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
        |
        |      val s = new Session(g)
        |
        |      s.runner().fetch("MyConst").run().get(0)
        |  }
        |}
    """.stripMargin
    
    // Compiled classes are written to an in-memory directory rather than to disk.
    val directory = new VirtualDirectory("(memory)", None)
    // The mirror must be backed by a class loader that can see that directory.
    val runtimeMirror = createRuntimeMirror(directory, runtime.currentMirror)
    compileCode(code, Nil, directory)
    // Main.foo() returns a function; the reflective call yields it as Any, so cast it back.
    val fooResult = runObjectMethod("Main", runtimeMirror, "foo")
    val tensor = fooResult.asInstanceOf[() => Tensor[_]]
    tensor() // STRING tensor with shape []
    
    /** Compiles a single Scala source string with an embedded compiler instance.
      *
      * @param code                 the Scala source to compile; it is presented to the compiler as a file named "(inline)"
      * @param classpathDirectories directories prepended (one by one) to the compiler classpath
      * @param outputDirectory      where the generated class files are written
      */
    def compileCode(code: String, classpathDirectories: List[AbstractFile], outputDirectory: AbstractFile): Unit = {
      val compilerSettings = new Settings
      for (entry <- classpathDirectories) {
        compilerSettings.classpath.prepend(entry.toString)
      }
      compilerSettings.outputDirs.setSingleOutput(outputDirectory)
      // Also expose the classpath of the running JVM to the compiler.
      compilerSettings.usejavacp.value = true

      val compiler = new Global(compilerSettings)
      val compilation = new compiler.Run
      val inlineSource = new BatchSourceFile("(inline)", code)
      compilation.compileSources(List(inlineSource))
    }
    
    /** Reflectively invokes a method on a top-level Scala `object`.
      *
      * @param objectName    fully qualified name of the object, resolved through `runtimeMirror`
      * @param runtimeMirror mirror used to look up and reflect the object
      * @param methodName    name of the method declared on the object
      * @param arguments     arguments forwarded to the method
      * @return whatever the invoked method returns, typed as `Any`
      */
    def runObjectMethod(objectName: String, runtimeMirror: Mirror, methodName: String, arguments: Any*): Any = {
      val module = runtimeMirror.staticModule(objectName)
      val instance = runtimeMirror.reflectModule(module).instance
      // `decl` only looks at members declared directly on the object's type.
      val method = module.typeSignature.decl(TermName(methodName)).asMethod
      val invoke = runtimeMirror.reflect(instance).reflectMethod(method)
      invoke(arguments: _*)
    }
    
    /** Builds a runtime mirror whose class loader reads classes from `directory`
      * and delegates to the class loader of `parentMirror`.
      */
    def createRuntimeMirror(directory: AbstractFile, parentMirror: Mirror): Mirror =
      universe.runtimeMirror(new AbstractFileClassLoader(directory, parentMirror.classLoader))
    

    dynamically parse json in flink map

    Dynamic compilation of multiple Scala classes at runtime

    How to eval code that uses InterfaceStability annotation (that fails with "illegal cyclic reference involving class InterfaceStability")?

    0 讨论(0)
  • 2020-12-20 10:38

    As Dmytro pointed out in his answer, this is not possible using Toolbox. He also pointed to another answer (How to eval code that uses InterfaceStability annotation (that fails with "illegal cyclic reference involving class InterfaceStability")?). I think there is a neat solution: reuse the Compiler class defined in that answer, and replace the Toolbox with that Compiler class.

    In that case, the final snippet will look like:

    import your.package.Compiler
    // Source of a block expression that evaluates to a `() => Any` function.
    // The function body is only run when the function is applied below.
    val fnStr = """
        {() =>
          import org.tensorflow.Graph
          import org.tensorflow.Session
          import org.tensorflow.Tensor
          import org.tensorflow.TensorFlow
    
          val g = new Graph()
          val value = "Hello from " + TensorFlow.version()
          val t = Tensor.create(value.getBytes("UTF-8"))
          g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
    
          val s = new Session(g)
    
          s.runner().fetch("MyConst").run().get(0)
        }
        """
    val tb = new Compiler() // this replaces the mirror and toolbox instantiation
    // Unlike Toolbox, Compiler has no separate `parse` step: `eval` takes the source string,
    // compiles it and returns the value of the block, which here is the function itself.
    val fn = tb.eval[() => Any](fnStr)
    // and finally, executing the function
    println(fn())
    

    And just for completeness, here is the Compiler class, copied from the solution in that answer:

      /** Compiles snippets of Scala source at runtime into an in-memory directory and loads
        * the resulting classes.
        *
        * Every snippet is wrapped in a class that extends `() => Any` and is named after the
        * SHA-1 digest of the snippet, so asking for the same snippet again reuses the class
        * that was already loaded instead of compiling it a second time.
        */
      class Compiler() {
        import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
        import scala.reflect.io.{AbstractFile, VirtualDirectory}
        import scala.reflect.runtime
        import scala.reflect.runtime.universe
        import scala.reflect.runtime.universe._
        import scala.tools.nsc.{Global, Settings}
        import scala.collection.mutable
        import java.security.MessageDigest
        import java.math.BigInteger
        import java.nio.charset.StandardCharsets

        /** In-memory directory that receives the compiled class files. */
        val target  = new VirtualDirectory("(memory)", None)

        /** Classes already loaded through [[findClass]], keyed by class name. */
        val classCache = mutable.Map[String, Class[_]]()

        private val settings = new Settings()
        settings.deprecation.value = true // enable detailed deprecation warnings
        settings.unchecked.value = true // enable detailed unchecked warnings
        settings.outputDirs.setSingleOutput(target)
        settings.usejavacp.value = true

        private val global = new Global(settings)
        // NOTE(review): one Run instance is shared by every compile() call — confirm that
        // reusing a Run across compilations is intended rather than creating one per call.
        private lazy val run = new global.Run

        /** Class loader that resolves classes from `target`, delegating to this class's loader. */
        val classLoader = new AbstractFileClassLoader(target, this.getClass.getClassLoader)

        /** Compiles the code as a class into the class loader of this compiler.
          *
          * If a class for this exact code is already loadable it is returned without recompiling.
          *
          * @param code the Scala source used as the body of the generated `apply()` method
          * @return the loaded class wrapping `code`
          * @throws java.util.NoSuchElementException if no class could be loaded after compiling
          *                                          (for example because the code did not compile)
          */
        def compile(code: String): Class[_] = {
          val className = classNameForCode(code)
          findClass(className).getOrElse {
            val sourceFiles = List(new BatchSourceFile("(inline)", wrapCodeInClass(className, code)))
            run.compileSources(sourceFiles)
            // Same exception type as the former `Option.get`, but with a message that
            // points at the likely cause instead of a bare "None.get".
            findClass(className).getOrElse {
              throw new NoSuchElementException(
                s"Compilation did not produce class $className; check the compiler output for errors")
            }
          }
        }

        /** Compiles the source string into the class loader and
          * evaluates it.
          *
          * @param code the Scala source to compile and run
          * @tparam T the type the result is cast to; the cast is unchecked
          * @return the value produced by running `code`, cast to `T`
          */
        def eval[T](code: String): T = {
          val cls = compile(code)
          cls.getConstructor().newInstance().asInstanceOf[() => Any].apply().asInstanceOf[T]
        }

        /** Looks a class up in the cache, falling back to the class loader.
          *
          * @param className name of the class to load
          * @return the class, or `None` when the class loader cannot find it
          */
        def findClass(className: String): Option[Class[_]] = {
          synchronized {
            classCache.get(className).orElse {
              try {
                val cls = classLoader.loadClass(className)
                classCache(className) = cls
                Some(cls)
              } catch {
                case _: ClassNotFoundException => None
              }
            }
          }
        }

        /** Derives the generated class name from the SHA-1 digest of the code.
          *
          * The code is encoded as UTF-8 explicitly: with the platform default charset,
          * characters the charset cannot represent are replaced during encoding, so two
          * different snippets could end up with the same name and share a cached class.
          *
          * @param code the source whose digest names the class
          * @return "sha" followed by the digest in hexadecimal
          */
        protected def classNameForCode(code: String): String = {
          val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes(StandardCharsets.UTF_8))
          "sha" + new BigInteger(1, digest).toString(16)
        }

        /*
         * Wrap source code in a new class with an apply method.
         */
        private def wrapCodeInClass(className: String, code: String): String = {
          "class " + className + " extends (() => Any) {\n" +
          "  def apply() = {\n" +
          code + "\n" +
          "  }\n" +
          "}\n"
        }
      }
    
    0 讨论(0)
提交回复
热议问题