arithmetic_typing/arith/
mod.rs

//! Types that allow customizing various aspects of the type system, such as type constraints
//! and the behavior of unary / binary ops.
3
4use core::{fmt, str::FromStr};
5
6use arithmetic_parser::{BinaryOp, UnaryOp};
7use num_traits::NumOps;
8
9pub(crate) use self::constraints::CompleteConstraints;
10pub use self::{
11    constraints::{
12        Constraint, ConstraintSet, LinearType, Linearity, ObjectSafeConstraint, Ops,
13        StructConstraint,
14    },
15    substitutions::Substitutions,
16};
17use crate::{
18    error::{ErrorKind, ErrorPathFragment, OpErrors},
19    PrimitiveType, Type,
20};
21
22mod constraints;
23mod substitutions;
24
25/// Maps a literal value from a certain [`Grammar`] to its type. This assumes that all literals
26/// are primitive.
27///
28/// [`Grammar`]: arithmetic_parser::grammars::Grammar
29pub trait MapPrimitiveType<Val> {
30    /// Types of literals output by this mapper.
31    type Prim: PrimitiveType;
32
33    /// Gets the type of the provided literal value.
34    fn type_of_literal(&self, lit: &Val) -> Self::Prim;
35}
36
37/// Arithmetic allowing to customize primitive types and how unary and binary operations are handled
38/// during type inference.
39///
40/// # Examples
41///
42/// See crate examples for examples how define custom arithmetics.
43pub trait TypeArithmetic<Prim: PrimitiveType> {
44    /// Handles a unary operation.
45    fn process_unary_op(
46        &self,
47        substitutions: &mut Substitutions<Prim>,
48        context: &UnaryOpContext<Prim>,
49        errors: OpErrors<'_, Prim>,
50    ) -> Type<Prim>;
51
52    /// Handles a binary operation.
53    fn process_binary_op(
54        &self,
55        substitutions: &mut Substitutions<Prim>,
56        context: &BinaryOpContext<Prim>,
57        errors: OpErrors<'_, Prim>,
58    ) -> Type<Prim>;
59}
60
61/// Code spans related to a unary operation.
62///
63/// Used in [`TypeArithmetic::process_unary_op()`].
64#[derive(Debug, Clone)]
65pub struct UnaryOpContext<Prim: PrimitiveType> {
66    /// Unary operation.
67    pub op: UnaryOp,
68    /// Operation argument.
69    pub arg: Type<Prim>,
70}
71
72/// Code spans related to a binary operation.
73///
74/// Used in [`TypeArithmetic::process_binary_op()`].
75#[derive(Debug, Clone)]
76pub struct BinaryOpContext<Prim: PrimitiveType> {
77    /// Binary operation.
78    pub op: BinaryOp,
79    /// Spanned left-hand side.
80    pub lhs: Type<Prim>,
81    /// Spanned right-hand side.
82    pub rhs: Type<Prim>,
83}
84
85/// [`PrimitiveType`] that has Boolean type as one of its variants.
86pub trait WithBoolean: PrimitiveType {
87    /// Boolean type.
88    const BOOL: Self;
89}
90
/// Minimal [`TypeArithmetic`] implementation in which unary / binary ops are only defined
/// on the Boolean type. Intended as a building block for more complex arithmetics.
#[derive(Clone, Copy, Debug, Default)]
pub struct BoolArithmetic;
95
96impl<Prim: WithBoolean> TypeArithmetic<Prim> for BoolArithmetic {
97    /// Processes a unary operation.
98    ///
99    /// - `!` requires a Boolean input and outputs a Boolean.
100    /// - Other operations fail with [`ErrorKind::UnsupportedFeature`].
101    fn process_unary_op(
102        &self,
103        substitutions: &mut Substitutions<Prim>,
104        context: &UnaryOpContext<Prim>,
105        mut errors: OpErrors<'_, Prim>,
106    ) -> Type<Prim> {
107        let op = context.op;
108        if op == UnaryOp::Not {
109            substitutions.unify(&Type::BOOL, &context.arg, errors);
110            Type::BOOL
111        } else {
112            let err = ErrorKind::unsupported(op);
113            errors.push(err);
114            substitutions.new_type_var()
115        }
116    }
117
118    /// Processes a binary operation.
119    ///
120    /// - `==` and `!=` require LHS and RHS to have the same type (no matter which one).
121    ///   These ops return `Bool`.
122    /// - `&&` and `||` require LHS and RHS to have `Bool` type. These ops return `Bool`.
123    /// - Other operations fail with [`ErrorKind::UnsupportedFeature`].
124    fn process_binary_op(
125        &self,
126        substitutions: &mut Substitutions<Prim>,
127        context: &BinaryOpContext<Prim>,
128        mut errors: OpErrors<'_, Prim>,
129    ) -> Type<Prim> {
130        match context.op {
131            BinaryOp::Eq | BinaryOp::NotEq => {
132                substitutions.unify(&context.lhs, &context.rhs, errors);
133                Type::BOOL
134            }
135
136            BinaryOp::And | BinaryOp::Or => {
137                substitutions.unify(
138                    &Type::BOOL,
139                    &context.lhs,
140                    errors.join_path(ErrorPathFragment::Lhs),
141                );
142                substitutions.unify(
143                    &Type::BOOL,
144                    &context.rhs,
145                    errors.join_path(ErrorPathFragment::Rhs),
146                );
147                Type::BOOL
148            }
149
150            _ => {
151                errors.push(ErrorKind::unsupported(context.op));
152                substitutions.new_type_var()
153            }
154        }
155    }
156}
157
158/// Settings for constraints placed on arguments of binary arithmetic operations.
159#[derive(Debug, Clone)]
160pub struct OpConstraintSettings<'a, Prim: PrimitiveType> {
161    /// Constraint applied to the argument of `T op Num` / `Num op T` ops.
162    pub lin: &'a dyn Constraint<Prim>,
163    /// Constraint applied to the arguments of in-kind binary arithmetic ops (`T op T`).
164    pub ops: &'a dyn Constraint<Prim>,
165}
166
167impl<Prim: PrimitiveType> Copy for OpConstraintSettings<'_, Prim> {}
168
/// Primitive types of the numeric arithmetic: the `Num`eric type and the `Bool`ean type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Num {
    /// Numeric type (e.g., the type of `1`).
    Num,
    /// Boolean type (the type of `true` and `false`).
    Bool,
}
177
178impl fmt::Display for Num {
179    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
180        formatter.write_str(match self {
181            Self::Num => "Num",
182            Self::Bool => "Bool",
183        })
184    }
185}
186
187impl FromStr for Num {
188    type Err = anyhow::Error;
189
190    fn from_str(s: &str) -> Result<Self, Self::Err> {
191        match s {
192            "Num" => Ok(Self::Num),
193            "Bool" => Ok(Self::Bool),
194            _ => Err(anyhow::anyhow!("Expected `Num` or `Bool`")),
195        }
196    }
197}
198
199impl PrimitiveType for Num {
200    fn well_known_constraints() -> ConstraintSet<Self> {
201        let mut constraints = ConstraintSet::default();
202        constraints.insert_object_safe(Linearity);
203        constraints.insert(Ops);
204        constraints
205    }
206}
207
208impl WithBoolean for Num {
209    const BOOL: Self = Self::Bool;
210}
211
212/// `Num`bers are linear, `Bool`ean values are not.
213impl LinearType for Num {
214    fn is_linear(&self) -> bool {
215        matches!(self, Self::Num)
216    }
217}
218
/// Arithmetic on [`Num`]bers.
///
/// # Unary ops
///
/// - Unary minus follows the equation `-T == T`, where `T` is any [linear](Linearity) type.
/// - Unary negation (`!`) is only defined for `Bool`s.
///
/// # Binary ops
///
/// Binary arithmetic ops fall into one of 3 cases: `Num op T == T`, `T op Num == T`,
/// or `T op T == T`, where `T` is any linear type (that is, `Num` or a tuple of linear types).
/// The `T op T` case is assumed by default; the two other cases are only selected if one of
/// the operands is known to be a number and the other one is known not to be a number.
///
/// # Comparisons
///
/// Support of order comparisons (`>`, `<`, `>=`, `<=`) is configurable: it is switched on by
/// the [`Self::with_comparisons()`] constructor and switched off by
/// [`Self::without_comparisons()`]. If switched on, both arguments of an order comparison
/// must be numbers.
#[derive(Clone, Debug)]
pub struct NumArithmetic {
    comparisons_enabled: bool,
}
242
243impl NumArithmetic {
244    /// Creates an instance of arithmetic that does not support order comparisons.
245    pub const fn without_comparisons() -> Self {
246        Self {
247            comparisons_enabled: false,
248        }
249    }
250
251    /// Creates an instance of arithmetic that supports order comparisons.
252    pub const fn with_comparisons() -> Self {
253        Self {
254            comparisons_enabled: true,
255        }
256    }
257
258    /// Applies [binary ops](#binary-ops) logic to unify the given LHS and RHS types.
259    /// Returns the result type of the binary operation.
260    ///
261    /// This logic can be reused by other [`TypeArithmetic`] implementations.
262    ///
263    /// # Arguments
264    ///
265    /// - `settings` are applied to arguments of arithmetic ops.
266    pub fn unify_binary_op<Prim: PrimitiveType>(
267        substitutions: &mut Substitutions<Prim>,
268        context: &BinaryOpContext<Prim>,
269        mut errors: OpErrors<'_, Prim>,
270        settings: OpConstraintSettings<'_, Prim>,
271    ) -> Type<Prim> {
272        let lhs_ty = &context.lhs;
273        let rhs_ty = &context.rhs;
274        let resolved_lhs_ty = substitutions.fast_resolve(lhs_ty);
275        let resolved_rhs_ty = substitutions.fast_resolve(rhs_ty);
276
277        match (
278            resolved_lhs_ty.is_primitive(),
279            resolved_rhs_ty.is_primitive(),
280        ) {
281            (Some(true), Some(false)) => {
282                let resolved_rhs_ty = resolved_rhs_ty.clone();
283                settings
284                    .lin
285                    .visitor(substitutions, errors.join_path(ErrorPathFragment::Lhs))
286                    .visit_type(lhs_ty);
287                settings
288                    .lin
289                    .visitor(substitutions, errors.join_path(ErrorPathFragment::Rhs))
290                    .visit_type(rhs_ty);
291                resolved_rhs_ty
292            }
293            (Some(false), Some(true)) => {
294                let resolved_lhs_ty = resolved_lhs_ty.clone();
295                settings
296                    .lin
297                    .visitor(substitutions, errors.join_path(ErrorPathFragment::Lhs))
298                    .visit_type(lhs_ty);
299                settings
300                    .lin
301                    .visitor(substitutions, errors.join_path(ErrorPathFragment::Rhs))
302                    .visit_type(rhs_ty);
303                resolved_lhs_ty
304            }
305            _ => {
306                let lhs_is_valid = errors.join_path(ErrorPathFragment::Lhs).check(|errors| {
307                    settings
308                        .ops
309                        .visitor(substitutions, errors)
310                        .visit_type(lhs_ty);
311                });
312                let rhs_is_valid = errors.join_path(ErrorPathFragment::Rhs).check(|errors| {
313                    settings
314                        .ops
315                        .visitor(substitutions, errors)
316                        .visit_type(rhs_ty);
317                });
318
319                if lhs_is_valid && rhs_is_valid {
320                    substitutions.unify(lhs_ty, rhs_ty, errors);
321                }
322                if lhs_is_valid {
323                    lhs_ty.clone()
324                } else {
325                    rhs_ty.clone()
326                }
327            }
328        }
329    }
330
331    /// Processes a unary operation according to [the numeric arithmetic rules](#unary-ops).
332    /// Returns the result type of the unary operation.
333    ///
334    /// This logic can be reused by other [`TypeArithmetic`] implementations.
335    pub fn process_unary_op<Prim: WithBoolean>(
336        substitutions: &mut Substitutions<Prim>,
337        context: &UnaryOpContext<Prim>,
338        mut errors: OpErrors<'_, Prim>,
339        constraints: &impl Constraint<Prim>,
340    ) -> Type<Prim> {
341        match context.op {
342            UnaryOp::Not => BoolArithmetic.process_unary_op(substitutions, context, errors),
343            UnaryOp::Neg => {
344                constraints
345                    .visitor(substitutions, errors)
346                    .visit_type(&context.arg);
347                context.arg.clone()
348            }
349            _ => {
350                errors.push(ErrorKind::unsupported(context.op));
351                substitutions.new_type_var()
352            }
353        }
354    }
355
356    /// Processes a binary operation according to [the numeric arithmetic rules](#binary-ops).
357    /// Returns the result type of the unary operation.
358    ///
359    /// This logic can be reused by other [`TypeArithmetic`] implementations.
360    ///
361    /// # Arguments
362    ///
363    /// - If `comparable_type` is set to `Some(_)`, it will be used to unify arguments of
364    ///   order comparisons. If `comparable_type` is `None`, order comparisons are not supported.
365    /// - `constraints` are applied to arguments of arithmetic ops.
366    pub fn process_binary_op<Prim: WithBoolean>(
367        substitutions: &mut Substitutions<Prim>,
368        context: &BinaryOpContext<Prim>,
369        mut errors: OpErrors<'_, Prim>,
370        comparable_type: Option<Prim>,
371        settings: OpConstraintSettings<'_, Prim>,
372    ) -> Type<Prim> {
373        match context.op {
374            BinaryOp::And | BinaryOp::Or | BinaryOp::Eq | BinaryOp::NotEq => {
375                BoolArithmetic.process_binary_op(substitutions, context, errors)
376            }
377
378            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
379                Self::unify_binary_op(substitutions, context, errors, settings)
380            }
381
382            BinaryOp::Ge | BinaryOp::Le | BinaryOp::Lt | BinaryOp::Gt => {
383                if let Some(ty) = comparable_type {
384                    let ty = Type::Prim(ty);
385                    substitutions.unify(
386                        &ty,
387                        &context.lhs,
388                        errors.join_path(ErrorPathFragment::Lhs),
389                    );
390                    substitutions.unify(
391                        &ty,
392                        &context.rhs,
393                        errors.join_path(ErrorPathFragment::Rhs),
394                    );
395                } else {
396                    let err = ErrorKind::unsupported(context.op);
397                    errors.push(err);
398                }
399                Type::BOOL
400            }
401
402            _ => {
403                errors.push(ErrorKind::unsupported(context.op));
404                substitutions.new_type_var()
405            }
406        }
407    }
408}
409
410impl<Val> MapPrimitiveType<Val> for NumArithmetic
411where
412    Val: Clone + NumOps + PartialEq,
413{
414    type Prim = Num;
415
416    fn type_of_literal(&self, _: &Val) -> Self::Prim {
417        Num::Num
418    }
419}
420
421impl TypeArithmetic<Num> for NumArithmetic {
422    fn process_unary_op(
423        &self,
424        substitutions: &mut Substitutions<Num>,
425        context: &UnaryOpContext<Num>,
426        errors: OpErrors<'_, Num>,
427    ) -> Type<Num> {
428        Self::process_unary_op(substitutions, context, errors, &Linearity)
429    }
430
431    fn process_binary_op(
432        &self,
433        substitutions: &mut Substitutions<Num>,
434        context: &BinaryOpContext<Num>,
435        errors: OpErrors<'_, Num>,
436    ) -> Type<Num> {
437        const OP_SETTINGS: OpConstraintSettings<'static, Num> = OpConstraintSettings {
438            lin: &Linearity,
439            ops: &Ops,
440        };
441
442        let comparable_type = if self.comparisons_enabled {
443            Some(Num::Num)
444        } else {
445            None
446        };
447
448        Self::process_binary_op(substitutions, context, errors, comparable_type, OP_SETTINGS)
449    }
450}