1use core::{cmp::Ordering, convert::TryFrom, marker::PhantomData, ops};
4
5use num_traits::{
6 CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedSub, NumOps, One, Pow, Signed, Unsigned,
7 WrappingAdd, WrappingMul, WrappingNeg, WrappingSub, Zero, checked_pow,
8};
9
10use crate::{
11 arith::{Arithmetic, OrdArithmetic},
12 error::ArithmeticError,
13};
14
/// Arithmetic that forwards every operation to the standard operators (`+`, `-`, `*`, `/`,
/// unary `-`) and to [`Pow`] implemented by the value type.
///
/// All operations of its [`Arithmetic`] implementation return `Ok(_)`; this arithmetic never
/// produces an [`ArithmeticError`] itself. Overflow and division-by-zero behavior is whatever
/// the value type's own operator implementations do.
#[derive(Default, Debug, Copy, Clone)]
pub struct StdArithmetic;
23
24impl<T> Arithmetic<T> for StdArithmetic
25where
26 T: Clone + NumOps + PartialEq + ops::Neg<Output = T> + Pow<T, Output = T>,
27{
28 #[inline]
29 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
30 Ok(x + y)
31 }
32
33 #[inline]
34 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
35 Ok(x - y)
36 }
37
38 #[inline]
39 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
40 Ok(x * y)
41 }
42
43 #[inline]
44 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
45 Ok(x / y)
46 }
47
48 #[inline]
49 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
50 Ok(x.pow(y))
51 }
52
53 #[inline]
54 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
55 Ok(-x)
56 }
57
58 #[inline]
59 fn eq(&self, x: &T, y: &T) -> bool {
60 *x == *y
61 }
62}
63
64impl<T> OrdArithmetic<T> for StdArithmetic
65where
66 Self: Arithmetic<T>,
67 T: PartialOrd,
68{
69 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
70 x.partial_cmp(y)
71 }
72}
73
// Compile-time check that `StdArithmetic` provides ordered arithmetic for both float widths.
// NOTE(review): presumably gated on `std` because num-traits' float `Pow` impls need it — confirm.
#[cfg(all(test, feature = "std"))]
static_assertions::assert_impl_all!(
    StdArithmetic: OrdArithmetic<f64>,
    OrdArithmetic<f32>
);

// Compile-time check that `StdArithmetic` covers both `num-complex` precisions.
#[cfg(all(test, feature = "complex"))]
static_assertions::assert_impl_all!(
    StdArithmetic: Arithmetic<num_complex::Complex64>,
    Arithmetic<num_complex::Complex32>
);
82
/// Strategy selecting how [`CheckedArithmetic`] negates values of type `T`.
///
/// Negation is the only operation delegated to the kind; every other operation of
/// `CheckedArithmetic` uses the `Checked*` traits directly.
pub trait CheckedArithmeticKind<T> {
    /// Negates `value`.
    ///
    /// Returns `None` if the negation is not permitted; `CheckedArithmetic::neg` reports
    /// that as `ArithmeticError::IntegerOverflow`.
    fn checked_neg(value: T) -> Option<T>;
}
88
89#[derive(Debug)]
99pub struct CheckedArithmetic<Kind = Checked>(PhantomData<Kind>);
100
101impl<Kind> Clone for CheckedArithmetic<Kind> {
102 fn clone(&self) -> Self {
103 *self
104 }
105}
106
107impl<Kind> Copy for CheckedArithmetic<Kind> {}
108
109impl<Kind> Default for CheckedArithmetic<Kind> {
110 fn default() -> Self {
111 Self(PhantomData)
112 }
113}
114
115impl<Kind> CheckedArithmetic<Kind> {
116 pub const fn new() -> Self {
118 Self(PhantomData)
119 }
120}
121
122impl<T, Kind> Arithmetic<T> for CheckedArithmetic<Kind>
123where
124 T: Clone + PartialEq + Zero + One + CheckedAdd + CheckedSub + CheckedMul + CheckedDiv,
125 Kind: CheckedArithmeticKind<T>,
126 usize: TryFrom<T>,
127{
128 #[inline]
129 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
130 x.checked_add(&y).ok_or(ArithmeticError::IntegerOverflow)
131 }
132
133 #[inline]
134 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
135 x.checked_sub(&y).ok_or(ArithmeticError::IntegerOverflow)
136 }
137
138 #[inline]
139 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
140 x.checked_mul(&y).ok_or(ArithmeticError::IntegerOverflow)
141 }
142
143 #[inline]
144 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
145 x.checked_div(&y).ok_or(ArithmeticError::DivisionByZero)
146 }
147
148 #[inline]
149 #[allow(clippy::map_err_ignore)]
150 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
151 let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
152 checked_pow(x, exp).ok_or(ArithmeticError::IntegerOverflow)
153 }
154
155 #[inline]
156 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
157 Kind::checked_neg(x).ok_or(ArithmeticError::IntegerOverflow)
158 }
159
160 #[inline]
161 fn eq(&self, x: &T, y: &T) -> bool {
162 *x == *y
163 }
164}
165
166#[derive(Debug)]
169pub struct Checked(());
170
171impl<T: CheckedNeg> CheckedArithmeticKind<T> for Checked {
172 fn checked_neg(value: T) -> Option<T> {
173 value.checked_neg()
174 }
175}
176
177#[derive(Debug)]
179pub struct NegateOnlyZero(());
180
181impl<T: Unsigned + Zero> CheckedArithmeticKind<T> for NegateOnlyZero {
182 fn checked_neg(value: T) -> Option<T> {
183 if value.is_zero() { Some(value) } else { None }
184 }
185}
186
187#[derive(Debug)]
191pub struct Unchecked(());
192
193impl<T: Signed> CheckedArithmeticKind<T> for Unchecked {
194 fn checked_neg(value: T) -> Option<T> {
195 Some(-value)
196 }
197}
198
199impl<T, Kind> OrdArithmetic<T> for CheckedArithmetic<Kind>
200where
201 Self: Arithmetic<T>,
202 T: PartialOrd,
203{
204 #[inline]
205 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
206 x.partial_cmp(y)
207 }
208}
209
// Compile-time check that `CheckedArithmetic` (with its default `Checked` kind) provides
// ordered arithmetic for all fixed-width primitive integer types, unsigned and signed.
#[cfg(test)]
static_assertions::assert_impl_all!(
    CheckedArithmetic: OrdArithmetic<u8>,
    OrdArithmetic<u16>,
    OrdArithmetic<u32>,
    OrdArithmetic<u64>,
    OrdArithmetic<u128>,
    OrdArithmetic<i8>,
    OrdArithmetic<i16>,
    OrdArithmetic<i32>,
    OrdArithmetic<i64>,
    OrdArithmetic<i128>
);
223
/// Arithmetic that wraps around on overflow (via the `Wrapping*` traits) instead of failing.
///
/// Division by zero and exponents that cannot be converted into `usize` are still reported as
/// [`ArithmeticError`]s by its [`Arithmetic`] implementation.
#[derive(Default, Debug, Copy, Clone)]
pub struct WrappingArithmetic;
230
231impl<T> Arithmetic<T> for WrappingArithmetic
232where
233 T: Copy
234 + PartialEq
235 + Zero
236 + One
237 + WrappingAdd
238 + WrappingSub
239 + WrappingMul
240 + WrappingNeg
241 + ops::Div<T, Output = T>,
242 usize: TryFrom<T>,
243{
244 #[inline]
245 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
246 Ok(x.wrapping_add(&y))
247 }
248
249 #[inline]
250 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
251 Ok(x.wrapping_sub(&y))
252 }
253
254 #[inline]
255 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
256 Ok(x.wrapping_mul(&y))
257 }
258
259 #[inline]
260 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
261 if y.is_zero() {
262 Err(ArithmeticError::DivisionByZero)
263 } else if y.wrapping_neg().is_one() {
264 Ok(x.wrapping_neg())
267 } else {
268 Ok(x / y)
269 }
270 }
271
272 #[inline]
273 #[allow(clippy::map_err_ignore)]
274 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
275 let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
276 Ok(wrapping_exp(x, exp))
277 }
278
279 #[inline]
280 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
281 Ok(x.wrapping_neg())
282 }
283
284 #[inline]
285 fn eq(&self, x: &T, y: &T) -> bool {
286 *x == *y
287 }
288}
289
290impl<T> OrdArithmetic<T> for WrappingArithmetic
291where
292 Self: Arithmetic<T>,
293 T: PartialOrd,
294{
295 #[inline]
296 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
297 x.partial_cmp(y)
298 }
299}
300
301fn wrapping_exp<T: Copy + One + WrappingMul>(mut base: T, mut exp: usize) -> T {
304 if exp == 0 {
305 return T::one();
306 }
307
308 while exp & 1 == 0 {
309 base = base.wrapping_mul(&base);
310 exp >>= 1;
311 }
312 if exp == 1 {
313 return base;
314 }
315
316 let mut acc = base;
317 while exp > 1 {
318 exp >>= 1;
319 base = base.wrapping_mul(&base);
320 if exp & 1 == 1 {
321 acc = acc.wrapping_mul(&base);
322 }
323 }
324 acc
325}
326
// Compile-time check that `WrappingArithmetic` provides ordered arithmetic for all fixed-width
// primitive integer types, unsigned and signed.
#[cfg(test)]
static_assertions::assert_impl_all!(
    WrappingArithmetic: OrdArithmetic<u8>,
    OrdArithmetic<u16>,
    OrdArithmetic<u32>,
    OrdArithmetic<u64>,
    OrdArithmetic<u128>,
    OrdArithmetic<i8>,
    OrdArithmetic<i16>,
    OrdArithmetic<i32>,
    OrdArithmetic<i64>,
    OrdArithmetic<i128>
);