ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 10999, Platinum 4
    백준 2021. 10. 8. 01:13
     

    10999번: 구간 합 구하기 2

    첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

    www.acmicpc.net

    문제

    • 빈번하게 변하는 배열의 구간합을 구하시오(변할 때 하나의 값이 아닌 구간으로 변한다)

     

    $O((m+k)log\ n)$

    lazy propagation in segment tree which constrctive recursive function

     

    • lazy propagation 의 핵심 아이디어는 갱신을 미리하지 말고 모았다가 (변화량은 lazy 배열에 저장해놓는다)
      필요할 때 한꺼번에 하는것이다.
    • seg_i 번째 노드의 정보는 다음과 같다
      1. 다루고 있는 배열 $arr_l\ \sim\ arr_r$ 
      2. 구간 $[arr_l\ arr_r]$의 모든 연산 값 (연산이 + 이면 구간의 합을 저장)
      3. 변화해야 하는 변화량, lazy 배열에 저장한다. 따라서 앞으로 lazy 값이라 부른다
    • 아래와 같이 구현할 때, 세그먼트 트리의 값을 담는 배열의 크기는 4*(다루는 숫자의 개수)

     

    def propagate( ) ~ seg_i 값은 갱신, 자식노드들은 갱신 안함

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    void propagate(int seg_i, int arr_l, int arr_r)
    {
        // 변화량이 있을 때
        if(lazy[seg_i]){
            // 자식노트 개수(arr_r - arr_l + 1)만큼 변화가 있어야 한다 
            seg_tree[seg_i] += (arr_r-arr_l+1)*lazy[seg_i];
            // leaf 노드가 아닐 때
            if(arr_l != arr_r){
                lazy[2*seg_i] += lazy[seg_i];
                lazy[2*seg_i+1+= lazy[seg_i];
            }
            // seg_i 을 갱신했으니 나중에 갱신할 값도 0이 됨
            lazy[seg_i] = 0;
        }

    ※ 변화량을 모았다가 바꾸는 것이다. 즉, 등호 lazy[i] = (값) 가 아닌 덧셈!

    ex) 아래와 같은 상황에서 seg_i = 1 에서 propagate 을 실행하면 다음과 같다

     

    def init( ) ~ seg_i 의 값을 리턴

    1
    2
    3
    4
    5
    6
    7
    8
    9
    ll init(int seg_i = 1int arr_l = 1int arr_r = n)
    {
        // leaf 노드일 경우
        if(arr_l == arr_r) return seg_tree[seg_i] = arr[arr_r];
        
        // leaf 노드가 아닐 경우
        int arr_m = (arr_l + arr_r)/2;
        return seg_tree[seg_i] = init(2*seg_i, arr_l, arr_m) + init(2*seg_i+1, arr_m+1, arr_r);

     

    def sum_query( ) ~ seg_i 값 return if [arr_l, arr_r] ⊂ [want_l, want_r]

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    ll sum_query(int want_l, int want_r, int seg_i = 1int arr_l = 1int arr_r = n)
    {
        propagate(seg_i, arr_l, arr_r);
     
        // 완전 다 쓰레기 정보
        if(want_r < arr_l || arr_r < want_l) return 0;
     
        // 정보를 얻을 수 있고, 완전 다 필요해
        if(want_l <= arr_l && arr_r <= want_r) return seg_tree[seg_i];
     
        // 정보를 얻을 수 있으나, 쓰레기 정보 있네?
        int arr_m = (arr_l + arr_r)/2;
        return sum_query(want_l, want_r, 2*seg_i, arr_l, arr_m) + sum_query(want_l, want_r, 2*seg_i+1, arr_m+1, arr_r);

    seg_i 에서 구하려는 구간 $[want_l, want_r]$ 과 다루고 있는 index $[arr_l,\ arr_r]$ 을 비교할 때

    i) 완전 다 쓰레기 정보
    $want_r < arr_l\ or\ arr_r < want_l$

    ii) 정보를 얻을 수 있고, 완전 다 필요해
    $want_l \leq arr_l \leq arr_r \leq want_r$

    iii) 정보를 얻을 수 있으나, 쓰레기 정보 있네?
    그 외의 경우

     

    def update( ) ~ seg_i 의 값 수정

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    void update(int want_l, int want_r, ll delta, int seg_i = 1int arr_l = 1int arr_r = n)
    {
        propagate(seg_i, arr_l, arr_r);
     
        // 완전 다 쓰레기 정보
        if(arr_r < want_l || want_r < arr_l) return;
     
        // 정보를 얻을 수 있고, 완전 다 필요해
        if(want_l <= arr_l && arr_r <= want_r){
            lazy[seg_i] += delta;
            propagate(seg_i, arr_l, arr_r);
            return;
        }
     
        // 정보를 얻을 수 있으나, 쓰레기 정보 있네?
        int arr_m = (arr_l + arr_r)/2;
        update(want_l, want_r, delta, 2*seg_i, arr_l, arr_m);
        update(want_l, want_r, delta, 2*seg_i+1, arr_m+1, arr_r);
        
        // lazy 값의 위로 전파
        seg_tree[seg_i] = seg_tree[2*seg_i] + seg_tree[2*seg_i + 1];

    Q1 3번째 줄에 propagate 왜 필요함? 

    - update 된 구간에서 lazy 값이 남아 있으면 root 값이 부정확하게 갱신됨
    반례)

    5 2 0
    1
    2
    3
    4
    5
    1 1 5 -2
    1 1 2 5
    2 1 5

     

    c++

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    #include <bits/stdc++.h>
    #define endl "\n"
    #define ooop(i, n) for(int i = 0; i < n; i++)
    #define loop(i, n) for(int i = 1; i <= n; i++)
     
    using namespace std;
    typedef long long ll;
    typedef pair<intint> pii;
    typedef pair<ll, ll> pll;
     
    const int N = 1e6;
    ll arr[N+1];
     
    ll seg_tree[4*N];
    ll lazy[4*N];
     
    int n, m, k;
     
    void propagate(int seg_i, int arr_l, int arr_r)
    {
        if(lazy[seg_i]){
            seg_tree[seg_i] += (arr_r-arr_l+1)*lazy[seg_i];
            if(arr_l != arr_r){
                lazy[2*seg_i] += lazy[seg_i];
                lazy[2*seg_i+1+= lazy[seg_i];
            }
            lazy[seg_i] = 0;
        }
    }
     
    ll init(int seg_i = 1int arr_l = 1int arr_r = n)
    {
        if(arr_l == arr_r) return seg_tree[seg_i] = arr[arr_r];
        int arr_m = (arr_l + arr_r)/2;
        return seg_tree[seg_i] = init(2*seg_i, arr_l, arr_m) + init(2*seg_i+1, arr_m+1, arr_r);
    }
     
    void update(int want_l, int want_r, ll delta, int seg_i = 1int arr_l = 1int arr_r = n)
    {
        propagate(seg_i, arr_l, arr_r);
        if(arr_r < want_l || want_r < arr_l) return;
        if(want_l <= arr_l && arr_r <= want_r){
            lazy[seg_i] += delta;
            propagate(seg_i, arr_l, arr_r);
            return;
        }
        int arr_m = (arr_l + arr_r)/2;
        update(want_l, want_r, delta, 2*seg_i, arr_l, arr_m);
        update(want_l, want_r, delta, 2*seg_i+1, arr_m+1, arr_r);
        seg_tree[seg_i] = seg_tree[2*seg_i] + seg_tree[2*seg_i + 1];
    }
     
    ll sum_query(int want_l, int want_r, int seg_i = 1int arr_l = 1int arr_r = n)
    {
        propagate(seg_i, arr_l, arr_r);
        if(want_r < arr_l || arr_r < want_l) return 0;
        if(want_l <= arr_l && arr_r <= want_r) return seg_tree[seg_i];
        int arr_m = (arr_l + arr_r)/2;
        return sum_query(want_l, want_r, 2*seg_i, arr_l, arr_m) + sum_query(want_l, want_r, 2*seg_i+1, arr_m+1, arr_r);
    }
     
    int main()
    {
        ios::sync_with_stdio(false);
        cin.tie(0), cout.tie(0);
     
        cin >> n >> m >> k;
        loop(i, n) cin >> arr[i];
     
        init();
     
        m+=k;
        int a, b, c;
        ll d;
        while(m--){
            cin >> a >> b >> c;
            if(a == 1){
                cin >> d;
                update(b, c, d);
            }
            else cout << sum_query(b, c) << endl;
            loop(i, 9cout << seg_tree[i] << ' ';
            cout << endl;
            loop(i, 9cout << lazy[i] << ' ';
            cout << endl;
        }
     
     
        return 0;

     

    python 3

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    import sys
     
    n, m, k = map(int, sys.stdin.readline().split())
    arr = [0]
    for _ in range(n): arr.append(int(input()))
    seg_tree = [0]*(4*n)
    lazy = [0]*(4*n)
     
    def propagate(seg_i, arr_l, arr_r):
        if lazy[seg_i]:
            seg_tree[seg_i] += (arr_r-arr_l+1)*lazy[seg_i]
            if arr_l != arr_r:
                lazy[2*seg_i] += lazy[seg_i]
                lazy[2*seg_i+1+= lazy[seg_i]
            lazy[seg_i] = 0
     
    def init(seg_i = 1, arr_l = 1, arr_r = n):
        if arr_l == arr_r:
            seg_tree[seg_i] = arr[arr_l]
            return seg_tree[seg_i]
        arr_m = (arr_l + arr_r)//2
        seg_tree[seg_i] = init(2*seg_i, arr_l, arr_m) + init(2*seg_i+1, arr_m+1, arr_r)
        return seg_tree[seg_i]
     
    def update(want_l, want_r, delta, seg_i = 1, arr_l = 1, arr_r = n):
        propagate(seg_i, arr_l, arr_r)
     
        if want_r < arr_l or arr_r < want_l: return
        if want_l <= arr_l and arr_r <= want_r:
            lazy[seg_i] += delta
            propagate(seg_i, arr_l, arr_r)
            return
        arr_m = (arr_l+arr_r)//2
        update(want_l, want_r, delta, 2*seg_i, arr_l, arr_m)
        update(want_l, want_r, delta, 2*seg_i+1, arr_m+1, arr_r)
        seg_tree[seg_i] = seg_tree[2*seg_i] + seg_tree[2*seg_i+1]
     
    def sum_query(want_l, want_r, seg_i = 1, arr_l = 1, arr_r = n):
        propagate(seg_i, arr_l, arr_r)
     
        if want_r < arr_l or arr_r < want_l: return 0
        if want_l <= arr_l and arr_r <= want_r: return seg_tree[seg_i]
        arr_m = (arr_l+arr_r)//2
        return sum_query(want_l, want_r, 2*seg_i, arr_l, arr_m) + sum_query(want_l, want_r, 2*seg_i+1, arr_m+1, arr_r)
     
    init()
    for _ in range(m+k):
        a, *= map(int, sys.stdin.readline().split())
        if a == 1:
            update(*b)
        else:
            print(sum_query(*b))cs

    '백준' 카테고리의 다른 글

    13549, Gold 5  (0) 2022.01.02
    2515, 골드 2  (0) 2021.11.11
    2042, Gold 1  (0) 2021.10.07
    16236, Gold 4  (0) 2021.10.03
    11286, Silver 1  (0) 2021.10.03

    댓글

Designed by Tistory.