kyopro-lib

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:x: Persistent union-find
(Mylib/DataStructure/UnionFind/persistent_unionfind.cpp)

Operations

Requirements

Notes

Problems

References

Depends on

Verified with

Code

#pragma once
#include <vector>
#include "Mylib/DataStructure/Array/persistent_array.cpp"

namespace haar_lib {
  class persistent_unionfind {
    persistent_array<int> par_;

    persistent_unionfind(persistent_array<int> par) : par_(par) {}

  public:
    persistent_unionfind() {}
    persistent_unionfind(int n) : par_(persistent_array<int>(std::vector<int>(n, -1))) {}

    int root_of(int i) const {
      const int p = par_[i];
      if (p < 0) return i;
      return root_of(p);
    }

    bool is_same(int i, int j) const {
      return root_of(i) == root_of(j);
    }

    int size_of(int i) const {
      return -par_[root_of(i)];
    }

    persistent_unionfind merge(int i, int j) const {
      const int ri = root_of(i), rj = root_of(j);
      if (ri == rj) return *this;

      const int size_i = -par_[ri];
      const int size_j = -par_[rj];

      persistent_array<int> ret = par_;

      if (size_i > size_j) {
        ret = ret.set(ri, -(size_i + size_j));
        ret = ret.set(rj, ri);
      } else {
        ret = ret.set(rj, -(size_i + size_j));
        ret = ret.set(ri, rj);
      }

      return persistent_unionfind(ret);
    }
  };
}  // namespace haar_lib
#line 2 "Mylib/DataStructure/UnionFind/persistent_unionfind.cpp"
#include <vector>
#line 2 "Mylib/DataStructure/Array/persistent_array.cpp"
#include <iostream>
#include <memory>
#line 5 "Mylib/DataStructure/Array/persistent_array.cpp"

namespace haar_lib {
  template <typename T>
  class persistent_array {
  public:
    using value_type = T;

  private:
    struct node {
      bool is_terminal;
      int size   = 1;
      node *left = nullptr, *right = nullptr;
      std::unique_ptr<T> value;

      node() : is_terminal(false) {}
      node(T v) : is_terminal(true), value(new T(v)) {}
    };

    size_t size_;
    int depth_;
    node *root_ = nullptr;

    int get_size(node *t) const {
      return t ? t->size : 0;
    }

    node *init(int s, int d) {
      if (s == 0) return nullptr;
      if (d == depth_) {
        return new node(T());
      } else {
        node *t  = new node();
        t->left  = init(s / 2, d + 1);
        t->right = init(s - s / 2, d + 1);
        t->size  = get_size(t->left) + get_size(t->right);
        return t;
      }
    }

    void apply_init(node *t, const std::vector<T> &ret, int &i) {
      if (not t) return;

      if (t->is_terminal) {
        *(t->value) = ret[i];
        ++i;
        return;
      }

      apply_init(t->left, ret, i);
      apply_init(t->right, ret, i);
    }

    persistent_array(node *root) : root_(root) {}

    void calc_depth() {
      depth_ = 1;
      while ((int) size_ > (1 << depth_)) depth_ += 1;
      depth_ += 1;
    }

  public:
    persistent_array() {}
    persistent_array(size_t size) : size_(size) {
      calc_depth();
      root_ = init(size_, 1);
    }

    persistent_array(const std::vector<T> &v) : size_(v.size()) {
      calc_depth();
      root_ = init(size_, 1);

      int i = 0;
      apply_init(root_, v, i);
    }

    persistent_array(const persistent_array &v) {
      this->root_  = v.root_;
      this->size_  = v.size_;
      this->depth_ = v.depth_;
    }

  protected:
    T get(node *t, int i) const {
      if (t->is_terminal) return *(t->value);

      int k = get_size(t->left);
      if (i < k)
        return get(t->left, i);
      else
        return get(t->right, i - k);
    }

  public:
    T operator[](int i) const {
      return get(root_, i);
    }

  protected:
    node *set(node *prev, int i, const T &val) const {
      if (prev->is_terminal) return new node(val);

      int k = get_size(prev->left);

      node *t = new node();
      if (i < k) {
        t->right = prev->right;
        t->left  = set(prev->left, i, val);
        t->size  = get_size(t->right) + get_size(t->left);
      } else {
        t->left  = prev->left;
        t->right = set(prev->right, i - k, val);
        t->size  = get_size(t->right) + get_size(t->left);
      }
      return t;
    }

  public:
    persistent_array set(int i, const T &val) const {
      node *ret = set(root_, i, val);
      return persistent_array(ret);
    }

  protected:
    void traverse(node *t, std::vector<T> &ret) const {
      if (not t) return;

      if (t->is_terminal) {
        ret.push_back(*(t->value));
        return;
      }

      traverse(t->left, ret);
      traverse(t->right, ret);
    }

  public:
    std::vector<T> data() const {
      std::vector<T> ret;
      traverse(root_, ret);
      return ret;
    }
  };
}  // namespace haar_lib
#line 4 "Mylib/DataStructure/UnionFind/persistent_unionfind.cpp"

namespace haar_lib {
  class persistent_unionfind {
    persistent_array<int> par_;

    persistent_unionfind(persistent_array<int> par) : par_(par) {}

  public:
    persistent_unionfind() {}
    persistent_unionfind(int n) : par_(persistent_array<int>(std::vector<int>(n, -1))) {}

    int root_of(int i) const {
      const int p = par_[i];
      if (p < 0) return i;
      return root_of(p);
    }

    bool is_same(int i, int j) const {
      return root_of(i) == root_of(j);
    }

    int size_of(int i) const {
      return -par_[root_of(i)];
    }

    persistent_unionfind merge(int i, int j) const {
      const int ri = root_of(i), rj = root_of(j);
      if (ri == rj) return *this;

      const int size_i = -par_[ri];
      const int size_j = -par_[rj];

      persistent_array<int> ret = par_;

      if (size_i > size_j) {
        ret = ret.set(ri, -(size_i + size_j));
        ret = ret.set(rj, ri);
      } else {
        ret = ret.set(rj, -(size_i + size_j));
        ret = ret.set(ri, rj);
      }

      return persistent_unionfind(ret);
    }
  };
}  // namespace haar_lib
Back to top page