arithmetic_typing/visit.rs
//! Visitor traits for traversing [`Type`] and related types.
2
3use crate::{DynConstraints, Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar};
4
5/// Recursive traversal across the shared reference to a [`Type`].
6///
7/// Inspired by the [`Visit` trait from `syn`](https://docs.rs/syn/^1/syn/visit/trait.Visit.html).
8///
9/// # Examples
10///
11/// ```
12/// use arithmetic_typing::{
13/// ast::TypeAst, visit::{self, Visit},
14/// PrimitiveType, Slice, Tuple, UnknownLen, Type, TypeVar,
15/// };
16/// # use std::{collections::HashMap, convert::TryFrom};
17///
18/// /// Counts the number of mentions of type / length params in a type.
19/// #[derive(Default)]
20/// pub struct Mentions {
21/// types: HashMap<usize, usize>,
22/// lengths: HashMap<usize, usize>,
23/// }
24///
25/// impl<Prim: PrimitiveType> Visit<Prim> for Mentions {
26/// fn visit_var(&mut self, var: TypeVar) {
27/// *self.types.entry(var.index()).or_default() += 1;
28/// }
29///
30/// fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
31/// let (_, middle, _) = tuple.parts();
32/// let len = middle.and_then(|middle| middle.len().components().0);
33/// if let Some(UnknownLen::Var(var)) = len {
34/// *self.lengths.entry(var.index()).or_default() += 1;
35/// }
36/// visit::visit_tuple(self, tuple);
37/// }
38/// }
39///
40/// # fn main() -> anyhow::Result<()> {
41/// let ty = TypeAst::try_from("(...['T; N], ('T) -> 'U) -> [('T, 'U); N]")?;
42/// let ty: Type = Type::try_from(&ty)?;
43///
44/// let mut mentions = Mentions::default();
45/// mentions.visit_type(&ty);
46/// assert_eq!(mentions.lengths[&0], 2); // `N` is mentioned twice
47/// assert_eq!(mentions.types[&0], 3); // `T` is mentioned 3 times
48/// assert_eq!(mentions.types[&1], 2); // `U` is mentioned twice
49/// # Ok(())
50/// # }
51/// ```
52#[allow(unused_variables)]
53pub trait Visit<Prim: PrimitiveType> {
54 /// Visits a generic type.
55 ///
56 /// The default implementation calls one of more specific methods corresponding to the `ty`
57 /// variant.
58 fn visit_type(&mut self, ty: &Type<Prim>) {
59 visit_type(self, ty);
60 }
61
62 /// Visits a type variable.
63 ///
64 /// The default implementation does nothing.
65 fn visit_var(&mut self, var: TypeVar) {
66 // Does nothing.
67 }
68
69 /// Visits a primitive type.
70 ///
71 /// The default implementation does nothing.
72 fn visit_primitive(&mut self, primitive: &Prim) {
73 // Does nothing.
74 }
75
76 /// Visits a tuple type.
77 ///
78 /// The default implementation calls [`Self::visit_type()`] for each tuple element,
79 /// including the middle element if any.
80 fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
81 visit_tuple(self, tuple);
82 }
83
84 /// Visits an object type.
85 fn visit_object(&mut self, object: &Object<Prim>) {
86 visit_object(self, object);
87 }
88
89 /// Visits a [`Type::Dyn`] variant.
90 ///
91 /// The default implementation visits the object constraint if it is present using
92 /// [`Self::visit_object()`].
93 fn visit_dyn_constraints(&mut self, constraints: &DynConstraints<Prim>) {
94 if let Some(object) = &constraints.inner.object {
95 self.visit_object(object);
96 }
97 }
98
99 /// Visits a functional type.
100 ///
101 /// The default implementation calls [`Self::visit_tuple()`] on arguments and then
102 /// [`Self::visit_type()`] on the return value.
103 fn visit_function(&mut self, function: &Function<Prim>) {
104 visit_function(self, function);
105 }
106}
107
108/// Default implementation of [`Visit::visit_type()`].
109pub fn visit_type<Prim, V>(visitor: &mut V, ty: &Type<Prim>)
110where
111 Prim: PrimitiveType,
112 V: Visit<Prim> + ?Sized,
113{
114 match ty {
115 Type::Any => { /* Do nothing. */ }
116 Type::Dyn(constraints) => visitor.visit_dyn_constraints(constraints),
117 Type::Var(var) => visitor.visit_var(*var),
118 Type::Prim(primitive) => visitor.visit_primitive(primitive),
119 Type::Tuple(tuple) => visitor.visit_tuple(tuple),
120 Type::Object(obj) => visitor.visit_object(obj),
121 Type::Function(function) => visitor.visit_function(function.as_ref()),
122 }
123}
124
125/// Default implementation of [`Visit::visit_tuple()`].
126pub fn visit_tuple<Prim, V>(visitor: &mut V, tuple: &Tuple<Prim>)
127where
128 Prim: PrimitiveType,
129 V: Visit<Prim> + ?Sized,
130{
131 for (_, ty) in tuple.element_types() {
132 visitor.visit_type(ty);
133 }
134}
135
136/// Default implementation of [`Visit::visit_object()`].
137pub fn visit_object<Prim, V>(visitor: &mut V, object: &Object<Prim>)
138where
139 Prim: PrimitiveType,
140 V: Visit<Prim> + ?Sized,
141{
142 for (_, ty) in object.iter() {
143 visitor.visit_type(ty);
144 }
145}
146
147/// Default implementation of [`Visit::visit_function()`].
148pub fn visit_function<Prim, V>(visitor: &mut V, function: &Function<Prim>)
149where
150 Prim: PrimitiveType,
151 V: Visit<Prim> + ?Sized,
152{
153 visitor.visit_tuple(&function.args);
154 visitor.visit_type(&function.return_type);
155}
156
157/// Recursive traversal across the exclusive reference to a [`Type`].
158///
159/// Inspired by the [`VisitMut` trait from `syn`].
160///
161/// [`VisitMut` trait from `syn`]: https://docs.rs/syn/^1/syn/visit_mut/trait.VisitMut.html
162///
163/// # Examples
164///
165/// ```
166/// use arithmetic_typing::{ast::TypeAst, arith::Num, Type};
167/// use arithmetic_typing::visit::{self, VisitMut};
168/// # use std::convert::TryFrom;
169///
170/// /// Replaces all primitive types with `Num`.
171/// struct Replacer;
172///
173/// impl VisitMut<Num> for Replacer {
174/// fn visit_type_mut(&mut self, ty: &mut Type) {
175/// match ty {
176/// Type::Prim(_) => *ty = Type::NUM,
177/// _ => visit::visit_type_mut(self, ty),
178/// }
179/// }
180/// }
181///
182/// # fn main() -> anyhow::Result<()> {
183/// let ty = TypeAst::try_from("(Num, Bool, (Num) -> (Bool, Num))")?;
184/// let mut ty = Type::try_from(&ty)?;
185/// Replacer.visit_type_mut(&mut ty);
186/// assert_eq!(ty.to_string(), "(Num, Num, (Num) -> (Num, Num))");
187/// # Ok(())
188/// # }
189/// ```
190#[allow(unused_variables)]
191pub trait VisitMut<Prim: PrimitiveType> {
192 /// Visits a generic type.
193 ///
194 /// The default implementation calls one of more specific methods corresponding to the `ty`
195 /// variant. For "simple" types (variables, params, primitive types) does nothing.
196 fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
197 visit_type_mut(self, ty);
198 }
199
200 /// Visits a tuple type.
201 ///
202 /// The default implementation calls [`Self::visit_middle_len_mut()`] for the middle length
203 /// if the tuple has a middle. Then, [`Self::visit_type_mut()`] is called
204 /// for each tuple element, including the middle element if any.
205 fn visit_tuple_mut(&mut self, tuple: &mut Tuple<Prim>) {
206 visit_tuple_mut(self, tuple);
207 }
208
209 /// Visits an object type.
210 fn visit_object_mut(&mut self, object: &mut Object<Prim>) {
211 visit_object_mut(self, object);
212 }
213
214 /// Visits a [`Type::Dyn`] variant.
215 ///
216 /// The default implementation visits the object constraint if it is present using
217 /// [`Self::visit_object_mut()`].
218 fn visit_dyn_constraints_mut(&mut self, constraints: &mut DynConstraints<Prim>) {
219 if let Some(object) = &mut constraints.inner.object {
220 self.visit_object_mut(object);
221 }
222 }
223
224 /// Visits a middle length of a tuple.
225 ///
226 /// The default implementation does nothing.
227 fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
228 // Does nothing.
229 }
230
231 /// Visits a functional type.
232 ///
233 /// The default implementation calls [`Self::visit_tuple_mut()`] on arguments and then
234 /// [`Self::visit_type_mut()`] on the return value.
235 fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
236 visit_function_mut(self, function);
237 }
238}
239
240/// Default implementation of [`VisitMut::visit_type_mut()`].
241pub fn visit_type_mut<Prim, V>(visitor: &mut V, ty: &mut Type<Prim>)
242where
243 Prim: PrimitiveType,
244 V: VisitMut<Prim> + ?Sized,
245{
246 match ty {
247 Type::Any | Type::Var(_) | Type::Prim(_) => {}
248 Type::Dyn(constraints) => visitor.visit_dyn_constraints_mut(constraints),
249 Type::Tuple(tuple) => visitor.visit_tuple_mut(tuple),
250 Type::Object(obj) => visitor.visit_object_mut(obj),
251 Type::Function(function) => visitor.visit_function_mut(function.as_mut()),
252 }
253}
254
255/// Default implementation of [`VisitMut::visit_tuple_mut()`].
256pub fn visit_tuple_mut<Prim, V>(visitor: &mut V, tuple: &mut Tuple<Prim>)
257where
258 Prim: PrimitiveType,
259 V: VisitMut<Prim> + ?Sized,
260{
261 if let Some(middle) = tuple.parts_mut().1 {
262 visitor.visit_middle_len_mut(middle.len_mut());
263 }
264 for ty in tuple.element_types_mut() {
265 visitor.visit_type_mut(ty);
266 }
267}
268
269/// Default implementation of [`VisitMut::visit_object_mut()`].
270pub fn visit_object_mut<Prim, V>(visitor: &mut V, object: &mut Object<Prim>)
271where
272 Prim: PrimitiveType,
273 V: VisitMut<Prim> + ?Sized,
274{
275 for (_, ty) in object.iter_mut() {
276 visitor.visit_type_mut(ty);
277 }
278}
279
280/// Default implementation of [`VisitMut::visit_function_mut()`].
281pub fn visit_function_mut<Prim, V>(visitor: &mut V, function: &mut Function<Prim>)
282where
283 Prim: PrimitiveType,
284 V: VisitMut<Prim> + ?Sized,
285{
286 visitor.visit_tuple_mut(&mut function.args);
287 visitor.visit_type_mut(&mut function.return_type);
288}