코딩 알고리즘 문제/Leetcode

973. K Closest Points to Origin (Array, Math, Divide and Conquer, Geometry, Sorting, Heap (Priority Queue), Quickselect)

highlightmoon 2025. 10. 17. 10:54
반응형

링크 - https://leetcode.com/problems/k-closest-points-to-origin/description/?envType=company&envId=facebook&favoriteSlug=facebook-thirty-days

난이도 - Medium

Intuition

1) Max Heap and Max Priority Queue

Heap을 사용하면 모든 포인트들을 돌면서 거리가 가장 짧은 k개만 유지시킬수 있다. 이때 heap에 넣을때는 (-distance, point)형태의 튜플로 넣어야 가장 적은 거리의 포인트들을 유지시킬 수 있다.

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        heap = []

        def getDistance(point):
            x, y = point
            return sqrt(x**2 + y**2)

        for point in points:
            distance = getDistance(point)
            heapq.heappush(heap, (-distance, point))
            if len(heap) > k:
                heapq.heappop(heap)

        return list(map(lambda x: x[1], heap))

Time Complexity: O(nlogk) heap에 원소를 넣고 빼는것은 O(logk)가 걸리므로 전체 n개에 대해서는 O(nlogk)가 나오게 된다.

Space Complexity: O(k) 힙은 최대 k개의 원소들만 갖는다.

2) Binary Search

이 알고리즘은 먼저 각 포인트들의 거리를 구해놓고, 거리의 최솟값(0)과 최댓값(max(distances))를 이용해 k번째를 구하는 방법이다. 이 방법은 매번 탐색하는 배열의 크기가 절반씩 줄어드므로, 결과적으로 O(2N)이 나오므로 시간복잡도가 O(N)이 되는 장점이 있다.

1. 먼저 각 포인트들의 거리를 구해 distances배열에 넣는다.

2. 거리의 최솟값을 0, 최대값을 max(distances)로 초기화 한다. (각각 low와 high)

3. mid = (low+high)/2로 구한다. 그 다음 distances값들을 보면서 mid보다 작은 값을 가지는 index는 closer에, 아니면 farther에 저장한다.

4. closer의 개수가 k개보다 많으면 우리가 찾는 포인트들은 closer에만 있다. 따라서 closer를 가지고 3번부터 다시 시작한다. 이때 high는 mid로 설정한다.

5. closer의 개수가 k개보다 적거나 같으면 closer에 있는 포인트 인덱스들은 우리가 찾고있는 답들이므로 closest에 extend시킨다. 그리고 k를 closer의 개수만큼 감소시킨다. 그리고 farther을 가지고 3번부터 다시 시작한다. 이때 low는 mid로 설정한다. 

6. closest에는 우리가 찾는 k개의 포인트 인덱스들이 있다. 이를 가지고 point를 가져와 반환한다.

 

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        distances = [self.getDistance(point) for point in points]
        remaining = [i for i in range(len(points))]
        low, high = 0, max(distances)

        closest = []
        while k:
            mid = (low + high) / 2
            closer, farther = self.splitDistance(remaining, distances, mid)
            if len(closer) > k:
                remaining = closer
                high = mid
            else:
                k -= len(closer)
                closest.extend(closer)
                remaining = farther
                low = mid

        return [points[i] for i in closest]
        

    def splitDistance(self, remaining: List[int], distances: List[float], mid: int) -> List[List[int]]:
        closer, farther = [], []
        for idx in remaining:
            if distances[idx] <= mid:
                closer.append(idx)
            else:
                farther.append(idx)
        return [closer, farther]
    
    def getDistance(self, point: List[int]) -> float:
        return point[0] ** 2 + point[1] ** 2

Time Complexity: O(n) worst case는 O(n^2)

Space Complexity: O(n)

3) QuickSelect

QuickSelect를 사용하면 2)와 같은 시간복잡도를 가지면서 Space complexity를 O(1)으로 가질 수 있다. 

1. 먼저 left와 right를 각각 0과 len(points) - 1로 초기화 시킨다. pivot_idx는 len(points)로 초기화 시킨다.

2. while문을 통해 pivot_idx가 k가 될때까지 돌린다.

3. 먼저 left와 right를 통해 point[left+(right-left)//2]를 구해서 이 포인트의 거리를 pivot_dist로 설정한다.

4. pivot_dist를 중심으로 left와 right를 움직인다. 만약 points[left]의 거리가 pivot_dist보다 크면, points[right]와 바꿔주고 right의 값을 하나 감소시킨다. 만약 points[left]의 거리가 pivot_dist보다 작으면, left를 1 증가시킨다.

5. 마지막에 다시 points[left]의 거리가 pivot_dist보다 작으면, left를 1 증가시킨다. 그리고 pivot_idx를 left로 설정한다.

6. pivot_idx <k 이면, 왼쪽의 값들은 이미 k개의 가까운 점들에 해당하므로 그것보다 거리가 먼 것들에서 답을 찾아야 한다. 그래서 left를 pivot_idx로 설정한다. pivot_idx >= k이면, 우리는 이미 k개보다 더 많은 답들을 가지고 있으므로 가까운 점들에서 답을 찾아야 한다. 그래서 right를 pivot_idx-1로 설정한다.

7. 3-6을 반복하면 pivot_idx는 k가 되어서 우리는 k개의 답을 points에 가지고 있다. 따라서 points[:k]를 반환한다.

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        return self.quickSelect(points, k)

    def quickSelect(self, points: List[List[int]], k: int) -> List[List[int]]:
        left, right = 0, len(points) - 1
        pivot_idx = len(points)
        while pivot_idx != k:
            pivot_idx = self.partition(points, left, right)
            if pivot_idx < k:
                left = pivot_idx
            else:
                right = pivot_idx - 1

        return points[:k]
        

    def partition(self, points: List[List[int]], left: int, right: int) -> int:
        pivot = self.choose_pivot(points, left, right)
        pivot_dist = self.getDistance(pivot)
        while left < right:
            if self.getDistance(points[left]) >= pivot_dist:
                points[left], points[right] = points[right], points[left]
                right -= 1
            else:
                left += 1

        if self.getDistance(points[left]) < pivot_dist:
            left += 1
        return left

    def choose_pivot(self, points: List[List[int]], left: int, right: int) -> List[int]:
        return points[left + (right-left) // 2]
    
    def getDistance(self, point: List[int]) -> float:
        return point[0] ** 2 + point[1] ** 2

Time Complexity: O(n) worst case는 O(n^2)

Space Complexity: O(1)

반응형