Home 백준 - 2042 구간 합 구하기 풀이
Post
Cancel

백준 - 2042 구간 합 구하기 풀이

문제

2042 구간 합 구하기

screencapture

풀이

  • Long형 data를 최대 1,000,000개 합쳐야 하는 문제이다.
  • 단순한 sum으로는 timeout이 발생한다.(아래 오답 부분 참고)
  • 구간의 합을 구하기 위해 segment tree를 이용한다.
  • Segment Tree는 부모 노드의 값이 자식 노드 값을 합이 되는 완전 이진 트리이다.
    • 즉, root 노드는 전체 값의 합이 되고, leaf 노드는 각 값이 된다.

Step by Step (문제의 예제 입력 1)

값을 배열로 입력 받는다.

input

세그먼트 트리 배열을 만든다.

  • 트리 배열의 크기: tn = 2 * 2⌈ log2n ⌉ - 1 = 17
    • 간단하게 계산을 하기 위해서는: tn = 4 * n - 1 = 19으로 계산하여도 된다. 배열이 더 넉넉하게 생성되어 약간의 메모리 낭비가 생긴하는 단점이 있다.
    • 배열의 빈칸은 비워둔다.
    • n이 2의 거듭 제곱수(n = 2x)이면 트리는 포화 이진트리(Perfect Binary Tree)가 되면서 배열이 꽉 차게 되지만, 그 외의 경우 배열의 빈 칸이 생긴다.
  • 트리 형태로 나타내면 아래 그림과 같이 된다.

input tree

  • kotlin code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
private fun createTree(arr:LongArray) = LongArray(arr.getTreeSize()).let {
    initTree(0, arr.lastIndex, arr, 0, it)
    it
}

private fun LongArray.getTreeSize() = Math.pow(2.0, kotlin.math.ceil(kotlin.math.log2(this.size.toDouble()))).toInt() * 2 - 1

private fun initTree(arrStartIndex: Int, arrEndIndex:Int, arr:LongArray, nodeIndex:Int, tree:LongArray): Long {
    if(arrStartIndex == arrEndIndex) {
        tree[nodeIndex] = arr[arrStartIndex]
        return tree[nodeIndex]
    }
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    val leftChildNode = initTree(arrStartIndex, arrMidIndex, arr, getNodeLeftChildIndex(nodeIndex), tree)
    val rightChildNode = initTree(arrMidIndex + 1, arrEndIndex, arr, getNodeRightChildIndex(nodeIndex), tree)
    tree[nodeIndex] = leftChildNode + rightChildNode
    return tree[nodeIndex]
}
private fun getArrMidIndex(arrStartIndex: Int, arrEndIndex: Int) = (arrStartIndex + arrEndIndex) / 2
private fun getNodeLeftChildIndex(nodeIndex: Int) = nodeIndex * 2 + 1
private fun getNodeRightChildIndex(nodeIndex: Int) = getNodeLeftChildIndex(nodeIndex) + 1

update

  • root 노드 부터 해당 leaf 노드까지 기존 새로운 값과 기존 값의 차이를 더한다.

update

  • kotlin code
1
2
3
4
5
6
7
8
9
10
11
12
13
private tailrec fun update(arrStartIndex: Int, arrEndIndex: Int, nodeIndex: Int, tree: LongArray, arrChangeIndex: Int, diffValue: Long) {
    tree[nodeIndex] += diffValue
    if(arrStartIndex == arrEndIndex) return
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    if(arrChangeIndex <= arrMidIndex)
        update(arrStartIndex, arrMidIndex, getNodeLeftChildIndex(nodeIndex), tree, arrChangeIndex, diffValue)
    else
        update(arrMidIndex + 1, arrEndIndex, getNodeRightChildIndex(nodeIndex), tree, arrChangeIndex, diffValue)
}

private fun getArrMidIndex(arrStartIndex: Int, arrEndIndex: Int) = (arrStartIndex + arrEndIndex) / 2
private fun getNodeLeftChildIndex(nodeIndex: Int) = nodeIndex * 2 + 1
private fun getNodeRightChildIndex(nodeIndex: Int) = getNodeLeftChildIndex(nodeIndex) + 1

sum

  • 합을 구하기 위한 인덱스 범위안에 노드가 있으면 이용하여 합을 구한다.

update

  • kotlin code
1
2
3
4
5
6
7
private fun sum(arrStartIndex:Int, arrEndIndex:Int, sumStartIndex:Int, sumEndIndex:Int, node:Int, tree:LongArray): Long {
    if(sumStartIndex <= arrStartIndex && arrEndIndex <= sumEndIndex) return tree[node]
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    val leftSum = if(sumStartIndex <= arrMidIndex) sum(arrStartIndex, arrMidIndex, sumStartIndex, sumEndIndex, getNodeLeftChildIndex(node), tree) else 0L
    val rightSum = if(sumEndIndex > arrMidIndex) sum(arrMidIndex + 1, arrEndIndex, sumStartIndex, sumEndIndex, getNodeRightChildIndex(node), tree) else 0L
    return leftSum + rightSum
}

kotlin code: 오답 - (당연하게도)timeout

1
2
3
4
5
6
7
8
9
10
11
fun main() {
    val (n, m, k) = readln().split(' ').map { it.toInt() }
    val array = LongArray(n) { readln().toLong() }
    repeat(m + k) {
        val strArray= readln().split(' ')
        when(strArray[0].toInt()) {
            1 -> array[strArray[1].toInt()-1] = strArray[2].toLong()
            2 -> println((strArray[1].toInt()-1 until strArray[2].toInt()).sumOf { array[it] })
        }
    }
}

kotlin code: 정답 - Segment Tree 사용

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
fun main() {
    val (n, m, k) = readln().split(' ').map { it.toInt() }
    val arr = LongArray(n) { readln().toLong() }
    val tree = createTree(arr)
    repeat(m + k) {
        val strArray= readln().split(' ')
        when(strArray[0].toInt()) {
            1 -> {
                val index = strArray[1].toInt()-1
                update(0, n-1, 0, tree, index, strArray[2].toLong() - arr[index])
                arr[index] = strArray[2].toLong()
            }
            2 -> println(sum(0, n-1, strArray[1].toInt()-1, strArray[2].toInt()-1, 0, tree))
        }
    }
}

private fun createTree(arr:LongArray) = LongArray(arr.getTreeSize()).let {
    initTree(0, arr.lastIndex, arr, 0, it)
    it
}

private fun LongArray.getTreeSize() = Math.pow(2.0, kotlin.math.ceil(kotlin.math.log2(this.size.toDouble()))).toInt() * 2 - 1

private fun initTree(arrStartIndex: Int, arrEndIndex:Int, arr:LongArray, nodeIndex:Int, tree:LongArray): Long {
    if(arrStartIndex == arrEndIndex) {
        tree[nodeIndex] = arr[arrStartIndex]
        return tree[nodeIndex]
    }
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    val leftChildNode = initTree(arrStartIndex, arrMidIndex, arr, getNodeLeftChildIndex(nodeIndex), tree)
    val rightChildNode = initTree(arrMidIndex + 1, arrEndIndex, arr, getNodeRightChildIndex(nodeIndex), tree)
    tree[nodeIndex] = leftChildNode + rightChildNode
    return tree[nodeIndex]
}


private fun sum(arrStartIndex:Int, arrEndIndex:Int, sumStartIndex:Int, sumEndIndex:Int, nodeIndex:Int, tree:LongArray): Long {
    if(sumStartIndex <= arrStartIndex && arrEndIndex <= sumEndIndex) return tree[nodeIndex]
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    val leftSum = if(sumStartIndex <= arrMidIndex) sum(arrStartIndex, arrMidIndex, sumStartIndex, sumEndIndex, getNodeLeftChildIndex(nodeIndex), tree) else 0L
    val rightSum = if(sumEndIndex > arrMidIndex) sum(arrMidIndex + 1, arrEndIndex, sumStartIndex, sumEndIndex, getNodeRightChildIndex(nodeIndex), tree) else 0L
    return leftSum + rightSum
}

private tailrec fun update(arrStartIndex: Int, arrEndIndex: Int, nodeIndex: Int, tree: LongArray, arrChangeIndex: Int, diffValue: Long) {
    tree[nodeIndex] += diffValue
    if(arrStartIndex == arrEndIndex) return
    val arrMidIndex = getArrMidIndex(arrStartIndex, arrEndIndex)
    if(arrChangeIndex <= arrMidIndex)
        update(arrStartIndex, arrMidIndex, getNodeLeftChildIndex(nodeIndex), tree, arrChangeIndex, diffValue)
    else
        update(arrMidIndex + 1, arrEndIndex, getNodeRightChildIndex(nodeIndex), tree, arrChangeIndex, diffValue)
}

private fun getArrMidIndex(arrStartIndex: Int, arrEndIndex: Int) = (arrStartIndex + arrEndIndex) / 2
private fun getNodeLeftChildIndex(nodeIndex: Int) = nodeIndex * 2 + 1
private fun getNodeRightChildIndex(nodeIndex: Int) = getNodeLeftChildIndex(nodeIndex) + 1

Attachments

This post is licensed under CC BY 4.0 by the author.

우선순위 큐

프로그래머스 - 92334 신고 결과 받기 풀이