Filtering Scala's Parallel Collections with early abort when the desired number of results is found

终归单人心 2021-02-08 13:02

Given a very large instance of collection.parallel.mutable.ParHashMap (or any other parallel collection), how can one abort a filtering parallel scan once a giv

3 Answers
  •  梦如初夏
    2021-02-08 13:21

    I carried out an interesting investigation of your case.

    Investigation reasoning

    I suspected the problem was the mutability of the input Map, and I will try to explain why: a HashMap implementation organizes its data into buckets, as described on Wikipedia.

    [Image: hash map diagram from Wikipedia]
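    To make the bucket idea concrete, here is a hypothetical sketch of how a table whose size is a power of two maps a key to a bucket. The object and method names are illustrative and not from any library. The table masks the key's hash code to get the index. Real implementations such as java.util.HashMap also spread the hash bits before masking, and this sketch omits that step.

```scala
// Hypothetical illustration of hash-map bucketing (not a library API).
object BucketDemo {
  val tableSize = 16 // must be a power of two for the mask below to work

  // Pick a bucket by keeping only the low bits of the key's hash code.
  def bucketIndex(key: Any): Int = key.hashCode & (tableSize - 1)

  // Group keys by the bucket they would land in; keys sharing a bucket
  // are the ones that would compete for the same per-bucket lock.
  def bucketsFor[K](keys: Seq[K]): Map[Int, Seq[K]] = keys.groupBy(k => bucketIndex(k))
}
```

    For example, the Int keys 5 and 21 share bucket 5, because an Int's hash code is the value itself and both have the same four low bits.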

    Java's first thread-safe collections, the synchronized collections, wrapped every method of the underlying implementation in a single lock, which resulted in poor performance. Further work led to the more performant concurrent collections, such as ConcurrentHashMap. Its approach is smarter: instead of one lock for the whole map, why not protect each bucket (or group of buckets) with its own lock?
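    The lock-striping idea can be sketched in Scala as follows. StripedMap is a hypothetical, simplified class, not a library type. It is also not how ConcurrentHashMap is actually implemented, since that class also relies on lock-free reads and CAS. The sketch splits the key space into a fixed number of stripes. Each stripe is a plain mutable.HashMap guarded by its own ReentrantLock, so threads writing keys in different stripes do not block each other.

```scala
import java.util.concurrent.locks.ReentrantLock
import scala.collection.mutable

// Hypothetical lock-striped map: one lock per partition instead of one global lock.
class StripedMap[K, V](stripes: Int = 16) {
  private val locks = Array.fill(stripes)(new ReentrantLock())
  private val parts = Array.fill(stripes)(mutable.HashMap.empty[K, V])

  // Choose the stripe from the key's hash; floorMod keeps the index non-negative.
  private def stripeOf(key: K): Int = java.lang.Math.floorMod(key.##, stripes)

  def put(key: K, value: V): Unit = {
    val i = stripeOf(key)
    locks(i).lock()
    try parts(i).update(key, value)
    finally locks(i).unlock()
  }

  def get(key: K): Option[V] = {
    val i = stripeOf(key)
    locks(i).lock()
    try parts(i).get(key)
    finally locks(i).unlock()
  }

  // Sums the partition sizes one stripe at a time (not an atomic snapshot).
  def size: Int = (0 until stripes).map { i =>
    locks(i).lock()
    try parts(i).size
    finally locks(i).unlock()
  }.sum
}
```

    Two threads only contend when their keys fall into the same stripe, which is exactly the kind of collision the hypothesis below is about.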

    My hunch was that the performance problem occurs because:

    • When you run your filter in parallel, some threads try to access the same bucket at the same time and contend for the same lock, because your map is mutable.
    • You maintain a counter to track how many results you have, although you could simply check the size of your result. If you have a thread-safe way to build a collection, you don't also need a thread-safe counter.

    Investigation result

    I wrote a test case and found out that I was wrong. The problem lies in the concurrent nature of the output map. The contention occurs when you put elements into that map, not when you iterate over the input. Moreover, since you only want the values in the result, you don't need the keys, the hashing, or any of the other map features. It would be interesting to see how your version performs if you remove the AtomicCounter and use only the result map to check whether you have collected enough elements.
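    One way to try the "no separate counter" idea is a lock-free collector whose own size decides when to stop. The sketch below is hypothetical: the BoundedCollector name is illustrative, and it is an alternative to the synchronized builder used in the code further down, not the author's method. It keeps an immutable Vector in an AtomicReference. Writers append with compareAndSet and retry if another thread got there first.

```scala
import java.util.concurrent.atomic.AtomicReference

// Hypothetical lock-free bounded collector: the Vector's size is the only count.
final class BoundedCollector[T](maxSize: Int) {
  private val ref = new AtomicReference[Vector[T]](Vector.empty)

  /** Returns true if `item` was added, false once `maxSize` items are held. */
  @annotation.tailrec
  def tryAdd(item: T): Boolean = {
    val current = ref.get()
    if (current.size >= maxSize) false
    else if (ref.compareAndSet(current, current :+ item)) true
    else tryAdd(item) // another thread updated the Vector first; re-read and retry
  }

  def isFull: Boolean = ref.get().size >= maxSize
  def result: Vector[T] = ref.get()
}
```

    Only maxSize appends can ever succeed, so the number of contended writes stays small no matter how large the input is. After that, every tryAdd returns false after a single read.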

    Please note that the following code targets Scala 2.9.2. I explain in another post why I need two different functions for the parallel and the non-parallel version: "Calling map on a parallel collection via a reference to an ancestor type".

    // Targets Scala 2.9.2; it should also compile on 2.10-2.12. On 2.13+ the
    // parallel collections live in the separate scala-parallel-collections module.
    import scala.collection.immutable.VectorBuilder
    import scala.collection.parallel.ParMap
    import scala.collection.parallel.immutable.{ParHashMap => ImmutableParHashMap}
    import scala.collection.parallel.mutable.{ParHashMap => MutableParHashMap}
    import scala.util.control.Breaks._

    object MapPerformance {

      val size = 100000
      val items = Seq.tabulate(size)(x => (x, x * 2))

      val concurrentParallelMap = ImmutableParHashMap(items: _*)
      val concurrentMutableParallelMap = MutableParHashMap(items: _*)
      val unparallelMap = Map(items: _*)

      // Collects at most maxSize items. The builder and the item count are guarded
      // by the same lock, so no separate thread-safe counter is needed.
      class ThreadSafeIndexedSeqBuilder[T](maxSize: Int) {
        private val underlyingBuilder = new VectorBuilder[T]()
        private var counter = 0
        def sizeHint(hint: Int): Unit = synchronized { underlyingBuilder.sizeHint(hint) }
        // Returns false (and drops the item) once maxSize items have been collected.
        def +=(item: T): Boolean = synchronized {
          if (counter >= maxSize) false
          else {
            underlyingBuilder += item
            counter += 1
            true
          }
        }
        def result(): Vector[T] = synchronized { underlyingBuilder.result() }
      }

      def find(map: ParMap[Int, Int], filter: Int => Boolean, maxResults: Int): Iterable[Int] = {
        // we already know the maximum size
        val resultsBuilder = new ThreadSafeIndexedSeqBuilder[Int](maxResults)
        resultsBuilder.sizeHint(maxResults)

        breakable {
          // foreach is called directly: in Scala 2 a for-comprehension with a tuple
          // pattern or an `if` guard goes through withFilter, which on a parallel
          // collection eagerly builds the whole filtered collection first.
          map.foreach { case (_, node) =>
            if (filter(node)) {
              val newItemAdded = resultsBuilder += node
              if (!newItemAdded)
                break() // enough results: stop scanning
            }
          }
        }
        resultsBuilder.result()
      }

      def findUnParallel(map: Map[Int, Int], filter: Int => Boolean, maxResults: Int): Iterable[Int] = {
        // we already know the maximum size
        val resultsBuilder = Array.newBuilder[Int]
        resultsBuilder.sizeHint(maxResults)

        var counter = 0
        // The second guard skips matches past maxResults but does not end the scan.
        for {
          (_, node) <- map if filter(node)
          if counter < maxResults
        } {
          resultsBuilder += node
          counter += 1
        }
        resultsBuilder.result()
      }

      def measureTime[K](f: => K): (Long, K) = {
        val start = System.currentTimeMillis()
        val result = f
        val end = System.currentTimeMillis()
        (end - start, result)
      }

      def main(args: Array[String]): Unit = {
        val maxResultSetting = 10
        (1 to 10).foreach { tryNumber =>
          println("Try number " + tryNumber)
          val (mutableTime, mutableResult) =
            measureTime(find(concurrentMutableParallelMap, _ % 2 == 0, maxResultSetting))
          // the immutable parallel map
          val (immutableTime, immutableResult) =
            measureTime(find(concurrentParallelMap, _ % 2 == 0, maxResultSetting))
          val (unparallelTime, unparallelResult) =
            measureTime(findUnParallel(unparallelMap, _ % 2 == 0, maxResultSetting))
          assert(mutableResult.size == maxResultSetting)
          assert(immutableResult.size == maxResultSetting)
          assert(unparallelResult.size == maxResultSetting)
          println(" The mutable version has taken " + mutableTime + " milliseconds")
          println(" The immutable version has taken " + immutableTime + " milliseconds")
          println(" The unparallel version has taken " + unparallelTime + " milliseconds")
        }
      }
    }
    

    With this code, the parallel versions (with both the mutable and the immutable input map) were consistently about 3.5 times faster than the non-parallel one on my machine.
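    For comparison, here is a minimal sketch of cooperative early abort that depends on neither the parallel collections library nor break(). It assumes the data is an IndexedSeq that can be split into chunks by index. It uses scala.concurrent.Future with the global execution context. The object and method names are hypothetical. Each worker re-checks a shared AtomicInteger before every element and stops once all result slots are claimed, so the scan itself ends early. Note that it returns some maxResults matching elements, not necessarily the first ones in index order.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Hypothetical helper, not part of the Scala parallel collections API.
object EarlyAbortFilter {
  // Returns up to maxResults elements satisfying p (any matches, not necessarily
  // the first ones). Workers stop scanning once every result slot is claimed.
  def findAnyN(data: IndexedSeq[Int], p: Int => Boolean, maxResults: Int, chunks: Int = 4): Vector[Int] = {
    val claimed = new AtomicInteger(0)     // slots handed out so far (may overshoot maxResults)
    val slots = new Array[Int](maxResults) // each slot is written by exactly one worker
    val chunkSize = math.max(1, (data.length + chunks - 1) / chunks)
    val workers = data.indices.grouped(chunkSize).toVector.map { chunk =>
      Future {
        val it = chunk.iterator
        while (it.hasNext && claimed.get() < maxResults) {
          val x = data(it.next())
          if (p(x)) {
            val slot = claimed.getAndIncrement()
            if (slot < maxResults) slots(slot) = x
          }
        }
      }
    }
    // Waiting for every Future also makes their writes to `slots` visible here.
    Await.result(Future.sequence(workers), Duration.Inf)
    slots.take(math.min(claimed.get(), maxResults)).toVector
  }
}
```

    Reserving a slot with getAndIncrement plays the role of the bounded builder. Each matching element gets a unique index without a lock, and the same counter doubles as the stop signal.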
