Question
Spoiler alert: this is related to Problem 14 from Project Euler.
The following code takes around 15s to run. I have a non-recursive Java solution that runs in 1s. I think I should be able to get this code much closer to that.
import Data.List

collatz a 1 = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])
I have profiled with +RTS -p and noticed that the allocated memory is large, and grows as the input grows. For n = 100,000, 1 GB is allocated(!); for n = 1,000,000, 13 GB(!!) is allocated.
Then again, -sstderr shows that although lots of bytes were allocated, total memory use was 1 MB and productivity was 95%+, so maybe the 13 GB is a red herring.
I can think of a few possibilities:
- Something isn't as strict as it needs to be. I've already discovered foldl1', but maybe I need to do more? Is it possible to mark collatz as strict (does that even make sense)?
- collatz isn't tail-call optimizing. I think it should be, but I don't know a way to confirm.
- The compiler isn't doing some optimizations I think it should; for instance, only two results of collatz need to be in memory at any one time (max and current).
Any suggestions?
This is pretty much a duplicate of "Why is this Haskell expression so slow?", though I will note that the fast Java solution does not have to perform any memoization. Are there any ways to speed this up without resorting to it?
For reference, here is my profiling output:
Wed Dec 28 09:33 2011 Time and Allocation Profiling Report (Final)

   scratch +RTS -p -hc -RTS

total time  =        5.12 secs   (256 ticks @ 20 ms)
total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE      MODULE                  %time %alloc

collatz          Main                     99.6   99.4

                                                       individual    inherited
COST CENTRE      MODULE                   no. entries  %time %alloc  %time %alloc

MAIN             MAIN                       1       0    0.0    0.0  100.0  100.0
 CAF             Main                     208      10    0.0    0.0  100.0  100.0
  collatz        Main                     215       1    0.0    0.0    0.0    0.0
  main           Main                     214       1    0.4    0.6  100.0  100.0
   collatz       Main                     216       0   99.6   99.4   99.6   99.4
 CAF             GHC.IO.Handle.FD         145       2    0.0    0.0    0.0    0.0
 CAF             System.Posix.Internals   144       1    0.0    0.0    0.0    0.0
 CAF             GHC.Conc                 128       1    0.0    0.0    0.0    0.0
 CAF             GHC.IO.Handle.Internals  119       1    0.0    0.0    0.0    0.0
 CAF             GHC.IO.Encoding.Iconv    113       5    0.0    0.0    0.0    0.0
And -sstderr:
./scratch +RTS -sstderr
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))
          12,824 bytes maximum slop
               1 MB total memory in use (0 MB lost due to fragmentation)

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)
  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)

  %GC time       1.1%  (1.4% elapsed)

  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed
And here is the Java solution (not mine; taken from the Project Euler forums, with memoization removed):
public class Collatz {
    public int getChainLength(int n)
    {
        long num = n;
        int count = 1;
        while (num > 1)
        {
            num = (num % 2 == 0) ? num >> 1 : 3 * num + 1;
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        Collatz obj = new Collatz();
        long tic = System.currentTimeMillis();
        int max = 0, len = 0, index = 0;
        for (int i = 3; i < 1000000; i++)
        {
            len = obj.getChainLength(i);
            if (len > max)
            {
                max = len;
                index = i;
            }
        }
        long toc = System.currentTimeMillis();
        System.out.println(toc - tic);
        System.out.println("Index: " + index + ", length = " + max);
    }
}
Answer 1:
At first, I thought you should try putting an exclamation mark before a in collatz:

collatz !a 1 = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

(You'll need to put {-# LANGUAGE BangPatterns #-} at the top of your source file for this to work.)
My reasoning went as follows: the problem is that you're building up a massive thunk in the first argument to collatz. It starts off as 1, then becomes 1 + 1, then (1 + 1) + 1, and so on, all without ever being evaluated. The bang pattern forces the first argument of collatz whenever a call is made, so it starts off as 1, then becomes 2, and so on, without building up a large unevaluated thunk: it just stays as an integer.
Note that a bang pattern is just shorthand for using seq; in this case, we could rewrite collatz as follows:

collatz a _ | seq a False = undefined
collatz a 1 = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)
The trick here is to force a in the guard, which then always evaluates to False (and so the body is irrelevant). Then evaluation continues with the next case, a having already been evaluated. However, a bang pattern is clearer.
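For comparison, the same accumulator can also be forced with the Prelude's strict application operator $!, which evaluates its argument before passing it on. This is not something the answer tried; it is a minimal sketch with a hypothetical name, collatzS, and an Int type chosen for illustration:

```haskell
-- collatzS: hypothetical variant of collatz that forces the accumulator
-- with the strict application operator ($!) instead of a bang pattern.
-- Returns the number of terms in the chain from x, starting the count at a.
collatzS :: Int -> Int -> Int
collatzS a 1 = a
collatzS a x
  -- (collatzS $! a + 1) parses as collatzS $! (a + 1): a + 1 is evaluated
  -- before the partial application is built, so no thunk chain grows.
  | even x    = (collatzS $! a + 1) (x `div` 2)
  | otherwise = (collatzS $! a + 1) (3 * x + 1)
```

In GHCi, collatzS 1 27 evaluates to 112, the number of terms in the chain that starts at 27.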
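Unfortunately, when compiled with -O2, this doesn't run any faster than the original! (Presumably GHC's strictness analysis already makes the accumulator strict at that optimisation level.) What else can we try? Well, one thing we can do is assume that the two numbers never overflow a machine-sized integer, and give collatz this type annotation:

collatz :: Int -> Int -> Int

Without a signature, collatz is polymorphic and the numbers in main default to Integer, Haskell's arbitrary-precision type, so the annotation switches all of the arithmetic to machine-sized Int. We'll leave the bang pattern there, since we should still avoid building up thunks, even if they aren't the root of the performance problem. This brings the time down to 8.5 seconds on my (slow) computer.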
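The Int annotation is only safe if no intermediate value 3 * x + 1 ever exceeds maxBound :: Int. That assumption can be checked once with arbitrary-precision Integer. A sketch, with hypothetical helper names chainPeak and allFitInInt:

```haskell
-- chainPeak: hypothetical helper returning the largest value reached in
-- the Collatz sequence that starts at n. It uses Integer, so the check
-- itself cannot overflow.
chainPeak :: Integer -> Integer
chainPeak = go 0
  where
    go best 1 = max best 1
    go best x
      | even x    = go (max best x) (x `div` 2)
      | otherwise = go (max best x) (3 * x + 1)

-- allFitInInt: True when every chain started from 1 up to limit - 1
-- stays within the range of a machine Int. Requires limit >= 2.
allFitInInt :: Integer -> Bool
allFitInInt limit =
  maximum (map chainPeak [1 .. limit - 1]) <= toInteger (maxBound :: Int)
```

On a 64-bit GHC, allFitInInt 1000000 should evaluate to True (slowly, since everything is Integer). With a 32-bit Int it would be False, because some chains starting below one million climb past 2^31; this is also why the Java version keeps num in a long.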
The next step is to try bringing this closer to the Java solution. The first thing to realise is that in Haskell, div rounds toward negative infinity, which is the more mathematically correct behaviour for negative integers, but it is slower than "normal" C division, which truncates toward zero and in Haskell is called quot. Replacing div with quot brought the runtime down to 5.2 seconds, and replacing x `quot` 2 with x `shiftR` 1 (importing Data.Bits) to match the Java solution brought it down to 4.9 seconds.
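To make the difference concrete: div rounds toward negative infinity, quot truncates toward zero (like integer / in C and Java), and shiftR on a signed Int is an arithmetic shift that fills with the sign bit, so on negative inputs it agrees with div rather than quot. For the non-negative values in a Collatz chain, all three give the same answer. A small sketch; the names halves and agreeOnNonNegative are illustrative only:

```haskell
import Data.Bits (shiftR)

-- halves: the three ways of halving an Int discussed above, side by side:
-- (floor division, truncating division, arithmetic right shift).
halves :: Int -> (Int, Int, Int)
halves x = (x `div` 2, x `quot` 2, x `shiftR` 1)

-- For non-negative x (every value in a Collatz chain) the three coincide,
-- so swapping one for another cannot change the chain lengths.
agreeOnNonNegative :: Int -> Bool
agreeOnNonNegative x = let (d, q, s) = halves x in d == q && q == s
```

For example, halves 7 is (3,3,3), while halves (-7) is (-4,-3,-4).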
This is about as low as I can get it for now, but I think this is a pretty good result; since your computer is faster than mine, it should hopefully be even closer to the Java solution.
Here's the final code (I did a little bit of clean-up on the way):
{-# LANGUAGE BangPatterns #-}
import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]
Looking at the GHC Core for this program (with ghc-core), I think this is probably about as good as it gets; the collatz loop uses unboxed integers and the rest of the program looks OK. The only improvement I can think of would be eliminating the boxing from the map collatz [1..1000000] iteration.
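One way to attempt that last improvement is to replace the list with an explicit strict loop over the starting values. This is a sketch rather than a measured result: longestUpTo and loop are hypothetical names, GHC may already fuse the map and fold in the version above, and the two should be timed against each other before assuming any gain. As a small extra, it returns the starting number along with the chain length:

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Bits (shiftR)

-- Chain length (number of terms) starting at x, as in the final code above.
collatz :: Int -> Int
collatz = go 1
  where
    go :: Int -> Int -> Int
    go !a 1 = a
    go !a x
      | even x    = go (a + 1) (x `shiftR` 1)
      | otherwise = go (a + 1) (3 * x + 1)

-- longestUpTo: hypothetical list-free replacement for
-- foldl1' max . map collatz $ [1..limit].
-- Returns (starting number, chain length) for the longest chain;
-- on ties the smallest starting number wins.
longestUpTo :: Int -> (Int, Int)
longestUpTo limit = loop 1 0 1
  where
    -- best/bestLen are the strict running maximum, n the next candidate.
    loop :: Int -> Int -> Int -> (Int, Int)
    loop !best !bestLen !n
      | n > limit = (best, bestLen)
      | otherwise =
          let len = collatz n
          in if len > bestLen
               then loop n len (n + 1)
               else loop best bestLen (n + 1)
```

main = print (longestUpTo 1000000) would print the pair; its second component should be the same 525 that the question's program prints.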
By the way, don't worry about the "total alloc" figure; it's the total memory allocated over the lifetime of the program, and it never decreases even when the GC reclaims that memory. Figures of multiple terabytes are common.
Answer 2:
You could lose the list and the bang patterns and still get the same performance by using the stack instead.
import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)

main = do
  print $ coll 100000
One problem with this is that you will have to increase the size of the stack for large inputs, such as 1,000,000.
Update: here is a tail-recursive version that doesn't suffer from the stack overflow problem.
import Data.Word

collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a, x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

-- Longest chain length over the starting values n, n-1, ..., 2.
coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n - 1)

main :: IO ()
main = print (coll 1000000)
Notice the use of Word instead of Int. It makes a difference in performance. You could still use the bang patterns if you want, and that would nearly double the performance.
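Here is a sketch of the tail-recursive version with bang patterns added, for anyone who wants to try that suggestion. It keeps Word, drops the unused second component of the tuple, and guards against n < 2 (the version above never terminates for coll 0, since collatz loops forever on 0). The names collatzW and collW are hypothetical, and the speed-up claimed above has not been re-measured here:

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Word (Word)

-- Count the terms of the chain from x, forcing the accumulator a.
-- Assumes x >= 1 (0 would loop forever, as in the original).
collatzW :: Word -> Word -> Word
collatzW !a 1 = a
collatzW !a x
  | even x    = collatzW (a + 1) (x `quot` 2)
  | otherwise = collatzW (a + 1) (3 * x + 1)

-- Longest chain length over the starting values n, n-1, ..., 2, with a
-- strict running maximum m. Returns 0 when n < 2, like coll 1 above.
collW :: Word -> Word
collW n = go 0 n
  where
    go !m k
      | k < 2     = m
      | otherwise = go (max (collatzW 1 k) m) (k - 1)
```

collW 1000000 should give 525, matching the other programs.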
Answer 3:
One thing I found made a surprising difference in this problem. I stuck with the plain recurrence relation rather than folding (you should pardon the expression) the counting into it. Rewriting
collatz n = if even n then n `div` 2 else 3 * n + 1
as
collatz n = case n `divMod` 2 of
              (n', 0) -> n'
              _       -> 3 * n + 1
took 1.2 seconds off the runtime of my program on a system with a 2.8 GHz Athlon II X4 430 CPU. My initial, faster version (2.3 seconds after the switch to divMod):
{-# LANGUAGE BangPatterns #-}
import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
  where collatzChainLen' n !l
          | n == 1    = l
          | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz :: Int -> Int
collatz n = case n `divMod` 2 of
              (n', 0) -> n'
              _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))
A perhaps more idiomatic Haskell version runs in about 9.7 seconds (8.5 with divMod); it's identical except for

collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n
Using Data.List.Stream is supposed to allow stream fusion that would make this version run more like that with the explicit accumulation, but I can't find an Ubuntu libghc* package that has Data.List.Stream, so I can't yet verify that.
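Answer 1 found quot faster than div; by the same reasoning, quotRem is a possible drop-in replacement for divMod here. That swap is an untested suggestion, not something measured in this answer. The two step functions agree for every Int, because the remainder is zero exactly when n is even, and then both quotients are exactly n / 2. A sketch, with hypothetical names collatzQR, collatzDM and chainLenQR:

```haskell
{-# LANGUAGE BangPatterns #-}

-- One Collatz step using quotRem (truncating) instead of divMod (flooring).
collatzQR :: Int -> Int
collatzQR n = case n `quotRem` 2 of
  (n', 0) -> n'
  _       -> 3 * n + 1

-- The divMod step from this answer, kept for comparison.
collatzDM :: Int -> Int
collatzDM n = case n `divMod` 2 of
  (n', 0) -> n'
  _       -> 3 * n + 1

-- Chain length with the same explicit, strict accumulator as collatzChainLen.
chainLenQR :: Int -> Int
chainLenQR n = go n 1
  where
    go :: Int -> Int -> Int
    go m !l
      | m == 1    = l
      | otherwise = go (collatzQR m) (l + 1)
```

Swapping collatzQR in for collatz in the program above is the whole change; whether it actually helps on a given GHC version would need timing.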
Source: https://stackoverflow.com/questions/8659345/why-is-this-simple-haskell-algorithm-so-slow