1use std::ops::{Add, Neg, Sub};
4
/// A number extended with negative and positive infinity.
///
/// The variants are declared in ascending order, so the derived
/// `PartialOrd`/`Ord` yield `NegInf < Value(_) < Inf`. Two finite values
/// compare using `T`'s own ordering.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum NumInf<T> {
    /// Negative infinity: compares less than every other value.
    NegInf,
    /// A finite value.
    Value(T),
    /// Positive infinity: compares greater than every other value.
    Inf,
}
15
16impl<T: Copy> NumInf<T> {
17 pub fn is_value(self) -> bool {
19 matches!(self, Self::Value(_))
20 }
21
22 pub fn is_inf(self) -> bool {
24 matches!(self, Self::Inf)
25 }
26
27 pub fn is_neg_inf(self) -> bool {
29 matches!(self, Self::NegInf)
30 }
31
32 pub fn unwrap(self) -> T {
38 match self {
39 Self::Value(x) => x,
40 Self::Inf => panic!("called `NumInf::unwrap()` on a `Inf` value"),
41 Self::NegInf => panic!("called `NumInf::unwrap()` on a `NegInf` value"),
42 }
43 }
44}
45
46impl<T: Add<Output = T>> Add for NumInf<T> {
47 type Output = Self;
48
49 fn add(self, other: Self) -> Self {
50 match self {
51 Self::Value(x) => match other {
52 Self::Value(y) => Self::Value(x + y),
53 y => y,
54 },
55 y => y,
56 }
57 }
58}
59
60impl<T: Sub<Output = T>> Sub for NumInf<T> {
61 type Output = Self;
62
63 fn sub(self, other: Self) -> Self {
64 match self {
65 Self::Value(x) => match other {
66 Self::Value(y) => Self::Value(x - y),
67 Self::Inf => Self::NegInf,
68 Self::NegInf => Self::Inf,
69 },
70 y => y,
71 }
72 }
73}
74
75impl<T: Neg<Output = T>> Neg for NumInf<T> {
76 type Output = Self;
77
78 fn neg(self) -> Self {
79 match self {
80 Self::Value(x) => Self::Value(-x),
81 Self::Inf => Self::NegInf,
82 Self::NegInf => Self::Inf,
83 }
84 }
85}
86
87impl<T: Default> Default for NumInf<T> {
88 fn default() -> Self {
89 Self::Value(T::default())
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures, so the operator tests don't each redeclare them.
    const A: NumInf<i64> = NumInf::Value(100);
    const INF: NumInf<i64> = NumInf::Inf;
    const NINF: NumInf<i64> = NumInf::NegInf;

    #[test]
    fn ordering_places_infinities_at_the_ends() {
        assert!(NumInf::Value(1) < NumInf::Inf);
        assert!(NumInf::NegInf < NumInf::Value(-100));
        assert!(NINF < INF);
        assert!(NumInf::Value(1) < NumInf::Value(2));

        assert_eq!(NumInf::Value(1).min(NumInf::Inf), NumInf::Value(1));
        assert_eq!(NumInf::Value(1).max(NumInf::Inf), NumInf::Inf);
    }

    #[test]
    fn add() {
        assert_eq!(NumInf::Value(1) + NumInf::Value(-4), NumInf::Value(-3));
        assert_eq!(INF + A, INF);
        assert_eq!(NINF + A, NINF);

        assert_eq!(A + INF, INF);
        assert_eq!(A + NINF, NINF);

        // Indeterminate forms resolve to the left operand.
        assert_eq!(INF + NINF, INF);
        assert_eq!(INF + INF, INF);
        assert_eq!(NINF + INF, NINF);
        assert_eq!(NINF + NINF, NINF);
    }

    #[test]
    fn sub() {
        assert_eq!(NumInf::Value(1) - NumInf::Value(-4), NumInf::Value(5));
        assert_eq!(INF - A, INF);
        assert_eq!(NINF - A, NINF);

        assert_eq!(A - INF, NINF);
        assert_eq!(A - NINF, INF);

        // Indeterminate forms resolve to the left operand.
        assert_eq!(INF - NINF, INF);
        assert_eq!(INF - INF, INF);
        assert_eq!(NINF - INF, NINF);
        assert_eq!(NINF - NINF, NINF);
    }

    #[test]
    fn neg() {
        assert_eq!(-NumInf::Value(1), NumInf::Value(-1));
        assert_eq!(-INF, NINF);
        assert_eq!(-NINF, INF);
    }

    #[test]
    fn predicates() {
        assert!(A.is_value());
        assert!(!A.is_inf());
        assert!(!A.is_neg_inf());

        assert!(INF.is_inf());
        assert!(!INF.is_value());

        assert!(NINF.is_neg_inf());
        assert!(!NINF.is_value());
    }

    #[test]
    fn unwrap_returns_finite_value() {
        assert_eq!(A.unwrap(), 100);
    }

    #[test]
    #[should_panic(expected = "NumInf::unwrap()")]
    fn unwrap_inf_panics() {
        INF.unwrap();
    }

    #[test]
    #[should_panic(expected = "NumInf::unwrap()")]
    fn unwrap_neg_inf_panics() {
        NINF.unwrap();
    }

    #[test]
    fn default_is_finite_default() {
        assert_eq!(NumInf::<i64>::default(), NumInf::Value(0));
    }
}