How to share work roughly evenly between processes in MPI despite the array_size not being cleanly divisible by the number of processes?

借酒劲吻你 2021-01-04 10:25

Hi all, I have an array of length N, and I'd like to divide it as evenly as possible between 'size' processes. N/size has a remainder, e.g. 1000 array elements divided between a number of processes that does not split them evenly. How should the leftover elements be assigned?
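
For concreteness, here is a quick sketch of the arithmetic being asked about, with hypothetical numbers (N = 1000 elements, size = 7 processes) chosen only for illustration:

    # Hypothetical numbers for illustration: N = 1000 elements, size = 7 processes.
    N, size = 1000, 7
    base, rem = divmod(N, size)   # base = 142, rem = 6
    # The first `rem` ranks take one extra element each.
    counts = [base + 1 if r < rem else base for r in range(size)]
    print(counts, sum(counts))    # [143, 143, 143, 143, 143, 143, 142] 1000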

7 Answers
  •  天涯浪人
    2021-01-04 10:51

    I had a similar problem; here is my non-optimal solution in Python using the mpi4py API. An optimal solution would take into account how the processors are laid out; here, the extra work is simply handed to the lower ranks. The workloads differ by at most one task, so this should not be a big deal in general.

    from mpi4py import MPI
    import sys
    def get_start_end(comm, N):
        """
        Distribute N consecutive things (rows of a matrix, blocks of a 1D array)
        as evenly as possible over a given communicator.
        The uneven workload (which differs by at most 1) goes to the initial ranks.

        Parameters
        ----------
        comm : MPI communicator
        N : int
            Total number of things to be distributed.

        Returns
        -------
        rstart : index of first local row
        rend : 1 + index of last local row

        Notes
        -----
        Indices are zero based.
        """
        P      = comm.size
        rank   = comm.rank
        rstart = 0
        rend   = N
        if P >= N:
            # More ranks than items: the first N ranks get one item each,
            # the remaining ranks get nothing.
            if rank < N:
                rstart = rank
                rend   = rank + 1
            else:
                rstart = 0
                rend   = 0
        else:
            # Base chunk of n items per rank; the first `remainder` ranks
            # each take one extra item.
            n = N // P          # integer division (PEP 238)
            remainder = N % P
            rstart    = n * rank
            rend      = n * (rank + 1)
            if remainder:
                if rank >= remainder:
                    rstart += remainder
                    rend   += remainder
                else:
                    rstart += rank
                    rend   += rank + 1
        return rstart, rend
    
    if __name__ == '__main__':
        comm = MPI.COMM_WORLD
        n = int(sys.argv[1])
        print(comm.rank,get_start_end(comm,n))
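
    If the split is then used to scatter actual data, the returned (rstart, rend) pair maps directly onto the counts and displacements that Scatterv expects. A minimal usage sketch, assuming NumPy is available and the full array lives on rank 0 (the array size and variable names here are made up for illustration):

    import numpy as np

    comm = MPI.COMM_WORLD
    N = 1000
    rstart, rend = get_start_end(comm, N)
    local_n = rend - rstart

    # Collective exchange of per-rank counts and offsets; the root uses them
    # to build the Scatterv arguments.
    counts = comm.allgather(local_n)
    displs = comm.allgather(rstart)

    local = np.empty(local_n, dtype='d')
    if comm.rank == 0:
        data = np.arange(N, dtype='d')
        comm.Scatterv([data, counts, displs, MPI.DOUBLE], local, root=0)
    else:
        comm.Scatterv(None, local, root=0)

    print(comm.rank, rstart, rend, local.shape)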
    
