LeetCode 215. 数组中的第K个最大元素 快排、最小堆题解,重学快排、最大堆最小堆。

题目

215. 数组中的第K个最大元素

https://leetcode.cn/problems/kth-largest-element-in-an-array/

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

暴力解

暴力思路最直接:

  • 排序!
  • 取对应位置元素即可。

复杂度分析

时间

Arrays.sort默认使用快排,O(n*logn)时间复杂度。

空间

递归栈空间有额外开销,空间复杂度O(logn),某些情况也可以视为无额外空间占用,本题我们主要跟其他方案做对比。

代码

1
2
3
4
5
6
7
public int findKthLargestBruteForce(int[] nums, int k) {
    if (nums == null || nums.length == 0 || nums.length < k) {
        return -1;
    }
    Arrays.sort(nums);
    return nums[nums.length - k];
}

小顶堆解

排序取第n的问题中,我们也常用最大堆最小堆来辅助解决问题,JDK中对应容器为PriorityQueue。也可以试着自己实现下最小堆的siftDown/siftUp堆合理化函数。

思路

我们使用一个小顶堆来维护前k个数据,堆顶为前k个中的边界值:最小值。

遍历大数组其余元素时,只需看当前遍历元素是否比边界值大,如果大,就可以尝试放入堆。

最后,堆顶即为所求。

代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
// 本问题中最小堆需要支持的操作:添加元素、查看最小值(堆顶)元素、移除最小元素

// 从下标1的位置存完全二叉树堆元素
int[] arr;
int capacity;
int size;

public KthLargestElementInAnArray() {
}

public KthLargestElementInAnArray(int capacity) {
    this.capacity = capacity;
    arr = new int[capacity + 1];
    size = 0;
}

/**
    * 添加元素
    */
public void add(int ele) {
    if (size == capacity) {
        throw new IllegalStateException("最小堆已满");
    }
    ++size;
    arr[size] = ele;
    siftUp();
}

/**
    * 查看最小值(堆顶)元素
    */
public int peekMin() {
    return arr[1];
}

/**
    * 移除最小元素
    */
public void removeMin() {
    arr[1] = arr[size];
    --size;

    siftDown();
}

private void siftDown() {
    // 从顶端开始遍历树,检查堆性质
    int parentIdx = 1;
    int leftChildIdx = getLeftChildIdx(parentIdx);
    int rightChildIdx = getRightChildIdx(parentIdx);
    while (leftChildIdx <= size && rightChildIdx <= size) {
        if (arr[parentIdx] <= arr[leftChildIdx] && arr[parentIdx] <= arr[rightChildIdx]) {
            break;
        }
        int smallerIdx = arr[rightChildIdx] <= arr[leftChildIdx] ? rightChildIdx : leftChildIdx;
        swap(arr, smallerIdx, parentIdx);
        parentIdx = smallerIdx;
        leftChildIdx = getLeftChildIdx(parentIdx);
        rightChildIdx = getRightChildIdx(parentIdx);
    }
}

private void siftUp() {
    // 从末端往上检查堆性质
    int childIdx = size;
    int parentIdx = getParentIdx(childIdx);
    while (parentIdx >= 1 && arr[parentIdx] > arr[childIdx]) {
        swap(arr, parentIdx, childIdx);
        childIdx = parentIdx;
        parentIdx = getParentIdx(childIdx);
    }
}


public int getParentIdx(int childIdx) {
    return childIdx / 2;
}

public int getLeftChildIdx(int parentIdx) {
    return 2 * parentIdx;
}

public int getRightChildIdx(int parentIdx) {
    return 2 * parentIdx + 1;
}

/**
    * 执行用时:
    * 2 ms
    * , 在所有 Java 提交中击败了
    * 81.55%
    * 的用户
    * 内存消耗:
    * 41.5 MB
    * , 在所有 Java 提交中击败了
    * 52.56%
    * 的用户
    * 通过测试用例:
    * 32 / 32
    */
public int findKthLargestWithMinHeap(int[] nums, int k) {
    KthLargestElementInAnArray minHeap = new KthLargestElementInAnArray(k);
    for (int i = 0; i < k; i++) {
        minHeap.add(nums[i]);
    }
    for (int i = k; i < nums.length; i++) {
        if (nums[i] > minHeap.peekMin()) {
            minHeap.removeMin();
            minHeap.add(nums[i]);
        }
    }
    return minHeap.peekMin();
}

/**
    * 执行用时:
    * 3 ms
    * , 在所有 Java 提交中击败了
    * 61.15%
    * 的用户
    * 内存消耗:
    * 41.8 MB
    * , 在所有 Java 提交中击败了
    * 10.47%
    * 的用户
    * 通过测试用例:
    * 32 / 32
    */
public int findKthLargestWithPriorityQueue(int[] nums, int k) {
    PriorityQueue<Integer> priorityQueue = new PriorityQueue<>(Comparator.naturalOrder());
    for (int i = 0; i < k; i++) {
        priorityQueue.offer(nums[i]);
    }
    for (int i = k; i < nums.length; i++) {
        if (nums[i] > priorityQueue.peek()) {
            priorityQueue.poll();
            priorityQueue.offer(nums[i]);
        }
    }
    return priorityQueue.peek();
}

复杂度分析

时间

这个方案下的操作:

  1. 遍历n容量的数组;
  2. 碰到大于当前堆顶的元素就入队(极端情况下会触发n-k次入队操作);

遍历开销 1,入队开销 logk,我们取大项logk

时间复杂度为O(k+(n-k)*logk),去除低阶为O(n*logk)

空间

用了一个数组(堆结构)维护我们前k项数据,空间复杂度O(k)

快排解

需熟悉快排的过程与精髓。

回顾下快排:https://github.com/redolog/algorithm-java/blob/main/src/main/java/com/algorithm/sort/QuickSort.java

思路

快排原本是用于全量排序,精髓:

  1. 分区:通过分治思想、递归手段逐步将大数组部分有序化->完全有序化;
  2. 分区同时是原地操作,复杂度低;
  3. 双路优化,优化了大量等值元素的情况;
  4. 随机基准数据,优化了边界极端值的情况;
  5. 三路优化,优化了等值元素多次对比的情况;

快排+fail-fast即为本题解的正确姿势。按liweiwei的说法,这个叫减治:减而治之,核心思想是逐步缩小问题范围。

我理解的核心思想:与搜索、查询接口类比,我们在程序中逐步将不符合要求的操作提前返回,即fail-fast。也叫快速选择。

代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/**
    * 执行用时:
    * 1 ms
    * , 在所有 Java 提交中击败了
    * 98.60%
    * 的用户
    * 内存消耗:
    * 41.6 MB
    * , 在所有 Java 提交中击败了
    * 44.91%
    * 的用户
    * 通过测试用例:
    * 32 / 32
    */
public int findKthLargestWithQuickSortPivot(int[] nums, int k) {
    if (nums == null || nums.length == 0 || nums.length < k) {
        return -1;
    }
    // 利用快排思想实现减治:将查询问题集缩小,也就是fail-fast
    int n = nums.length;
    int leftIdx = 0;
    int rightIdx = n - 1;

    // 统一使用下标法
    k = k - 1;

    while (true) {
        // 选择一个pivotVal 基准值,左区间存大于等于pivot的元素,右区间存小于等于pivotVal的元素;
        // 排出pivotVal所在位置后,判断目标位置k在左还是右,下一轮只搜对应区间
        // 从最右侧取一个pivotVal,为了防止极值用例的情况,我们每次选取pivot做一次随机化下标
        int randomIdx = (int) (Math.random() * (rightIdx - leftIdx + 1) + leftIdx);
        swap(nums, randomIdx, rightIdx);

        int pivotVal = nums[rightIdx];

        // leftI遍历从左,rightI遍历从右
        int leftI = leftIdx, rightI = rightIdx - 1;
        while (true) {
            // 从左一直找到小于或者等于pivotVal的位置,待与右侧元素交换
            while (leftI <= rightIdx - 1 && nums[leftI] > pivotVal) {
                ++leftI;
            }
            // 从右侧一直找到大于或者等于pivotVal的位置,待与左侧元素交换
            while (rightI >= leftIdx && nums[rightI] < pivotVal) {
                --rightI;
            }

            if (rightI <= leftI) {
                break;
            }
            // 左侧小于等于的元素 与 右侧大于等于的元素交换
            swap(nums, rightI, leftI);
            --rightI;
            ++leftI;
        }
        swap(nums, leftI, rightIdx);
        // leftI 作为基准值下标
        if (leftI == k) {
            return nums[leftI];
        } else if (leftI < k) {
            leftIdx = leftI + 1;
        } else if (leftI > k) {
            rightIdx = leftI - 1;
        }
    }
}

复杂度分析

时间
最差:

partition每次选pivotVal都选了最右侧小数区间,并且只有一个值,这种情况我们代码中通过随机pivot下标的方式进行了优化,发生概率极低。

最优:

partition首次分区就定位到了k值的位置,此时复杂度为O(n)

平均:

使用公式逐步推导,参考baeldung网站对快速选择算法的分析

复杂度为O(n)

空间

针对本题解,我们去掉了递归,没有递归执行栈的开销,空间复杂度O(1)

Ref