raahii.meのブログのロゴ画像

ウェブログ

Union Find のメモと Go による実装

AtCoder Beginners Content 120のD問題でUnionFindを使う問題が出題されたので学習した流れと実装をメモ.

問題

以下,問題ページ(D: Decayed Bridges)より引用.

問題文:

NN 個の島と MM 本の橋があります。

ii 番目の橋は AiA_i 番目の島と BiB_i 番目の島を繋いでおり、双方向に行き来可能です。

はじめ、どの 2 つの島についてもいくつかの橋を渡って互いに行き来できます。調査の結果、老朽化のためこれら MM 本の橋は 1 番目の橋から順に全て崩落することがわかりました。

「いくつかの橋を渡って互いに行き来できなくなった 2 つの島の組(a,b)(a<b)(a,b) (a<b)の数」を不便さと呼ぶことにします。

i(1iM)i (1\leq i \leq M) について、ii 番目の橋が崩落した直後の不便さを求めてください。

制約:

入力は全て整数である

  • 2N1052\leq N \leq 10^5
  • 1M1051 \leq M \leq 10^5
  • 1Ai<BiN1 \leq A_i \lt B_i \leq N
  • (Ai,Bi)(A_i, B_i)の組はすべて異なる
  • 初期状態における不便さは0である

全探索による解法

今回の問題は O(NM)O(NM) が通らないので全探索は無理なのですが,そもそもグラフの問題をきちんと解いたことがなかったので,まずは素直に実装してみた.前から順番に橋を落としていき,毎回独立に0から隣接行列を計算して到達可能でない島の数を数えています.

package main

func calcReachable(A, B []int, n int) [][]int {
  reachable := Ints2d(n, n)
  for i := 0; i < len(A); i++ {
    reachable[A[i]][B[i]], reachable[B[i]][A[i]] = 1, 1
  }

  // 隣接する島を順に解決していくので,
  // 最大n-1の島をたどる必要がある
  // 計算量は N^4
  for k := 0; k < n-1; k++ {
    for i := 0; i < n; i++ {
      for j := 0; j < n; j++ {
        // 各要素の隣の隣の島を取得する
        if i != j && reachable[i][j] == 1 {
          for l := 0; l < n; l++ {
            if i != l && reachable[j][l] == 1 {
              reachable[i][l] = 1
            }
          }
        }
      }
    }
  }
  return reachable
}

func solve(A, B []int, n int) {
  m := len(A)

  for k := 1; k <= m; k++ {
    reachable := calcReachable(A[k:], B[k:], n)

    ans := 0
    for i := 0; i < n; i++ {
      for j := i + 1; j < n; j++ {
        if reachable[i][j] == 0 {
          ans++
        }
      }
    }
    fmt.Println(ans)
  }
}

func main() {
  var n, m int
  fmt.Scan(&n, &m)

  A, B := make([]int, m), make([]int, m)
  var a, b int
  for i := 0; i < m; i++ {
    fmt.Scan(&a, &b)
    A[i], B[i] = a-1, b-1
  }

  solve(A, B, n)
}

重要なのは新しい隣接情報が与えられたときに,隣の隣の島の情報もきちんと反映させることで,今回は一つ隣の島の隣接情報を拾う処理をN1N-1回繰り返すことでそれを実現しています.

この方法で書くにしてももっと効率よく書けそうな気もしますが,計算量はMN4MN^4​でとにかく全く間に合いません.

代表頂点を使ってグループを管理する

全探索では駄目なので工夫が必要なのですが,まずグラフからノードを削除していく方針ではなく,グラフを0から順に構築していくように見ていく方針を取ることでより簡単に問題が解けます.確かに,隣接している島の情報を消していく操作は,その隣接に依存している他の隣接情報を解決する必要があり筋が悪そうです.

よって,まず不便さの初期値をN(N1)/2N(N-1)/2とし,与えられた隣接情報を後ろからみて追加していくことで各不便さを求め,それを逆順に出力するようにします.

またここで,行き来可能な島の集合というのは,グラフにおけるグループだと考えることができます.このグループ管理を各頂点が所属するグループの代表頂点の番号を保存することで行うことにします.

こうすることで,2つの頂点が同じグループに属しているかどうかは代表頂点の相違で判断することができます.またこれは実装上,単なる配列の要素参照となり高速に実現できます.

これを使うと,新たな橋の隣接情報が与えられた時,2つの頂点が元々同じグループであれば不便さに変化はなし,違うグループであれば2つのグループが併合することになるのでお互いの要素数の積だけ不便さが減少することになります.

package main

func merge(a, b int) {
  ra, rb := root[a], root[b]

  for i := 0; i < len(item[rb]); i++ {
    root[item[rb][i]] = ra
  }

  for i := 0; i < len(item[rb]); i++ {
    item[ra] = append(item[ra], item[rb][i])
  }
  item[rb] = []int{}
}

func solve(A, B []int, n int) {
  m := len(A)
  ans := make([]int, m)
  now := n * (n - 1) / 2
  ans[m-1] = now

  for i := m - 1; i > 0; i-- {
    if root[A[i]] == root[B[i]] {
      ans[i-1] = now
      continue
    }

    s1 := len(item[root[A[i]]])
    s2 := len(item[root[B[i]]])
    now -= s1 * s2
    ans[i-1] = now
    merge(A[i], B[i])
  }

  for _, v := range ans {
    fmt.Println(v)
  }
}

var item [][]int
var root []int

func prepare(n int) {
  item = make([][]int, n)
  root = make([]int, n)
  for i := 0; i < n; i++ {
    item[i] = []int{i}
    root[i] = i
  }
}

func main() {
  var n, m int
  fmt.Scan(&n, &m)
  prepare(n)

  A, B := make([]int, m), make([]int, m)
  var a, b int
  for i := 0; i < m; i++ {
    fmt.Scan(&a, &b)
    A[i], B[i] = a-1, b-1
  }

  solve(A, B, n)
}

しかしながらこの方法でも,2つのグループを併合する際に,要素の移動(O(N)O(N))が発生してしまうため,全体の計算量はO(MN)O(MN)となりTLEとなってしいます.

Union Find

そこでUnion Findが登場します.実はグループの管理の仕方自体は変わりません.変更点は,先程起きたグループの併合における計算量の問題を,代表頂点の代表頂点を書き換えることによってO(1)O(1)で行うところにあります.

ただし,この変更により,ある頂点の代表頂点は単に配列の要素を参照するのではなく,親の親…という風に再帰的に探索する必要がでてきます.これは特に要素1のグループが1つの階層構造を持ってしまった場合(全グループが縦に並ぶイメージ)に,探索がO(N)O(N)​かかってしまうため結局問題となります.よって,以下の2つを行うことでこれを解決します.

  • 代表頂点の探索処理と同時に,その処理に関わった全てのノードを直接根につなぎ直す(多段になっている階層構造を解消する)
  • 2つのグループを併合する時に,要素数の大きい側に少ない側を併合するようにする.

これにより,Union Findの計算量は少なくともO(logN)O(\text{log}N)未満となるようです.重要な制約として,グループの併合はできても分割はできないことが挙げられます.最終的なコードは以下の通りになります.

package main

type UnionFind struct {
  Size  int
  Nodes []int
  Ranks []int
}

func NewUnionFind(n int) *UnionFind {
  nodes := make([]int, n)
  counts := make([]int, n)
  for i := 0; i < n; i++ {
    nodes[i] = i
    counts[i] = 1
  }
  return &UnionFind{n, nodes, counts}
}

func (p *UnionFind) Root(a int) int {
  nodes := []int{}

  par := p.Nodes[a]
  var root int
  for {
    if par == p.Nodes[par] {
      root = par
      break
    }
    nodes = append(nodes, par)
    par = p.Nodes[par]
  }

  for _, n := range nodes {
    p.Nodes[n] = root
  }

  return root
}

func (p *UnionFind) Merge(a, b int) {
  if a == b {
    return
  }

  if p.Rank(a) > p.Rank(b) {
    a, b = b, a
  }
  ra, rb := p.Root(a), p.Root(b)
  p.Nodes[rb] = ra

  p.Ranks[ra] += p.Ranks[rb]
  p.Ranks[rb] = 0
}

func (p *UnionFind) Same(a, b int) bool {
  return p.Root(a) == p.Root(b)
}

func (p *UnionFind) Rank(a int) int {
  return p.Ranks[p.Root(a)]
}

func solve(A, B []int, n int) {
  m := len(A)
  ans := make([]int, m)
  now := n * (n - 1) / 2
  ans[m-1] = now

  uf := NewUnionFind(n)

  for i := m - 1; i > 0; i-- {
    if uf.Same(A[i], B[i]) {
      ans[i-1] = now
      continue
    }

    now -= uf.Rank(A[i]) * uf.Rank(B[i])
    ans[i-1] = now
    uf.Merge(A[i], B[i])
  }

  for _, v := range ans {
    fmt.Println(v)
  }
}

func main() {
  var n, m int
  fmt.Scan(&n, &m)

  A, B := make([]int, m), make([]int, m)
  var a, b int
  for i := 0; i < m; i++ {
    fmt.Scan(&a, &b)
    A[i], B[i] = a-1, b-1
  }

  solve(A, B, n)
}

参考