Find the top k sums of two sorted arrays

独厮守ぢ 2021-02-06 01:33

You are given two sorted arrays of sizes n and m, respectively. Your task (should you choose to accept it) is to output the largest k sums of the form a[i]+b[j].
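A common way to handle the two-array case is a best-first search with a max heap over index pairs: the pair (0, 0) gives the largest sum, and each popped pair (i, j) can only unlock (i+1, j) and (i, j+1). The sketch below assumes both arrays are sorted in descending order (for ascending input, walk from the ends instead). The names TopKPairSums and topKSums are hypothetical.

```java
import java.util.*;

class TopKPairSums {
    // Returns up to k of the largest values of a[i] + b[j], in descending order.
    // Assumption: a and b are both sorted in descending order.
    static List<Integer> topKSums(int[] a, int[] b, int k) {
        List<Integer> result = new ArrayList<>();
        if (a.length == 0 || b.length == 0) return result;
        // Max heap of {sum, i, j}, largest sum first.
        PriorityQueue<int[]> heap = new PriorityQueue<>((x, y) -> Integer.compare(y[0], x[0]));
        // Visited (i, j) pairs, encoded as i * b.length + j so each pair is pushed once.
        Set<Long> seen = new HashSet<>();
        heap.add(new int[] {a[0] + b[0], 0, 0});
        seen.add(0L);
        while (result.size() < k && !heap.isEmpty()) {
            int[] top = heap.poll();
            int i = top[1], j = top[2];
            result.add(top[0]);
            // The only new candidates are the two neighbours of (i, j).
            if (i + 1 < a.length && seen.add((long) (i + 1) * b.length + j)) {
                heap.add(new int[] {a[i + 1] + b[j], i + 1, j});
            }
            if (j + 1 < b.length && seen.add((long) i * b.length + j + 1)) {
                heap.add(new int[] {a[i] + b[j + 1], i, j + 1});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[] a = {9, 5, 1};
        int[] b = {8, 3, 2};
        System.out.println(topKSums(a, b, 4)); // [17, 13, 12, 11]
    }
}
```

The heap never holds more than O(k) pairs, so the search costs O(k log k) instead of sorting all n·m sums.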

4 answers
  •  离开以前
    2021-02-06 02:26

    Many thanks to @rlibby and @xuhdev for such an original idea for solving this kind of problem. I had a similar coding exercise in an interview that required finding the N largest sums formed by K elements from K arrays sorted in descending order, meaning we must pick one element from each sorted array to build each sum.

    Example: List<Integer> findHighestSums(int[][] lists, int n) {}
    
    Given the lists
    
    [5,4,3,2,1]
    [4,1]
    [5,0,0]
    [6,4,2]
    [1]
    
    and a value of 5 for n, your procedure should return a List of size 5:
    
    [21,20,19,19,18]
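The expected output can be checked by brute force: enumerate all 5 × 2 × 3 × 3 × 1 = 90 ways to pick one element per list and keep the five largest sums. This is a minimal sketch for verification only; the names BruteForceHighestSums and highestSumsBruteForce are hypothetical, and the cost grows exponentially with the number of lists.

```java
import java.util.*;

class BruteForceHighestSums {
    // Enumerates every way to pick one element from each list and returns
    // the n largest sums in descending order. Exponential; only for checking.
    static List<Integer> highestSumsBruteForce(int[][] lists, int n) {
        List<Integer> sums = new ArrayList<>();
        collect(lists, 0, 0, sums);
        sums.sort(Collections.reverseOrder());
        return sums.subList(0, Math.min(n, sums.size()));
    }

    // Depth-first enumeration: depth is the current list, partial is the sum so far.
    private static void collect(int[][] lists, int depth, int partial, List<Integer> out) {
        if (depth == lists.length) {
            out.add(partial);
            return;
        }
        for (int value : lists[depth]) {
            collect(lists, depth + 1, partial + value, out);
        }
    }

    public static void main(String[] args) {
        int[][] lists = {{5, 4, 3, 2, 1}, {4, 1}, {5, 0, 0}, {6, 4, 2}, {1}};
        System.out.println(highestSumsBruteForce(lists, 5)); // [21, 20, 19, 19, 18]
    }
}
```

The two 19s come from different picks (3 from the first list, or 4 from the fourth), which is why the result may contain duplicate sums.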
    

    Below is my code; please read the block comments carefully :D

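    // Members of an enclosing class; they require import java.util.*;
    private class Pair implements Comparable<Pair> {
        String state;
    
        int sum;
    
        public Pair(String state, int sum) {
            this.state = state;
            this.sum = sum;
        }
    
        @Override
        public int compareTo(Pair o) {
            // Max heap: larger sums come first. Integer.compare avoids the
            // overflow that o.sum - this.sum can produce.
            return Integer.compare(o.sum, this.sum);
        }
    }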
    
    List<Integer> findHighestSums(int[][] lists, int n) {
    
        int numOfLists = lists.length;
        int totalCharacterInState = 0;
    
        /*
         * Represent the state (the chosen index in every list) as a String.
         * List i gets enough decimal digits to hold its largest index, lists[i].length - 1,
         * which is max(1, ceil(log10(lists[i].length))) characters.
         * For example:
         *      If list1 contains from 11 to 100 elements,
         *      then the state for list1 requires 2 characters.
         */
        int[] positionStartingCharacterOfListState = new int[numOfLists + 1];
        positionStartingCharacterOfListState[0] = 0;
    
        // i runs up to numOfLists inclusive so that the extra entry records where the last list's characters end
        for(int i = 1; i <= numOfLists; i++) {  
            int previousListNumOfCharacters = 1;
            if(lists[i-1].length > 10) {
                previousListNumOfCharacters = (int)Math.ceil(Math.log10(lists[i-1].length));
            }
            positionStartingCharacterOfListState[i] = positionStartingCharacterOfListState[i-1] + previousListNumOfCharacters;
            totalCharacterInState += previousListNumOfCharacters;
        }
    
        // Visited states: ensures each combination of indices enters the heap at most once
        Set<String> states = new HashSet<>();
        List<Integer> result = new ArrayList<>();
        StringBuilder sb = new StringBuilder();
    
        // Max heap of (state, sum) pairs, ordered by sum
        PriorityQueue<Pair> pq = new PriorityQueue<>();
    
        char[] stateChars = new char[totalCharacterInState];
        Arrays.fill(stateChars, '0');
        sb.append(stateChars);
        String firstState = sb.toString();
        states.add(firstState);
    
        int firstLargestSum = 0;
        for(int i = 0; i < numOfLists; i++) firstLargestSum += lists[i][0];
    
        // Imagine this is the initial state in a graph
        pq.add(new Pair(firstState, firstLargestSum));
    
        while(n > 0) {
            // In case n is larger than the number of combinations of all list entries 
            if(pq.isEmpty()) break;
            Pair top = pq.poll();
            String currentState = top.state;
            int currentSum = top.sum;
    
            /*
             * Loop over all lists and generate the new states that differ from the current state
             * in exactly one list, advanced by one index.
             * For example, from the initial state (Stage 0) 0 0 0 0 0,
             * the next states (Stage 1) are:
             *  1 0 0 0 0 (choose the element at index 1 from the 1st array)
             *  0 1 0 0 0 (choose the element at index 1 from the 2nd array)
             *  0 0 1 0 0 (choose the element at index 1 from the 3rd array)
             *  0 0 0 1 0 (choose the element at index 1 from the 4th array)
             *  0 0 0 0 1 (choose the element at index 1 from the 5th array)
             * But don't forget to check whether the new index would exceed the list's length.
             */
            for(int i = 0; i < numOfLists; i++) {
                int indexInList = Integer.parseInt(
                        currentState.substring(positionStartingCharacterOfListState[i], positionStartingCharacterOfListState[i+1]));
                if( indexInList < lists[i].length - 1) {
                    int numberOfCharacters = positionStartingCharacterOfListState[i+1] - positionStartingCharacterOfListState[i];
                    sb = new StringBuilder(currentState.substring(0, positionStartingCharacterOfListState[i]));
                    sb.append(String.format("%0" + numberOfCharacters + "d", indexInList + 1));
                    sb.append(currentState.substring(positionStartingCharacterOfListState[i+1]));
                    String newState = sb.toString();
                    if(!states.contains(newState)) {
    
                        // The newSum is always <= currentSum
                        int newSum = currentSum - lists[i][indexInList] + lists[i][indexInList+1];
    
                        states.add(newState);
                        // With a priority queue we can immediately retrieve the largest sum at stage k while keeping track of all other unused states.
                        // From the state of that largest sum we then generate new states.
                        // Sums from newly generated states are not guaranteed to be larger than sums from older unused states.
                        pq.add(new Pair(newState, newSum));
                    }
    
                }
            }
            result.add(currentSum);
            n--;
        }
        return result;
    }
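To run the idea end to end, here is a minimal self-contained sketch of the same best-first search. It is not the code above verbatim: it assumes non-empty lists sorted in descending order, and it replaces the fixed-width digit strings with Arrays.toString keys of an int[] index tuple, so no per-list character width has to be computed. The class name HighestSums and the State holder are hypothetical.

```java
import java.util.*;

class HighestSums {
    // One heap entry: the chosen index in every list and the resulting sum.
    private static final class State {
        final int[] indices;
        final int sum;

        State(int[] indices, int sum) {
            this.indices = indices;
            this.sum = sum;
        }
    }

    // Returns up to n of the largest sums that take exactly one element from each list.
    // Assumption: every list is non-empty and sorted in descending order.
    static List<Integer> findHighestSums(int[][] lists, int n) {
        List<Integer> result = new ArrayList<>();
        // Max heap: the state with the largest sum is polled first.
        PriorityQueue<State> pq = new PriorityQueue<>((x, y) -> Integer.compare(y.sum, x.sum));
        // Visited index tuples, keyed by Arrays.toString.
        Set<String> seen = new HashSet<>();

        int[] start = new int[lists.length]; // all zeros: the first (largest) element of every list
        int firstSum = 0;
        for (int[] list : lists) {
            firstSum += list[0];
        }
        pq.add(new State(start, firstSum));
        seen.add(Arrays.toString(start));

        while (result.size() < n && !pq.isEmpty()) {
            State top = pq.poll();
            result.add(top.sum);
            // Successors advance exactly one list by one position.
            for (int i = 0; i < lists.length; i++) {
                int idx = top.indices[i];
                if (idx + 1 >= lists[i].length) {
                    continue;
                }
                int[] next = top.indices.clone();
                next[i] = idx + 1;
                if (seen.add(Arrays.toString(next))) {
                    pq.add(new State(next, top.sum - lists[i][idx] + lists[i][idx + 1]));
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] lists = {{5, 4, 3, 2, 1}, {4, 1}, {5, 0, 0}, {6, 4, 2}, {1}};
        System.out.println(findHighestSums(lists, 5)); // [21, 20, 19, 19, 18]
    }
}
```

Keying by the index tuple costs about the same as the string state, but it cannot be broken by a miscomputed field width.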
    

    Let me explain the complexity of my solution:

    1. The while loop in my answer executes at most N times, and each iteration works on the max heap (priority queue).
    2. Each iteration does one poll. Every iteration removes one Pair and adds at most K new ones, so the heap never holds more than 1 + N * (K - 1) elements, and a poll costs O(log(N * K)).
    3. Each iteration also does up to K insertions, each costing O(log(N * K)), and builds and hashes up to K state strings of length S (totalCharacterInState). Therefore, the total complexity is O(N * K * (log(N * K) + S)).
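The per-iteration costs can be written out explicitly. In this derivation K is the number of lists, N the number of requested sums, H_t the heap size after t iterations, and S the state-string length (totalCharacterInState in the code); these symbols are introduced here for the sketch.

```latex
\begin{aligned}
H_t &\le 1 + t\,(K-1) \le 1 + N\,(K-1) && \text{(one poll and at most } K \text{ pushes per iteration)}\\
T_{\text{iter}} &= \underbrace{O(\log H_t)}_{\text{poll}} + \underbrace{K \cdot O(\log H_t)}_{\text{pushes}} + \underbrace{K \cdot O(S)}_{\text{build and hash states}}\\
T &= \sum_{t=1}^{N} T_{\text{iter}} = O\big(N K \,(\log(NK) + S)\big)
\end{aligned}
```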
