// haar_lib/typical/double_sigma/mod.rs

1//! $\sum \sum f(a_i, a_j)$の形の問題
2
3pub mod difference;
4pub mod max;
5pub mod prod;
6pub mod range_prod;
7pub mod range_sum;
8pub mod range_xor;
9pub mod sum;
10pub mod xor;
11
#[cfg(test)]
mod tests {
    use std::ops::AddAssign;

    use rand::Rng;

    use crate::num::const_modint::*;

    const M998244353: u32 = 998244353;

    /// Brute-force accumulator shared by [`solve`] and [`solve_range`].
    ///
    /// Adds `f(a, i, j)` onto `init` for every index pair `0 <= i < j < end`.
    /// `f` receives the whole slice so it can look at either the two elements
    /// `a[i]`, `a[j]` or the half-open range `a[i..j]`.
    fn sum_over_index_pairs<T, U, F>(a: &[T], end: usize, init: U, mut f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        let mut ans = init;
        for i in 0..end {
            for j in i + 1..end {
                ans += f(a, i, j);
            }
        }
        ans
    }

    /// Naive $\sum_{0 \le i < j < n} f(a, i, j)$ over all pairs of distinct
    /// indices of `a`, starting from `init`.
    fn solve<T, U, F>(a: &[T], init: U, f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        sum_over_index_pairs(a, a.len(), init, f)
    }

    /// Naive $\sum_{0 \le i < j \le n} f(a, i, j)$, i.e. `f` evaluated on every
    /// non-empty subarray `a[i..j]`, starting from `init`.
    fn solve_range<T, U, F>(a: &[T], init: U, f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        // `i < j <= n` is the same pair set as `i < j < n + 1`; for `i == n`
        // the inner range is empty, so `f` is never called out of bounds.
        sum_over_index_pairs(a, a.len() + 1, init, f)
    }

    #[test]
    fn test_difference() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = (0..n)
            .map(|_| i64::from(rng.gen::<i32>()))
            .collect::<Vec<_>>();

        // Compute the brute-force answer first so `a` can be moved afterwards.
        let ans = solve(&a, 0, |a, i, j| (a[i] - a[j]).abs());
        let res = super::difference::sum_of_sum_of_difference(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_xor() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = (0..n)
            .map(|_| u64::from(rng.gen::<u32>()))
            .collect::<Vec<_>>();

        let ans = solve(&a, 0, |a, i, j| a[i] ^ a[j]);
        let res = super::xor::sum_of_sum_of_xor(a) as u64;

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_xor() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let a = (0..n)
            .map(|_| rng.gen::<u64>() % 2_u64.pow(32))
            .collect::<Vec<_>>();

        let ans = solve_range(&a, 0, |a, i, j| a[i..j].iter().fold(0, |x, y| x ^ y));
        let res = super::range_xor::sum_of_sum_of_range_xor(a) as u64;

        assert_eq!(res, ans);
    }

    #[test]
    fn test_sum() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let modulo = ConstModIntBuilder::<M998244353>;
        let a = (0..n)
            .map(|_| modulo.from_i64(rng.gen::<i64>()))
            .collect::<Vec<_>>();

        let ans = solve(&a, modulo.from_u64(0), |a, i, j| a[i] + a[j]);
        let res = super::sum::sum_of_sum_of_sum(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_prod() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let modulo = ConstModIntBuilder::<M998244353>;
        let a = (0..n)
            .map(|_| modulo.from_i64(rng.gen::<i64>()))
            .collect::<Vec<_>>();

        let ans = solve(&a, modulo.from_u64(0), |a, i, j| a[i] * a[j]);
        let res = super::prod::sum_of_sum_of_prod(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_sum() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let a = (0..n)
            .map(|_| i64::from(rng.gen::<i32>()))
            .collect::<Vec<_>>();

        let ans = solve_range(&a, 0, |a, i, j| a[i..j].iter().sum());
        let res = super::range_sum::sum_of_sum_of_range_sum(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let a = (0..n)
            .map(|_| i64::from(rng.gen::<i32>()))
            .collect::<Vec<_>>();

        let ans = solve(&a, 0, |a, i, j| a[i].max(a[j]));
        let res = super::max::sum_of_sum_of_max(a);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_prod() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let modulo = ConstModIntBuilder::<M998244353>;
        let a = (0..n)
            .map(|_| modulo.from_i64(rng.gen::<i64>()))
            .collect::<Vec<_>>();

        let ans = solve_range(&a, modulo.from_u64(0), |a, i, j| {
            a[i..j].iter().fold(modulo.from_u64(1), |x, &y| x * y)
        });
        let res = super::range_prod::sum_of_sum_of_range_prod(a);

        assert_eq!(res, ans);
    }
}
164}