arithmetic_eval/arith/mod.rs
1//! `Arithmetic` trait and its implementations.
2//!
3//! # Traits
4//!
5//! An [`Arithmetic`] defines fallible arithmetic operations on primitive values
6//! of an [`ExecutableModule`], namely, addition, subtraction, multiplication, division,
7//! exponentiation (all binary ops), and negation (a unary op). Any module can be run
8//! with any `Arithmetic` on its primitive values, although some modules are reasonably tied
9//! to a particular arithmetic or a class of arithmetics (e.g., arithmetics on finite fields).
10//!
11//! [`OrdArithmetic`] extends [`Arithmetic`] with a partial comparison operation
12//! (i.e., an analogue to [`PartialOrd`]). This is motivated by the fact that comparisons
13//! may be switched off during parsing, and some `Arithmetic`s do not have well-defined comparisons.
14//!
//! [`ArithmeticExt`] helps convert an [`Arithmetic`] into an [`OrdArithmetic`].
16//!
17//! # Implementations
18//!
19//! This module defines the following kinds of arithmetics:
20//!
21//! - [`StdArithmetic`] takes all implementations from the corresponding [`ops`](core::ops) traits.
22//! This means that it's safe to use *provided* the ops are infallible. As a counter-example,
23//! using [`StdArithmetic`] with built-in integer types (such as `u64`) is usually not a good
24//! idea since the corresponding ops have failure modes (e.g., division by zero or integer
25//! overflow).
26//! - [`WrappingArithmetic`] is defined for integer types; it uses wrapping semantics for all ops.
27//! - [`CheckedArithmetic`] is defined for integer types; it uses checked semantics for all ops.
28//! - [`ModularArithmetic`] operates on integers modulo the specified number.
29//!
30//! All defined [`Arithmetic`]s strive to be as generic as possible.
31//!
32//! [`ExecutableModule`]: crate::ExecutableModule
33
34use core::{cmp::Ordering, fmt};
35
36pub use self::{
37 generic::{
38 Checked, CheckedArithmetic, CheckedArithmeticKind, NegateOnlyZero, StdArithmetic,
39 Unchecked, WrappingArithmetic,
40 },
41 modular::{DoubleWidth, ModularArithmetic},
42};
43use crate::{alloc::Box, error::ArithmeticError};
44
45#[cfg(feature = "bigint")]
46mod bigint;
47mod generic;
48mod modular;
49
50/// Encapsulates arithmetic operations on a certain primitive type (or an enum of primitive types).
51///
52/// Unlike operations on built-in integer types, arithmetic operations may be fallible.
53/// Additionally, the arithmetic can have a state. This is used, for example, in
54/// [`ModularArithmetic`], which stores the modulus in the state.
55pub trait Arithmetic<T> {
56 /// Adds two values.
57 ///
58 /// # Errors
59 ///
60 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
61 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError>;
62
63 /// Subtracts two values.
64 ///
65 /// # Errors
66 ///
67 /// Returns an error if the operation is unsuccessful (e.g., on integer underflow).
68 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError>;
69
70 /// Multiplies two values.
71 ///
72 /// # Errors
73 ///
74 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
75 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError>;
76
77 /// Divides two values.
78 ///
79 /// # Errors
80 ///
81 /// Returns an error if the operation is unsuccessful (e.g., if `y` is zero or does
82 /// not have a multiplicative inverse in the case of modular arithmetic).
83 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError>;
84
85 /// Raises `x` to the power of `y`.
86 ///
87 /// # Errors
88 ///
89 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
90 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError>;
91
92 /// Negates a value.
93 ///
94 /// # Errors
95 ///
96 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
97 fn neg(&self, x: T) -> Result<T, ArithmeticError>;
98
99 /// Checks if two values are equal. Note that equality can be a non-trivial operation;
100 /// e.g., different numbers may be equal as per modular arithmetic.
101 fn eq(&self, x: &T, y: &T) -> bool;
102}
103
104/// Extends an [`Arithmetic`] with a comparison operation on values.
105pub trait OrdArithmetic<T>: Arithmetic<T> {
106 /// Compares two values. Returns `None` if the numbers are not comparable, or the comparison
107 /// result otherwise.
108 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering>;
109}
110
111impl<T> fmt::Debug for dyn OrdArithmetic<T> + '_ {
112 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
113 formatter.debug_tuple("OrdArithmetic").finish()
114 }
115}
116
117impl<T> Arithmetic<T> for Box<dyn OrdArithmetic<T>> {
118 #[inline]
119 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
120 (**self).add(x, y)
121 }
122
123 #[inline]
124 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
125 (**self).sub(x, y)
126 }
127
128 #[inline]
129 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
130 (**self).mul(x, y)
131 }
132
133 #[inline]
134 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
135 (**self).div(x, y)
136 }
137
138 #[inline]
139 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
140 (**self).pow(x, y)
141 }
142
143 #[inline]
144 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
145 (**self).neg(x)
146 }
147
148 #[inline]
149 fn eq(&self, x: &T, y: &T) -> bool {
150 (**self).eq(x, y)
151 }
152}
153
154impl<T> OrdArithmetic<T> for Box<dyn OrdArithmetic<T>> {
155 #[inline]
156 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
157 (**self).partial_cmp(x, y)
158 }
159}
160
/// Pairs an [`Arithmetic`] with a comparison function, turning it into an [`OrdArithmetic`].
///
/// # Examples
///
/// Instances are created by the methods of the [`ArithmeticExt`] trait;
/// see its documentation for usage examples.
pub struct FullArithmetic<T, A> {
    /// Base arithmetic performing all non-comparison operations.
    base: A,
    /// Plain function pointer used to compare values.
    comparison: fn(&T, &T) -> Option<Ordering>,
}
171
172impl<T, A: Clone> Clone for FullArithmetic<T, A> {
173 fn clone(&self) -> Self {
174 Self {
175 base: self.base.clone(),
176 comparison: self.comparison,
177 }
178 }
179}
180
181impl<T, A: Copy> Copy for FullArithmetic<T, A> {}
182
183impl<T, A: fmt::Debug> fmt::Debug for FullArithmetic<T, A> {
184 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
185 formatter
186 .debug_struct("FullArithmetic")
187 .field("base", &self.base)
188 .finish_non_exhaustive()
189 }
190}
191
192impl<T, A> Arithmetic<T> for FullArithmetic<T, A>
193where
194 A: Arithmetic<T>,
195{
196 #[inline]
197 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
198 self.base.add(x, y)
199 }
200
201 #[inline]
202 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
203 self.base.sub(x, y)
204 }
205
206 #[inline]
207 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
208 self.base.mul(x, y)
209 }
210
211 #[inline]
212 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
213 self.base.div(x, y)
214 }
215
216 #[inline]
217 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
218 self.base.pow(x, y)
219 }
220
221 #[inline]
222 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
223 self.base.neg(x)
224 }
225
226 #[inline]
227 fn eq(&self, x: &T, y: &T) -> bool {
228 self.base.eq(x, y)
229 }
230}
231
232impl<T, A> OrdArithmetic<T> for FullArithmetic<T, A>
233where
234 A: Arithmetic<T>,
235{
236 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
237 (self.comparison)(x, y)
238 }
239}
240
241/// Extension trait for [`Arithmetic`] allowing to combine the arithmetic with comparisons.
242///
243/// # Examples
244///
245/// ```
246/// use arithmetic_eval::arith::{ArithmeticExt, ModularArithmetic};
247/// # use arithmetic_eval::{Environment, ExecutableModule, Value};
248/// # use arithmetic_parser::grammars::{NumGrammar, Untyped, Parse};
249///
250/// # fn main() -> anyhow::Result<()> {
251/// let base = ModularArithmetic::new(11);
252///
253/// // `ModularArithmetic` requires to define how numbers will be compared -
254/// // and the simplest solution is to not compare them at all.
255/// let program = Untyped::<NumGrammar<u32>>::parse_statements("1 < 3 || 1 >= 3")?;
256/// let module = ExecutableModule::new("test", &program)?;
257/// let env = Environment::with_arithmetic(base.without_comparisons());
258/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(false));
259///
260/// // We can compare numbers by their integer value. This can lead
261/// // to pretty confusing results, though.
262/// let bogus_arithmetic = base.with_natural_comparison();
263/// let program = Untyped::<NumGrammar<u32>>::parse_statements("
264/// (x, y, z) = (1, 12, 5);
265/// x == y && x < z && y > z
266/// ")?;
267/// let module = ExecutableModule::new("test", &program)?;
268/// let env = Environment::with_arithmetic(bogus_arithmetic);
269/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
270///
271/// // It's possible to fix the situation using a custom comparison function,
272/// // which will compare numbers by their residual class.
273/// let less_bogus_arithmetic = base.with_comparison(|&x: &u32, &y: &u32| {
274/// (x % 11).partial_cmp(&(y % 11))
275/// });
276/// let env = Environment::with_arithmetic(less_bogus_arithmetic);
277/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(false));
278/// # Ok(())
279/// # }
280/// ```
281pub trait ArithmeticExt<T>: Arithmetic<T> + Sized {
282 /// Combines this arithmetic with a comparison function that assumes any two values are
283 /// incomparable.
284 fn without_comparisons(self) -> FullArithmetic<T, Self> {
285 FullArithmetic {
286 base: self,
287 comparison: |_, _| None,
288 }
289 }
290
291 /// Combines this arithmetic with a comparison function specified by the [`PartialOrd`]
292 /// implementation for `T`.
293 fn with_natural_comparison(self) -> FullArithmetic<T, Self>
294 where
295 T: PartialOrd,
296 {
297 FullArithmetic {
298 base: self,
299 comparison: T::partial_cmp,
300 }
301 }
302
303 /// Combines this arithmetic with the specified comparison function.
304 fn with_comparison(
305 self,
306 comparison: fn(&T, &T) -> Option<Ordering>,
307 ) -> FullArithmetic<T, Self> {
308 FullArithmetic {
309 base: self,
310 comparison,
311 }
312 }
313}
314
315impl<T, A> ArithmeticExt<T> for A where A: Arithmetic<T> {}