Making Sieve of Eratosthenes more memory efficient in python?

心在旅途 2021-01-20 00:41

Sieve of Eratosthenes memory constraint issue

I'm currently trying to implement a version of the Sieve of Eratosthenes for a Kattis problem; however, I am running in

5 Answers
  • 2021-01-20 00:59

    Here is an example of a segmented sieve approach that should not exceed 8MB of memory.

    from itertools import chain

    def primeSieve(n,X,window=10**6): 
        primes     = []       # only store minimum number of primes to shift windows
        primeCount = 0        # count primes beyond the ones stored
        flags      = list(X)  # numbers will be replaced by 0 or 1 as we progress
        base       = 1        # number corresponding to 1st element of sieve
        isPrime    = [False]+[True]*(window-1) # starting sieve
        
        def flagPrimes(): # flag x values for current sieve window
            flags[:] = [isPrime[x-base]*1 if x in range(base,base+window) else x
                        for x in flags]
        # chain() yields candidates lazily; (2,*range(...)) would build an n/2-element tuple
        for p in chain((2,), range(3,n+1,2)):  # potential primes: 2 and odd numbers
            if p >= base+window:            # shift sieve window as needed
                flagPrimes()                # set X flags before shifting window
                isPrime = [True]*window     # initialize next sieve window
                base    = p                 # 1st number in window
                for k in primes:            # update sieve using known primes 
                    if k>base+window:break
                    i = (k-base%k)%k + k*(k==p)  
                    isPrime[i::k] = (False for _ in range(i,window,k))
            if not isPrime[p-base]: continue
            primeCount += 1                 # count primes 
            if p*p<=n:primes.append(p)      # store shifting primes, update sieve
            isPrime[p*p-base::p] = (False for _ in range(p*p-base,window,p))
    
        flagPrimes() # update flags with last window (should cover the rest of them)
        return primeCount,flags     
            
    

    output:

    print(*primeSieve(9973,[1,2,3,4,9972,9973]))
    # 1229 [0, 1, 1, 0, 0, 1]
    
    print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
    # 5761455 [0, 1, 1, 0, 0, 1, 0]
    

    You can play with the window size to get the best trade-off between execution time and memory consumption. The execution time (on my laptop) is still rather long for large values of n, though:

    from timeit import timeit
    for w in range(3,9):
        t = timeit(lambda:primeSieve(10**8,[],10**w),number=1)
        print(f"10^{w} window:",t)
    
    10^3 window: 119.463959956
    10^4 window: 33.33273301199999
    10^5 window: 24.153761258999992
    10^6 window: 24.649398391000005
    10^7 window: 27.616014667
    10^8 window: 27.919413531000004
    

    Strangely enough, window sizes beyond 10^6 give worse performance. The sweet spot seems to be somewhere between 10^5 and 10^6. A window of 10^7 would exceed your 50MB limit anyway.
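To check memory figures like these yourself, the standard-library tracemalloc module can report the peak memory allocated while some code runs. The sketch below (the helper name `peak_bytes` is hypothetical) compares one million flags stored as a list of booleans against a bytearray. On 64-bit CPython a list costs about 8 bytes per element just for its pointers, so a 10^7-element window list is around 80 MB.

```python
import tracemalloc

def peak_bytes(build):
    """Return the peak number of bytes traced by tracemalloc while build() runs."""
    tracemalloc.start()
    try:
        build()                                   # result is discarded; the peak is kept
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak

window = 10**6
list_peak = peak_bytes(lambda: [True] * window)               # one pointer per flag
bytes_peak = peak_bytes(lambda: bytearray(b"\x01") * window)  # one byte per flag
print(f"list of bools: {list_peak / 1e6:.1f} MB")
print(f"bytearray:     {bytes_peak / 1e6:.1f} MB")
```

On a 64-bit build the list should come out near 8 MB and the bytearray near 1 MB.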

  • 2021-01-20 01:07

    There's a trick I learned just yesterday: if you divide the numbers into groups of 6, only 2 of the 6 may be prime (apart from 2 and 3 themselves). The others are divisible by 2 or 3. That means it takes only 2 bits to track the primality of 6 numbers, so a byte containing 8 bits can track primality for 24 numbers! This greatly reduces the memory requirements of your sieve.
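The remainder_bit table in the code below can be derived rather than typed by hand. A small sketch (the names `candidates`, `bit_for` and `locate` are hypothetical): keep the residues mod 24 that are not divisible by 2 or 3, give each one its own bit, and split a number into a byte index and a bit mask with divmod.

```python
from math import gcd

# Residues mod 24 not divisible by 2 or 3: only these can hold primes > 3
candidates = [r for r in range(24) if gcd(r, 6) == 1]

# One bit per candidate residue: 8 residues fit exactly in one byte
bit_for = [0] * 24
for k, r in enumerate(candidates):
    bit_for[r] = 1 << k

def locate(a):
    """Return (byte index, bit mask) for a; a mask of 0 means a is a multiple of 2 or 3."""
    index, rem = divmod(a, 24)
    return index, bit_for[rem]

print(candidates)            # [1, 5, 7, 11, 13, 17, 19, 23]
print(locate(29))            # (1, 2): 29 = 24 + 5, and residue 5 owns bit 0x02
print((10**8 + 23) // 24)    # 4166667 bytes (about 4.2 MB) for n = 10**8
```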

    In Python 3.7.5 (64-bit) on Windows 10, the following code never went over 36.4 MB.

    remainder_bit = [0, 0x01, 0, 0, 0, 0x02,
                     0, 0x04, 0, 0, 0, 0x08,
                     0, 0x10, 0, 0, 0, 0x20,
                     0, 0x40, 0, 0, 0, 0x80]
    
    def is_prime(xs, a):
        if a <= 3:
            return a > 1
        index, rem = divmod(a, 24)
        bit = remainder_bit[rem]
        if not bit:
            return False
        return not (xs[index] & bit)
    
    def sieve_of_eratosthenes(xs, n):
        count = n - n // 2 - n // 3 + n // 6 + 1  # numbers coprime to 6, minus 1, plus the primes 2 and 3 (n >= 3)
        p = 5
        while p*p <= n:
            if is_prime(xs, p):
                for i in range(5 * p, n + 1, p):
                    index, rem = divmod(i, 24)
                    bit = remainder_bit[rem]
                    if bit and not (xs[index] & bit):
                        xs[index] |= bit
                        count -= 1
            p += 2
            if is_prime(xs, p):
                for i in range(5 * p, n + 1, p):
                    index, rem = divmod(i, 24)
                    bit = remainder_bit[rem]
                    if bit and not (xs[index] & bit):
                        xs[index] |= bit
                        count -= 1
            p += 4
    
        return count
    
    
    def init_sieve(n):
        return bytearray((n + 23) // 24)
    
    n = 100000000
    xs = init_sieve(n)
    print(sieve_of_eratosthenes(xs, n))                # 5761455
    print(sum(is_prime(xs, i) for i in range(n + 1)))  # 5761455 (slow double-check)
    

    Edit: the key to understanding how this works is that a sieve creates a repeating pattern. For the primes 2 and 3, the pattern repeats every 2*3 = 6 numbers, and of those 6, 4 are ruled out as primes, leaving only 2. Nothing limits which primes you use to generate the pattern, except perhaps the law of diminishing returns. I decided to try adding 5 to the mix, making the pattern repeat every 2*3*5 = 30 numbers. Of these 30 numbers only 8 can be prime, so each byte can track 30 numbers instead of the 24 above! That gives you a 20% saving in memory.
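To see the diminishing returns mentioned above, one can compute, for each wheel modulus (a product of the first few primes), how many residues survive and roughly how much memory a one-bit-per-candidate sieve would need for n = 10^8. This is an illustrative calculation, not part of the answer's code; the helper name `kept_residues` is hypothetical.

```python
from math import gcd

def kept_residues(modulus):
    """Count residues mod `modulus` sharing no factor with it (the only ones that can hold primes)."""
    return sum(1 for r in range(modulus) if gcd(r, modulus) == 1)

n = 10**8
for modulus in (2, 6, 30, 210):                 # 2, 2*3, 2*3*5, 2*3*5*7
    kept = kept_residues(modulus)
    megabytes = n * kept / modulus / 8 / 1e6    # one bit per surviving number
    print(f"wheel {modulus:3d}: {kept:2d}/{modulus} kept, about {megabytes:.2f} MB")
```

This works out to roughly 6.25, 4.17, 3.33 and 2.86 MB. Going from 6 to 30 saves 20%, while going on to 210 saves only about another 14% and needs a 210-entry lookup table with 48 bits (6 bytes) per block.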

    Here's the updated code. I also simplified it a bit and took out the counting of primes as it went along.

    remainder_bit30 = [0,    0x01, 0,    0,    0,    0,    0, 0x02, 0,    0,
                       0,    0x04, 0,    0x08, 0,    0,    0, 0x10, 0,    0x20,
                       0,    0,    0,    0x40, 0,    0,    0, 0,    0,    0x80]
    
    def is_prime(xs, a):
        if a <= 5:
            return (a > 1) and (a != 4)
        index, rem = divmod(a, 30)
        bit = remainder_bit30[rem]
        return (bit != 0) and not (xs[index] & bit)
    
    def sieve_of_eratosthenes(xs):
        n = 30 * len(xs) - 1
        p = 0
        while p*p < n:
            for offset in (1, 7, 11, 13, 17, 19, 23, 29):
                p += offset
                if is_prime(xs, p):
                    for i in range(p * p, n + 1, p):
                        index, rem = divmod(i, 30)
                        if index < len(xs):
                            bit = remainder_bit30[rem]
                            xs[index] |= bit
                p -= offset
            p += 30
    
    def init_sieve(n):
        b = bytearray((n + 30) // 30)
        return b
    
  • 2021-01-20 01:07

    I had another idea on how to generate primes quickly in a memory efficient way. It is based on the same concept as the Sieve of Eratosthenes but uses a dictionary to hold the next value that each prime will invalidate (i.e. skip). This only requires storage of one dictionary entry per prime up to the square root of n.

    def genPrimes(maxPrime):
        if maxPrime>=2: yield 2           # special processing for 2
        primeSkips = dict()               # skipValue:prime
        for n in range(3,maxPrime+1,2):
            if n not in primeSkips:       # if not in skip list, it is a new prime
                yield n
                if n*n <= maxPrime:       # first skip will be at n^2
                    primeSkips[n*n] = n
                continue
            prime = primeSkips.pop(n)     # find next skip for n's prime
            skip  = n+2*prime
            while skip in primeSkips:     # must not already be skipped
                skip += 2*prime                
            if skip<=maxPrime:            # don't skip beyond maxPrime
                primeSkips[skip]=prime           
    

    Using this, the primeSieve function can simply run through the prime numbers, count them, and flag the x values:

    def primeSieve(n,X):
        primeCount = 0
        nonPrimes  = set(X)
        for prime in genPrimes(n):
            primeCount += 1
            nonPrimes.discard(prime)
        return primeCount,[0 if x in nonPrimes else 1 for x in X]
    
    
    print(*primeSieve(9973,[1,2,3,4,9972,9973]))
    # 1229 [0, 1, 1, 0, 0, 1]
    
    print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
    # 5761455 [0, 1, 1, 0, 0, 1, 0]
    

    This runs slightly faster than my previous answer and consumes only 78 KB of memory to generate the primes up to 10^8 (in 21 seconds).
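The low memory use rests on the dictionary never holding more than one entry per prime p with p*p <= maxPrime. One way to check this is to instrument a copy of genPrimes so it also records the largest dictionary size reached. In this sketch, `gen_primes_tracked` and its `stats` argument are hypothetical names, and the `continue` of the original is expressed as an `else` branch.

```python
def gen_primes_tracked(max_prime, stats):
    """Same dictionary-based sieve as genPrimes, but records the peak dict size in stats."""
    stats["max_entries"] = 0
    if max_prime >= 2:
        yield 2
    prime_skips = {}                       # next odd multiple to skip -> its prime
    for n in range(3, max_prime + 1, 2):
        if n not in prime_skips:           # never skipped, so n is prime
            yield n
            if n * n <= max_prime:         # first skip will be at n^2
                prime_skips[n * n] = n
        else:
            prime = prime_skips.pop(n)
            skip = n + 2 * prime
            while skip in prime_skips:     # find the next odd multiple not already claimed
                skip += 2 * prime
            if skip <= max_prime:
                prime_skips[skip] = prime
        stats["max_entries"] = max(stats["max_entries"], len(prime_skips))

stats = {}
count = sum(1 for _ in gen_primes_tracked(10**5, stats))
print(count)                  # 9592 primes up to 10**5
print(stats["max_entries"])   # at most 65, the number of primes up to 316
```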

  • 2021-01-20 01:22

    I think you could try using a list of booleans that marks whether each index is prime or not:

    def sieve_of_erato(range_max):
        # Assumes range_max >= 1; the candidates are 2..range_max
        primes_count = range_max - 1
        is_prime = [True] * (range_max + 1)
        is_prime[0] = is_prime[1] = False
        # Cross out all even numbers above 2 first.
        for i in range(4, range_max + 1, 2):
            is_prime[i] = False
            primes_count -= 1
        i = 3
        while i * i <= range_max:
            if is_prime[i]:
                # Cross out the odd multiples of i, starting at i*i.
                # Step by 2*i instead of i: i*i is odd, so i*i + i is even
                # and has already been crossed out above.
                for multiple in range(i * i, range_max + 1, i * 2):
                    if is_prime[multiple]:      # count each composite only once
                        is_prime[multiple] = False
                        primes_count -= 1
            i += 2                              # even i can never be prime here
    
        return primes_count
    
    
    def main():
        num_primes = sieve_of_erato(100)
        print(num_primes)
    
    
    if __name__ == "__main__":
        main()
    

    If you also return the is_prime list from the function, you can later check whether a number is prime with a simple lookup, is_prime[number].

    If this doesn't work, then try a segmented sieve.

    As a bonus, you might be surprised to know that there is a way to generate the sieve in O(n) rather than O(n log log n) time. Check the code here.
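One well-known O(n) approach is the linear (Euler) sieve, in which every composite is crossed out exactly once, by its smallest prime factor. The following is a minimal sketch and not necessarily the code the answer refers to. Note that it stores a smallest-prime-factor entry for every number, so it uses more memory than a boolean sieve and does not help with the memory limit.

```python
def linear_sieve(n):
    """Return (primes <= n, smallest-prime-factor table) in O(n) time."""
    spf = [0] * (n + 1)          # spf[i] = smallest prime factor of i (0 = not yet known)
    primes = []
    for i in range(2, n + 1):
        if spf[i] == 0:          # nothing smaller divides i, so i is prime
            spf[i] = i
            primes.append(i)
        for p in primes:
            # Stop once p exceeds spf[i]: then p is always the smallest prime
            # factor of i*p, so each composite is marked exactly once
            if p > spf[i] or i * p > n:
                break
            spf[i * p] = p
    return primes, spf

primes, spf = linear_sieve(100)
print(len(primes))   # 25
print(spf[91])       # 7, since 91 = 7 * 13
```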

  • 2021-01-20 01:24

    This is a very challenging problem indeed. With a maximum possible N of 10^8, using one byte per value results in almost 100 MB of data assuming no overhead whatsoever. Even halving the data by only storing odd numbers will put you very close to 50 MB after overhead is considered.

    This means the solution will have to make use of one or more of a few strategies:

    1. Using a more efficient data type for our array of primality flags. Python lists maintain an array of pointers to their items (8 bytes each on a 64-bit Python). We effectively need raw binary storage, which pretty much only leaves bytearray in standard Python.
    2. Using only one bit per value in the sieve instead of an entire byte (a boolean technically needs only one bit, but is typically stored in a full byte).
    3. Skipping even numbers, and possibly also multiples of 3, 5, 7, etc.
    4. Using a segmented sieve
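Strategies 1 and 4 combine naturally. The following is a sketch, not the code from this answer: a segmented sieve that keeps its flags in bytearray objects and uses slice assignment so the crossing-out runs in C. It holds only the base primes up to sqrt(n) and one segment at a time (the default segment size of 2^18 is an arbitrary assumption), so memory stays well under a megabyte. It only counts primes; answering individual queries would mean checking each query value against the segment that contains it.

```python
from math import isqrt

def count_primes_segmented(n, segment_size=1 << 18):
    """Count the primes <= n, holding only sqrt(n) base flags and one segment at a time."""
    if n < 2:
        return 0
    root = isqrt(n)
    # Plain bytearray sieve for the base primes up to sqrt(n)
    base = bytearray([1]) * (root + 1)
    base[0] = base[1] = 0
    for i in range(2, isqrt(root) + 1):
        if base[i]:
            base[i * i::i] = bytes(len(range(i * i, root + 1, i)))
    base_primes = [i for i in range(2, root + 1) if base[i]]

    count = 0
    for low in range(2, n + 1, segment_size):
        high = min(low + segment_size - 1, n)
        seg = bytearray([1]) * (high - low + 1)   # seg[j] stands for low + j
        for p in base_primes:
            if p * p > high:
                break
            # First multiple of p in the segment that is at least p*p
            start = max(p * p, (low + p - 1) // p * p)
            seg[start - low::p] = bytes(len(range(start - low, high - low + 1, p)))
        count += seg.count(1)
    return count

print(count_primes_segmented(100))   # 25
```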

    I initially tried to solve the problem by storing only 1 bit per value in the sieve, and while memory usage was indeed within the requirements, Python's slow bit manipulation made the execution time far too long. It was also rather difficult to get the indexing right so that the correct bits were counted reliably.

    I then implemented the odd-numbers-only solution using a bytearray, and while it was quite a bit faster, memory was still an issue.

    Bytearray odd-numbers-only implementation:

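    class Sieve:
        def __init__(self, n):
            # Store odd numbers only: index k stands for the number 2*k + 1
            size = (n + 1) // 2
            self.not_prime = bytearray(size)
            if size:
                self.not_prime[0] = 1            # 1 is not prime
            for i in range(3, int(n**.5) + 1, 2):
                if not self.not_prime[i // 2]:
                    start = i * i // 2           # index of i*i; odd multiples are i apart
                    # A bytes object avoids building a huge temporary list
                    self.not_prime[start::i] = b"\x01" * len(range(start, size, i))
            # Odd primes, plus the even prime 2
            self.n_prime = size - self.not_prime.count(1) + (n >= 2)
            
        def is_prime(self, n):
            if n % 2 == 0:
                return int(n == 2)
            return int(not self.not_prime[n // 2])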
            
    
    
    def main():
        n, q = map(int, input().split())
        s = Sieve(n)
        print(s.n_prime)
        for _ in range(q):
            i = int(input())
            print(s.is_prime(i))
    
    if __name__ == "__main__":
        main()
    

    Reducing memory further from this point should make it work.

    EDIT: Also removing multiples of 2 and 3 did not seem to reduce memory enough, even though guppy.hpy().heap() suggested my usage was in fact a bit under 50 MB.
