본문 바로가기
TIL

[코테연습] 1792. Maximum Average Pass Ratio

by 크라00 2024. 12. 15.

문제 타입

  • 우선순위 큐 (Priority Queue / Heap)
  • 탐욕 알고리즘 (Greedy Algorithm)

문제 분석

  • 문제 요구사항:
    • 각 수업의 통과 비율을 최대화하도록 추가 학생을 배치.
    • 통과 비율은 pass / total로 계산되며, 학생을 추가하면 비율이 변동.
    • 최종적으로 모든 수업의 평균 통과 비율을 최대화해야 함.
  • 핵심 개념:
    1. 추가 학생을 배치할 때, 한 번의 배치로 최대 증가량을 얻는 수업을 우선적으로 선택.
    2. 비율 증가량을 계산:
      • 현재 비율: pass / total
      • 한 명 추가 후 비율: (pass + 1) / (total + 1)
      • 증가량: (pass + 1) / (total + 1) - pass / total
    3. 이 증가량을 기준으로 우선순위를 설정하면 문제를 효율적으로 해결 가능.
  • 우선순위 큐 활용:
    • 증가량이 큰 수업을 우선적으로 선택해야 하므로, 최대 힙으로 구현.
    • Python의 heapq는 최소 힙이므로, 값에 음수를 붙여 최대 힙처럼 사용.

문제 풀이

  1. 증가량 계산 함수 정의:
    • 각 수업에서 학생 1명을 추가했을 때 비율 증가량을 계산하는 함수 작성.
  2. 우선순위 큐 초기화:
    • 모든 수업을 증가량 기준으로 정렬하여 우선순위 큐에 삽입.
    • 큐에 삽입할 때, (-증가량, [pass, total]) 형태로 저장하여 최대 힙 구현.
  3. 추가 학생 배치:
    • extraStudents만큼 반복하면서:
      • 증가량이 가장 큰 수업을 큐에서 꺼내 학생 1명을 추가.
      • 업데이트된 값을 다시 우선순위 큐에 삽입.
  4. 최종 평균 계산:
    • 우선순위 큐에 남은 수업들의 통과 비율을 합산하여 평균 계산.
  5. 소수점 반올림:
    • 결과를 소수점 5자리까지 반올림하여 반환.

> java

 

class Solution {
    public double maxAverageRatio(int[][] classes, int extraStudents) {
        int n = classes.length; // 전체 수업의 수

        // 우선순위 큐 생성 (Comparator를 사용하여 우선순위를 설정)
        // 우선순위는 (추가 학생으로 인한 증가량) 기준으로 설정
        PriorityQueue<int[]> pq = new PriorityQueue<>((o1, o2) -> {
            // 현재 비율 계산
            double current1 = ((double) o1[0]) / ((double) o1[1]);
            // 추가 학생 1명 배치 후의 비율 계산
            double next1 = ((double) (o1[0] + 1)) / (o1[1] + 1);
            
            // 현재 비율 계산 (다른 수업)
            double current2 = ((double) o2[0]) / ((double) o2[1]);
            // 추가 학생 1명 배치 후의 비율 계산 (다른 수업)
            double next2 = ((double) (o2[0] + 1)) / (o2[1] + 1);
            
            // (next2 - current2)와 (next1 - current1)를 비교
            // 더 큰 증가량을 가지는 수업이 우선순위가 높도록 설정
            return Double.compare((next2 - current2), (next1 - current1));
        });

        // 모든 수업을 우선순위 큐에 추가
        for (int[] cls : classes) {
            pq.offer(cls);
        }

        // extraStudents만큼 추가 학생 배치
        for (int i = 0; i < extraStudents; i++) {
            // 가장 증가율이 높은 수업 선택
            int[] ele = pq.poll();
            // 학생 한 명 추가
            ele[0] += 1;
            ele[1] += 1;
            // 다시 큐에 삽입 (변경된 값 기준으로 재정렬됨)
            pq.offer(ele);
        }
        
        double result = 0; // 최종 평균 비율 계산을 위한 변수

        // 우선순위 큐에서 모든 수업의 비율을 계산하여 결과에 누적
        while (!pq.isEmpty()) {
            int[] ele = pq.poll();
            int pass = ele[0]; // 통과한 학생 수
            int total = ele[1]; // 총 학생 수
            double ratio = (double) pass / (double) total; // 비율 계산
            result += ratio; // 비율 합산
        }

        // 전체 수업 수로 나누어 평균 비율 계산
        result /= (double) n;

        // 소수점 5자리까지 반올림하여 반환
        return (Math.round(result * 100000)) / 100000.0;
    }
}

 

> python

 

import heapq

class Solution:
    def maxAverageRatio(self, classes, extraStudents):
        n = len(classes)  # 전체 수업의 수

        # 우선순위 큐에 삽입할 항목을 정의
        # Python의 heapq는 최소 힙이므로, -증가량으로 최대 힙처럼 동작하도록 설정
        def improvement(class_info):
            pass_students, total_students = class_info
            current_ratio = pass_students / total_students
            next_ratio = (pass_students + 1) / (total_students + 1)
            return next_ratio - current_ratio

        # 우선순위 큐 초기화 (증가량 기준으로 정렬)
        pq = []
        for cls in classes:
            heapq.heappush(pq, (-improvement(cls), cls))  # -증가량을 큐에 삽입

        # extraStudents만큼 추가 학생 배치
        for _ in range(extraStudents):
            _, cls = heapq.heappop(pq)  # 가장 증가량이 큰 수업 선택
            cls[0] += 1  # 통과한 학생 수 증가
            cls[1] += 1  # 총 학생 수 증가
            heapq.heappush(pq, (-improvement(cls), cls))  # 변경된 수업 정보를 다시 큐에 삽입

        # 최종 평균 비율 계산
        result = 0
        while pq:
            _, cls = heapq.heappop(pq)
            pass_students, total_students = cls
            result += pass_students / total_students

        # 평균 비율 계산
        result /= n

        # 소수점 5자리까지 반올림
        return round(result, 5)