arithmetic_typing/
visit.rs

//! Visitor traits for traversing [`Type`] and related types.
2
3use crate::{DynConstraints, Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar};
4
5/// Recursive traversal across the shared reference to a [`Type`].
6///
7/// Inspired by the [`Visit` trait from `syn`](https://docs.rs/syn/^1/syn/visit/trait.Visit.html).
8///
9/// # Examples
10///
11/// ```
12/// use arithmetic_typing::{
13///     ast::TypeAst, visit::{self, Visit},
14///     PrimitiveType, Slice, Tuple, UnknownLen, Type, TypeVar,
15/// };
16/// # use std::{collections::HashMap, convert::TryFrom};
17///
18/// /// Counts the number of mentions of type / length params in a type.
19/// #[derive(Default)]
20/// pub struct Mentions {
21///     types: HashMap<usize, usize>,
22///     lengths: HashMap<usize, usize>,
23/// }
24///
25/// impl<Prim: PrimitiveType> Visit<Prim> for Mentions {
26///     fn visit_var(&mut self, var: TypeVar) {
27///         *self.types.entry(var.index()).or_default() += 1;
28///     }
29///
30///     fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
31///         let (_, middle, _) = tuple.parts();
32///         let len = middle.and_then(|middle| middle.len().components().0);
33///         if let Some(UnknownLen::Var(var)) = len {
34///             *self.lengths.entry(var.index()).or_default() += 1;
35///         }
36///         visit::visit_tuple(self, tuple);
37///     }
38/// }
39///
40/// # fn main() -> anyhow::Result<()> {
41/// let ty = TypeAst::try_from("(...['T; N], ('T) -> 'U) -> [('T, 'U); N]")?;
42/// let ty: Type = Type::try_from(&ty)?;
43///
44/// let mut mentions = Mentions::default();
45/// mentions.visit_type(&ty);
46/// assert_eq!(mentions.lengths[&0], 2); // `N` is mentioned twice
47/// assert_eq!(mentions.types[&0], 3); // `T` is mentioned 3 times
48/// assert_eq!(mentions.types[&1], 2); // `U` is mentioned twice
49/// # Ok(())
50/// # }
51/// ```
52#[allow(unused_variables)]
53pub trait Visit<Prim: PrimitiveType> {
54    /// Visits a generic type.
55    ///
56    /// The default implementation calls one of more specific methods corresponding to the `ty`
57    /// variant.
58    fn visit_type(&mut self, ty: &Type<Prim>) {
59        visit_type(self, ty);
60    }
61
62    /// Visits a type variable.
63    ///
64    /// The default implementation does nothing.
65    fn visit_var(&mut self, var: TypeVar) {
66        // Does nothing.
67    }
68
69    /// Visits a primitive type.
70    ///
71    /// The default implementation does nothing.
72    fn visit_primitive(&mut self, primitive: &Prim) {
73        // Does nothing.
74    }
75
76    /// Visits a tuple type.
77    ///
78    /// The default implementation calls [`Self::visit_type()`] for each tuple element,
79    /// including the middle element if any.
80    fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
81        visit_tuple(self, tuple);
82    }
83
84    /// Visits an object type.
85    fn visit_object(&mut self, object: &Object<Prim>) {
86        visit_object(self, object);
87    }
88
89    /// Visits a [`Type::Dyn`] variant.
90    ///
91    /// The default implementation visits the object constraint if it is present using
92    /// [`Self::visit_object()`].
93    fn visit_dyn_constraints(&mut self, constraints: &DynConstraints<Prim>) {
94        if let Some(object) = &constraints.inner.object {
95            self.visit_object(object);
96        }
97    }
98
99    /// Visits a functional type.
100    ///
101    /// The default implementation calls [`Self::visit_tuple()`] on arguments and then
102    /// [`Self::visit_type()`] on the return value.
103    fn visit_function(&mut self, function: &Function<Prim>) {
104        visit_function(self, function);
105    }
106}
107
108/// Default implementation of [`Visit::visit_type()`].
109pub fn visit_type<Prim, V>(visitor: &mut V, ty: &Type<Prim>)
110where
111    Prim: PrimitiveType,
112    V: Visit<Prim> + ?Sized,
113{
114    match ty {
115        Type::Any => { /* Do nothing. */ }
116        Type::Dyn(constraints) => visitor.visit_dyn_constraints(constraints),
117        Type::Var(var) => visitor.visit_var(*var),
118        Type::Prim(primitive) => visitor.visit_primitive(primitive),
119        Type::Tuple(tuple) => visitor.visit_tuple(tuple),
120        Type::Object(obj) => visitor.visit_object(obj),
121        Type::Function(function) => visitor.visit_function(function.as_ref()),
122    }
123}
124
125/// Default implementation of [`Visit::visit_tuple()`].
126pub fn visit_tuple<Prim, V>(visitor: &mut V, tuple: &Tuple<Prim>)
127where
128    Prim: PrimitiveType,
129    V: Visit<Prim> + ?Sized,
130{
131    for (_, ty) in tuple.element_types() {
132        visitor.visit_type(ty);
133    }
134}
135
136/// Default implementation of [`Visit::visit_object()`].
137pub fn visit_object<Prim, V>(visitor: &mut V, object: &Object<Prim>)
138where
139    Prim: PrimitiveType,
140    V: Visit<Prim> + ?Sized,
141{
142    for (_, ty) in object.iter() {
143        visitor.visit_type(ty);
144    }
145}
146
147/// Default implementation of [`Visit::visit_function()`].
148pub fn visit_function<Prim, V>(visitor: &mut V, function: &Function<Prim>)
149where
150    Prim: PrimitiveType,
151    V: Visit<Prim> + ?Sized,
152{
153    visitor.visit_tuple(&function.args);
154    visitor.visit_type(&function.return_type);
155}
156
157/// Recursive traversal across the exclusive reference to a [`Type`].
158///
159/// Inspired by the [`VisitMut` trait from `syn`].
160///
161/// [`VisitMut` trait from `syn`]: https://docs.rs/syn/^1/syn/visit_mut/trait.VisitMut.html
162///
163/// # Examples
164///
165/// ```
166/// use arithmetic_typing::{ast::TypeAst, arith::Num, Type};
167/// use arithmetic_typing::visit::{self, VisitMut};
168/// # use std::convert::TryFrom;
169///
170/// /// Replaces all primitive types with `Num`.
171/// struct Replacer;
172///
173/// impl VisitMut<Num> for Replacer {
174///     fn visit_type_mut(&mut self, ty: &mut Type) {
175///         match ty {
176///             Type::Prim(_) => *ty = Type::NUM,
177///             _ => visit::visit_type_mut(self, ty),
178///         }
179///     }
180/// }
181///
182/// # fn main() -> anyhow::Result<()> {
183/// let ty = TypeAst::try_from("(Num, Bool, (Num) -> (Bool, Num))")?;
184/// let mut ty = Type::try_from(&ty)?;
185/// Replacer.visit_type_mut(&mut ty);
186/// assert_eq!(ty.to_string(), "(Num, Num, (Num) -> (Num, Num))");
187/// # Ok(())
188/// # }
189/// ```
190#[allow(unused_variables)]
191pub trait VisitMut<Prim: PrimitiveType> {
192    /// Visits a generic type.
193    ///
194    /// The default implementation calls one of more specific methods corresponding to the `ty`
195    /// variant. For "simple" types (variables, params, primitive types) does nothing.
196    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
197        visit_type_mut(self, ty);
198    }
199
200    /// Visits a tuple type.
201    ///
202    /// The default implementation calls [`Self::visit_middle_len_mut()`] for the middle length
203    /// if the tuple has a middle. Then, [`Self::visit_type_mut()`] is called
204    /// for each tuple element, including the middle element if any.
205    fn visit_tuple_mut(&mut self, tuple: &mut Tuple<Prim>) {
206        visit_tuple_mut(self, tuple);
207    }
208
209    /// Visits an object type.
210    fn visit_object_mut(&mut self, object: &mut Object<Prim>) {
211        visit_object_mut(self, object);
212    }
213
214    /// Visits a [`Type::Dyn`] variant.
215    ///
216    /// The default implementation visits the object constraint if it is present using
217    /// [`Self::visit_object_mut()`].
218    fn visit_dyn_constraints_mut(&mut self, constraints: &mut DynConstraints<Prim>) {
219        if let Some(object) = &mut constraints.inner.object {
220            self.visit_object_mut(object);
221        }
222    }
223
224    /// Visits a middle length of a tuple.
225    ///
226    /// The default implementation does nothing.
227    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
228        // Does nothing.
229    }
230
231    /// Visits a functional type.
232    ///
233    /// The default implementation calls [`Self::visit_tuple_mut()`] on arguments and then
234    /// [`Self::visit_type_mut()`] on the return value.
235    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
236        visit_function_mut(self, function);
237    }
238}
239
240/// Default implementation of [`VisitMut::visit_type_mut()`].
241pub fn visit_type_mut<Prim, V>(visitor: &mut V, ty: &mut Type<Prim>)
242where
243    Prim: PrimitiveType,
244    V: VisitMut<Prim> + ?Sized,
245{
246    match ty {
247        Type::Any | Type::Var(_) | Type::Prim(_) => {}
248        Type::Dyn(constraints) => visitor.visit_dyn_constraints_mut(constraints),
249        Type::Tuple(tuple) => visitor.visit_tuple_mut(tuple),
250        Type::Object(obj) => visitor.visit_object_mut(obj),
251        Type::Function(function) => visitor.visit_function_mut(function.as_mut()),
252    }
253}
254
255/// Default implementation of [`VisitMut::visit_tuple_mut()`].
256pub fn visit_tuple_mut<Prim, V>(visitor: &mut V, tuple: &mut Tuple<Prim>)
257where
258    Prim: PrimitiveType,
259    V: VisitMut<Prim> + ?Sized,
260{
261    if let Some(middle) = tuple.parts_mut().1 {
262        visitor.visit_middle_len_mut(middle.len_mut());
263    }
264    for ty in tuple.element_types_mut() {
265        visitor.visit_type_mut(ty);
266    }
267}
268
269/// Default implementation of [`VisitMut::visit_object_mut()`].
270pub fn visit_object_mut<Prim, V>(visitor: &mut V, object: &mut Object<Prim>)
271where
272    Prim: PrimitiveType,
273    V: VisitMut<Prim> + ?Sized,
274{
275    for (_, ty) in object.iter_mut() {
276        visitor.visit_type_mut(ty);
277    }
278}
279
280/// Default implementation of [`VisitMut::visit_function_mut()`].
281pub fn visit_function_mut<Prim, V>(visitor: &mut V, function: &mut Function<Prim>)
282where
283    Prim: PrimitiveType,
284    V: VisitMut<Prim> + ?Sized,
285{
286    visitor.visit_tuple_mut(&mut function.args);
287    visitor.visit_type_mut(&mut function.return_type);
288}