kyopro-lib

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:question: Wavelet matrix
(Mylib/DataStructure/WaveletMatrix/wavelet_matrix.cpp)

Operations

Requirements

Notes

Problems

References

Depends on

Verified with

Code

#pragma once
#include <cassert>
#include <optional>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
#include "Mylib/DataStructure/WaveletMatrix/succinct_dictionary.cpp"

namespace haar_lib {
  /**
   * @brief Wavelet matrix: a static sequence of B-bit values supporting access,
   *        rank/select, k-th smallest and range-frequency queries.
   *
   * Level k (0-based) stores bit (B - 1 - k) of every element, in the order obtained
   * by stably partitioning the previous level's order by its bit (0-bits first).
   * zero_pos_[k] is the number of 0-bits at level k, i.e. the offset at which the
   * 1-bit elements start in level k + 1. Only the lowest B bits of each value are
   * represented.
   *
   * @tparam T value type
   * @tparam B number of bits examined per value
   */
  template <typename T, int B>
  class wavelet_matrix {
  public:
    using value_type = T;

  private:
    int N_;
    succinct_dict sdict_[B];
    int zero_pos_[B];

    // Maps position `pos` at `level` to the corresponding position at level + 1,
    // following the child selected by bit `t` (0-child block first, then 1-child block).
    int descend(int level, int pos, int t) const {
      return sdict_[level].rank(pos, t) + t * zero_pos_[level];
    }

  public:
    wavelet_matrix() : N_(0), zero_pos_() {}

    /**
     * @brief Builds the matrix from `data` (one pass over the data per level).
     */
    wavelet_matrix(std::vector<T> data) : N_(data.size()) {
      std::vector<bool> s(N_);

      for (int k = 0; k < B; ++k) {
        std::vector<T> left, right;

        for (int i = 0; i < N_; ++i) {
          s[i] = (data[i] >> (B - 1 - k)) & 1;
          if (s[i]) {
            right.push_back(data[i]);
          } else {
            left.push_back(data[i]);
          }
        }

        sdict_[k]    = succinct_dict(s);
        zero_pos_[k] = left.size();

        // Order for the next level: stable partition with all 0-bit elements first.
        std::swap(data, left);
        data.insert(data.end(), right.begin(), right.end());
      }
    }

    /**
     * @return data[index]
     */
    T access(int index) const {
      assert(0 <= index and index < N_);
      T ret = 0;

      int p = index;
      for (int i = 0; i < B; ++i) {
        const int t = sdict_[i].access(p);
        ret |= ((T) t << (B - 1 - i));
        p = descend(i, p, t);
      }

      return ret;
    }

    /**
     * @brief Follows the bit path of `val` from the top level to the bottom level.
     * @return [l, r) range at the bottom level; r - l is the number of occurrences of
     *         `val` in data[0, index), and l is where the occurrences of `val` start.
     */
    std::pair<int, int> rank_aux(int index, const T &val) const {
      int l = 0, r = index;

      for (int i = 0; i < B; ++i) {
        const int t = (val >> (B - i - 1)) & 1;
        l           = descend(i, l, t);
        r           = descend(i, r, t);
      }

      return std::make_pair(l, r);
    }

    /**
     * @return number of occurrences of val in data[0, index)
     */
    int rank(int index, const T &val) const {
      assert(0 <= index and index <= N_);
      auto [l, r] = rank_aux(index, val);
      return r - l;
    }

    /**
     * @return number of occurrences of val in data[l, r)
     */
    int count(int l, int r, const T &val) const {
      assert(0 <= l and l <= r and r <= N_);
      return rank(r, val) - rank(l, val);
    }

    /**
     * @return position of the count-th (1-indexed) occurrence of val, or nullopt if
     *         val occurs fewer than count times
     */
    std::optional<int> select(int count, const T &val) const {
      assert(1 <= count);

      auto [l, r] = rank_aux(N_, val);
      if (r - l < count) return {};

      // Position of the requested occurrence at the bottom level; walk back up.
      int p = l + count - 1;

      for (int i = B - 1; i >= 0; --i) {
        const int t = (val >> (B - i - 1)) & 1;
        p           = *sdict_[i].select(p - t * zero_pos_[i] + 1, t);
      }

      return {p};
    }

    /**
     * @return the k-th (1-indexed) smallest value in data[l, r), or nullopt when k is
     *         outside [1, r - l] (this includes an empty range)
     */
    std::optional<T> quantile(int l, int r, int k) const {
      assert(0 <= l and l <= r and r <= N_);
      // Without this check an out-of-range k walks into empty ranges and yields garbage.
      if (k < 1 or k > r - l) return {};

      T ret = 0;

      for (int i = 0; i < B; ++i) {
        const int count_0 = sdict_[i].count(l, r, 0);

        int t = 0;

        if (k > count_0) {
          // The answer is among the elements having a 1 at this bit.
          t = 1;
          ret |= ((T) 1 << (B - i - 1));
          k -= count_0;
        }

        l = descend(i, l, t);
        r = descend(i, r, t);
      }

      return {ret};
    }

    /**
     * @return the largest value in data[l, r) (range must be non-empty)
     */
    T maximum(int l, int r) const {
      assert(0 <= l and l < r and r <= N_);
      return *quantile(l, r, r - l);
    }

    /**
     * @return the smallest value in data[l, r) (range must be non-empty)
     */
    T minimum(int l, int r) const {
      assert(0 <= l and l < r and r <= N_);
      return *quantile(l, r, 1);
    }

    /**
     * @return the smallest value in data[l, r) that is >= lb, or nullopt if none
     */
    std::optional<T> next_value(int l, int r, T lb) const {
      const int c = range_freq_lt(l, r, lb);
      return quantile(l, r, c + 1);
    }

    /**
     * @return the largest value in data[l, r) that is < ub, or nullopt if none
     */
    std::optional<T> prev_value(int l, int r, T ub) const {
      const int c = range_freq_lt(l, r, ub);
      return quantile(l, r, c);
    }

    /**
     * @return number of values in data[l, r) strictly less than ub
     *         (only the lowest B bits of ub are compared)
     */
    int range_freq_lt(int l, int r, T ub) const {
      int ret = 0;

      for (int i = 0; i < B; ++i) {
        const int t = (ub >> (B - i - 1)) & 1;

        // Elements sharing ub's prefix but with 0 where ub has 1 are all smaller than ub.
        if (t) {
          ret += sdict_[i].count(l, r, 0);
        }

        l = descend(i, l, t);
        r = descend(i, r, t);
      }

      return ret;
    }

    /**
     * @return number of values in data[l, r) lying in [lb, ub); 0 when ub <= lb
     */
    int range_freq(int l, int r, T lb, T ub) const {
      if (ub <= lb) return 0;
      return range_freq_lt(l, r, ub) - range_freq_lt(l, r, lb);
    }

    /**
     * @return (frequency, value) pairs for every distinct value of data[l, r) lying in
     *         [lb, ub), in ascending order of value
     */
    auto range_freq_list(int l, int r, T lb, T ub) const {
      assert(0 <= l and l <= r and r <= N_);
      std::vector<std::pair<int, T>> ret;
      // Each entry: [from, to) range at `depth`, and the value bits fixed so far.
      std::queue<std::tuple<int, int, int, T>> q;

      q.emplace(l, r, 0, 0);

      while (not q.empty()) {
        const auto [from, to, depth, prefix] = q.front();
        q.pop();

        if (depth == B) {
          if (lb <= prefix and prefix < ub) {
            ret.emplace_back(to - from, prefix);
          }
          continue;
        }

        // Bit examined at this depth.
        const T b = (T) 1 << (B - depth - 1);
        // Bits above `b` (the prefix already fixed). (b << 1) - 1 covers bits [0, B - depth);
        // computing it from `b` avoids shifting by B, which is undefined behavior when B
        // equals the bit width of T (e.g. wavelet_matrix<uint32_t, 32>).
        const T mask = ~(T) 0 ^ (T) ((b << 1) - 1);

        // Skip the 0-child when its whole subtree is < lb.
        if (sdict_[depth].count(from, to, 0) != 0) {
          if (prefix != (lb & mask) or not(lb & b)) {
            q.emplace(descend(depth, from, 0), descend(depth, to, 0), depth + 1, prefix);
          }
        }

        // Skip the 1-child when its whole subtree is > ub.
        if (sdict_[depth].count(from, to, 1) != 0) {
          if (prefix != (ub & mask) or (ub & b)) {
            q.emplace(descend(depth, from, 1), descend(depth, to, 1), depth + 1, prefix | b);
          }
        }
      }

      return ret;
    }

    /**
     * @return up to k (frequency, value) pairs from data[l, r), most frequent first;
     *         empty when k <= 0
     */
    auto top_k(int l, int r, int k) const {
      std::vector<std::pair<int, T>> ret;
      if (k <= 0) return ret;

      assert(0 <= l and l <= r and r <= N_);
      // Max-heap keyed on range length: a node's length bounds every leaf below it,
      // so leaves come out in non-increasing order of frequency.
      std::priority_queue<std::tuple<int, int, int, int, T>> q;

      q.emplace(r - l, l, r, 0, 0);

      while (not q.empty()) {
        const auto [len, from, to, depth, prefix] = q.top();
        q.pop();

        if (depth == B) {
          ret.emplace_back(len, prefix);
          if ((int) ret.size() >= k) break;
          continue;
        }

        if (sdict_[depth].count(from, to, 0) != 0) {
          const int L = descend(depth, from, 0);
          const int R = descend(depth, to, 0);
          q.emplace(R - L, L, R, depth + 1, prefix);
        }

        if (sdict_[depth].count(from, to, 1) != 0) {
          const int L = descend(depth, from, 1);
          const int R = descend(depth, to, 1);
          q.emplace(R - L, L, R, depth + 1, prefix | ((T) 1 << (B - depth - 1)));
        }
      }

      return ret;
    }
  };

  /**
   * @brief Convenience factory: wavelet matrix over 32-bit unsigned values.
   * @note `inline` because this definition lives in a file meant to be #included;
   *       a non-inline definition causes duplicate-symbol link errors (ODR violation)
   *       when the file is included from more than one translation unit.
   */
  inline wavelet_matrix<uint32_t, 32> make_wavelet_matrix_int(const std::vector<uint32_t> &data) {
    return wavelet_matrix<uint32_t, 32>(data);
  }
}  // namespace haar_lib
#line 2 "Mylib/DataStructure/WaveletMatrix/wavelet_matrix.cpp"
#include <cassert>
#include <optional>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
#line 5 "Mylib/DataStructure/WaveletMatrix/succinct_dictionary.cpp"

namespace haar_lib {
  /**
   * @brief Static bit vector supporting access, rank and select.
   *
   * Bits are packed into 64-bit words. Cumulative popcounts are kept per 256-bit chunk
   * (`chunks_`) and per 64-bit block relative to its chunk (`blocks_`), so `rank` needs
   * two table lookups and one popcount.
   */
  class succinct_dict {
    int N_;

    static const int chunk_size_ = 256;
    static const int block_size_ = 64;
    // Bit i is stored in data_[i / 64] at bit position i % 64.
    std::vector<uint64_t> data_;

    // blocks_[c][j]: number of 1s in chunk c before its block j (at most 3 * 64, fits uint8_t).
    std::vector<std::vector<uint8_t>> blocks_;

    // chunks_[c]: number of 1s in [0, c * chunk_size_).
    std::vector<uint32_t> chunks_;

    int chunk_num_;
    static const int block_num_ = chunk_size_ / block_size_;

  public:
    /// Empty dictionary (size 0); its tables are built so that rank(0, b) is valid.
    succinct_dict() : succinct_dict(std::vector<bool>()) {}

    succinct_dict(const std::vector<bool> &b) : N_(b.size()) {
      chunk_num_ = (N_ + chunk_size_ - 1) / chunk_size_;

      // One spare word so that rank(N_) may read data_ even when N_ is a multiple of 256.
      data_.assign(chunk_num_ * block_num_ + 1, 0);

      for (int i = 0; i < N_; ++i) {
        if (b[i]) {
          // Unsigned literal: bit 63 must be settable without signed-shift pitfalls.
          data_[i / block_size_] |= uint64_t{1} << (i % block_size_);
        }
      }

      chunks_.assign(chunk_num_ + 1, 0);
      blocks_.assign(chunk_num_ + 1, std::vector<uint8_t>(block_num_, 0));

      for (int i = 0; i < chunk_num_; ++i) {
        for (int j = 0; j < block_num_ - 1; ++j) {
          blocks_[i][j + 1] = blocks_[i][j] + __builtin_popcountll(data_[i * block_num_ + j]);
        }

        chunks_[i + 1] = chunks_[i] + blocks_[i][block_num_ - 1] + __builtin_popcountll(data_[(i + 1) * block_num_ - 1]);
      }
    }

    int size() const { return N_; }

    /**
     * @return number of bits equal to b in [0, index); index is clamped to size()
     */
    int rank(int index, int b) const {
      assert(0 <= index);
      // Clamp before computing either count so the 0-count also refers to [0, N_).
      if (index > N_) index = N_;

      const int chunk_pos = index / chunk_size_;
      const int block_pos = (index % chunk_size_) / block_size_;

      const uint64_t below =
          data_[chunk_pos * block_num_ + block_pos] & ((uint64_t{1} << (index % block_size_)) - 1);

      const int ones = chunks_[chunk_pos] +
                       blocks_[chunk_pos][block_pos] +
                       __builtin_popcountll(below);

      return b == 0 ? index - ones : ones;
    }

    /**
     * @return number of bits equal to b in [l, r)
     */
    int count(int l, int r, int b) const {
      return rank(r, b) - rank(l, b);
    }

    /**
     * @return the bit at position index (0 or 1)
     */
    int access(int index) const {
      assert(0 <= index and index < N_);
      return (data_[index / block_size_] >> (index % block_size_)) & 1;
    }

    /**
     * @note n in [1, N]
     * @return position of the n-th (1-indexed) bit equal to b, or nullopt if there are
     *         fewer than n such bits
     */
    std::optional<int> select(int n, int b) const {
      assert(n >= 1);

      if (rank(N_, b) < n) return {};

      // Invariant: rank(ub, b) >= n and (lb == -1 or rank(lb, b) < n). On exit
      // ub == lb + 1, so the n-th b sits at position ub - 1 == lb.
      int lb = -1, ub = N_;
      while (ub - lb > 1) {
        const int mid = (lb + ub) / 2;

        if (rank(mid, b) >= n) {
          ub = mid;
        } else {
          lb = mid;
        }
      }

      return {lb};
    }
  };
}  // namespace haar_lib
#line 9 "Mylib/DataStructure/WaveletMatrix/wavelet_matrix.cpp"

namespace haar_lib {
  /**
   * @brief Wavelet matrix: a static sequence of B-bit values supporting access,
   *        rank/select, k-th smallest and range-frequency queries.
   *
   * Level k (0-based) stores bit (B - 1 - k) of every element, in the order obtained
   * by stably partitioning the previous level's order by its bit (0-bits first).
   * zero_pos_[k] is the number of 0-bits at level k, i.e. the offset at which the
   * 1-bit elements start in level k + 1. Only the lowest B bits of each value are
   * represented.
   *
   * @tparam T value type
   * @tparam B number of bits examined per value
   */
  template <typename T, int B>
  class wavelet_matrix {
  public:
    using value_type = T;

  private:
    int N_;
    succinct_dict sdict_[B];
    int zero_pos_[B];

    // Maps position `pos` at `level` to the corresponding position at level + 1,
    // following the child selected by bit `t` (0-child block first, then 1-child block).
    int descend(int level, int pos, int t) const {
      return sdict_[level].rank(pos, t) + t * zero_pos_[level];
    }

  public:
    wavelet_matrix() : N_(0), zero_pos_() {}

    /**
     * @brief Builds the matrix from `data` (one pass over the data per level).
     */
    wavelet_matrix(std::vector<T> data) : N_(data.size()) {
      std::vector<bool> s(N_);

      for (int k = 0; k < B; ++k) {
        std::vector<T> left, right;

        for (int i = 0; i < N_; ++i) {
          s[i] = (data[i] >> (B - 1 - k)) & 1;
          if (s[i]) {
            right.push_back(data[i]);
          } else {
            left.push_back(data[i]);
          }
        }

        sdict_[k]    = succinct_dict(s);
        zero_pos_[k] = left.size();

        // Order for the next level: stable partition with all 0-bit elements first.
        std::swap(data, left);
        data.insert(data.end(), right.begin(), right.end());
      }
    }

    /**
     * @return data[index]
     */
    T access(int index) const {
      assert(0 <= index and index < N_);
      T ret = 0;

      int p = index;
      for (int i = 0; i < B; ++i) {
        const int t = sdict_[i].access(p);
        ret |= ((T) t << (B - 1 - i));
        p = descend(i, p, t);
      }

      return ret;
    }

    /**
     * @brief Follows the bit path of `val` from the top level to the bottom level.
     * @return [l, r) range at the bottom level; r - l is the number of occurrences of
     *         `val` in data[0, index), and l is where the occurrences of `val` start.
     */
    std::pair<int, int> rank_aux(int index, const T &val) const {
      int l = 0, r = index;

      for (int i = 0; i < B; ++i) {
        const int t = (val >> (B - i - 1)) & 1;
        l           = descend(i, l, t);
        r           = descend(i, r, t);
      }

      return std::make_pair(l, r);
    }

    /**
     * @return number of occurrences of val in data[0, index)
     */
    int rank(int index, const T &val) const {
      assert(0 <= index and index <= N_);
      auto [l, r] = rank_aux(index, val);
      return r - l;
    }

    /**
     * @return number of occurrences of val in data[l, r)
     */
    int count(int l, int r, const T &val) const {
      assert(0 <= l and l <= r and r <= N_);
      return rank(r, val) - rank(l, val);
    }

    /**
     * @return position of the count-th (1-indexed) occurrence of val, or nullopt if
     *         val occurs fewer than count times
     */
    std::optional<int> select(int count, const T &val) const {
      assert(1 <= count);

      auto [l, r] = rank_aux(N_, val);
      if (r - l < count) return {};

      // Position of the requested occurrence at the bottom level; walk back up.
      int p = l + count - 1;

      for (int i = B - 1; i >= 0; --i) {
        const int t = (val >> (B - i - 1)) & 1;
        p           = *sdict_[i].select(p - t * zero_pos_[i] + 1, t);
      }

      return {p};
    }

    /**
     * @return the k-th (1-indexed) smallest value in data[l, r), or nullopt when k is
     *         outside [1, r - l] (this includes an empty range)
     */
    std::optional<T> quantile(int l, int r, int k) const {
      assert(0 <= l and l <= r and r <= N_);
      // Without this check an out-of-range k walks into empty ranges and yields garbage.
      if (k < 1 or k > r - l) return {};

      T ret = 0;

      for (int i = 0; i < B; ++i) {
        const int count_0 = sdict_[i].count(l, r, 0);

        int t = 0;

        if (k > count_0) {
          // The answer is among the elements having a 1 at this bit.
          t = 1;
          ret |= ((T) 1 << (B - i - 1));
          k -= count_0;
        }

        l = descend(i, l, t);
        r = descend(i, r, t);
      }

      return {ret};
    }

    /**
     * @return the largest value in data[l, r) (range must be non-empty)
     */
    T maximum(int l, int r) const {
      assert(0 <= l and l < r and r <= N_);
      return *quantile(l, r, r - l);
    }

    /**
     * @return the smallest value in data[l, r) (range must be non-empty)
     */
    T minimum(int l, int r) const {
      assert(0 <= l and l < r and r <= N_);
      return *quantile(l, r, 1);
    }

    /**
     * @return the smallest value in data[l, r) that is >= lb, or nullopt if none
     */
    std::optional<T> next_value(int l, int r, T lb) const {
      const int c = range_freq_lt(l, r, lb);
      return quantile(l, r, c + 1);
    }

    /**
     * @return the largest value in data[l, r) that is < ub, or nullopt if none
     */
    std::optional<T> prev_value(int l, int r, T ub) const {
      const int c = range_freq_lt(l, r, ub);
      return quantile(l, r, c);
    }

    /**
     * @return number of values in data[l, r) strictly less than ub
     *         (only the lowest B bits of ub are compared)
     */
    int range_freq_lt(int l, int r, T ub) const {
      int ret = 0;

      for (int i = 0; i < B; ++i) {
        const int t = (ub >> (B - i - 1)) & 1;

        // Elements sharing ub's prefix but with 0 where ub has 1 are all smaller than ub.
        if (t) {
          ret += sdict_[i].count(l, r, 0);
        }

        l = descend(i, l, t);
        r = descend(i, r, t);
      }

      return ret;
    }

    /**
     * @return number of values in data[l, r) lying in [lb, ub); 0 when ub <= lb
     */
    int range_freq(int l, int r, T lb, T ub) const {
      if (ub <= lb) return 0;
      return range_freq_lt(l, r, ub) - range_freq_lt(l, r, lb);
    }

    /**
     * @return (frequency, value) pairs for every distinct value of data[l, r) lying in
     *         [lb, ub), in ascending order of value
     */
    auto range_freq_list(int l, int r, T lb, T ub) const {
      assert(0 <= l and l <= r and r <= N_);
      std::vector<std::pair<int, T>> ret;
      // Each entry: [from, to) range at `depth`, and the value bits fixed so far.
      std::queue<std::tuple<int, int, int, T>> q;

      q.emplace(l, r, 0, 0);

      while (not q.empty()) {
        const auto [from, to, depth, prefix] = q.front();
        q.pop();

        if (depth == B) {
          if (lb <= prefix and prefix < ub) {
            ret.emplace_back(to - from, prefix);
          }
          continue;
        }

        // Bit examined at this depth.
        const T b = (T) 1 << (B - depth - 1);
        // Bits above `b` (the prefix already fixed). (b << 1) - 1 covers bits [0, B - depth);
        // computing it from `b` avoids shifting by B, which is undefined behavior when B
        // equals the bit width of T (e.g. wavelet_matrix<uint32_t, 32>).
        const T mask = ~(T) 0 ^ (T) ((b << 1) - 1);

        // Skip the 0-child when its whole subtree is < lb.
        if (sdict_[depth].count(from, to, 0) != 0) {
          if (prefix != (lb & mask) or not(lb & b)) {
            q.emplace(descend(depth, from, 0), descend(depth, to, 0), depth + 1, prefix);
          }
        }

        // Skip the 1-child when its whole subtree is > ub.
        if (sdict_[depth].count(from, to, 1) != 0) {
          if (prefix != (ub & mask) or (ub & b)) {
            q.emplace(descend(depth, from, 1), descend(depth, to, 1), depth + 1, prefix | b);
          }
        }
      }

      return ret;
    }

    /**
     * @return up to k (frequency, value) pairs from data[l, r), most frequent first;
     *         empty when k <= 0
     */
    auto top_k(int l, int r, int k) const {
      std::vector<std::pair<int, T>> ret;
      if (k <= 0) return ret;

      assert(0 <= l and l <= r and r <= N_);
      // Max-heap keyed on range length: a node's length bounds every leaf below it,
      // so leaves come out in non-increasing order of frequency.
      std::priority_queue<std::tuple<int, int, int, int, T>> q;

      q.emplace(r - l, l, r, 0, 0);

      while (not q.empty()) {
        const auto [len, from, to, depth, prefix] = q.top();
        q.pop();

        if (depth == B) {
          ret.emplace_back(len, prefix);
          if ((int) ret.size() >= k) break;
          continue;
        }

        if (sdict_[depth].count(from, to, 0) != 0) {
          const int L = descend(depth, from, 0);
          const int R = descend(depth, to, 0);
          q.emplace(R - L, L, R, depth + 1, prefix);
        }

        if (sdict_[depth].count(from, to, 1) != 0) {
          const int L = descend(depth, from, 1);
          const int R = descend(depth, to, 1);
          q.emplace(R - L, L, R, depth + 1, prefix | ((T) 1 << (B - depth - 1)));
        }
      }

      return ret;
    }
  };

  /**
   * @brief Convenience factory: wavelet matrix over 32-bit unsigned values.
   * @note `inline` because this definition lives in a file meant to be #included;
   *       a non-inline definition causes duplicate-symbol link errors (ODR violation)
   *       when the file is included from more than one translation unit.
   */
  inline wavelet_matrix<uint32_t, 32> make_wavelet_matrix_int(const std::vector<uint32_t> &data) {
    return wavelet_matrix<uint32_t, 32>(data);
  }
}  // namespace haar_lib
Back to top page