1use std::ops::{Add, Neg, Sub};
4
/// A value of type `T` extended with a negative and a positive infinity.
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order, so
/// `NegInf < Value(_) < Inf`, and two `Value`s compare by their payload.
/// Reordering the variants would therefore change the ordering.
///
/// `Hash` is derived alongside `Eq` so the type can be used as a key in
/// `HashMap`/`HashSet` whenever `T` itself is hashable.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum NumInf<T> {
    /// Sorts below every `Value`.
    NegInf,
    /// An ordinary `T`.
    Value(T),
    /// Sorts above every `Value`.
    Inf,
}
15
16impl<T: Copy> NumInf<T> {
17 pub fn is_value(self) -> bool {
19 matches!(self, NumInf::Value(_))
20 }
21
22 pub fn is_inf(self) -> bool {
24 matches!(self, NumInf::Inf)
25 }
26
27 pub fn is_neg_inf(self) -> bool {
29 matches!(self, NumInf::NegInf)
30 }
31
32 pub fn unwrap(self) -> T {
38 match self {
39 NumInf::Value(x) => x,
40 NumInf::Inf => panic!("called `NumInf::unwrap()` on a `Inf` value"),
41 NumInf::NegInf => panic!("called `NumInf::unwrap()` on a `NegInf` value"),
42 }
43 }
44}
45
46impl<T: Add<Output = T>> Add for NumInf<T> {
47 type Output = Self;
48
49 fn add(self, other: Self) -> Self {
50 match self {
51 NumInf::Value(x) => match other {
52 NumInf::Value(y) => NumInf::Value(x + y),
53 y => y,
54 },
55 y => y,
56 }
57 }
58}
59
60impl<T: Sub<Output = T>> Sub for NumInf<T> {
61 type Output = Self;
62
63 fn sub(self, other: Self) -> Self {
64 match self {
65 NumInf::Value(x) => match other {
66 NumInf::Value(y) => NumInf::Value(x - y),
67 NumInf::Inf => NumInf::NegInf,
68 NumInf::NegInf => NumInf::Inf,
69 },
70 y => y,
71 }
72 }
73}
74
75impl<T: Neg<Output = T>> Neg for NumInf<T> {
76 type Output = Self;
77
78 fn neg(self) -> Self {
79 match self {
80 NumInf::Value(x) => NumInf::Value(-x),
81 NumInf::Inf => NumInf::NegInf,
82 NumInf::NegInf => NumInf::Inf,
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures for the arithmetic checks below.
    const FINITE: NumInf<i64> = NumInf::Value(100);
    const INF: NumInf<i64> = NumInf::Inf;
    const NEG_INF: NumInf<i64> = NumInf::NegInf;

    /// Derived ordering: `NegInf < Value(_) < Inf`, plus `min`/`max` from `Ord`.
    fn check_ordering() {
        assert!(NumInf::Value(1) < NumInf::Inf);
        assert!(NumInf::NegInf < NumInf::Value(-100));
        assert!(NumInf::<i64>::NegInf < NumInf::Inf);

        assert_eq!(NumInf::Value(1).min(NumInf::Inf), NumInf::Value(1));
        assert_eq!(NumInf::Value(1).max(NumInf::Inf), NumInf::Inf);
    }

    /// Addition: an infinite left operand always wins.
    fn check_addition() {
        assert_eq!(NumInf::Value(1) + NumInf::Value(-4), NumInf::Value(-3));

        // (lhs, rhs, expected lhs + rhs)
        let cases = [
            (INF, FINITE, INF),
            (NEG_INF, FINITE, NEG_INF),
            (FINITE, INF, INF),
            (FINITE, NEG_INF, NEG_INF),
            (INF, NEG_INF, INF),
            (INF, INF, INF),
            (NEG_INF, INF, NEG_INF),
            (NEG_INF, NEG_INF, NEG_INF),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(lhs + rhs, expected);
        }
    }

    /// Subtraction: an infinite left operand always wins; a finite left
    /// operand yields the opposite infinity.
    fn check_subtraction() {
        assert_eq!(NumInf::Value(1) - NumInf::Value(-4), NumInf::Value(5));

        // (lhs, rhs, expected lhs - rhs)
        let cases = [
            (INF, FINITE, INF),
            (NEG_INF, FINITE, NEG_INF),
            (FINITE, INF, NEG_INF),
            (FINITE, NEG_INF, INF),
            (INF, NEG_INF, INF),
            (INF, INF, INF),
            (NEG_INF, INF, NEG_INF),
            (NEG_INF, NEG_INF, NEG_INF),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(lhs - rhs, expected);
        }
    }

    /// Negation swaps the infinities and negates finite payloads.
    fn check_negation() {
        assert_eq!(-NumInf::Value(1), NumInf::Value(-1));
        assert_eq!(-INF, NEG_INF);
        assert_eq!(-NEG_INF, INF);
    }

    #[test]
    fn test() {
        check_ordering();
        check_addition();
        check_subtraction();
        check_negation();
    }
}