Reasonably efficient pure-functional matrix product in Haskell?

梦谈多话 2021-01-31 05:26

I know Haskell a little bit, and I wonder if it's possible to write something like a matrix-matrix product in Haskell that is all of the following:

  1. Pure-f
3 answers
  • 2021-01-31 05:34

    There are two angles from which to attack this problem.

    1. Research, along these lines, is ongoing. Now, there are plenty of Haskell programmers who are smarter than me; a fact I am constantly reminded of and humbled by. One of them may come by and correct me, but I don't know of any simple way to compose safe Haskell primitives into a top-of-the-line matrix multiplication routine. Those papers that you talk about sound like a good start.

      You wrote: "However, I'm not a computer science researcher. I wonder if it's possible to keep simple things simple in Haskell."

      If you cite those papers, maybe we could help decipher them.

    2. Software engineering, along these lines, is well-understood, straightforward, and even easy. A savvy Haskell coder would use a thin wrapper around BLAS, or look for such a wrapper in Hackage.

    Deciphering cutting-edge research is an ongoing process that shifts knowledge from the researchers to the engineers. It was a computer science researcher, C.A.R. Hoare, who first discovered quicksort and published a paper about it. Today, it is a rare computer science graduate who can't personally implement quicksort from memory (at least, those that graduated recently).

    Bit of history

    Almost this exact question has been asked a few times before.

    1. Is it possible to write matrix arithmetic in Fortran that is as fast as assembly?

    2. Is it possible to write matrix arithmetic in C that is as fast as Fortran?

    3. Is it possible to write matrix arithmetic in Java that is as fast as C?

    4. Is it possible to write matrix arithmetic in Haskell that is as fast as Java?

    So far, the answer has always been, "not yet", followed by "close enough". The advances that make this possible come from improvements in writing code, improvements to compilers, and improvements in the programming language itself.

    As a specific example, C was not able to surpass Fortran in many real-world applications until C99 compilers became widespread in the past decade. In Fortran, different arrays are assumed to have distinct storage from each other, whereas in C this is not generally the case. Fortran compilers were therefore permitted to make optimizations that C compilers could not. Well, not until C99 came out and you could add the restrict qualifier to your code.

    The Fortran programmers waited. Eventually the processors became complex enough that writing good assembly became more difficult, and the compilers became sophisticated enough that Fortran code was fast.

    Then C programmers waited until the 2000s for the ability to write code that matched Fortran. Until that point, they used libraries written in Fortran or assembler (or both), or put up with the reduced speed.

    The Java programmers, likewise, had to wait for JIT compilers, and for specific optimizations to appear. JIT compilers were originally an esoteric research concept until they became a part of daily life. Bounds-check elimination was also necessary in order to avoid a test and branch on every array access.

    Back to Haskell

    So, it is clear the Haskell programmers are "waiting", just like the Java, C, and Fortran programmers before them. What are we waiting for?

    • Maybe we're just waiting for someone to write the code, and show us how it's done.

    • Maybe we're waiting for the compilers to get better.

    • Maybe we're waiting for an update to the Haskell language itself.

    And maybe we're waiting for some combination of the above.

    About purity

    Purity and monads get conflated a lot in Haskell. The reason is that in Haskell, impure functions always use the IO monad. But the converse does not hold: the State monad, for example, is 100% pure. So when you say "pure" and "type signature does not use the State monad", those are actually completely independent and separate requirements.
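    To make the "State is pure" point concrete, here is a minimal sketch with a hand-rolled State monad (a simplified stand-in for the one in the transformers/mtl libraries, written out so that nothing is hidden). It is nothing but a state-passing function s -> (a, s), with no IO anywhere, and the hypothetical label function has a pure type even though it is implemented with State.

```haskell
-- A minimal State monad: a state-passing function, nothing more.
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s -> let (a, s') = g s in (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State f <*> State g = State $ \s ->
    let (h, s1) = f s
        (a, s2) = g s1
    in (h a, s2)

instance Monad (State s) where
  State g >>= k = State $ \s ->
    let (a, s') = g s
    in runState (k a) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Pure signature, State-based implementation: number the elements of a list.
label :: [a] -> [(Int, a)]
label xs = fst (runState (mapM step xs) 0)
  where
    step x = do
      n <- get
      put (n + 1)
      return (n, x)

main :: IO ()
main = print (label "abc")  -- [(0,'a'),(1,'b'),(2,'c')]
```

    Nothing in label's type mentions State; a caller cannot tell how it was implemented, and running it twice on the same list always gives the same answer.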

    However, you can also use the IO monad in the implementation of pure functions, and in fact, it's quite easy:

    import System.IO.Unsafe (unsafePerformIO)

    addSix :: Int -> Int
    addSix n = unsafePerformIO $ return (n + 6)
    

    Okay, yes, that's a stupid function, but it is pure. It's even obviously pure. The test for purity is twofold:

    1. Does it give the same result for the same inputs? Yes.

    2. Does it produce any semantically significant side effects? No.

    We like purity because pure functions are easier to compose and manipulate than impure functions are. How they're implemented doesn't matter as much. I don't know if you're aware of this, but Integer and ByteString are both basically wrappers around impure C functions, even though the interface is pure. (There is work on a new implementation of Integer; I don't know how far along it is.)

    Final answer

    The question is whether Haskell's approach (purity encoded in the type system) is compatible with efficiency, memory-safety and simplicity.

    The answer to that part is "yes", since we can take simple functions from BLAS and put them in a pure, type-safe wrapper. The wrapper's type encodes the safety of the function, even though the Haskell compiler is unable to prove that the function's implementation is pure. Our use of unsafePerformIO in its implementation is both an acknowledgement that we have proven the purity of the function, and it is also a concession that we couldn't figure out a way to express that proof in Haskell's type system.
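    A minimal sketch of that wrapper pattern, assuming nothing beyond base: dotKernel is a hypothetical IO routine working on raw pointers, standing in for a foreign BLAS call such as sdot (no BLAS is linked here). The pure dot function checks the lengths itself; that check is the "proof" that justifies unsafePerformIO, and the compiler simply takes our word for it.

```haskell
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peekElemOff)
import System.IO.Unsafe (unsafePerformIO)

-- Hypothetical stand-in for a foreign kernel (e.g. BLAS sdot): reads raw
-- buffers in IO. In a real binding this would be a `foreign import ccall`.
dotKernel :: Int -> Ptr Float -> Ptr Float -> IO Float
dotKernel n px py = go 0 0
  where
    go i acc
      | i >= n    = return acc
      | otherwise = do
          x <- peekElemOff px i
          y <- peekElemOff py i
          let acc' = acc + x * y
          acc' `seq` go (i + 1) acc'

-- The type says "pure". The length check guarantees the kernel never reads
-- past either buffer, which is what makes unsafePerformIO sound here.
dot :: [Float] -> [Float] -> Float
dot xs ys
  | length xs /= length ys = error "dot: length mismatch"
  | otherwise = unsafePerformIO $
      withArrayLen xs $ \n px ->
        withArrayLen ys $ \_ py ->
          dotKernel n px py

main :: IO ()
main = print (dot [1, 2, 3] [4, 5, 6])  -- 32.0
```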

    But the answer is also "not yet", since I don't know how to implement the function entirely in Haskell as such.

    Research in this area is ongoing. People are looking at proof systems like Coq, new languages like Agda, and developments in GHC itself, in order to see what kind of type system we'd need to prove that high-performance BLAS routines can be used safely. These tools can also be used with other languages like Java. For example, you could write a proof in Coq that your Java implementation is pure.

    I apologize for the "yes and no" answer, but no other answer would recognize both the contributions of engineers (who care about "yes") and researchers (who care about "not yet").

    P.S. Please cite the papers.

  • 2021-01-31 05:39

    You wrote: "As efficient as, say, Java. For concreteness, let's assume I'm talking about a simple triple loop, single precision, contiguous column-major layout (float[], not float[][]) and matrices of size 1000x1000, and a single-core CPU. (If you are getting 0.5-2 floating point operations per cycle, you are probably in the ballpark.)"

    So something like

    public class MatrixProd {
        static float[] matProd(float[] a, int ra, int ca, float[] b, int rb, int cb) {
            if (ca != rb) {
                throw new IllegalArgumentException("Matrices not fitting");
            }
            float[] c = new float[ra*cb];
            for(int i = 0; i < ra; ++i) {
                for(int j = 0; j < cb; ++j) {
                    float sum = 0;
                    for(int k = 0; k < ca; ++k) {
                        sum += a[i*ca+k]*b[k*cb+j];
                    }
                    c[i*cb+j] = sum;
                }
            }
            return c;
        }
    
        static float[] mkMat(int rs, int cs, float x, float d) {
            float[] arr = new float[rs*cs];
            for(int i = 0; i < rs; ++i) {
                for(int j = 0; j < cs; ++j) {
                    arr[i*cs+j] = x;
                    x += d;
                }
            }
            return arr;
        }
    
        public static void main(String[] args) {
            int sz = 100;
            float strt = -32, del = 0.0625f;
            if (args.length > 0) {
                sz = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                strt = Float.parseFloat(args[1]);
            }
            if (args.length > 2) {
                del = Float.parseFloat(args[2]);
            }
            float[] a = mkMat(sz,sz,strt,del);
            float[] b = mkMat(sz,sz,strt-16,del);
            System.out.println(a[sz*sz-1]);
            System.out.println(b[sz*sz-1]);
            long t0 = System.currentTimeMillis();
            float[] c = matProd(a,sz,sz,b,sz,sz);
            System.out.println(c[sz*sz-1]);
            long t1 = System.currentTimeMillis();
            double dur = (t1-t0)*1e-3;
            System.out.println(dur);
        }
    }
    

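    I suppose? (I hadn't read the spec properly before coding, so the layout is row-major. But since both matrices use the same layout, the access pattern is essentially the same as with column-major storage (one operand contiguous, the other strided), so it doesn't make the kind of difference that mixing layouts would. I'll assume that's okay.)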
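    To see why one consistent layout doesn't change the picture, here is a small sketch (rowMajor, colMajor and innerStrides are hypothetical helper names) that computes how far apart in memory the inner loop's successive reads of A(i,k) and B(k,j) are as k increases, for 1000x1000 matrices.

```haskell
-- Flat-buffer offset of element (i, j) of an r x c matrix.
rowMajor :: Int -> Int -> Int -> Int -> Int
rowMajor _r c i j = i * c + j

colMajor :: Int -> Int -> Int -> Int -> Int
colMajor r _c i j = j * r + i

-- Memory distance between consecutive inner-loop reads (k -> k + 1)
-- of A(i, k) and B(k, j), measured at i = j = k = 0.
innerStrides :: (Int -> Int -> Int) -> (Int -> Int -> Int) -> (Int, Int)
innerStrides idxA idxB = (idxA 0 1 - idxA 0 0, idxB 1 0 - idxB 0 0)

main :: IO ()
main = do
  let n = 1000
  print (innerStrides (rowMajor n n) (rowMajor n n))  -- (1,1000): both row-major
  print (innerStrides (colMajor n n) (colMajor n n))  -- (1000,1): both column-major
  print (innerStrides (rowMajor n n) (colMajor n n))  -- (1,1): mixed layouts
```

    With a single layout, one operand is read contiguously and the other with a stride of 1000, just mirrored; only mixing the layouts makes both reads contiguous.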

    I haven't spent any time thinking about a clever algorithm or low-level optimisation tricks (I wouldn't achieve much in Java with those anyway). I just wrote the simple loop, because you wrote:

    "I don't want this to sound like a challenge, but note that Java can satisfy all of the above easily."

    And that's what Java gives easily, so I'll take that.

    You also wrote: "If you are getting 0.5-2 floating point operations per cycle, you are probably in the ballpark."

    Nowhere near, I'm afraid, neither in Java nor in Haskell. There are too many cache misses to reach that throughput with the simple triple loop.

    Doing the same in Haskell, again without trying to be clever, with a plain, straightforward triple loop:

    {-# LANGUAGE BangPatterns #-}
    module MatProd where
    
    import Data.Array.ST
    import Data.Array.Unboxed
    
    matProd :: UArray Int Float -> Int -> Int -> UArray Int Float -> Int -> Int -> UArray Int Float
    matProd a ra ca b rb cb =
        let (al,ah)     = bounds a
            (bl,bh)     = bounds b
            {-# INLINE getA #-}
            getA i j    = a!(i*ca + j)
            {-# INLINE getB #-}
            getB i j    = b!(i*cb + j)
            {-# INLINE idx #-}
            idx i j     = i*cb + j
        in if al /= 0 || ah+1 /= ra*ca || bl /= 0 || bh+1 /= rb*cb || ca /= rb
             then error $ "Matrices not fitting: " ++ show (ra,ca,al,ah,rb,cb,bl,bh)
             else runSTUArray $ do
                arr <- newArray (0,ra*cb-1) 0
                let outer i j
                        | ra <= i   = return arr
                        | cb <= j   = outer (i+1) 0
                        | otherwise = do
                            !x <- inner i j 0 0
                            writeArray arr (idx i j) x
                            outer i (j+1)
                    inner i j k !y
                        | ca <= k   = return y
                        | otherwise = inner i j (k+1) (y + getA i k * getB k j)
                outer 0 0
    
    mkMat :: Int -> Int -> Float -> Float -> UArray Int Float
    mkMat rs cs x d = runSTUArray $ do
        let !r = rs - 1
            !c = cs - 1
            {-# INLINE idx #-}
            idx i j = cs*i + j
        arr <- newArray (0,rs*cs-1) 0
        let outer i j y
                | r < i     = return arr
                | c < j     = outer (i+1) 0 y
                | otherwise = do
                    writeArray arr (idx i j) y
                    outer i (j+1) (y + d)
        outer 0 0 x
    

    and the calling module

    module Main (main) where
    
    import System.Environment (getArgs)
    import Data.Array.Unboxed
    
    import System.CPUTime
    import Text.Printf
    
    import MatProd
    
    main :: IO ()
    main = do
        args <- getArgs
        let (sz, strt, del) = case args of
                                (a:b:c:_) -> (read a, read b, read c)
                                (a:b:_)   -> (read a, read b, 0.0625)
                                (a:_)     -> (read a, -32, 0.0625)
                                _         -> (100, -32, 0.0625)
            a = mkMat sz sz strt del
            b = mkMat sz sz (strt - 16) del
        print (a!(sz*sz-1))
        print (b!(sz*sz-1))
        t0 <- getCPUTime
        let c = matProd a sz sz b sz sz
        print $ c!(sz*sz-1)
        t1 <- getCPUTime
        printf "%.6f\n" (fromInteger (t1-t0)*1e-12 :: Double)
    

    So we're doing almost exactly the same things in both languages. Compiling the Haskell with -O2 and the Java with javac:

    $ java MatrixProd 1000 "-13.7" 0.013
    12915.623
    12899.999
    8.3592897E10
    8.193
    $ ./vmmult 1000 "-13.7" 0.013
    12915.623
    12899.999
    8.35929e10
    8.558699
    

    And the resulting times are quite close.

    And if we compile the Java code to native, with gcj -O3 -Wall -Wextra --main=MatrixProd -fno-bounds-check -fno-store-check -o jmatProd MatrixProd.java,

    $ ./jmatProd 1000 "-13.7" 0.013
    12915.623
    12899.999
    8.3592896512E10
    8.215
    

    there's still no big difference.

    As a special bonus, the same algorithm in C (gcc -O3):

    $ ./cmatProd 1000 "-13.7" 0.013
    12915.623047
    12899.999023
    8.35929e+10
    8.079759
    

    So this reveals no fundamental difference between straightforward Java and straightforward Haskell when it comes to computationally intensive tasks using floating point numbers, and both are close to C with this algorithm. (When dealing with integer arithmetic on medium to large numbers, GHC's use of GMP makes Haskell outperform Java's BigInteger by a huge margin for many tasks, but that is of course a library issue, not a language one.)

    In all fairness, though, that is because the access pattern causes a cache-miss every other nanosecond, so in all three languages this computation is memory-bound.
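    A quick back-of-the-envelope check of these numbers, as a sketch: the naive n x n product does n^3 multiplications and n^3 additions, i.e. 2*10^9 floating-point operations for n = 1000. Dividing by the times reported above gives the achieved rate. The 2 GHz clock used for the per-cycle figure is an assumption (the CPU isn't stated), and note that the Haskell program measures CPU time (getCPUTime) while the Java program measures wall-clock time (currentTimeMillis).

```haskell
import Text.Printf (printf)

-- Floating-point operations in the naive n x n product: n^3 multiplies + n^3 adds.
flops :: Double -> Double
flops n = 2 * n ^ (3 :: Int)

-- Achieved rate in GFLOP/s for a run of `secs` seconds.
gflops :: Double -> Double -> Double
gflops n secs = flops n / secs / 1e9

-- Operations per cycle at an assumed clock frequency (in GHz).
perCycle :: Double -> Double -> Double -> Double
perCycle ghz n secs = gflops n secs / ghz

main :: IO ()
main = mapM_ report
  [("java", 8.193), ("ghc -O2", 8.558699), ("gcj", 8.215), ("gcc -O3", 8.079759)]
  where
    report :: (String, Double) -> IO ()
    report (name, t) =
      printf "%-8s %.3f GFLOP/s, %.3f flops/cycle at an assumed 2 GHz\n"
        name (gflops 1000 t) (perCycle 2 1000 t)
```

    All four runs come out at roughly 0.24 GFLOP/s, about 0.12 operations per cycle under that clock assumption, well short of the 0.5-2 range from the question and consistent with a memory-bound loop.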

    If we improve the access pattern by multiplying a row-major matrix with a column-major matrix, all three become faster: the gcc-compiled C finishes in 1.18s, Java takes 1.23s, and the ghc-compiled Haskell takes around 5.8s, which can be reduced to 3 seconds by using the LLVM backend.
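    The answer doesn't show the code for this variant, so here is a minimal sketch of the idea only (matProdT and transposeRM are hypothetical names, and this list-based version is written for clarity, not to reproduce the timings above): store B column-major, i.e. transposed, so that the inner loop reads both A and B with stride 1.

```haskell
import Data.Array.Unboxed

-- Convert an r x c row-major matrix into column-major storage.
transposeRM :: Int -> Int -> UArray Int Float -> UArray Int Float
transposeRM r c m =
  listArray (0, r * c - 1) [m ! (i * c + j) | j <- [0 .. c - 1], i <- [0 .. r - 1]]

-- A is ra x ca (row-major); bT holds the ca x cb matrix B column-major.
-- Element (i, j) of the result is the dot product of two contiguous runs.
matProdT :: UArray Int Float -> Int -> Int -> UArray Int Float -> Int -> UArray Int Float
matProdT a ra ca bT cb =
  listArray (0, ra * cb - 1) [dotIJ i j | i <- [0 .. ra - 1], j <- [0 .. cb - 1]]
  where
    dotIJ i j = go 0 0
      where
        go k acc
          | k >= ca   = acc
          | otherwise =
              let acc' = acc + a ! (i * ca + k) * bT ! (j * ca + k)
              in acc' `seq` go (k + 1) acc'

main :: IO ()
main = do
  let a = listArray (0, 3) [1, 2, 3, 4] :: UArray Int Float  -- [[1,2],[3,4]]
      b = listArray (0, 3) [5, 6, 7, 8] :: UArray Int Float  -- [[5,6],[7,8]]
  print (elems (matProdT a 2 2 (transposeRM 2 2 b) 2))  -- [19.0,22.0,43.0,50.0]
```

    Every access here still goes through the bounds-checked (!) of the array library.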

    Here, the range check performed by the array library really hurts. Using unchecked array access (as one should, after checking for bugs, since the checks are already done in the code controlling the loops), GHC's native backend finishes in 2.4s, and going via the LLVM backend lets the computation finish in 1.55s, which is decent, although significantly slower than both C and Java. Using the primitives from GHC.Prim instead of the array library, the LLVM backend produces code that runs in 1.16s. (Again, this is without bounds checking on each access, but in this case it is easy to prove beforehand that only valid indices are produced during the computation, so no memory safety is sacrificed¹. Checking each access brings the time up to 1.96s, still significantly better than the bounds checking of the array library.)
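    A minimal sketch of "check once, then access unchecked" with the array library (dotUnchecked is a hypothetical name; unsafeAt comes from Data.Array.Base and takes a 0-based offset with no bounds check). The guard establishes that every offset the loop can produce is in range, so skipping the per-access checks doesn't give up memory safety.

```haskell
import Data.Array.Base (unsafeAt)
import Data.Array.Unboxed (UArray, bounds, listArray)

-- Dot product of row i of A (row-major, ca columns) with column j of B,
-- where B is stored column-major (columns of length ca).
dotUnchecked :: UArray Int Float -> UArray Int Float -> Int -> Int -> Int -> Float
dotUnchecked a bT ca i j
  | ok        = go 0 0
  | otherwise = error "dotUnchecked: index out of range"
  where
    (alo, ahi) = bounds a
    (blo, bhi) = bounds bT
    -- One up-front check covers every unsafeAt in the loop below.
    ok = alo == 0 && blo == 0 && i >= 0 && j >= 0 && ca >= 0
         && (i + 1) * ca - 1 <= ahi && (j + 1) * ca - 1 <= bhi
    go k acc
      | k >= ca   = acc
      | otherwise =
          let acc' = acc + unsafeAt a (i * ca + k) * unsafeAt bT (j * ca + k)
          in acc' `seq` go (k + 1) acc'

main :: IO ()
main = do
  let a  = listArray (0, 3) [1, 2, 3, 4] :: UArray Int Float  -- [[1,2],[3,4]]
      bT = listArray (0, 3) [5, 7, 6, 8] :: UArray Int Float  -- [[5,6],[7,8]], column-major
  print (dotUnchecked a bT 2 0 0, dotUnchecked a bT 2 1 1)  -- (19.0,50.0)
```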

    Bottom line: GHC needs (much) faster branching for the bounds-checking, and there's room for improvement in the optimiser, but in principle, "Haskell's approach (purity encoded in the type system) is compatible with efficiency, memory-safety and simplicity", we're just not yet there. For the time being, one has to decide how much of which point one is willing to sacrifice.


    ¹ Yes, that's a special case, in general omitting the bounds-check does sacrifice memory-safety, or it is at least harder to prove that it doesn't.

  • 2021-01-31 05:39

    Like Java, Haskell is not the best language for writing numerical code.

    Haskell's code generation for numeric-heavy code is... average. It hasn't had the years of research behind it that the likes of Intel and GCC have.

    What Haskell gives you instead is a way to cleanly interface your "fast" code with the rest of your application. Remember that 3% of code is responsible for 97% of your application's running time.[1]

    With Haskell, you have a way to call these highly optimized functions in a way that interfaces extremely nicely with the rest of your code: via the very nice C Foreign Function Interface. In fact, if you so desired, you could write your numeric code in the assembly language of your architecture and get even more performance! Dipping into C for performance-heavy parts of your application isn't a bug - it's a feature.
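    As a minimal illustration of how little ceremony this takes (c_cos and cosViaC are hypothetical names; the C function is the standard library's cos): a side-effect-free C function can be imported with a pure Haskell type, so callers never see IO. As with unsafePerformIO, the compiler trusts the programmer's claim that the function is pure.

```haskell
{-# LANGUAGE ForeignFunctionInterface #-}
import Foreign.C.Types (CDouble (..))

-- C's cos has no side effects, so we import it with a pure Haskell type.
-- "unsafe" promises the C code won't call back into Haskell, making the call cheaper.
foreign import ccall unsafe "math.h cos"
  c_cos :: CDouble -> CDouble

-- Convert at the boundary so the rest of the program sees plain Doubles.
cosViaC :: Double -> Double
cosViaC = realToFrac . c_cos . realToFrac

main :: IO ()
main = print (cosViaC 0)  -- 1.0
```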

    But I digress.

    By having these highly optimized functions isolated, and with a similar interface to the rest of your Haskell code, you can perform high-level optimizations with Haskell's very powerful rewrite rules, which let you write rules such as reverse . reverse == id that automagically reduce complex expressions at compile time [2]. This leads to extremely fast, purely functional, and easy-to-use libraries like Data.Text [3] and Data.Vector [4].
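    A minimal sketch of such a rule, attached to a hypothetical myReverse rather than the Prelude's reverse so that we control its inlining. GHC applies rules only when optimising (-O), and it does not check them: this one is only valid for finite lists, because for an infinite list the left-hand side never terminates while the right-hand side does.

```haskell
-- A list reversal we own, so that we can attach a rule to it.
myReverse :: [a] -> [a]
myReverse = foldl (flip (:)) []
{-# NOINLINE myReverse #-}  -- keep calls visible so the rule has something to match

-- With -O, GHC may rewrite myReverse (myReverse xs) to xs at compile time.
-- The compiler takes this equation on trust; it is the programmer's job to
-- make sure it holds (here: only for finite lists).
{-# RULES
"myReverse/myReverse" forall xs. myReverse (myReverse xs) = xs
  #-}

main :: IO ()
main = print (myReverse (myReverse [1, 2, 3 :: Int]))  -- [1,2,3], rule or no rule
```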

    By combining high and low levels of optimization, we end up with a much more optimized implementation, with each half ("C/asm", and "Haskell") relatively easy to read. The low level optimization is done in its native tongue (C or assembly), the high level optimization gets a special DSL (Haskell rewrite rules), and the rest of the code is oblivious to it completely.

    In conclusion, yes, Haskell can be faster than Java. But it cheats by going through C for the raw FLOPS. This is much harder to do in Java (and Java's FFI has a much higher overhead), so it's avoided there. In Haskell, it's natural. If your application spends an exorbitant amount of time doing numeric calculations, then maybe, instead of looking at Haskell or Java, you should look at Fortran. If your application spends a large portion of its time in a tiny part of performance-sensitive code, then the Haskell FFI is your best bet. If your application doesn't spend any time in numeric code... then use whatever you like. =)

    Neither Haskell nor Java, for that matter, is Fortran.

    [1] These numbers were made up, but you get my point.

    [2] http://www.cse.unsw.edu.au/~dons/papers/CLS07.html

    [3] http://hackage.haskell.org/package/text

    [4] http://hackage.haskell.org/package/vector


    Now that that's out of the way, to answer your actual question:

    No, it's not currently smart to write your matrix multiplications in Haskell. At the moment, REPA is the canonical way to do this [5]. The implementation partially breaks memory safety (it uses unsafeSlice), but the "broken memory safety" is isolated to that function, is actually very safe (though not easily verified by the compiler), and is easy to remove if things go wrong (replace "unsafeSlice" with "slice").

    But this is Haskell! Very rarely are the performance characteristics of a function to be taken in isolation. That can be a bad thing (in the case of space leaks), or a very, very good thing.

    The matrix multiplication algorithm used there is naive, so it will perform worse in a raw benchmark. But rarely does our code look like a benchmark.

    What if you were a scientist with millions of data points and want to multiply huge matrices? [7]

    For those people, we have mmultP [6]. This performs the same matrix multiplication, but it is data-parallel, using REPA's parallel array evaluation. Also note that the code is essentially unchanged from the sequential version.

    For those people who don't multiply huge matrices, and instead multiply lots of little matrices, there tends to be other code interacting with those matrices: possibly cutting them up into column vectors and finding their dot products, maybe finding their eigenvalues, maybe something else entirely. Unlike C, Haskell knows that although you like to solve problems in isolation, the most efficient solution usually isn't found there.

    Like ByteString, Text, and Vector, REPA arrays are subject to fusion [2]. You should actually read [2], by the way; it's a very well-written paper. This, combined with aggressive inlining of relevant code and REPA's highly parallel nature, allows us to express these high-level mathematical concepts with very advanced high-level optimizations behind the scenes.

    Although no method of writing a truly efficient matrix multiplication in a pure functional language is currently known, we can come somewhat close (no automatic vectorization, a few excessive dereferences to get to the actual data, etc.), though nowhere near what IFORT or GCC can do. But programs don't exist on an island, and making the island as a whole perform well is much, much easier in Haskell than in Java.

    [5] http://hackage.haskell.org/packages/archive/repa-algorithms/3.2.1.1/doc/html/src/Data-Array-Repa-Algorithms-Matrix.html#mmultS

    [6] http://hackage.haskell.org/packages/archive/repa-algorithms/3.2.1.1/doc/html/src/Data-Array-Repa-Algorithms-Matrix.html#mmultP

    [7] Actually, the best way to do this is by using the GPU. There are a few GPU DSLs available for Haskell which make this possible to do natively. They're really neat!
