According to Leetcode:
We highly recommend Kth Largest Element in an Array, which has been asked many times in an Amazon phone interview.
Given an integer array nums and an integer k, return the kth largest element in the array.
Note that it is the kth largest element in the sorted order, not the kth distinct element.
You must solve it in O(n) time complexity.
There are several ways to solve this task. The most straightforward is to iterate through the array while keeping only the k largest elements seen so far. A min-heap is a natural data structure for that purpose.
The algorithm:
iterate through the array and add each element to the heap
if the heap size is greater than k, remove the smallest element (which is at the head of the priority queue)
return the head of the heap, which is now the kth largest element
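import java.util.PriorityQueue;

class Solution {
    public int findKthLargest(int[] nums, int k) {
        // min-heap: the smallest of the retained elements is at the head
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        for (int n : nums) {
            heap.add(n);
            // keep only the k largest elements seen so far
            if (heap.size() > k) {
                heap.poll();
            }
        }
        // the head is the smallest of the k largest, i.e. the kth largest
        return heap.poll();
    }
}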
We iterate through the array and insert each element into a heap that never holds more than k + 1 elements, so the time complexity is O(N log K). Can we do better? :)
First, let’s discuss the high-level design, and then we can dive deeper. Here is a link to the problem.
We need to find the k-th largest element in the array. Since it is more natural to think about an array sorted in non-decreasing order, let’s rephrase the task: we need the element that would sit at index N - k (0-based) of the sorted array, that is, the (N - k + 1)-th smallest element.
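The index conversion can be checked with a plain sorting baseline. This is only a sketch for illustration (it runs in O(N log N), so it does not meet the O(n) requirement), and the class name `KthLargestBySorting` is made up for this example:

```java
import java.util.Arrays;

class KthLargestBySorting {
    // Returns the k-th largest element (k is 1-based) by sorting a copy of the input.
    static int kthLargest(int[] nums, int k) {
        int[] sorted = Arrays.copyOf(nums, nums.length);
        Arrays.sort(sorted);               // non-decreasing order
        return sorted[sorted.length - k];  // k-th largest sits at 0-based index N - k
    }

    public static void main(String[] args) {
        int[] nums = {3, 2, 1, 5, 6, 4};
        System.out.println(kthLargest(nums, 2)); // prints 5
    }
}
```

The same index, N - k, is the target that the quickselect version searches for.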
The idea behind the algorithm is to use a partitioning algorithm from quickselect.
We partition the whole array (O(n) time complexity).
Then we throw away roughly half of it and continue with the other half (O(n/2) time complexity).
Again we throw away half of it and continue with 1/4 of the original array (O(n/4) time complexity).
We continue doing so until we reach a single element. Assuming each pivot splits the interval roughly in half, which is what happens on average, the overall time complexity is:
n + n/2 + n/4 + n/8 + ... < 2n = O(n)
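You may be confused about the time complexity, thinking “wait, this is almost the same as binary search, which has O(log N) time complexity, so where is the log N part?”. You are right that the number of steps is about log N. The key point is that the steps do not all cost the same: binary search does constant work per step, while each partition here does work proportional to the size of the interval that remains.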
This more precise analysis, which uses the fact that the work done keeps decreasing on each iteration, gives the O(n) runtime.
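Written out as a geometric series, and still under the assumption that every partition discards about half of the remaining elements (true only on average with a random pivot), the total work is bounded as follows:

```latex
\[
T(n) \approx n + \frac{n}{2} + \frac{n}{4} + \frac{n}{8} + \cdots
     = n \sum_{i=0}^{\log_2 n} \frac{1}{2^{i}}
     < n \sum_{i=0}^{\infty} \frac{1}{2^{i}}
     = 2n
     = O(n)
\]
```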
If you are still confused about time complexity take a look at this answer and this article.
Quickselect chooses a pivot and finds its position in the sorted array in linear time using the so-called partition algorithm.
The toughest thing in this algorithm is understanding how the partition works. Let’s imagine we have an array [2,6,3,4,7,1,8,5]. With the Lomuto partitioning scheme the algorithm will look like this:
Choose a pivot and keep it at the end of the interval (here the pivot is the last element, 5).
Use two pointers, i and startIndex, both starting from the beginning of the target interval.
i scans the whole interval and checks the condition: if the value at i is less than the pivot, swap it with the value at startIndex and increment startIndex.
Finally, swap the pivot with the value at startIndex (since this is the place for our pivot value) and return startIndex.
To achieve that we need a single for loop and a dedicated variable startIndex (named storeIndex in the first listing below), which initially equals the index of the first element of the interval (in our case it is 0).
For the array above we end up with storeIndex = 4, which is the final position of our pivot element. We swap the pivot into that position and return the index. The array now looks like [2,3,4,1,5,6,8,7], and the pivot index is 4.
// Both methods assume nums is a field of the enclosing class.
public void swap(int i, int j) {
    int temp = nums[i];
    nums[i] = nums[j];
    nums[j] = temp;
}

public int partition(int start, int end) {
    int pivot = nums[end];
    int storeIndex = start;
    // the pivot itself sits at 'end', so the scan can stop just before it
    for (int i = start; i < end; i++) {
        if (nums[i] < pivot) {
            swap(i, storeIndex);
            storeIndex++;
        }
    }
    // don't forget to move the pivot element from the end of the interval to its position
    swap(end, storeIndex);
    return storeIndex;
}
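To make the walkthrough concrete, here is a small self-contained sketch that runs the same Lomuto partition on the example array [2,6,3,4,7,1,8,5]. The class name `LomutoPartitionDemo` and the static, array-passing signatures are choices made for this example so that it runs on its own:

```java
import java.util.Arrays;

class LomutoPartitionDemo {
    // Lomuto partition of nums[start..end] with nums[end] as the pivot.
    // Returns the final index of the pivot.
    static int partition(int[] nums, int start, int end) {
        int pivot = nums[end];
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            if (nums[i] < pivot) {
                swap(nums, i, storeIndex);
                storeIndex++;
            }
        }
        swap(nums, end, storeIndex); // put the pivot into its final position
        return storeIndex;
    }

    static void swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    public static void main(String[] args) {
        int[] nums = {2, 6, 3, 4, 7, 1, 8, 5};
        int pivotIndex = partition(nums, 0, nums.length - 1);
        System.out.println(pivotIndex + " " + Arrays.toString(nums));
        // prints: 4 [2, 3, 4, 1, 5, 6, 8, 7]
    }
}
```

Everything to the left of the returned index is smaller than the pivot and everything to the right is greater than or equal to it, although neither side is sorted.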
The worst case of this algorithm is O(n^2). Why is that so? The algorithm is sensitive to the pivot that is chosen. Imagine the array is already sorted and each time you select the first (or the last) element as the pivot. Then each partition shrinks the search interval by only 1 element. To make this case unlikely, we select a random pivot each time:
Random random = new Random();
// assume that 'start' is the index of the first element in the search interval of the array
// and 'end' is the index of the last element in that interval, then:
// (the bound of nextInt is exclusive, so + 1 is needed for 'end' itself to be eligible)
int pivotIndex = start + random.nextInt(end - start + 1);
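As an alternative sketch, the random index can also be drawn with `ThreadLocalRandom`, whose two-argument `nextInt(origin, bound)` takes the lower bound directly. The helper name `randomPivotIndex` is hypothetical. The upper bound is exclusive, so `end + 1` is passed to make the last index of the interval eligible:

```java
import java.util.concurrent.ThreadLocalRandom;

class RandomPivot {
    // Returns an index chosen uniformly at random from the inclusive range [start, end].
    static int randomPivotIndex(int start, int end) {
        // nextInt(origin, bound) excludes bound, hence the + 1
        return ThreadLocalRandom.current().nextInt(start, end + 1);
    }

    public static void main(String[] args) {
        // every printed index lies between 3 and 7, both inclusive
        for (int round = 0; round < 5; round++) {
            System.out.println(randomPivotIndex(3, 7));
        }
    }
}
```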
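Finally, we need to implement the quickselect algorithm itself. The steps are:
if the search interval contains a single element, return it
choose a random pivot index and partition the interval around that pivot
if the final index of the pivot equals the target index N - k, return the pivot
otherwise continue in the right part if the pivot index is smaller than the target, or in the left part if it is larger
Here is the source code: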
import java.util.Random;

class Solution {
    int[] nums;
    // one generator for the whole search instead of a new one per recursive call
    Random random = new Random();

    public void swap(int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    public int partition(int start, int end, int pivotIndex) {
        int pivot = nums[pivotIndex];
        // move the pivot to the end of the interval
        swap(end, pivotIndex);
        int startIndex = start;
        // the pivot itself sits at 'end', so the scan can stop just before it
        for (int i = start; i < end; i++) {
            if (nums[i] < pivot) {
                swap(i, startIndex);
                startIndex++;
            }
        }
        // don't forget to move the pivot element from the end of the interval to its position
        swap(end, startIndex);
        return startIndex;
    }

    public int findKthLargest(int[] nums, int k) {
        this.nums = nums;
        // the kth largest element sits at index N - k of the sorted array
        return quickselect(0, nums.length - 1, nums.length - k);
    }

    public int quickselect(int start, int end, int k) {
        if (start == end) {
            return nums[start];
        }
        // the bound of nextInt is exclusive, so + 1 makes 'end' eligible as a pivot
        int pivotIndex = start + random.nextInt(end - start + 1);
        pivotIndex = partition(start, end, pivotIndex);
        if (pivotIndex == k) {
            return nums[pivotIndex];
        }
        if (pivotIndex < k) {
            return quickselect(pivotIndex + 1, end, k);
        }
        return quickselect(start, pivotIndex - 1, k);
    }
}
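One way to gain confidence in a quickselect implementation is to compare it against the sorting baseline on random inputs. The sketch below is not the listing above: it is an iterative variant written for this check, it works on a copy of the input, and the names `QuickselectCheck` and `RANDOM` are made up for the example.

```java
import java.util.Arrays;
import java.util.Random;

class QuickselectCheck {
    private static final Random RANDOM = new Random();

    // Iterative quickselect; returns the k-th largest element (k is 1-based).
    static int findKthLargest(int[] input, int k) {
        int[] nums = Arrays.copyOf(input, input.length); // leave the caller's array intact
        int target = nums.length - k;                    // index in sorted order
        int start = 0;
        int end = nums.length - 1;
        while (start < end) {
            int randomIndex = start + RANDOM.nextInt(end - start + 1);
            int pivotIndex = partition(nums, start, end, randomIndex);
            if (pivotIndex == target) {
                return nums[pivotIndex];
            }
            if (pivotIndex < target) {
                start = pivotIndex + 1; // target is to the right of the pivot
            } else {
                end = pivotIndex - 1;   // target is to the left of the pivot
            }
        }
        return nums[start]; // interval shrank to the target index
    }

    // Lomuto partition around nums[pivotIndex]; returns the pivot's final index.
    static int partition(int[] nums, int start, int end, int pivotIndex) {
        int pivot = nums[pivotIndex];
        swap(nums, pivotIndex, end);
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            if (nums[i] < pivot) {
                swap(nums, i, storeIndex);
                storeIndex++;
            }
        }
        swap(nums, storeIndex, end);
        return storeIndex;
    }

    static void swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    public static void main(String[] args) {
        Random gen = new Random(42);
        for (int round = 0; round < 1000; round++) {
            int n = 1 + gen.nextInt(20);
            int[] nums = new int[n];
            for (int i = 0; i < n; i++) {
                nums[i] = gen.nextInt(10); // small range forces duplicates
            }
            int k = 1 + gen.nextInt(n);
            int[] sorted = Arrays.copyOf(nums, n);
            Arrays.sort(sorted);
            if (findKthLargest(nums, k) != sorted[n - k]) {
                throw new AssertionError("mismatch for " + Arrays.toString(nums) + ", k=" + k);
            }
        }
        System.out.println("all rounds agree with the sorting baseline");
    }
}
```

The loop replaces the recursion, so the stack depth stays constant even when unlucky pivots shrink the interval slowly.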