Median of Medians in Java

清歌不尽, 2020-12-03 09:11

I am trying to implement Median of Medians in Java for a method like this:

Select(Comparable[] list, int pos, int colSize, int colMed)
5 answers
  • 2020-12-03 09:50

    I agree with Chip Uni's answer. I will just comment on the sorting part and add some further explanation:

    You do not need a full sorting algorithm. The algorithm is similar to quicksort, except that only one side of each partition (left or right) is recursed into. We just need a pivot that splits the range reasonably evenly: if every split were perfect, the total work would be N + N/2 + N/4 + ... ≤ 2N, i.e. O(N). The algorithm above, called median of medians, uses the median of the medians of groups of 5 as the pivot; this guarantees that roughly 3/10 of the elements fall on each side of it, which is enough to keep the running time linear.
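
    As a rough worked version of the linear-time argument, here is the usual recurrence, assuming groups of 5, a constant c for the linear grouping and partitioning work, and ignoring floors and ceilings. Finding the pivot M costs T(n/5); since M is at least as large as about 3n/10 of the elements and at most as large as about 3n/10 of them, the recursive call sees at most about 7n/10 elements. Guessing T(m) ≤ 10cm for all smaller inputs closes the induction:

```latex
\begin{aligned}
T(n) &\le T\!\left(\tfrac{n}{5}\right) + T\!\left(\tfrac{7n}{10}\right) + cn \\
     &\le 10c\cdot\tfrac{n}{5} + 10c\cdot\tfrac{7n}{10} + cn
      = 2cn + 7cn + cn = 10cn .
\end{aligned}
```

    The key point is that 1/5 + 7/10 < 1, so the sizes of the two recursive calls shrink geometrically overall.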

    However, a sorting algorithm is used once the range being searched for the nth smallest/greatest element (which I suppose is what you are implementing with this algorithm) becomes small, to speed things up. Insertion sort is particularly fast on small arrays of up to about 7 to 10 elements.
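
    A minimal sketch of such a small-range base case, assuming int arrays and 0-based k. The class name `SmallRangeSelect`, the method `selectSmall`, and the `CUTOFF` value of 10 are hypothetical choices for illustration, not taken from any answer here:

```java
class SmallRangeSelect {
    // Assumed cutoff: ranges of at most this many elements are simply sorted.
    // 10 is picked from the 7-10 range mentioned above, not a tuned value.
    static final int CUTOFF = 10;

    // Sorts a[lo..hi) in place with insertion sort.
    static void insertionSort(int[] a, int lo, int hi) {
        for (int i = lo + 1; i < hi; i++) {
            int x = a[i];
            int j = i - 1;
            // Shift elements larger than x one slot to the right.
            while (j >= lo && a[j] > x) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = x;
        }
    }

    // Base case of a selection routine: returns the k-th smallest (0-based)
    // element of a small range a[lo..hi), sorting that range in place.
    static int selectSmall(int[] a, int lo, int hi, int k) {
        if (hi - lo > CUTOFF || k < 0 || k >= hi - lo) {
            throw new IllegalArgumentException("range too large or k out of range");
        }
        insertionSort(a, lo, hi);
        return a[lo + k];
    }

    public static void main(String[] args) {
        int[] a = {9, 4, 7, 1, 8, 2};
        System.out.println(selectSmall(a, 0, a.length, 2)); // prints 4
    }
}
```

    In a full median-of-medians routine this would replace the recursive call whenever the current range has at most `CUTOFF` elements.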

    Implementation note:

    M = select({x[i]}, n/10)
    

    actually means taking the median of all the medians of the 5-element groups. You can do that by creating another array of size (n - 1)/5 + 1 (that is, ceil(n/5)) and calling the same algorithm recursively on it to find its n/10-th element, which is the median of the newly created array.
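
    A sketch of building that array of medians, assuming an int[] input. `GroupMedians.groupMedians` is a hypothetical helper; it sorts each small group directly, and the recursive call on the returned array (to pick its middle element as the pivot) is left to whichever select routine you are writing:

```java
import java.util.Arrays;

class GroupMedians {
    // Builds the array of medians of consecutive 5-element groups of a.
    // Its length is (n - 1) / 5 + 1, i.e. ceil(n / 5) for n >= 1; the last
    // group may be shorter, so its median is taken by its actual length.
    static int[] groupMedians(int[] a) {
        int n = a.length;
        if (n == 0) throw new IllegalArgumentException("empty array");
        int[] medians = new int[(n - 1) / 5 + 1];
        for (int g = 0; g < medians.length; g++) {
            int from = g * 5;
            int to = Math.min(from + 5, n);
            // Copy the group so the input is not reordered, then sort it;
            // at most 5 elements, so this is constant work per group.
            int[] group = Arrays.copyOfRange(a, from, to);
            Arrays.sort(group);
            medians[g] = group[(group.length - 1) / 2]; // lower median
        }
        return medians;
    }
}
```

    For an 11-element input this produces 3 medians, the last one coming from a single-element group.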

  • 2020-12-03 09:58

    I don't know whether you still need this solved, but http://www.ics.uci.edu/~eppstein/161/960130.html gives this algorithm:

    select(L,k)
    {
        if (L has 10 or fewer elements)
        {
            sort L
            return the element in the kth position
        }
    
        partition L into subsets S[i] of five elements each
            (there will be n/5 subsets total).
    
        for (i = 1 to n/5) do
            x[i] = select(S[i],3)
    
        M = select({x[i]}, n/10)
    
        partition L into L1<M, L2=M, L3>M
        if (k <= length(L1))
            return select(L1,k)
        else if (k > length(L1)+length(L2))
            return select(L3,k-length(L1)-length(L2))
        else return M
    }
    

    Good luck!
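
    For reference, a fairly literal Java translation of that pseudocode, assuming a List<Integer> input and the pseudocode's 1-based k. It builds new lists rather than working in place, and it picks each group's median and the median of medians by actual length ((size + 1) / 2) instead of the fixed 3 and n/10, which sidesteps the short-last-group issue. The class name `EppsteinSelect` is hypothetical; the capitalized variable names mirror the pseudocode:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class EppsteinSelect {
    // Returns the k-th smallest element of L, with 1 <= k <= L.size().
    static int select(List<Integer> L, int k) {
        int n = L.size();
        if (k < 1 || k > n) throw new IllegalArgumentException("k out of range");
        if (n <= 10) {
            // Small input: sort a copy and index into it.
            List<Integer> sorted = new ArrayList<>(L);
            Collections.sort(sorted);
            return sorted.get(k - 1);
        }
        // x[i] = median of each group of five (the last group may be shorter).
        List<Integer> x = new ArrayList<>();
        for (int from = 0; from < n; from += 5) {
            List<Integer> group = L.subList(from, Math.min(from + 5, n));
            x.add(select(group, (group.size() + 1) / 2));
        }
        // M = median of the medians, used as the pivot.
        int M = select(x, (x.size() + 1) / 2);

        // Partition L into L1 < M, L2 == M, L3 > M.
        List<Integer> L1 = new ArrayList<>(), L2 = new ArrayList<>(), L3 = new ArrayList<>();
        for (int v : L) {
            if (v < M) L1.add(v);
            else if (v > M) L3.add(v);
            else L2.add(v);
        }
        if (k <= L1.size()) return select(L1, k);
        if (k > L1.size() + L2.size()) return select(L3, k - L1.size() - L2.size());
        return M;
    }
}
```

    Since M itself lands in L2, both L1 and L3 are strictly smaller than L, so the recursion always terminates.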

  • 2020-12-03 10:02

    I know this is a very old post and you might not remember it any more, but did you measure the running time of your implementation?

    I tried this algorithm and compared it with the simple approach of sorting the array with Java's Arrays.sort() and then picking the kth element. In my tests, this algorithm only beat Java's sort once the array had about a hundred thousand elements or more, and even then it was only about 2 or 3 times faster, which is clearly not the log(n) factor one might expect.

    Do you have any comments on that?
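
    A minimal sketch of the sort-based baseline described above, with a rough timing harness. The names `SortBaseline` and `kthBySorting` are hypothetical, and a single `System.nanoTime()` measurement without JIT warm-up is only indicative; a harness such as JMH gives more trustworthy numbers. Asymptotically, sorting costs O(n log n) against O(n) for median of medians, but the latter's large constant factors can push the crossover point out to large n:

```java
import java.util.Arrays;
import java.util.Random;

class SortBaseline {
    // The simple approach: sort a copy, then index into it (k is 0-based).
    // O(n log n) time; the input array is left untouched.
    static int kthBySorting(int[] a, int k) {
        int[] copy = a.clone();
        Arrays.sort(copy);
        return copy[k];
    }

    // Rough, single-shot timing; no warm-up or repetition, so treat the
    // printed time as indicative only.
    public static void main(String[] args) {
        int n = 100_000;
        int[] a = new Random(42).ints(n).toArray();
        long t0 = System.nanoTime();
        int median = kthBySorting(a, n / 2);
        long t1 = System.nanoTime();
        System.out.println("median=" + median + ", sort-based ms=" + (t1 - t0) / 1_000_000.0);
    }
}
```

    To compare, time your own select on a fresh copy of the same array in the same way.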

  • 2020-12-03 10:04

    @android developer:

    for (i = 1 to n/5) do
        x[i] = select(S[i],3)
    

    is really

    for (i = 1 to ceiling(n/5)) do
        x[i] = select(S[i],3)
    

    with a ceiling function appropriate for your data (e.g. in Java, convert to double before calling Math.ceil, or use the integer form (n + 4) / 5). This also changes the index of the median of medians compared with simply taking n/10, but that is fine: we are looking for a pivot close to the median that actually occurs in the array, not the exact median. Another note is that the last group S[i] may have fewer than 3 elements, so its median should be found with respect to its actual length; passing it to select with k = 3 won't always work (e.g. for n = 11 there are 3 groups: two with 5 elements and one with 1 element).
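
    A small sketch of the integer arithmetic involved, assuming groups of 5 and the pseudocode's 1-based ranks. `GroupSizes` and its methods are hypothetical helpers for illustration:

```java
class GroupSizes {
    // Number of groups of five needed to cover n elements: ceil(n / 5),
    // computed with integer arithmetic instead of doubles.
    static int groupCount(int n) {
        return (n + 4) / 5;
    }

    // Size of group g (0-based): 5 for every group except possibly the last.
    static int groupSize(int n, int g) {
        return Math.min(5, n - 5 * g);
    }

    // 1-based rank of a group's median, from the group's real length
    // instead of the fixed 3 used in the pseudocode.
    static int medianRank(int size) {
        return (size + 1) / 2;
    }

    public static void main(String[] args) {
        int n = 11;
        for (int g = 0; g < groupCount(n); g++) {
            int size = groupSize(n, g);
            System.out.println("group " + g + ": size " + size + ", median rank " + medianRank(size));
        }
    }
}
```

    For n = 11 this reports two groups of size 5 (median rank 3) and a last group of size 1 (median rank 1).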

  • 2020-12-03 10:14

    The question asked for Java, so here it is:

    import java.util.*;
    
    public class MedianOfMedians {
        // Static utility class; not meant to be instantiated.
        private MedianOfMedians() {
        }
    
        /**
         * Returns median of list in linear time.
         * 
         * @param list list to search, which may be reordered on return
         * @return median of list (for an even size, the upper of the two middle elements).
         */
        public static Comparable getMedian(ArrayList<Comparable> list) {
            int s = list.size();
            if (s < 1)
                throw new IllegalArgumentException();
            int pos = select(list, 0, s, s / 2);
            return list.get(pos);
        }
    
        /**
         * Returns position of k'th smallest element (0-based) of sub-list.
         * 
         * @param list list to search, whose sub-list may be shuffled before
         *            returning
         * @param lo first element of sub-list in list
         * @param hi just after last element of sub-list in list
         * @param k rank within the sub-list, 0-based (0 is the smallest)
         * @return position of k'th smallest element of (possibly shuffled) sub-list.
         */
        public static int select(ArrayList<Comparable> list, int lo, int hi, int k) {
            if (lo >= hi || k < 0 || lo + k >= hi)
                throw new IllegalArgumentException();
            if (hi - lo < 10) {
                Collections.sort(list.subList(lo, hi));
                return lo + k;
            }
            int s = hi - lo;
            int np = s / 5; // Number of partitions
            for (int i = 0; i < np; i++) {
                // For each partition, move its median to front of our sublist
                int lo2 = lo + i * 5;
                int hi2 = (i + 1 == np) ? hi : (lo2 + 5);
                int pos = select(list, lo2, hi2, 2);
                Collections.swap(list, pos, lo + i);
            }
    
            // Partition medians were moved to front, so we can recurse without making another list.
            int pos = select(list, lo, lo + np, np / 2);
    
            // Re-partition list to [<=pivot][pivot][>pivot]
            int m = triage(list, lo, hi, pos);
            int cmp = lo + k - m;
            if (cmp > 0)
                return select(list, m + 1, hi, k - (m - lo) - 1);
            else if (cmp < 0)
                return select(list, lo, m, k);
            return lo + k;
        }
    
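        /**
         * Partition sub-list into 3 parts [<=pivot][pivot][>pivot].
         * 
         * @param list list containing the sub-list
         * @param lo first element of sub-list in list
         * @param hi just after last element of sub-list in list
         * @param pos input position of pivot value
         * @return output position of pivot value
         */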
        private static int triage(ArrayList<Comparable> list, int lo, int hi,
                int pos) {
            Comparable pivot = list.get(pos);
            int lo3 = lo;
            int hi3 = hi;
            while (lo3 < hi3) {
                Comparable e = list.get(lo3);
                int cmp = e.compareTo(pivot);
                if (cmp < 0)
                    lo3++;
                else if (cmp > 0)
                    Collections.swap(list, lo3, --hi3);
                else {
                    while (hi3 > lo3 + 1) {
                        assert (list.get(lo3).compareTo(pivot) == 0);
                        e = list.get(--hi3);
                        cmp = e.compareTo(pivot);
                        if (cmp <= 0) {
                            if (lo3 + 1 == hi3) {
                                Collections.swap(list, lo3, lo3 + 1);
                                lo3++;
                                break;
                            }
                            Collections.swap(list, lo3, lo3 + 1);
                            assert (list.get(lo3 + 1).compareTo(pivot) == 0);
                            Collections.swap(list, lo3, hi3);
                            lo3++;
                            hi3++;
                        }
                    }
                    break;
                }
            }
            assert (list.get(lo3).compareTo(pivot) == 0);
            return lo3;
        }
    
    }
    

    Here is a unit test to check that it works:

    import java.util.*;
    
    import junit.framework.TestCase;
    
    public class MedianOfMedianTest extends TestCase {
        public void testMedianOfMedianTest() {
            Random r = new Random(1);
            int n = 87;
            for (int trial = 0; trial < 1000; trial++) {
                ArrayList<Comparable> list = new ArrayList<>();
                int[] a = new int[n];
                for (int i = 0; i < n; i++) {
                    int v = r.nextInt(256);
                    a[i] = v;
                    list.add(v);
                }
                int m1 = (Integer)MedianOfMedians.getMedian(list);
                Arrays.sort(a);
                int m2 = a[n/2];
                assertEquals(m2, m1); // expected value first, actual second
            }
        }
    }
    

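    However, the code above is too slow for practical use: median of medians has large constant factors, and this version also works on boxed Comparable objects in an ArrayList.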

    Here is a simpler quickselect that gets the k'th element without guaranteeing linear worst-case time, but is much faster in practice:

    /**
     * Returns position of k'th smallest element (0-based) of sub-list.
     * 
     * @param list list to search, whose sub-list may be shuffled before
     *            returning
     * @param lo first element of sub-list in list
     * @param hi just after last element of sub-list in list
     * @param k rank within the sub-list, 0-based (0 is the smallest)
     * @return position of k'th smallest element of (possibly shuffled) sub-list.
     */
    static int select(double[] list, int lo, int hi, int k) {
        int n = hi - lo;
        if (n < 2)
            return lo;
    
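        // Pseudo-random pivot derived from k; long arithmetic avoids int
        // overflow (and a negative index) when k exceeds about 271,000.
        double pivot = list[lo + (int) ((k * 7919L) % n)];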
    
        // Triage list to [<pivot][=pivot][>pivot]
        int nLess = 0, nSame = 0, nMore = 0;
        int lo3 = lo;
        int hi3 = hi;
        while (lo3 < hi3) {
            double e = list[lo3];
            int cmp = compare(e, pivot);
            if (cmp < 0) {
                nLess++;
                lo3++;
            } else if (cmp > 0) {
                swap(list, lo3, --hi3);
                if (nSame > 0)
                    swap(list, hi3, hi3 + nSame);
                nMore++;
            } else {
                nSame++;
                swap(list, lo3, --hi3);
            }
        }
        assert (nSame > 0);
        assert (nLess + nSame + nMore == n);
        assert (compare(list[lo + nLess], pivot) == 0);
        assert (compare(list[hi - nMore - 1], pivot) == 0);
        if (k >= n - nMore)
            return select(list, hi - nMore, hi, k - nLess - nSame);
        else if (k < nLess)
            return select(list, lo, lo + nLess, k);
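        return lo + k;
    }

    // Helper methods used by select above.
    static int compare(double a, double b) {
        return Double.compare(a, b);
    }

    static void swap(double[] a, int i, int j) {
        double t = a[i];
        a[i] = a[j];
        a[j] = t;
    }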
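
    For comparison, a sketch of the same quickselect idea with a genuinely random pivot from java.util.Random and a simpler two-way Lomuto partition. Both are substitutions, not part of the code above, and `RandomPivotSelect` is a hypothetical name. One design note: with many duplicate values the two-way partition can degrade toward quadratic time, which the three-way [<pivot][=pivot][>pivot] triage above avoids.

```java
import java.util.Random;

class RandomPivotSelect {
    private static final Random RNG = new Random();

    // Returns the k-th smallest (0-based) value in a[lo..hi), reordering that
    // range. Expected linear time; quadratic in the worst case.
    static double select(double[] a, int lo, int hi, int k) {
        if (k < 0 || k >= hi - lo) throw new IllegalArgumentException("k out of range");
        int target = lo + k; // absolute index where the answer will end up
        while (hi - lo > 1) {
            // Pick a uniformly random pivot in [lo, hi) and park it at the end.
            swap(a, lo + RNG.nextInt(hi - lo), hi - 1);
            double pivot = a[hi - 1];
            // Lomuto partition: afterwards a[lo..store) < pivot <= a[store+1..hi).
            int store = lo;
            for (int i = lo; i < hi - 1; i++) {
                if (Double.compare(a[i], pivot) < 0) swap(a, i, store++);
            }
            swap(a, store, hi - 1); // pivot lands at its sorted position
            if (target == store) return a[store];
            if (target < store) hi = store;
            else lo = store + 1;
        }
        return a[target];
    }

    static void swap(double[] a, int i, int j) {
        double t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        double[] data = {3.5, -1.0, 2.0, 2.0, 10.0, 0.5, 7.25};
        System.out.println(select(data, 0, data.length, data.length / 2)); // prints 2.0
    }
}
```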
    