haar_lib/typical/double_sigma/
mod.rs1pub mod difference;
4pub mod max;
5pub mod prod;
6pub mod range_prod;
7pub mod range_sum;
8pub mod range_xor;
9pub mod sum;
10pub mod xor;
11
#[cfg(test)]
mod tests {
    //! Randomized tests: each fast `sum_of_sum_of_*` function is compared
    //! against a direct double loop over all index pairs / subarrays.

    use std::ops::AddAssign;

    use rand::Rng;

    use crate::{math::prime_mod::Prime, num::const_modint::*};

    const M: u32 = 998244353;
    type P = Prime<M>;

    /// Collects `n` values produced by successive calls to `next_value`.
    fn random_vec<T>(n: usize, next_value: impl FnMut() -> T) -> Vec<T> {
        std::iter::repeat_with(next_value).take(n).collect()
    }

    /// Shared brute-force driver: starting from `init`, adds `f(a, i, j)` for
    /// every `0 <= i < a.len()` and `i < j < j_end`.
    fn brute_force<T, U, F>(a: &[T], init: U, j_end: usize, mut f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        let mut ans = init;

        for i in 0..a.len() {
            for j in i + 1..j_end {
                ans += f(a, i, j);
            }
        }

        ans
    }

    /// Sums `f(a, i, j)` over every pair of distinct indices `0 <= i < j < n`.
    fn solve<T, U, F>(a: Vec<T>, init: U, f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        let n = a.len();
        brute_force(&a, init, n, f)
    }

    /// Sums `f(a, i, j)` over every non-empty half-open range `a[i..j]`,
    /// i.e. `0 <= i < j <= n`.
    fn solve_range<T, U, F>(a: Vec<T>, init: U, f: F) -> U
    where
        U: AddAssign,
        F: FnMut(&[T], usize, usize) -> U,
    {
        // `j` may equal `n` because `j` is an exclusive slice end here.
        let n = a.len();
        brute_force(&a, init, n + 1, f)
    }

    #[test]
    fn test_difference() {
        let mut rng = rand::thread_rng();
        let a = random_vec(300, || rng.gen::<i32>() as i64);

        let res = super::difference::sum_of_sum_of_difference(a.clone());
        let ans = solve(a, 0, |a, i, j| (a[i] - a[j]).abs());

        assert_eq!(res, ans);
    }

    #[test]
    fn test_xor() {
        let mut rng = rand::thread_rng();
        let a = random_vec(300, || rng.gen::<u32>() as u64);

        let res = super::xor::sum_of_sum_of_xor(a.clone()) as u64;
        let ans = solve(a, 0, |a, i, j| a[i] ^ a[j]);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_xor() {
        let mut rng = rand::thread_rng();
        // Values are kept below 2^32 so the brute-force total cannot overflow u64.
        let a = random_vec(100, || rng.gen::<u64>() % 2_u64.pow(32));

        let res = super::range_xor::sum_of_sum_of_range_xor(a.clone()) as u64;
        let ans = solve_range(a, 0, |a, i, j| a[i..j].iter().fold(0, |x, y| x ^ y));

        assert_eq!(res, ans);
    }

    #[test]
    fn test_sum() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(300, || modulo.from_i64(rng.gen::<i64>()));

        let res = super::sum::sum_of_sum_of_sum(a.clone());
        let ans = solve(a, modulo.from_u64(0), |a, i, j| a[i] + a[j]);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_prod() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(300, || modulo.from_i64(rng.gen::<i64>()));

        let res = super::prod::sum_of_sum_of_prod(a.clone());
        let ans = solve(a, modulo.from_u64(0), |a, i, j| a[i] * a[j]);

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_sum() {
        let mut rng = rand::thread_rng();
        let a = random_vec(100, || rng.gen::<i32>() as i64);

        let res = super::range_sum::sum_of_sum_of_range_sum(a.clone());
        let ans = solve_range(a, 0, |a, i, j| a[i..j].iter().sum());

        assert_eq!(res, ans);
    }

    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let a = random_vec(300, || rng.gen::<i32>() as i64);

        let res = super::max::sum_of_sum_of_max(a.clone());
        let ans = solve(a, 0, |a, i, j| a[i].max(a[j]));

        assert_eq!(res, ans);
    }

    #[test]
    fn test_range_prod() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<P>::new();
        let a = random_vec(100, || modulo.from_i64(rng.gen::<i64>()));

        let res = super::range_prod::sum_of_sum_of_range_prod(a.clone());
        let ans = solve_range(a, modulo.from_u64(0), |a, i, j| {
            a[i..j].iter().fold(modulo.from_u64(1), |x, &y| x * y)
        });

        assert_eq!(res, ans);
    }
}