arithmetic_typing/arith/
constraints.rs

//! The [`Constraint`] trait, constraint sets, and standard constraint implementations.
2
3use core::{fmt, marker::PhantomData};
4
5use crate::{
6    alloc::{Box, HashMap, String, ToString},
7    arith::Substitutions,
8    error::{ErrorKind, OpErrors},
9    visit::{self, Visit},
10    Function, Object, PrimitiveType, Slice, Tuple, Type, TypeVar,
11};
12
13/// Constraint that can be placed on [`Type`]s.
14///
15/// Constraints can be placed on [`Function`] type variables, and can be applied
16/// to types in [`TypeArithmetic`] impls. For example, [`NumArithmetic`] places
17/// the [`Linearity`] constraint on types involved in arithmetic ops.
18///
19/// The constraint mechanism is similar to trait constraints in Rust, but is much more limited:
20///
21/// - Constraints cannot be parametric (cf. parameters in traits, such `AsRef<_>`
22///   or `Iterator<Item = _>`).
23/// - Constraints are applied to types in separation; it is impossible to create a constraint
24///   involving several type variables.
25/// - Constraints cannot contradict each other.
26///
27/// # Implementation rules
28///
29/// - [`Display`](fmt::Display) must display constraint as an identifier (e.g., `Lin`).
30///   The string presentation of a constraint must be unique within a [`PrimitiveType`];
31///   it is used to identify constraints in a [`ConstraintSet`].
32///
33/// [`TypeArithmetic`]: crate::arith::TypeArithmetic
34/// [`NumArithmetic`]: crate::arith::NumArithmetic
35pub trait Constraint<Prim: PrimitiveType>: fmt::Display + Send + Sync + 'static {
36    /// Returns a [`Visit`]or that will be applied to constrained [`Type`]s. The visitor
37    /// may use `substitutions` to resolve types and `errors` to record constraint errors.
38    ///
39    /// # Tips
40    ///
41    /// - You can use [`StructConstraint`] for typical use cases, which involve recursively
42    ///   traversing `ty`.
43    fn visitor<'r>(
44        &self,
45        substitutions: &'r mut Substitutions<Prim>,
46        errors: OpErrors<'r, Prim>,
47    ) -> Box<dyn Visit<Prim> + 'r>;
48
49    /// Clones this constraint into a `Box`.
50    ///
51    /// This method should be implemented by implementing [`Clone`] and boxing its output.
52    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>>;
53}
54
55impl<Prim: PrimitiveType> fmt::Debug for dyn Constraint<Prim> {
56    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
57        formatter
58            .debug_tuple("dyn Constraint")
59            .field(&self.to_string())
60            .finish()
61    }
62}
63
64impl<Prim: PrimitiveType> Clone for Box<dyn Constraint<Prim>> {
65    fn clone(&self) -> Self {
66        self.clone_boxed()
67    }
68}
69
70/// Marker trait for object-safe constraints, i.e., constraints that can be included
71/// into a [`DynConstraints`](crate::DynConstraints).
72///
73/// Object safety is similar to this notion in Rust. For a constraint `C` to be object-safe,
74/// it should be the case that `dyn C` (the untagged union of all types implementing `C`)
75/// implements `C`. As an example, this is the case for [`Linearity`], but is not the case
76/// for [`Ops`]. Indeed, [`Ops`] requires the type to be addable to itself,
77/// which would be impossible for `dyn Ops`.
78pub trait ObjectSafeConstraint<Prim: PrimitiveType>: Constraint<Prim> {}
79
/// Helper to define *structural* [`Constraint`]s, i.e., constraints recursively checking
/// the provided type.
///
/// The following logic is used to check whether a type satisfies the constraint:
///
/// - Primitive types satisfy the constraint iff the predicate provided in [`Self::new()`]
///   returns `true`.
/// - [`Type::Any`] always satisfies the constraint.
/// - [`Type::Dyn`] types satisfy the constraint iff the [`Constraint`] wrapped by this helper
///   is present among [`DynConstraints`](crate::DynConstraints). Thus,
///   if the wrapped constraint is not [object-safe](ObjectSafeConstraint), it will not be satisfied
///   by any `Dyn` type.
/// - Functional types never satisfy the constraint.
/// - Compound types (tuples and objects) satisfy the constraint iff all their elements
///   (for tuples) or fields (for objects) satisfy the constraint.
/// - For free type variables, the constraint is recorded in the type substitutions; if the
///   variable already resolves to a non-variable type, that type is checked as well.
/// - If [`Self::deny_dyn_slices()`] is set, tuple types need to have static length.
///
/// # Examples
///
/// Defining a constraint type using `StructConstraint`:
///
/// ```
/// # use arithmetic_typing::{
/// #     arith::{Constraint, StructConstraint, Substitutions}, error::OpErrors, visit::Visit,
/// #     PrimitiveType, Type,
/// # };
/// # use std::fmt;
/// /// Constraint for hashable types.
/// #[derive(Clone, Copy)]
/// struct Hashed;
///
/// impl fmt::Display for Hashed {
///     fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
///         formatter.write_str("Hash")
///     }
/// }
///
/// impl<Prim: PrimitiveType> Constraint<Prim> for Hashed {
///     fn visitor<'r>(
///         &self,
///         substitutions: &'r mut Substitutions<Prim>,
///         errors: OpErrors<'r, Prim>,
///     ) -> Box<dyn Visit<Prim> + 'r> {
///         // We can hash everything except for functions (and thus,
///         // types containing functions).
///         StructConstraint::new(*self, |_| true)
///             .visitor(substitutions, errors)
///     }
///
///     fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
///         Box::new(*self)
///     }
/// }
/// ```
#[derive(Debug)]
pub struct StructConstraint<Prim, C, F> {
    /// Constraint reported in errors and looked up in `dyn` constraint sets.
    constraint: C,
    /// Decides whether a primitive type satisfies the constraint.
    predicate: F,
    /// Whether tuples with dynamically sized middle slices fail the check.
    deny_dyn_slices: bool,
    _prim: PhantomData<Prim>,
}
141
142impl<Prim, C, F> StructConstraint<Prim, C, F>
143where
144    Prim: PrimitiveType,
145    C: Constraint<Prim> + Clone,
146    F: Fn(&Prim) -> bool + 'static,
147{
148    /// Creates a new helper. `predicate` determines whether a particular primitive type
149    /// should satisfy the `constraint`.
150    pub fn new(constraint: C, predicate: F) -> Self {
151        Self {
152            constraint,
153            predicate,
154            deny_dyn_slices: false,
155            _prim: PhantomData,
156        }
157    }
158
159    /// Marks that dynamically sized slices should fail the constraint check.
160    #[must_use]
161    pub fn deny_dyn_slices(mut self) -> Self {
162        self.deny_dyn_slices = true;
163        self
164    }
165
166    /// Returns a [`Visit`]or that can be used for [`Constraint::visitor()`] implementations.
167    pub fn visitor<'r>(
168        self,
169        substitutions: &'r mut Substitutions<Prim>,
170        errors: OpErrors<'r, Prim>,
171    ) -> Box<dyn Visit<Prim> + 'r> {
172        Box::new(StructConstraintVisitor {
173            inner: self,
174            substitutions,
175            errors,
176        })
177    }
178}
179
180#[derive(Debug)]
181struct StructConstraintVisitor<'r, Prim: PrimitiveType, C, F> {
182    inner: StructConstraint<Prim, C, F>,
183    substitutions: &'r mut Substitutions<Prim>,
184    errors: OpErrors<'r, Prim>,
185}
186
187impl<Prim, C, F> Visit<Prim> for StructConstraintVisitor<'_, Prim, C, F>
188where
189    Prim: PrimitiveType,
190    C: Constraint<Prim> + Clone,
191    F: Fn(&Prim) -> bool + 'static,
192{
193    fn visit_type(&mut self, ty: &Type<Prim>) {
194        match ty {
195            Type::Dyn(constraints) => {
196                if !constraints.inner.simple.contains(&self.inner.constraint) {
197                    self.errors.push(ErrorKind::failed_constraint(
198                        ty.clone(),
199                        self.inner.constraint.clone(),
200                    ));
201                }
202            }
203            _ => visit::visit_type(self, ty),
204        }
205    }
206
207    fn visit_var(&mut self, var: TypeVar) {
208        debug_assert!(var.is_free());
209        self.substitutions.insert_constraint(
210            var.index(),
211            &self.inner.constraint,
212            self.errors.by_ref(),
213        );
214
215        let resolved = self.substitutions.fast_resolve(&Type::Var(var)).clone();
216        if let Type::Var(_) = resolved {
217            // Avoid infinite recursion.
218        } else {
219            visit::visit_type(self, &resolved);
220        }
221    }
222
223    fn visit_primitive(&mut self, primitive: &Prim) {
224        if !(self.inner.predicate)(primitive) {
225            self.errors.push(ErrorKind::failed_constraint(
226                Type::Prim(primitive.clone()),
227                self.inner.constraint.clone(),
228            ));
229        }
230    }
231
232    fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
233        if self.inner.deny_dyn_slices {
234            let middle_len = tuple.parts().1.map(Slice::len);
235            if let Some(middle_len) = middle_len {
236                if let Err(err) = self.substitutions.apply_static_len(middle_len) {
237                    self.errors.push(err);
238                }
239            }
240        }
241
242        for (i, element) in tuple.element_types() {
243            self.errors.push_path_fragment(i);
244            self.visit_type(element);
245            self.errors.pop_path_fragment();
246        }
247    }
248
249    fn visit_object(&mut self, obj: &Object<Prim>) {
250        for (name, element) in obj.iter() {
251            self.errors.push_path_fragment(name);
252            self.visit_type(element);
253            self.errors.pop_path_fragment();
254        }
255    }
256
257    fn visit_function(&mut self, function: &Function<Prim>) {
258        self.errors.push(ErrorKind::failed_constraint(
259            function.clone().into(),
260            self.inner.constraint.clone(),
261        ));
262    }
263}
264
/// [`Constraint`] for numeric types that support unary `-` and may participate
/// in `T op Num` / `Num op T` operations.
///
/// Recursively defined as [linear](LinearType) primitive types, together with
/// tuples / objects composed of linear types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Linearity;

impl fmt::Display for Linearity {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Lin")
    }
}
278
279impl<Prim: LinearType> Constraint<Prim> for Linearity {
280    fn visitor<'r>(
281        &self,
282        substitutions: &'r mut Substitutions<Prim>,
283        errors: OpErrors<'r, Prim>,
284    ) -> Box<dyn Visit<Prim> + 'r> {
285        StructConstraint::new(*self, LinearType::is_linear).visitor(substitutions, errors)
286    }
287
288    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
289        Box::new(*self)
290    }
291}
292
293impl<Prim: LinearType> ObjectSafeConstraint<Prim> for Linearity {}
294
295/// Primitive type which supports a notion of *linearity*. Linear types are types that
296/// can be used in arithmetic ops.
297pub trait LinearType: PrimitiveType {
298    /// Returns `true` iff this type is linear.
299    fn is_linear(&self) -> bool;
300}
301
/// [`Constraint`] for numeric types that may participate in binary arithmetic ops (`T op T`).
///
/// This is the subset of `Lin` types that excludes dynamically sized slices,
/// as well as any types containing dynamically sized slices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ops;

impl fmt::Display for Ops {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Ops")
    }
}
314
315impl<Prim: LinearType> Constraint<Prim> for Ops {
316    fn visitor<'r>(
317        &self,
318        substitutions: &'r mut Substitutions<Prim>,
319        errors: OpErrors<'r, Prim>,
320    ) -> Box<dyn Visit<Prim> + 'r> {
321        StructConstraint::new(*self, LinearType::is_linear)
322            .deny_dyn_slices()
323            .visitor(substitutions, errors)
324    }
325
326    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
327        Box::new(*self)
328    }
329}
330
331/// Set of [`Constraint`]s.
332///
333/// [`Display`](fmt::Display)ed as `Foo + Bar + Quux`, where `Foo`, `Bar` and `Quux` are
334/// constraints in the set.
335#[derive(Debug, Clone)]
336pub struct ConstraintSet<Prim: PrimitiveType> {
337    inner: HashMap<String, (Box<dyn Constraint<Prim>>, bool)>,
338}
339
340impl<Prim: PrimitiveType> Default for ConstraintSet<Prim> {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
346impl<Prim: PrimitiveType> PartialEq for ConstraintSet<Prim> {
347    fn eq(&self, other: &Self) -> bool {
348        if self.inner.len() == other.inner.len() {
349            self.inner.keys().all(|key| other.inner.contains_key(key))
350        } else {
351            false
352        }
353    }
354}
355
356impl<Prim: PrimitiveType> fmt::Display for ConstraintSet<Prim> {
357    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358        let len = self.inner.len();
359        for (i, (constraint, _)) in self.inner.values().enumerate() {
360            fmt::Display::fmt(constraint, formatter)?;
361            if i + 1 < len {
362                formatter.write_str(" + ")?;
363            }
364        }
365        Ok(())
366    }
367}
368
369impl<Prim: PrimitiveType> ConstraintSet<Prim> {
370    /// Creates an empty set.
371    pub fn new() -> Self {
372        Self {
373            inner: HashMap::new(),
374        }
375    }
376
377    /// Creates a set with one constraint.
378    pub fn just(constraint: impl Constraint<Prim>) -> Self {
379        let mut this = Self::new();
380        this.insert(constraint);
381        this
382    }
383
384    /// Checks if this constraint set is empty.
385    pub fn is_empty(&self) -> bool {
386        self.inner.is_empty()
387    }
388
389    fn contains(&self, constraint: &impl Constraint<Prim>) -> bool {
390        self.inner.contains_key(&constraint.to_string())
391    }
392
393    /// Inserts a constraint into this set.
394    pub fn insert(&mut self, constraint: impl Constraint<Prim>) {
395        self.inner
396            .insert(constraint.to_string(), (Box::new(constraint), false));
397    }
398
399    /// Inserts an object-safe constraint into this set.
400    pub fn insert_object_safe(&mut self, constraint: impl ObjectSafeConstraint<Prim>) {
401        self.inner
402            .insert(constraint.to_string(), (Box::new(constraint), true));
403    }
404
405    /// Inserts a boxed constraint into this set.
406    pub(crate) fn insert_boxed(&mut self, constraint: Box<dyn Constraint<Prim>>) {
407        self.inner
408            .insert(constraint.to_string(), (constraint, false));
409    }
410
411    /// Returns the link to constraint and an indicator whether it is object-safe.
412    pub(crate) fn get_by_name(&self, name: &str) -> Option<(&dyn Constraint<Prim>, bool)> {
413        self.inner
414            .get(name)
415            .map(|(constraint, is_object_safe)| (constraint.as_ref(), *is_object_safe))
416    }
417
418    /// Applies all constraints from this set.
419    pub(crate) fn apply_all(
420        &self,
421        ty: &Type<Prim>,
422        substitutions: &mut Substitutions<Prim>,
423        mut errors: OpErrors<'_, Prim>,
424    ) {
425        for (constraint, _) in self.inner.values() {
426            constraint
427                .visitor(substitutions, errors.by_ref())
428                .visit_type(ty);
429        }
430    }
431
432    /// Applies all constraints from this set to an object.
433    pub(crate) fn apply_all_to_object(
434        &self,
435        object: &Object<Prim>,
436        substitutions: &mut Substitutions<Prim>,
437        mut errors: OpErrors<'_, Prim>,
438    ) {
439        for (constraint, _) in self.inner.values() {
440            constraint
441                .visitor(substitutions, errors.by_ref())
442                .visit_object(object);
443        }
444    }
445}
446
447/// Extended [`ConstraintSet`] that additionally supports object constraints.
448#[derive(Debug, Clone, PartialEq)]
449pub(crate) struct CompleteConstraints<Prim: PrimitiveType> {
450    pub simple: ConstraintSet<Prim>,
451    /// Object constraint. Stored as `Type` for convenience.
452    pub object: Option<Object<Prim>>,
453}
454
455impl<Prim: PrimitiveType> Default for CompleteConstraints<Prim> {
456    fn default() -> Self {
457        Self {
458            simple: ConstraintSet::new(),
459            object: None,
460        }
461    }
462}
463
464impl<Prim: PrimitiveType> From<ConstraintSet<Prim>> for CompleteConstraints<Prim> {
465    fn from(constraints: ConstraintSet<Prim>) -> Self {
466        Self {
467            simple: constraints,
468            object: None,
469        }
470    }
471}
472
473impl<Prim: PrimitiveType> From<Object<Prim>> for CompleteConstraints<Prim> {
474    fn from(object: Object<Prim>) -> Self {
475        Self {
476            simple: ConstraintSet::default(),
477            object: Some(object),
478        }
479    }
480}
481
482impl<Prim: PrimitiveType> fmt::Display for CompleteConstraints<Prim> {
483    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
484        match (&self.object, self.simple.is_empty()) {
485            (Some(object), false) => write!(formatter, "{object} + {}", self.simple),
486            (Some(object), true) => fmt::Display::fmt(object, formatter),
487            (None, _) => fmt::Display::fmt(&self.simple, formatter),
488        }
489    }
490}
491
492impl<Prim: PrimitiveType> CompleteConstraints<Prim> {
493    /// Checks if this constraint set is empty.
494    pub fn is_empty(&self) -> bool {
495        self.object.is_none() && self.simple.is_empty()
496    }
497
498    /// Inserts a constraint into this set.
499    pub fn insert(
500        &mut self,
501        constraint: impl Constraint<Prim>,
502        substitutions: &mut Substitutions<Prim>,
503        errors: OpErrors<'_, Prim>,
504    ) {
505        self.simple.insert(constraint);
506        self.check_object_consistency(substitutions, errors);
507    }
508
509    /// Applies all constraints from this set.
510    pub fn apply_all(
511        &self,
512        ty: &Type<Prim>,
513        substitutions: &mut Substitutions<Prim>,
514        mut errors: OpErrors<'_, Prim>,
515    ) {
516        self.simple.apply_all(ty, substitutions, errors.by_ref());
517        if let Some(lhs) = &self.object {
518            lhs.apply_as_constraint(ty, substitutions, errors);
519        }
520    }
521
522    /// Maps the object constraint if present.
523    pub fn map_object(self, map: impl FnOnce(&mut Object<Prim>)) -> Self {
524        Self {
525            simple: self.simple,
526            object: self.object.map(|mut object| {
527                map(&mut object);
528                object
529            }),
530        }
531    }
532
533    /// Inserts an object constraint into this set.
534    pub fn insert_obj_constraint(
535        &mut self,
536        object: Object<Prim>,
537        substitutions: &mut Substitutions<Prim>,
538        mut errors: OpErrors<'_, Prim>,
539    ) {
540        if let Some(existing_object) = &mut self.object {
541            existing_object.extend_from(object, substitutions, errors.by_ref());
542        } else {
543            self.object = Some(object);
544        }
545        self.check_object_consistency(substitutions, errors);
546    }
547
548    fn check_object_consistency(
549        &self,
550        substitutions: &mut Substitutions<Prim>,
551        errors: OpErrors<'_, Prim>,
552    ) {
553        if let Some(object) = &self.object {
554            self.simple
555                .apply_all_to_object(object, substitutions, errors);
556        }
557    }
558}