Memory efficient sort of massive numpy array in Python

孤街浪徒 2021-02-05 11:33

I need to sort a VERY large genomic dataset using numpy. I have an array of 2.6 billion floats, dimensions = (868940742, 3) which takes up about 20GB of memory on m

2 answers
  •  说谎 (original poster)
    2021-02-05 12:22

    EDIT: In case anyone new to programming and numpy comes across this post, I want to point out the importance of considering the np.dtype that you are using. In my case, I was actually able to get away with using half-precision floating point, i.e. np.float16, which reduced a 20GB object in memory to 5GB and made sorting much more manageable. The default used by numpy is np.float64, which is a lot of precision that you may not need. Check out the doc here, which describes the capacity of the different data types. Thanks to @ali_m for pointing this out in the comments.
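
    As a rough illustration of the dtype point, the sketch below computes how much memory an array of the shape quoted in the question would need for a few float types, without allocating anything. The helper name array_nbytes is made up for this example. It also prints one caveat worth checking before switching to np.float16: half precision has an 11-bit significand and a small maximum value, so large integers (such as genomic positions) may not survive the conversion.

```python
import numpy as np

N_ROWS, N_COLS = 868940742, 3  # shape quoted in the question


def array_nbytes(dtype):
    """Bytes needed for an (N_ROWS, N_COLS) array of the given dtype.

    Hypothetical helper for this example; nothing is allocated.
    """
    return N_ROWS * N_COLS * np.dtype(dtype).itemsize


sizes = {name: array_nbytes(name) for name in ("float64", "float32", "float16")}
for name, nbytes in sizes.items():
    print(f"{name}: {nbytes / 1e9:.1f} GB")

# float16 cannot represent every integer above 2048, and its largest
# finite value is small, so check the range of your data first.
print(float(np.float16(2049)))
print(float(np.finfo(np.float16).max))
```

    If the values do not fit in half precision, np.float32 still halves the footprint relative to the default.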

    I did a bad job explaining this question but I have discovered some helpful workarounds that I think would be useful to share for anyone who needs to sort a truly massive numpy array.

    I am building a very large numpy array from 22 "sub-arrays" of human genome data containing the elements [position, value]. Ultimately, the final array must be numerically sorted "in place" based on the values in a particular column and without shuffling the values within rows.

    The sub-array dimensions follow the form:

    arr1.shape = (N1, 2)
    ...
    arr22.shape = (N22, 2)
    

    sum([N1..N22]) = 868940742, i.e. there are close to 1 billion positions to sort.

    First I process the 22 sub-arrays with the function process_sub_arrs, which returns a 3-tuple of 1D arrays the same length as the input. I stack the 1D arrays into a new (N, 3) array and insert them into an np.zeros array initialized for the full dataset:

        # sub_arrs stands for the list of the 22 sub-arrays, [arr1, ..., arr22]
        full_arr = np.zeros((868940742, 3))
        i = 0

        for arr in sub_arrs:
            # (i, j) bound the rows filled by the current sub-array
            j = i + len(arr)
            full_arr[i:j, :] = np.column_stack(process_sub_arrs(arr))
            i = j
    

    EDIT: Since I realized my dataset could be represented with half-precision floats, I now initialize full_arr as follows: full_arr = np.zeros([868940742, 3], dtype=np.float16), which is only 1/4 the size and much easier to sort.

    The result, with the default np.float64, is a massive 20GB array:

    full_arr.nbytes = 20854577808
    

    As @ali_m pointed out in his detailed post, my earlier routine was inefficient:

    sort_idx = np.argsort(full_arr[:,idx])
    full_arr = full_arr[sort_idx]
    

    The array sort_idx, which is 33% the size of full_arr, hangs around and wastes memory after full_arr has been sorted. Because "fancy" indexing returns a copy rather than a view, this step can also push peak memory use to 233% of what is already used to hold the massive array: the original, the sorted copy, and the index. This is the slow step, lasting about ten minutes and relying heavily on virtual memory.

    I'm not sure the "fancy" sort makes a persistent copy however. Watching the memory usage on my machine, it seems that full_arr = full_arr[sort_idx] deletes the reference to the unsorted original, because after about 1 second all that is left is the memory used by the sorted array and the index, even if there is a transient copy.

    A more compact usage of argsort() to save memory is this one:

        full_arr = full_arr[full_arr[:,idx].argsort()]
    

    This still causes a spike at the time of the assignment, where both a transient index array and a transient copy are made, but the memory is almost instantly freed again.
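
    A small sketch of the argsort approach on toy data may help make the memory accounting concrete. The array toy and the column number idx are placeholders, not the real dataset. It shows that the index array is about a third of the data size for three float64 columns, and that the indexed result is a separate array rather than a view.

```python
import numpy as np

rng = np.random.default_rng(0)
toy = rng.random((1000, 3))   # toy stand-in for full_arr
idx = 1                       # hypothetical sort column

order = toy[:, idx].argsort()            # one integer index per row
print(order.nbytes / toy.nbytes)         # index size relative to the data

sorted_toy = toy[order]                  # fancy indexing allocates a new array
print(np.shares_memory(sorted_toy, toy))

# Rows move as units: the sort column is ordered and the other
# columns travel with it.
is_sorted = bool(np.all(np.diff(sorted_toy[:, idx]) >= 0))
print(is_sorted)
```

    Rebinding the name (toy = toy[order]) drops the last reference to the unsorted array, which is why the extra memory disappears shortly after the assignment.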

    @ali_m pointed out a nice trick (credited to Joe Kington) for generating a de facto structured array with a view on full_arr. The benefit is that the view can be sorted "in place" while keeping the values within each row together:

    full_arr.view('f8, f8, f8').sort(order=['f0'], axis=0)
    

    Views work great for performing mathematical array operations, but sorting through a structured view is far too slow for even a single sub-array from my dataset. In general, structured arrays just don't seem to scale very well even though they have really useful properties. If anyone has any idea why this is I would be interested to know.
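
    For reference, here is a minimal sketch of the structured-view trick on a tiny array. The helper sort_rows_inplace is a made-up name, and it assumes a C-contiguous 2D array whose columns all share one dtype. It builds the record dtype from the array itself, so it is not tied to 'f8'. One likely reason for the poor scaling is that sorting records with order= goes through a generic field-by-field comparison rather than a type-specialized sort, though I have not profiled this.

```python
import numpy as np


def sort_rows_inplace(arr, col):
    """Sort the rows of a C-contiguous 2D array in place by column `col`.

    Hypothetical helper: views each row as one record with fields
    f0, f1, ... and sorts the records by the chosen field.
    """
    row_dtype = np.dtype([(f"f{k}", arr.dtype) for k in range(arr.shape[1])])
    arr.view(row_dtype).sort(order=[f"f{col}"], axis=0)


toy = np.array([[3.0, 30.0, 300.0],
                [1.0, 10.0, 100.0],
                [2.0, 20.0, 200.0]])
sort_rows_inplace(toy, 0)
print(toy)
```

    No index array and no second copy of the data are created here, which is the appeal of the approach when it is fast enough.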

    One good option to minimize memory consumption and improve performance with very large arrays is to build a pipeline of small, simple functions. Functions clear local variables once they have completed so if intermediate data structures are building up and sapping memory this can be a good solution.

    This is a sketch of the pipeline I've used to speed up the massive array sort:

    import numpy as np

    def process_sub_arrs(arr):
        """process a sub-array and return a 3-tuple of 1D values arrays"""
        # processing steps omitted in this sketch
        return values1, values2, values3

    def build_arr(sub_arrs):
        """build the initial array by joining processed sub-arrays"""
        full_arr = np.zeros((868940742, 3))
        i = 0

        for arr in sub_arrs:  # sub_arrs is the list [arr1, ..., arr22]
            # (i, j) bound the rows filled by the current sub-array
            j = i + len(arr)
            full_arr[i:j, :] = np.column_stack(process_sub_arrs(arr))
            i = j

        return full_arr

    def sort_arr(sub_arrs, idx):
        """build full_arr and return a copy sorted by column idx"""
        full_arr = build_arr(sub_arrs)
        sort_idx = np.argsort(full_arr[:, idx])

        # full_arr and sort_idx are released when this function returns
        return full_arr[sort_idx]

    def get_sorted_arr(sub_arrs, idx):
        """call through nested functions, then return summary statistics"""
        sorted_arr = sort_arr(sub_arrs, idx)
        # statistics are computed from sorted_arr here (omitted in this sketch)

        return statistics
    

    call stack: get_sorted_arr --> sort_arr --> build_arr --> process_sub_arrs

    Once each inner function has completed, get_sorted_arr() holds only the sorted array and then returns a small array of statistics.
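
    To check the claim that intermediates disappear once the inner functions return, the toy sketch below tracks them with weak references. Everything here is a stand-in: build_arr just generates random data, the sizes are tiny, and the list tracked exists only for the demonstration. The immediate release relies on CPython's reference counting, so other interpreters may free the memory later.

```python
import weakref

import numpy as np

tracked = []  # weak references to intermediates, for demonstration only


def build_arr(n_rows):
    """Toy stand-in for the real build step: a random (n_rows, 3) array."""
    return np.random.default_rng(0).random((n_rows, 3))


def sort_arr(n_rows, col):
    full_arr = build_arr(n_rows)
    sort_idx = np.argsort(full_arr[:, col])
    tracked.append(weakref.ref(full_arr))
    tracked.append(weakref.ref(sort_idx))
    # The result is a new array; full_arr and sort_idx die on return.
    return full_arr[sort_idx]


def get_sorted_arr(n_rows, col):
    sorted_arr = sort_arr(n_rows, col)
    # Only this small summary outlives the call.
    return np.array([sorted_arr[0, col], sorted_arr[-1, col]])


stats = get_sorted_arr(1000, col=1)
freed = [ref() is None for ref in tracked]
print(stats, freed)
```

    A weak reference that returns None means the object it pointed to has already been garbage collected.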

    EDIT: It is also worth pointing out here that even if you are able to use a more compact dtype to represent your huge array, you will want to use higher precision for summary calculations. For example, since full_arr.dtype = np.float16, the command np.mean(full_arr[:,idx]) tries to calculate the mean in half-precision floating point, but this quickly overflows when summing over a massive array. Using np.mean(full_arr[:,idx], dtype=np.float64) will prevent the overflow.
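
    The overflow is easy to reproduce on a small toy column. The sketch below uses np.sum, whose result for float16 input is stored in half precision and therefore cannot exceed 65504. NumPy's documentation for np.mean notes that float16 input is averaged with float32 intermediates by default, so the exact behavior of the mean may depend on the version; passing dtype=np.float64 explicitly is safe either way.

```python
import numpy as np

# Toy column: 10,000 values of 100.0, so the true sum is 1,000,000.
col = np.full(10_000, 100.0, dtype=np.float16)

with np.errstate(over="ignore"):
    total16 = np.sum(col)                  # result stored as float16

total64 = np.sum(col, dtype=np.float64)    # accumulate in double precision
mean64 = np.mean(col, dtype=np.float64)

print(total16, total64, mean64)
```

    The dtype argument changes only the accumulator, so the array itself stays in its compact representation.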

    I posted this question initially because I was puzzled by the fact that a dataset of identical size suddenly began choking up my system memory, although there was a big difference in the proportion of unique values in the new "slow" set. @ali_m pointed out that, indeed, more uniform data with fewer unique values is easier to sort:

    The qsort variant of Quicksort works by recursively selecting a 'pivot' element in the array, then reordering the array such that all the elements less than the pivot value are placed before it, and all of the elements greater than the pivot value are placed after it. Values that are equal to the pivot are already sorted, so intuitively, the fewer unique values there are in the array, the smaller the number of swaps there are that need to be made.

    On that note, the final change I ended up making to attempt to resolve this issue was to round the newer dataset in advance, since there was an unnecessarily high level of decimal precision leftover from an interpolation step. This ultimately had an even bigger effect than the other memory saving steps, showing that the sort algorithm itself was the limiting factor in this case.
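
    As a sketch of the rounding step, assuming two decimal places are enough (the right number depends on the data), the example below counts unique values in a toy column before and after rounding.

```python
import numpy as np

rng = np.random.default_rng(42)
values = rng.random(100_000)         # toy stand-in for an interpolated column
rounded = np.round(values, 2)        # 2 decimals is a hypothetical choice

n_before = np.unique(values).size
n_after = np.unique(rounded).size
print(n_before, n_after)
```

    On a large array, np.round(arr, 2, out=arr) writes the result back into the same buffer and avoids a temporary copy.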

    Look forward to other comments or suggestions anyone might have on this topic, and I almost certainly misspoke about some technical issues so I would be glad to hear back :-)
