arithmetic_eval/arith/
mod.rs

//! `Arithmetic` trait and its implementations.
//!
//! # Traits
//!
//! An [`Arithmetic`] defines fallible arithmetic operations on primitive values
//! of an [`ExecutableModule`], namely, addition, subtraction, multiplication, division,
//! exponentiation (all binary ops), and negation (a unary op). It also defines an equality
//! check, since equality can be non-trivial (e.g., in modular arithmetic). Any module can be run
//! with any `Arithmetic` on its primitive values, although some modules are reasonably tied
//! to a particular arithmetic or a class of arithmetics (e.g., arithmetics on finite fields).
//!
//! [`OrdArithmetic`] extends [`Arithmetic`] with a partial comparison operation
//! (i.e., an analogue of [`PartialOrd`]). This is motivated by the fact that comparisons
//! may be switched off during parsing, and some `Arithmetic`s do not have well-defined comparisons.
//!
//! [`ArithmeticExt`] helps convert an [`Arithmetic`] into an [`OrdArithmetic`].
//!
//! # Implementations
//!
//! This module defines the following kinds of arithmetics:
//!
//! - [`StdArithmetic`] takes all implementations from the corresponding [`ops`](core::ops) traits.
//!   This means that it is safe to use *provided* the ops are infallible. As a counter-example,
//!   using [`StdArithmetic`] with built-in integer types (such as `u64`) is usually not a good
//!   idea, since the corresponding ops have failure modes (e.g., division by zero or integer
//!   overflow).
//! - [`WrappingArithmetic`] is defined for integer types; it uses wrapping semantics for all ops.
//! - [`CheckedArithmetic`] is defined for integer types; it uses checked semantics for all ops.
//! - [`ModularArithmetic`] operates on integers modulo a specified number.
//!
//! All defined [`Arithmetic`]s strive to be as generic as possible.
//!
//! [`ExecutableModule`]: crate::ExecutableModule
33
34use core::{cmp::Ordering, fmt};
35
36pub use self::{
37    generic::{
38        Checked, CheckedArithmetic, CheckedArithmeticKind, NegateOnlyZero, StdArithmetic,
39        Unchecked, WrappingArithmetic,
40    },
41    modular::{DoubleWidth, ModularArithmetic},
42};
43use crate::{alloc::Box, error::ArithmeticError};
44
45#[cfg(feature = "bigint")]
46mod bigint;
47mod generic;
48mod modular;
49
50/// Encapsulates arithmetic operations on a certain primitive type (or an enum of primitive types).
51///
52/// Unlike operations on built-in integer types, arithmetic operations may be fallible.
53/// Additionally, the arithmetic can have a state. This is used, for example, in
54/// [`ModularArithmetic`], which stores the modulus in the state.
55pub trait Arithmetic<T> {
56    /// Adds two values.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
61    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError>;
62
63    /// Subtracts two values.
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if the operation is unsuccessful (e.g., on integer underflow).
68    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError>;
69
70    /// Multiplies two values.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
75    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError>;
76
77    /// Divides two values.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the operation is unsuccessful (e.g., if `y` is zero or does
82    /// not have a multiplicative inverse in the case of modular arithmetic).
83    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError>;
84
85    /// Raises `x` to the power of `y`.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
90    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError>;
91
92    /// Negates a value.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
97    fn neg(&self, x: T) -> Result<T, ArithmeticError>;
98
99    /// Checks if two values are equal. Note that equality can be a non-trivial operation;
100    /// e.g., different numbers may be equal as per modular arithmetic.
101    fn eq(&self, x: &T, y: &T) -> bool;
102}
103
104/// Extends an [`Arithmetic`] with a comparison operation on values.
105pub trait OrdArithmetic<T>: Arithmetic<T> {
106    /// Compares two values. Returns `None` if the numbers are not comparable, or the comparison
107    /// result otherwise.
108    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering>;
109}
110
111impl<T> fmt::Debug for dyn OrdArithmetic<T> + '_ {
112    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
113        formatter.debug_tuple("OrdArithmetic").finish()
114    }
115}
116
117impl<T> Arithmetic<T> for Box<dyn OrdArithmetic<T>> {
118    #[inline]
119    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
120        (**self).add(x, y)
121    }
122
123    #[inline]
124    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
125        (**self).sub(x, y)
126    }
127
128    #[inline]
129    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
130        (**self).mul(x, y)
131    }
132
133    #[inline]
134    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
135        (**self).div(x, y)
136    }
137
138    #[inline]
139    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
140        (**self).pow(x, y)
141    }
142
143    #[inline]
144    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
145        (**self).neg(x)
146    }
147
148    #[inline]
149    fn eq(&self, x: &T, y: &T) -> bool {
150        (**self).eq(x, y)
151    }
152}
153
154impl<T> OrdArithmetic<T> for Box<dyn OrdArithmetic<T>> {
155    #[inline]
156    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
157        (**self).partial_cmp(x, y)
158    }
159}
160
/// An [`Arithmetic`] paired with a comparison function, which together implement
/// [`OrdArithmetic`].
///
/// # Examples
///
/// Values of this type can only be constructed via the [`ArithmeticExt`] trait;
/// see that trait for usage examples.
pub struct FullArithmetic<T, A> {
    /// Function backing [`OrdArithmetic::partial_cmp`].
    comparison: fn(&T, &T) -> Option<Ordering>,
    /// Arithmetic to which all [`Arithmetic`] operations are forwarded.
    base: A,
}
171
172impl<T, A: Clone> Clone for FullArithmetic<T, A> {
173    fn clone(&self) -> Self {
174        Self {
175            base: self.base.clone(),
176            comparison: self.comparison,
177        }
178    }
179}
180
181impl<T, A: Copy> Copy for FullArithmetic<T, A> {}
182
183impl<T, A: fmt::Debug> fmt::Debug for FullArithmetic<T, A> {
184    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
185        formatter
186            .debug_struct("FullArithmetic")
187            .field("base", &self.base)
188            .finish_non_exhaustive()
189    }
190}
191
192impl<T, A> Arithmetic<T> for FullArithmetic<T, A>
193where
194    A: Arithmetic<T>,
195{
196    #[inline]
197    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
198        self.base.add(x, y)
199    }
200
201    #[inline]
202    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
203        self.base.sub(x, y)
204    }
205
206    #[inline]
207    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
208        self.base.mul(x, y)
209    }
210
211    #[inline]
212    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
213        self.base.div(x, y)
214    }
215
216    #[inline]
217    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
218        self.base.pow(x, y)
219    }
220
221    #[inline]
222    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
223        self.base.neg(x)
224    }
225
226    #[inline]
227    fn eq(&self, x: &T, y: &T) -> bool {
228        self.base.eq(x, y)
229    }
230}
231
232impl<T, A> OrdArithmetic<T> for FullArithmetic<T, A>
233where
234    A: Arithmetic<T>,
235{
236    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
237        (self.comparison)(x, y)
238    }
239}
240
241/// Extension trait for [`Arithmetic`] allowing to combine the arithmetic with comparisons.
242///
243/// # Examples
244///
245/// ```
246/// use arithmetic_eval::arith::{ArithmeticExt, ModularArithmetic};
247/// # use arithmetic_eval::{Environment, ExecutableModule, Value};
248/// # use arithmetic_parser::grammars::{NumGrammar, Untyped, Parse};
249///
250/// # fn main() -> anyhow::Result<()> {
251/// let base = ModularArithmetic::new(11);
252///
253/// // `ModularArithmetic` requires to define how numbers will be compared -
254/// // and the simplest solution is to not compare them at all.
255/// let program = Untyped::<NumGrammar<u32>>::parse_statements("1 < 3 || 1 >= 3")?;
256/// let module = ExecutableModule::new("test", &program)?;
257/// let env = Environment::with_arithmetic(base.without_comparisons());
258/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(false));
259///
260/// // We can compare numbers by their integer value. This can lead
261/// // to pretty confusing results, though.
262/// let bogus_arithmetic = base.with_natural_comparison();
263/// let program = Untyped::<NumGrammar<u32>>::parse_statements("
264///     (x, y, z) = (1, 12, 5);
265///     x == y && x < z && y > z
266/// ")?;
267/// let module = ExecutableModule::new("test", &program)?;
268/// let env = Environment::with_arithmetic(bogus_arithmetic);
269/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
270///
271/// // It's possible to fix the situation using a custom comparison function,
272/// // which will compare numbers by their residual class.
273/// let less_bogus_arithmetic = base.with_comparison(|&x: &u32, &y: &u32| {
274///     (x % 11).partial_cmp(&(y % 11))
275/// });
276/// let env = Environment::with_arithmetic(less_bogus_arithmetic);
277/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(false));
278/// # Ok(())
279/// # }
280/// ```
281pub trait ArithmeticExt<T>: Arithmetic<T> + Sized {
282    /// Combines this arithmetic with a comparison function that assumes any two values are
283    /// incomparable.
284    fn without_comparisons(self) -> FullArithmetic<T, Self> {
285        FullArithmetic {
286            base: self,
287            comparison: |_, _| None,
288        }
289    }
290
291    /// Combines this arithmetic with a comparison function specified by the [`PartialOrd`]
292    /// implementation for `T`.
293    fn with_natural_comparison(self) -> FullArithmetic<T, Self>
294    where
295        T: PartialOrd,
296    {
297        FullArithmetic {
298            base: self,
299            comparison: T::partial_cmp,
300        }
301    }
302
303    /// Combines this arithmetic with the specified comparison function.
304    fn with_comparison(
305        self,
306        comparison: fn(&T, &T) -> Option<Ordering>,
307    ) -> FullArithmetic<T, Self> {
308        FullArithmetic {
309            base: self,
310            comparison,
311        }
312    }
313}
314
315impl<T, A> ArithmeticExt<T> for A where A: Arithmetic<T> {}