// haar_lib/typical/double_sigma/mod.rs

1//! $\sum \sum f(a_i, a_j)$の形の問題
2
3pub mod difference;
4pub mod max;
5pub mod prod;
6pub mod range_prod;
7pub mod range_sum;
8pub mod range_xor;
9pub mod sum;
10pub mod xor;
11
#[cfg(test)]
mod tests {
    use std::ops::AddAssign;

    use rand::Rng;

    use crate::{math::prime_mod::Prime, num::const_modint::*};

    const M: u32 = 998244353;
    type P = Prime<M>;

    /// Collects `n` values produced by repeatedly calling `sample`.
    fn random_vec<T>(n: usize, sample: impl FnMut() -> T) -> Vec<T> {
        std::iter::repeat_with(sample).take(n).collect()
    }

    /// Brute-force reference over all index pairs.
    ///
    /// Returns `init + Σ f(a, i, j)` for every `0 <= i < j < a.len()`.
    fn solve<T, U, F>(a: &[T], init: U, mut f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        let n = a.len();
        let mut ans = init;

        for i in 0..n {
            for j in i + 1..n {
                ans += f(a, i, j);
            }
        }

        ans
    }

    /// Brute-force reference over all non-empty contiguous ranges.
    ///
    /// Returns `init + Σ f(a, i, j)` for every `0 <= i < j <= a.len()`.
    /// `j` is an exclusive end, so `a[i..j]` is never empty.
    fn solve_range<T, U, F>(a: &[T], init: U, mut f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        let n = a.len();
        let mut ans = init;

        for i in 0..n {
            for j in i + 1..=n {
                ans += f(a, i, j);
            }
        }

        ans
    }

    #[test]
    fn test_difference() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = random_vec(n, || rng.gen::<i32>() as i64);

        let ans = solve(&a, 0, |a, i, j| (a[i] - a[j]).abs());
        let res = super::difference::sum_of_sum_of_difference(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_xor() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = random_vec(n, || rng.gen::<u32>() as u64);

        let ans = solve(&a, 0, |a, i, j| a[i] ^ a[j]);
        let res = super::xor::sum_of_sum_of_xor(a) as u64;

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_xor() {
        let mut rng = rand::thread_rng();
        let n = 100;
        // Values below 2^32 keep the total over all n(n+1)/2 ranges far from u64 overflow.
        let a = random_vec(n, || rng.gen::<u32>() as u64);

        let ans = solve_range(&a, 0, |a, i, j| a[i..j].iter().fold(0, |x, y| x ^ y));
        let res = super::range_xor::sum_of_sum_of_range_xor(a) as u64;

        assert_eq!(res, ans);
    }

    #[test]
    fn test_sum() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(n, || modulo.from_i64(rng.gen::<i64>()));

        let ans = solve(&a, modulo.from_u64(0), |a, i, j| a[i] + a[j]);
        let res = super::sum::sum_of_sum_of_sum(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_prod() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(n, || modulo.from_i64(rng.gen::<i64>()));

        let ans = solve(&a, modulo.from_u64(0), |a, i, j| a[i] * a[j]);
        let res = super::prod::sum_of_sum_of_prod(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_sum() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let a = random_vec(n, || rng.gen::<i32>() as i64);

        let ans = solve_range(&a, 0, |a, i, j| a[i..j].iter().sum());
        let res = super::range_sum::sum_of_sum_of_range_sum(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = random_vec(n, || rng.gen::<i32>() as i64);

        let ans = solve(&a, 0, |a, i, j| a[i].max(a[j]));
        let res = super::max::sum_of_sum_of_max(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_prod() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(n, || modulo.from_i64(rng.gen::<i64>()));

        let ans = solve_range(&a, modulo.from_u64(0), |a, i, j| {
            a[i..j].iter().fold(modulo.from_u64(1), |x, &y| x * y)
        });
        let res = super::range_prod::sum_of_sum_of_range_prod(a);

        assert_eq!(res, ans);
    }
}