Filtering Scala's Parallel Collections with early abort when desired number of results found

终归单人心 2021-02-08 13:02

Given a very large instance of collection.parallel.mutable.ParHashMap (or any other parallel collection), how can one abort a filtering parallel scan once a given number of results has been found?

3 Answers
  • 2021-02-08 13:21

    I performed an interesting investigation of your case.

    Investigation reasoning

    I suspected the problem was the mutability of the input Map, and I will try to explain why: a HashMap implementation organizes the data into different buckets, as one can see on Wikipedia.

    [Image: HashMap bucket diagram from Wikipedia]

    The first thread-safe collections in Java, the synchronized collections, were based on synchronizing every method around the underlying implementation, which resulted in poor performance. Further research and thinking led to the more performant concurrent collections, such as ConcurrentHashMap, whose approach was smarter: why not protect each bucket with its own lock?
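    To make the per-bucket locking idea concrete, here is a minimal lock-striping sketch in Scala. `StripedMap` and its `stripes` parameter are hypothetical names chosen for illustration; this is a simplification, not the actual `ConcurrentHashMap` implementation, and it only shows why writers whose keys hash to different stripes do not block each other.

```scala
import scala.collection.mutable

// Hypothetical illustration of lock striping (not the real ConcurrentHashMap code):
// keys are spread over `stripes` independent hash maps, each guarded by its own lock,
// so two threads only contend when their keys land in the same stripe.
final class StripedMap[K, V](stripes: Int) {
  require(stripes > 0, "need at least one stripe")

  private val locks   = Vector.fill(stripes)(new Object)
  private val buckets = Vector.fill(stripes)(mutable.HashMap.empty[K, V])

  // Pick the stripe from the key's hash; floorMod keeps the index non-negative.
  private def stripeOf(key: K): Int = Math.floorMod(key.##, stripes)

  def put(key: K, value: V): Unit = {
    val i = stripeOf(key)
    locks(i).synchronized { buckets(i).update(key, value) }
  }

  def get(key: K): Option[V] = {
    val i = stripeOf(key)
    locks(i).synchronized { buckets(i).get(key) }
  }

  // Locks each stripe in turn, so under concurrent writes this is not an atomic snapshot.
  def size: Int =
    (0 until stripes).map(i => locks(i).synchronized { buckets(i).size }).sum
}
```

    With 16 stripes, threads writing unrelated keys usually take different locks, which is the reduced contention described above.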

    My hunch was that the performance problem occurs because:

    • when you run your filter in parallel, some threads will contend for the same bucket at the same time and hit the same lock, because your map is mutable.
    • You keep a counter to track how many results you have, while you could simply check the size of your result. If you have a thread-safe way to build a collection, you don't need a thread-safe counter as well.

    Investigation result

    I developed a test case and found out I was wrong. The problem lies in the concurrent nature of the output map: the contention occurs when you put elements into the map, rather than when you iterate over it. Additionally, since you only want the values in the result, you don't need the keys, the hashing, or any of the other map features. It might be interesting to test how your version performs if you remove the AtomicCounter and use only the result map to check whether you have collected enough elements.

    Please be careful: the following code was written for Scala 2.9.2. I explain in another post why I need two different functions for the parallel and the non-parallel versions: Calling map on a parallel collection via a reference to an ancestor type

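    import scala.collection.immutable.VectorBuilder
    import scala.collection.parallel.ParMap
    import scala.collection.parallel.immutable.{ParHashMap => ImmutableParHashMap}
    import scala.collection.parallel.mutable.{ParHashMap => MutableParHashMap}
    
    object MapPerformance {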
    
      val size = 100000
      val items = Seq.tabulate(size)( x => (x,x*2))
    
    
      val concurrentParallelMap = ImmutableParHashMap(items:_*)
    
      val concurrentMutableParallelMap = MutableParHashMap(items:_*)
    
      val unparallelMap = Map(items:_*)
    
    
      class ThreadSafeIndexedSeqBuilder[T](maxSize:Int) {
        val underlyingBuilder = new VectorBuilder[T]()
        var counter = 0
        def sizeHint(hint:Int) { underlyingBuilder.sizeHint(hint) }
        def +=(item:T):Boolean ={
          synchronized{
            if(counter>=maxSize)
              false
            else{
              underlyingBuilder+=item
              counter+=1
              true
            }
          }
        }
        def result():Vector[T] = underlyingBuilder.result()
    
      }
    
      def find(map:ParMap[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
      {
    
        // we already know the maximum size
        val resultsBuilder = new ThreadSafeIndexedSeqBuilder[Int](maxResults)
        resultsBuilder.sizeHint(maxResults)
    
        import util.control.Breaks._
        breakable
        {
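          // Note: on a parallel collection withFilter delegates to the strict filter,
          // so this guard is evaluated over the whole map before the loop body runs.
          for ((key, node) <- map if filter(node))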
          {
            val newItemAdded = resultsBuilder+=node
            if (!newItemAdded)
              break()
    
          }
        }
        resultsBuilder.result().seq
    
      }
    
      def findUnParallel(map:Map[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
      {
    
        // we already know the maximum size
        val resultsBuilder = Array.newBuilder[Int]
        resultsBuilder.sizeHint(maxResults)
    
        var counter = 0
        for {
          (key, node) <- map if filter(node)
          if counter < maxResults
        }{
          resultsBuilder+=node
          counter+=1
        }
    
        resultsBuilder.result()
    
      }
    
      def measureTime[K](f: => K):(Long,K) = {
        val start = System.currentTimeMillis()
        val result = f
        val end = System.currentTimeMillis()
        (end - start, result)
      }
    
      def main(args:Array[String]) = {
        val maxResultSetting=10
        (1 to 10).foreach{
          tryNumber =>
            println("Try number " +tryNumber)
            val (mutableTime, mutableResult) = measureTime(find(concurrentMutableParallelMap,_%2==0,maxResultSetting))
            val (immutableTime, immutableResult) = measureTime(find(concurrentParallelMap,_%2==0,maxResultSetting))
            val (unparallelTime, unparallelResult) = measureTime(findUnParallel(unparallelMap,_%2==0,maxResultSetting))
            assert(mutableResult.size==maxResultSetting)
            assert(immutableResult.size==maxResultSetting)
            assert(unparallelResult.size==maxResultSetting)
            println(" The mutable version has taken " + mutableTime + " milliseconds")
            println(" The immutable version has taken " + immutableTime + " milliseconds")
            println(" The unparallel version has taken " + unparallelTime + " milliseconds")
         }
      }
    
    }
    

    With this code, the parallel versions (with both the mutable and the immutable input map) are systematically about 3.5 times faster than the unparallel one on my machine.

  • 2021-02-08 13:29

    You could try to get an iterator and then create a lazy list (a Stream) from it, which you filter with your predicate before taking the number of elements you want. Because a Stream is non-strict, this 'taking' of elements is not evaluated right away. Afterwards you can force the execution by adding ".par" to the whole thing and achieve parallelization.

    Example code:

    A parallelized map with random values (simulating your parallel hash map):

    scala> myMap
    res14: scala.collection.parallel.immutable.ParMap[Int,Int] = ParMap(66978401 -> -1331298976, 256964068 -> 126442706, 1698061835 -> 1622679396, -1556333580 -> -1737927220, 791194343 -> -591951714, -1907806173 -> 365922424, 1970481797 -> 162004380, -475841243 -> -445098544, -33856724 -> -1418863050, 1851826878 -> 64176692, 1797820893 -> 405915272, -1838192182 -> 1152824098, 1028423518 -> -2124589278, -670924872 -> 1056679706, 1530917115 -> 1265988738, -808655189 -> -1742792788, 873935965 -> 733748120, -1026980400 -> -163182914, 576661388 -> 900607992, -1950678599 -> -731236098)
    

    Get an iterator, create a Stream from it, and filter it. In this case my predicate only accepts even values (it tests the value member of each map entry). I want 10 even elements, so I take 10 elements, which will only be evaluated when I force them:

    scala> val mapIterator = myMap.toIterator
    mapIterator: Iterator[(Int, Int)] = HashTrieIterator(20)
    
    
    scala> val r = Stream.continually(mapIterator.next()).filter(_._2 % 2 == 0).take(10)
    r: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), ?)
    

    Finally, I force the evaluation, which retrieves only 10 elements as planned:

    scala> r.force
    res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), (256964068,126442706), (1698061835,1622679396), (-1556333580,-1737927220), (791194343,-591951714), (-1907806173,365922424), (1970481797,162004380), (-475841243,-445098544), (-33856724,-1418863050), (1851826878,64176692))
    

    This way you only get the number of elements you want (without needing to process the remaining elements) and you parallelize the process without locks, atomics or breaks.
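    A variant of the same lazy idea, sketched under the assumption that a plain sequential `Map` stands in for `myMap`: filtering the iterator directly avoids `Stream.continually(mapIterator.next())`, which throws `NoSuchElementException` when the map holds fewer matching entries than requested. `firstMatches` is a hypothetical helper name, and this sketch runs on a single thread.

```scala
// Hypothetical helper: return at most `n` entries whose value satisfies `p`.
// Iterator.filter and take are lazy, so iteration stops as soon as `n` matches are found,
// and the result is simply shorter (no exception) if fewer than `n` entries match.
def firstMatches[K, V](m: scala.collection.Map[K, V], p: V => Boolean, n: Int): Vector[(K, V)] =
  m.iterator.filter { case (_, v) => p(v) }.take(n).toVector
```

    For example, `firstMatches(myMap.seq, (v: Int) => v % 2 == 0, 10)` would mirror the REPL session above.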

    Please compare this to your solutions to see if it is any good.

  • 2021-02-08 13:45

    I would first do a parallel scan in which the variable maxResults would be thread-local. This would find up to (maxResults * numberOfThreads) results.

    Then I would do a single-threaded scan to reduce that to maxResults.
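    A minimal sketch of this two-phase approach using only the standard library. Assumptions: the data is an `IndexedSeq` split into `chunks` slices, each processed by its own `Future`; a per-task limit stands in for the thread-local `maxResults`; `findTwoPhase` is a hypothetical name.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Hypothetical two-phase search:
// Phase 1 (parallel): each chunk runs in its own Future and stops after `maxResults`
//   local matches, so at most maxResults * number-of-chunks candidates come back.
// Phase 2 (sequential): concatenate the candidates in chunk order and keep `maxResults`.
def findTwoPhase[A](data: IndexedSeq[A], p: A => Boolean, maxResults: Int, chunks: Int): Vector[A] = {
  require(maxResults >= 0 && chunks > 0)
  val chunkSize = math.max(1, (data.size + chunks - 1) / chunks) // ceiling division
  val tasks: Vector[Future[Vector[A]]] =
    data.grouped(chunkSize).toVector.map { chunk =>
      Future(chunk.iterator.filter(p).take(maxResults).toVector)
    }
  val perChunk = Await.result(Future.sequence(tasks), Duration.Inf)
  perChunk.flatten.take(maxResults)
}
```

    Because the candidates are concatenated in chunk order, this particular sketch returns the first `maxResults` matches of the earliest chunks.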
