Question
Sieve of Eratosthenes memory constraint issue
I'm currently trying to implement a version of the Sieve of Eratosthenes for a Kattis problem. However, I am running into memory constraints that my implementation won't pass.
Here is a link to the problem statement. In short, the problem first asks for the number of primes less than or equal to n. It then asks a number of queries, each one asking whether a given number i is prime. There is a memory limit of 50 MB, and only Python's standard library may be used (no numpy etc.). The memory limit is where I am stuck.
Here is my code so far:
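```python
import sys

def sieve_of_eratosthenes(xs, n):
    # xs holds the odd numbers 3, 5, 7, ... <= n, so xs[i] == 2*i + 3
    count = len(xs) + 1  # every odd number, plus 2
    p = 3  # start at three
    index = 0
    while p * p <= n:  # <= so that p*p itself (e.g. 9, 25) is crossed out
        # xs[index] == p; stepping p in the index moves 2p in value: 3p, 5p, ...
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1
        # find the next number that has not been crossed out
        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index
    return count

def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    return bool(xs[(a >> 1) - 1])  # odd a sits at index (a - 3) // 2

def main():
    n, q = map(int, sys.stdin.readline().split())
    # each element is a pointer to its own int object, far more than n/2 bytes
    odds = [num for num in range(2, n + 1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))
    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')

if __name__ == "__main__":
    main()
```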
I've made some improvements so far. For example, I keep a list of only the odd numbers, which halves the memory usage. I am also certain that the code works as intended when calculating the primes (it does not give wrong answers). My question now is: how can I make my code even more memory efficient? Should I use some other data structure? Should I replace my list of integers with booleans, or with a bitarray?
Any advice is much appreciated!
EDIT
After some more tweaking of the Python code, I hit a wall: my implementation of a segmented sieve still would not pass the memory requirement.
Instead, I chose to implement the solution in Java, which took very little effort. Here is the code:
```java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.BitSet;

// Assumed wrapper class and main(): the post showed only the methods below
public class Main {
    private BitSet sieve; // bit (k - 3) / 2 is set when the odd number k is composite

    public int sieveOfEratosthenes(int n){
        sieve = new BitSet((n+1) / 2);
        int count = (n + 1) / 2; // 2 plus the odd numbers 3, 5, ..., <= n
        for (int i=3; i*i <= n; i += 2){
            if (isComposite(i)) {
                continue;
            }
            // Step by 2 * i so that only odd multiples of i are visited
            for (int c = i * i; c <= n; c += 2 * i){
                if(!isComposite(c)){
                    setComposite(c);
                    count--;
                }
            }
        }
        return count;
    }

    public boolean isComposite(int k) {
        return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
    }

    public void setComposite(int k) {
        sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
    }

    public boolean isPrime(int a) {
        if (a < 3)
            return a == 2;
        if ((a & 1) == 1)
            return !isComposite(a);
        else
            return false;
    }

    public void run() throws Exception{
        BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
        String[] line = scan.readLine().split(" ");
        int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
        System.out.println(sieveOfEratosthenes(n));
        for (int i=0; i < q; i++){
            line = scan.readLine().split(" ");
            System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
        }
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }
}
```
I have not found a way to implement this BitSet solution in Python using only the standard library.
If anyone comes across a neat Python implementation for this problem, whether it uses a segmented sieve, a bitarray or something else, I would be interested to see it.
Answer 1:
This is a very challenging problem indeed. With a maximum possible N of 10^8, using one byte per value results in almost 100 MB of data, assuming no overhead whatsoever. Even if you halve the data by storing only odd numbers, you will be very close to 50 MB once overhead is considered.
This means the solution will have to make use of one or more of a few strategies:
- Using a more efficient data type for our array of primality flags. A Python list keeps an array of pointers to its items (8 bytes each on 64-bit Python), in addition to the item objects themselves. We effectively need raw binary storage, and in standard Python that pretty much leaves only bytearray.
- Using only one bit per value in the sieve instead of an entire byte. A boolean technically needs only one bit, but it typically occupies at least a full byte.
- Sub-dividing the numbers so that even numbers, and possibly also multiples of 3, 5, 7 etc., are not stored at all.
- Using a segmented sieve.
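The following small measurement sketch makes the per-element cost of the first strategy concrete. The size `N = 10**6` and the variable names are illustrative choices and do not come from the answer. Exact byte counts depend on the CPython build. `sys.getsizeof` reports only the list's own pointer array, not the objects that the list refers to.

```python
import struct
import sys

N = 1_000_000  # illustrative size; the problem's limit is 10**8

pointer_size = struct.calcsize("P")   # size of one pointer: 8 on a 64-bit build
flags_list = [0] * N                  # N pointers, all referring to the same cached int 0
flags_bytes = bytearray(N)            # N raw bytes

print("pointer size:", pointer_size, "bytes")
print("list of", N, "flags:", sys.getsizeof(flags_list), "bytes")
print("bytearray of", N, "flags:", sys.getsizeof(flags_bytes), "bytes")

# Scaled to 10**8 flags on a 64-bit build (pointer array only, before any int objects):
#   list                  ~ 10**8 * 8 bytes = 800 MB
#   bytearray             ~ 10**8 bytes     = 100 MB
#   bytearray, odds only  ~ 5 * 10**7 bytes =  50 MB
#   one bit per odd       ~ 6.25 * 10**6 bytes
```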
I initially tried to solve the problem by storing only 1 bit per value in the sieve. The memory usage was indeed within the requirements, but Python's slow bit manipulation made the execution time far too long. It was also rather difficult to get the complex indexing right so that the correct bits were counted reliably.
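The following sketch shows the one-bit-per-odd-number layout. It is modeled on the Java BitSet version in the question and uses the same index, (k - 3) // 2. The class name `OddBitSieve`, its methods and the `POPCOUNT` table are hypothetical. The marking loop still sets one bit per Python call, so it illustrates the indexing rather than solving the speed problem described above. At n = 10^8 the bit array itself is about 6.25 MB.

```python
import math

# POPCOUNT[b] = number of set bits in the byte value b
POPCOUNT = bytes(bin(b).count("1") for b in range(256))


class OddBitSieve:
    """One bit per odd number >= 3; bit j stands for k = 2*j + 3 (as in the Java version)."""

    def __init__(self, n):
        self.n = n
        self.size = max(0, (n - 1) // 2)          # how many odd numbers 3, 5, ... are <= n
        self.bits = bytearray((self.size + 7) // 8)
        for i in range(3, math.isqrt(n) + 1, 2):
            if not self.is_composite(i):
                # odd multiples of i only: i*i, i*i + 2i, ...
                for c in range(i * i, n + 1, 2 * i):
                    self.set_composite(c)

    def is_composite(self, k):
        j = (k - 3) // 2
        return bool((self.bits[j >> 3] >> (j & 7)) & 1)

    def set_composite(self, k):
        j = (k - 3) // 2
        self.bits[j >> 3] |= 1 << (j & 7)

    def is_prime(self, a):
        if a < 3:
            return a == 2
        return a % 2 == 1 and not self.is_composite(a)

    def count(self):
        # Padding bits in the last byte are never set, so set bits == composites.
        composites = sum(self.bits.translate(POPCOUNT))
        return (1 if self.n >= 2 else 0) + self.size - composites


sieve = OddBitSieve(100)
print(sieve.count())                                 # 25
print([a for a in range(30) if sieve.is_prime(a)])   # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

`bytearray.translate` with a 256-entry table turns each byte into its popcount in C, which avoids a Python-level loop over the individual bits.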
I then implemented the odd-numbers-only solution using a bytearray. It was quite a bit faster, but memory was still an issue.
Bytearray implementation (as shown, this version keeps one flag for every number from 0 to n):
```python
class Sieve:
    def __init__(self, n):
        self.not_prime = bytearray(n+1)   # one byte per number 0..n
        self.not_prime[0] = self.not_prime[1] = 1
        for i in range(2, int(n**.5)+1):
            if self.not_prime[i] == 0:
                # b"\x01" * k avoids building a k-element list (and a slice copy just to measure it)
                self.not_prime[i*i::i] = b"\x01" * len(range(i*i, n+1, i))
        self.n_prime = n + 1 - self.not_prime.count(1)

    def is_prime(self, n):
        return int(not self.not_prime[n])

def main():
    n, q = map(int, input().split())
    s = Sieve(n)
    print(s.n_prime)
    for _ in range(q):
        i = int(input())
        print(s.is_prime(i))

if __name__ == "__main__":
    main()
```
Further reduction in memory from this should* make it work.
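The class above keeps one byte per number. The following is a minimal sketch of the odd-only bytearray layout that the answer describes. The class name `OddSieve` is hypothetical, and the odd value v is stored at index v // 2. This layout halves the array to about n/2 bytes, which is roughly 50 MB at n = 10^8. Each slice assignment also builds a temporary bytes object (about n/6 bytes for p = 3). Peak memory can therefore still go over the limit, which matches the answer's experience.

```python
class OddSieve:
    """Bytearray sieve over odd numbers only: index i stands for the value 2*i + 1."""

    def __init__(self, n):
        size = (n + 1) // 2                  # odd numbers 1, 3, ..., <= n
        self.not_prime = bytearray(size)
        if size:
            self.not_prime[0] = 1            # 1 is not prime
        i = 1                                # index of the value 3
        while (2 * i + 1) ** 2 <= n:
            if not self.not_prime[i]:
                p = 2 * i + 1
                start = p * p // 2           # index of p*p
                # a step of p in index space is a step of 2p in value space (odd multiples only)
                self.not_prime[start::p] = b"\x01" * len(range(start, size, p))
            i += 1
        self.n_prime = size - self.not_prime.count(1) + (1 if n >= 2 else 0)

    def is_prime(self, a):
        if a == 2:
            return 1
        if a < 2 or a % 2 == 0:
            return 0
        return int(not self.not_prime[a // 2])


s = OddSieve(100)
print(s.n_prime)        # 25
print(s.is_prime(97))   # 1
```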
EDIT: Removing multiples of both 2 and 3 also did not seem to reduce memory enough, even though guppy.hpy().heap() suggested my usage was in fact a bit under 50 MB. 🤷♂️
Answer 2:
There's a trick I learned just yesterday - if you divide the numbers into groups of 6, only 2 of the 6 may be prime. The others can be evenly divided by either 2 or 3. That means it only takes 2 bits to track the primality of 6 numbers; a byte containing 8 bits can track primality for 24 numbers! This greatly diminishes the memory requirements of your sieve.
In Python 3.7.5 64 bit on Windows 10, the following code didn't go over 36.4 MB.
```python
remainder_bit = [0, 0x01, 0, 0, 0, 0x02,
                 0, 0x04, 0, 0, 0, 0x08,
                 0, 0x10, 0, 0, 0, 0x20,
                 0, 0x40, 0, 0, 0, 0x80]

def is_prime(xs, a):
    if a <= 3:
        return a > 1
    index, rem = divmod(a, 24)
    bit = remainder_bit[rem]
    if not bit:
        return False
    return not (xs[index] & bit)

def sieve_of_eratosthenes(xs, n):
    # Start from 2 and 3 plus every number in [1, n] that is 1 or 5 mod 6, minus 1 itself;
    # the loop below then subtracts each composite exactly once.
    count = (n + 1) // 6 + (n + 5) // 6 - (n >= 1) + (n >= 2) + (n >= 3)
    p = 5
    while p*p <= n:
        if is_prime(xs, p):                     # p = 6k - 1
            for i in range(5 * p, n + 1, p):    # 2p, 3p, 4p have no bit anyway
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 2
        if is_prime(xs, p):                     # p = 6k + 1
            for i in range(5 * p, n + 1, p):
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 4
    return count

def init_sieve(n):
    return bytearray((n + 23) // 24)
```

```python
>>> n = 100000000
>>> xs = init_sieve(n)
>>> sieve_of_eratosthenes(xs, n)
5761455
>>> sum(is_prime(xs, i) for i in range(n+1))
5761455
```
Edit: the key to understanding how this works is that a sieve creates a repeating pattern. For the primes 2 and 3 the pattern repeats every 2*3 = 6 numbers. Of those 6, 4 are ruled out as primes, which leaves only 2. Nothing limits your choice of primes for generating the pattern, except perhaps the law of diminishing returns. I decided to try adding 5 to the mix, which makes the pattern repeat every 2*3*5 = 30 numbers. Out of these 30 numbers only 8 can be prime, meaning each byte can track 30 numbers instead of the 24 above! That gives you a 20% advantage in memory usage.
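The lookup tables in this answer can be generated instead of typed by hand. Apart from the wheel primes themselves, a residue r modulo m can hold a prime only if gcd(r, m) = 1. The following is a small sketch of that idea, and the helper name `wheel_table` is hypothetical. For 24 = 2^3 * 3, the surviving residues are the 6-wheel's 1 and 5 repeated four times, which is why one byte covers 24 numbers.

```python
from math import gcd


def wheel_table(modulus):
    """Residues that can still hold a prime (gcd with modulus is 1), each given its own bit."""
    residues = [r for r in range(modulus) if gcd(r, modulus) == 1]
    table = [0] * modulus
    for k, r in enumerate(residues):
        table[r] = 1 << k              # 0x01, 0x02, 0x04, ... in increasing residue order
    return residues, table


for m in (6, 24, 30):
    residues, table = wheel_table(m)
    print(m, residues, f"{len(residues)} bits per {m} numbers")

# 6  -> [1, 5]                            2 bits per 6 numbers
# 24 -> [1, 5, 7, 11, 13, 17, 19, 23]     8 bits per 24 numbers (remainder_bit)
# 30 -> [1, 7, 11, 13, 17, 19, 23, 29]    8 bits per 30 numbers (remainder_bit30)
```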
Here's the updated code. I also simplified it a bit and took out the counting of primes as it went along.
```python
remainder_bit30 = [0, 0x01, 0, 0, 0, 0, 0, 0x02, 0, 0,
                   0, 0x04, 0, 0x08, 0, 0, 0, 0x10, 0, 0x20,
                   0, 0, 0, 0x40, 0, 0, 0, 0, 0, 0x80]

def is_prime(xs, a):
    if a <= 5:
        return (a > 1) and (a != 4)
    index, rem = divmod(a, 30)
    bit = remainder_bit30[rem]
    return (bit != 0) and not (xs[index] & bit)

def sieve_of_eratosthenes(xs):
    n = 30 * len(xs) - 1              # largest number the array can represent
    p = 0
    while p*p < n:
        # candidates in this block of 30: p+1, p+7, ..., p+29
        for offset in (1, 7, 11, 13, 17, 19, 23, 29):
            p += offset
            if is_prime(xs, p):
                for i in range(p * p, n + 1, p):
                    index, rem = divmod(i, 30)
                    if index < len(xs):
                        bit = remainder_bit30[rem]
                        xs[index] |= bit   # bit is 0 for multiples of 2, 3 or 5
            p -= offset
        p += 30

def init_sieve(n):
    b = bytearray((n + 30) // 30)
    return b
```
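The updated code no longer counts primes, but the Kattis problem asks for that count. One way to get it back without calling `is_prime` for every number is to count the clear bits with a popcount lookup. The following sketch restates the mod-30 wheel in a compact form so that it runs on its own. The names `RESIDUES`, `BIT`, `POPCOUNT`, `build_sieve` and `count_primes` are hypothetical, and the bit layout matches `remainder_bit30`. The marking loop is still pure Python, so it will be slow at n = 10^8.

```python
import math

RESIDUES = (1, 7, 11, 13, 17, 19, 23, 29)          # residues mod 30 coprime to 30
BIT = {r: 1 << k for k, r in enumerate(RESIDUES)}   # same bit layout as remainder_bit30
POPCOUNT = bytes(bin(b).count("1") for b in range(256))


def is_prime(xs, a):
    if a <= 5:
        return a in (2, 3, 5)
    index, rem = divmod(a, 30)
    bit = BIT.get(rem, 0)
    return bit != 0 and not xs[index] & bit


def build_sieve(n):
    """Bit BIT[r] of xs[i] is set when 30*i + r is composite."""
    xs = bytearray(n // 30 + 1)                       # same length as (n + 30) // 30
    for p in range(7, math.isqrt(n) + 1, 2):
        if is_prime(xs, p):
            for m in range(p * p, n + 1, p):
                index, rem = divmod(m, 30)
                xs[index] |= BIT.get(rem, 0)          # multiples of 2, 3, 5 have no bit
    return xs


def count_primes(xs, n):
    """Number of primes <= n, read off the clear bits of xs."""
    if n < 30:
        return sum(is_prime(xs, a) for a in range(n + 1))
    full = n // 30                                    # bytes 0..full-1 cover 0..30*full - 1
    composites = sum(xs[:full].translate(POPCOUNT))   # set bits in the complete bytes
    count = 8 * full - composites                     # surviving wheel candidates, including 1
    count += 3 - 1                                    # add 2, 3 and 5; remove 1
    for r in RESIDUES:                                # last, partially used byte
        if 30 * full + r <= n and not xs[full] & BIT[r]:
            count += 1
    return count


n = 100
xs = build_sieve(n)
print(count_primes(xs, n))   # 25
```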
Answer 3:
I think you can try using a list of booleans, where each index marks whether that number is prime:
```python
def sieve_of_erato(range_max):
    # Assumes range_max >= 1
    primes_count = range_max - 1          # candidates 2..range_max
    is_prime = [True] * (range_max + 1)
    is_prime[0] = is_prime[1] = False     # 0 and 1 are not prime
    # Cross out all even numbers above 2 first.
    for i in range(4, range_max + 1, 2):
        is_prime[i] = False
        primes_count -= 1
    i = 3
    while i * i <= range_max:
        if is_prime[i]:
            # Update all multiples of this prime number
            # CAREFUL: Take note of the range args.
            # Reason for stepping by 2*i instead of i:
            # since p and p*p are both odd, (p*p + p) will be even,
            # which means that it has already been crossed out.
            for multiple in range(i * i, range_max + 1, i * 2):
                if is_prime[multiple]:    # count each composite only once
                    is_prime[multiple] = False
                    primes_count -= 1
        i += 2                            # even i are never prime here
    return primes_count, is_prime

def main():
    num_primes, is_prime = sieve_of_erato(100)
    print(num_primes)  # 25

if __name__ == "__main__":
    main()
```
You can later use the is_prime list to check whether a number is prime, by simply checking is_prime[number].
If this doesn't work, then try a segmented sieve.
As a bonus, you might be surprised to know that there is a way to generate the sieve in O(n) rather than O(n log log n). Check the code here.
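The following is a sketch of one common form of such a linear-time sieve, sometimes called Euler's sieve or the linear sieve. It is not necessarily the linked code, and the function name `linear_sieve` is hypothetical. Each composite i * p is written exactly once, by its smallest prime factor p. It keeps a full Python list of smallest prime factors, so it uses far more memory than the bytearray sieves in this thread. It does not help with the 50 MB limit.

```python
def linear_sieve(n):
    """Return (primes <= n, lp) where lp[k] is the smallest prime factor of k (for k >= 2)."""
    lp = [0] * (n + 1)
    primes = []
    for i in range(2, n + 1):
        if lp[i] == 0:                 # nothing has marked i, so it is prime
            lp[i] = i
            primes.append(i)
        for p in primes:
            # i * p gets smallest prime factor p; stopping once p exceeds lp[i]
            # means each composite is written exactly once, hence O(n) total work
            if p > lp[i] or i * p > n:
                break
            lp[i * p] = p
    return primes, lp


primes, lp = linear_sieve(100)
print(len(primes))   # 25
print(lp[91])        # 7
```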
Answer 4:
Here is an example of a segmented sieve approach that should not exceed 8MB of memory.
```python
from itertools import chain

def primeSieve(n,X,window=10**6):
    primes     = []      # only store minimum number of primes to shift windows
    primeCount = 0       # count primes beyond the ones stored
    flags      = list(X) # numbers will be replaced by 0 or 1 as we progress
    base       = 1       # number corresponding to 1st element of sieve
    isPrime    = [False]+[True]*(window-1) # starting sieve

    def flagPrimes(): # flag x values for current sieve window
        flags[:] = [isPrime[x-base]*1 if x in range(base,base+window) else x
                    for x in flags]

    # chain() keeps this lazy; (2,*range(...)) would build a tuple of ~n/2 ints
    for p in chain((2,), range(3,n+1,2)):  # potential primes: 2 and odd numbers
        if p >= base+window:        # shift sieve window as needed
            flagPrimes()            # set X flags before shifting window
            isPrime = [True]*window # initialize next sieve window
            base    = p             # 1st number in window
            for k in primes:        # update sieve using known primes
                if k>base+window:break
                i = (k-base%k)%k + k*(k==p)
                isPrime[i::k] = (False for _ in range(i,window,k))
        if not isPrime[p-base]: continue
        primeCount += 1             # count primes
        if p*p<=n:primes.append(p)  # store shifting primes, update sieve
        isPrime[p*p-base::p] = (False for _ in range(p*p-base,window,p))

    flagPrimes() # update flags with last window (should cover the rest of them)
    return primeCount,flags
```
Usage and output:
```python
print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229 [0, 1, 1, 0, 0, 1]
print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]
```
You can play with the window size to get the best trade-off between execution time and memory consumption. However, the execution time (on my laptop) is still rather long for large values of n:
```python
from timeit import timeit
for w in range(3,9):
    t = timeit(lambda:primeSieve(10**8,[],10**w),number=1)
    print(f"10^{w} window:",t)
```

```
10^3 window: 119.463959956
10^4 window: 33.33273301199999
10^5 window: 24.153761258999992
10^6 window: 24.649398391000005
10^7 window: 27.616014667
10^8 window: 27.919413531000004
```
Strangely enough, window sizes beyond 10^6 give worse performance. The sweet spot seems to be somewhere between 10^5 and 10^6. A window of 10^7 would exceed your 50MB limit anyway.
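The same windowing idea can also be sketched with bytearray segments and slice assignment. With that approach the crossing-out loops run in C instead of going through Python generators. This is a variation on the answer and not its code. The name `segmented_sieve`, the default window of 2**18 numbers and the way queries are handled are all assumptions, and this sketch has not been timed here. Memory stays at roughly one window plus the base primes up to sqrt(n).

```python
import math


def segmented_sieve(n, queries, window=1 << 18):
    """Return (number of primes <= n, [1 if q is prime else 0 for q in queries])."""
    if n < 2:
        return 0, [0] * len(queries)
    root = math.isqrt(n)
    # Plain bytearray sieve for the base primes up to sqrt(n).
    base = bytearray([1]) * (root + 1)
    base[0] = base[1] = 0
    for i in range(2, math.isqrt(root) + 1):
        if base[i]:
            base[i * i::i] = bytes(len(range(i * i, root + 1, i)))
    small_primes = [i for i in range(2, root + 1) if base[i]]

    targets = sorted({q for q in queries if 0 <= q <= n})
    flags = {}            # query value -> 1 if prime else 0
    t = 0
    count = 0
    for lo in range(0, n + 1, window):
        hi = min(lo + window, n + 1)              # this window covers [lo, hi)
        seg = bytearray([1]) * (hi - lo)
        for p in small_primes:
            if p * p >= hi:                       # p has no multiple to clear in this window
                break
            start = max(p * p, (lo + p - 1) // p * p)   # first multiple of p to clear
            seg[start - lo::p] = bytes(len(range(start, hi, p)))
        if lo == 0:
            seg[0] = 0                            # 0 is not prime
            if hi > 1:
                seg[1] = 0                        # neither is 1
        count += seg.count(1)
        while t < len(targets) and targets[t] < hi:
            flags[targets[t]] = seg[targets[t] - lo]
            t += 1
    return count, [flags.get(q, 0) for q in queries]


print(*segmented_sieve(9973, [1, 2, 3, 4, 9972, 9973]))   # 1229 [0, 1, 1, 0, 0, 1]
```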
Answer 5:
I had another idea on how to generate primes quickly in a memory-efficient way. It is based on the same concept as the Sieve of Eratosthenes, but it uses a dictionary to hold the next value that each prime will invalidate (i.e. skip). This requires storing only one dictionary entry per prime up to the square root of n.
```python
def genPrimes(maxPrime):
    if maxPrime>=2: yield 2           # special processing for 2
    primeSkips = dict()               # skipValue:prime
    for n in range(3,maxPrime+1,2):
        if n not in primeSkips:       # if not in skip list, it is a new prime
            yield n
            if n*n <= maxPrime:       # first skip will be at n^2
                primeSkips[n*n] = n
            continue
        prime = primeSkips.pop(n)     # find next skip for n's prime
        skip  = n+2*prime
        while skip in primeSkips:     # must not already be skipped
            skip += 2*prime
        if skip<=maxPrime:            # don't skip beyond maxPrime
            primeSkips[skip]=prime
```
Using this, the primeSieve function can simply run through the prime numbers, count them, and flag the x values:
```python
def primeSieve(n,X):
    primeCount = 0
    nonPrimes  = set(X)
    for prime in genPrimes(n):
        primeCount += 1
        nonPrimes.discard(prime)
    return primeCount,[0 if x in nonPrimes else 1 for x in X]

print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229 [0, 1, 1, 0, 0, 1]
print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]
```
This runs slightly faster than my previous answer and consumes only 78 KB of memory to generate the primes up to 10^8 (in 21 seconds).
Source: https://stackoverflow.com/questions/62899578/making-sieve-of-eratosthenes-more-memory-efficient-in-python