1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
//! Visitor traits for traversing [`Type`] and related types.

use crate::{DynConstraints, Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar};

/// Recursive traversal across the shared reference to a [`Type`].
///
/// Inspired by the [`Visit` trait from `syn`](https://docs.rs/syn/^1/syn/visit/trait.Visit.html).
///
/// Every method has a default implementation. The recursive defaults delegate to the free
/// functions of the same name in this module (e.g., [`visit_tuple()`]); an overriding
/// implementation can call these functions to continue traversal into child types.
///
/// # Examples
///
/// ```
/// use arithmetic_typing::{
///     ast::TypeAst, visit::{self, Visit},
///     PrimitiveType, Slice, Tuple, UnknownLen, Type, TypeVar,
/// };
/// # use std::{collections::HashMap, convert::TryFrom};
///
/// /// Counts the number of mentions of type / length params in a type.
/// #[derive(Default)]
/// pub struct Mentions {
///     types: HashMap<usize, usize>,
///     lengths: HashMap<usize, usize>,
/// }
///
/// impl<Prim: PrimitiveType> Visit<Prim> for Mentions {
///     fn visit_var(&mut self, var: TypeVar) {
///         *self.types.entry(var.index()).or_default() += 1;
///     }
///
///     fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
///         let (_, middle, _) = tuple.parts();
///         let len = middle.and_then(|middle| middle.len().components().0);
///         if let Some(UnknownLen::Var(var)) = len {
///             *self.lengths.entry(var.index()).or_default() += 1;
///         }
///         visit::visit_tuple(self, tuple);
///     }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let ty = TypeAst::try_from("(...['T; N], ('T) -> 'U) -> [('T, 'U); N]")?;
/// let ty: Type = Type::try_from(&ty)?;
///
/// let mut mentions = Mentions::default();
/// mentions.visit_type(&ty);
/// assert_eq!(mentions.lengths[&0], 2); // `N` is mentioned twice
/// assert_eq!(mentions.types[&0], 3); // `T` is mentioned 3 times
/// assert_eq!(mentions.types[&1], 2); // `U` is mentioned twice
/// # Ok(())
/// # }
/// ```
// Silences warnings for the parameters of the no-op default methods (e.g., `visit_var`);
// they keep descriptive names because those names are part of the documented signatures.
#[allow(unused_variables)]
pub trait Visit<Prim: PrimitiveType> {
    /// Visits a generic type.
    ///
    /// The default implementation calls one of the more specific methods corresponding to
    /// the `ty` variant. [`Type::Any`] has no dedicated method and is skipped.
    fn visit_type(&mut self, ty: &Type<Prim>) {
        visit_type(self, ty);
    }

    /// Visits a type variable.
    ///
    /// The default implementation does nothing.
    fn visit_var(&mut self, var: TypeVar) {
        // Does nothing.
    }

    /// Visits a primitive type.
    ///
    /// The default implementation does nothing.
    fn visit_primitive(&mut self, primitive: &Prim) {
        // Does nothing.
    }

    /// Visits a tuple type.
    ///
    /// The default implementation calls [`Self::visit_type()`] for each tuple element,
    /// including the middle element if any.
    fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
        visit_tuple(self, tuple);
    }

    /// Visits an object type.
    ///
    /// The default implementation calls [`Self::visit_type()`] for each type contained
    /// in the object.
    fn visit_object(&mut self, object: &Object<Prim>) {
        visit_object(self, object);
    }

    /// Visits a [`Type::Dyn`] variant.
    ///
    /// The default implementation visits the object constraint, if present, using
    /// [`Self::visit_object()`].
    fn visit_dyn_constraints(&mut self, constraints: &DynConstraints<Prim>) {
        if let Some(object) = &constraints.inner.object {
            self.visit_object(object);
        }
    }

    /// Visits a functional type.
    ///
    /// The default implementation calls [`Self::visit_tuple()`] on the arguments and then
    /// [`Self::visit_type()`] on the return value.
    fn visit_function(&mut self, function: &Function<Prim>) {
        visit_function(self, function);
    }
}

/// Default implementation of [`Visit::visit_type()`].
///
/// Dispatches `ty` to the `visitor` method matching its variant. [`Type::Any`] carries
/// nothing to visit and is ignored.
pub fn visit_type<Prim, V>(visitor: &mut V, ty: &Type<Prim>)
where
    Prim: PrimitiveType,
    V: Visit<Prim> + ?Sized,
{
    match ty {
        Type::Var(type_var) => visitor.visit_var(*type_var),
        Type::Prim(prim) => visitor.visit_primitive(prim),
        Type::Tuple(elements) => visitor.visit_tuple(elements),
        Type::Object(fields) => visitor.visit_object(fields),
        Type::Dyn(dyn_constraints) => visitor.visit_dyn_constraints(dyn_constraints),
        Type::Function(func) => visitor.visit_function(func.as_ref()),
        Type::Any => {
            // Nothing to visit.
        }
    }
}

/// Default implementation of [`Visit::visit_tuple()`].
///
/// Feeds every element type of `tuple` (the middle element included, if present) to
/// [`Visit::visit_type()`].
pub fn visit_tuple<Prim, V>(visitor: &mut V, tuple: &Tuple<Prim>)
where
    Prim: PrimitiveType,
    V: Visit<Prim> + ?Sized,
{
    let elements = tuple.element_types();
    for (_, element) in elements {
        visitor.visit_type(element);
    }
}

/// Default implementation of [`Visit::visit_object()`].
///
/// Passes each type stored in `object` to [`Visit::visit_type()`], in iteration order.
pub fn visit_object<Prim, V>(visitor: &mut V, object: &Object<Prim>)
where
    Prim: PrimitiveType,
    V: Visit<Prim> + ?Sized,
{
    let entries = object.iter();
    for (_, field_ty) in entries {
        visitor.visit_type(field_ty);
    }
}

/// Default implementation of [`Visit::visit_function()`].
pub fn visit_function<Prim, V>(visitor: &mut V, function: &Function<Prim>)
where
    Prim: PrimitiveType,
    V: Visit<Prim> + ?Sized,
{
    visitor.visit_tuple(&function.args);
    visitor.visit_type(&function.return_type);
}

/// Recursive traversal across the exclusive reference to a [`Type`].
///
/// Inspired by the [`VisitMut` trait from `syn`].
///
/// Every method has a default implementation. The recursive defaults delegate to the free
/// functions of the same name in this module (e.g., [`visit_type_mut()`]). There are no
/// dedicated hooks for type variables or primitive types; to modify those, override
/// [`Self::visit_type_mut()`] as shown in the example below.
///
/// [`VisitMut` trait from `syn`]: https://docs.rs/syn/^1/syn/visit_mut/trait.VisitMut.html
///
/// # Examples
///
/// ```
/// use arithmetic_typing::{ast::TypeAst, arith::Num, Type};
/// use arithmetic_typing::visit::{self, VisitMut};
/// # use std::convert::TryFrom;
///
/// /// Replaces all primitive types with `Num`.
/// struct Replacer;
///
/// impl VisitMut<Num> for Replacer {
///     fn visit_type_mut(&mut self, ty: &mut Type) {
///         match ty {
///             Type::Prim(_) => *ty = Type::NUM,
///             _ => visit::visit_type_mut(self, ty),
///         }
///     }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let ty = TypeAst::try_from("(Num, Bool, (Num) -> (Bool, Num))")?;
/// let mut ty = Type::try_from(&ty)?;
/// Replacer.visit_type_mut(&mut ty);
/// assert_eq!(ty.to_string(), "(Num, Num, (Num) -> (Num, Num))");
/// # Ok(())
/// # }
/// ```
// Silences warnings for the parameters of the no-op default method `visit_middle_len_mut`;
// it keeps a descriptive parameter name because that name is part of the documented signature.
#[allow(unused_variables)]
pub trait VisitMut<Prim: PrimitiveType> {
    /// Visits a generic type.
    ///
    /// The default implementation calls one of the more specific methods corresponding to
    /// the `ty` variant. For simple types ([`Type::Any`], type variables and primitive types),
    /// it does nothing.
    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
        visit_type_mut(self, ty);
    }

    /// Visits a tuple type.
    ///
    /// If the tuple has a middle, the default implementation first calls
    /// [`Self::visit_middle_len_mut()`] for the middle length. It then calls
    /// [`Self::visit_type_mut()`] for each tuple element, including the middle element if any.
    fn visit_tuple_mut(&mut self, tuple: &mut Tuple<Prim>) {
        visit_tuple_mut(self, tuple);
    }

    /// Visits an object type.
    ///
    /// The default implementation calls [`Self::visit_type_mut()`] for each type contained
    /// in the object.
    fn visit_object_mut(&mut self, object: &mut Object<Prim>) {
        visit_object_mut(self, object);
    }

    /// Visits a [`Type::Dyn`] variant.
    ///
    /// The default implementation visits the object constraint, if present, using
    /// [`Self::visit_object_mut()`].
    fn visit_dyn_constraints_mut(&mut self, constraints: &mut DynConstraints<Prim>) {
        if let Some(object) = &mut constraints.inner.object {
            self.visit_object_mut(object);
        }
    }

    /// Visits a middle length of a tuple.
    ///
    /// The default implementation does nothing.
    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
        // Does nothing.
    }

    /// Visits a functional type.
    ///
    /// The default implementation calls [`Self::visit_tuple_mut()`] on the arguments and then
    /// [`Self::visit_type_mut()`] on the return value.
    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
        visit_function_mut(self, function);
    }
}

/// Default implementation of [`VisitMut::visit_type_mut()`].
///
/// Forwards composite variants of `ty` to the matching `visitor` method. [`Type::Any`],
/// type variables and primitive types are left untouched.
pub fn visit_type_mut<Prim, V>(visitor: &mut V, ty: &mut Type<Prim>)
where
    Prim: PrimitiveType,
    V: VisitMut<Prim> + ?Sized,
{
    match ty {
        Type::Tuple(elements) => visitor.visit_tuple_mut(elements),
        Type::Object(fields) => visitor.visit_object_mut(fields),
        Type::Function(func) => visitor.visit_function_mut(func.as_mut()),
        Type::Dyn(dyn_constraints) => visitor.visit_dyn_constraints_mut(dyn_constraints),
        Type::Any | Type::Var(_) | Type::Prim(_) => {
            // Leaf types: nothing to recurse into.
        }
    }
}

/// Default implementation of [`VisitMut::visit_tuple_mut()`].
///
/// Hands the middle length (when the tuple has a middle) to
/// [`VisitMut::visit_middle_len_mut()`]. It then hands every element type, the middle
/// element included, to [`VisitMut::visit_type_mut()`].
pub fn visit_tuple_mut<Prim, V>(visitor: &mut V, tuple: &mut Tuple<Prim>)
where
    Prim: PrimitiveType,
    V: VisitMut<Prim> + ?Sized,
{
    // The length is visited before any element type, as documented on the trait method.
    let maybe_middle = tuple.parts_mut().1;
    if let Some(middle_slice) = maybe_middle {
        visitor.visit_middle_len_mut(middle_slice.len_mut());
    }

    let elements = tuple.element_types_mut();
    for element in elements {
        visitor.visit_type_mut(element);
    }
}

/// Default implementation of [`VisitMut::visit_object_mut()`].
///
/// Passes a mutable reference to each type stored in `object` to
/// [`VisitMut::visit_type_mut()`], in iteration order.
pub fn visit_object_mut<Prim, V>(visitor: &mut V, object: &mut Object<Prim>)
where
    Prim: PrimitiveType,
    V: VisitMut<Prim> + ?Sized,
{
    let entries = object.iter_mut();
    for (_, field_ty) in entries {
        visitor.visit_type_mut(field_ty);
    }
}

/// Default implementation of [`VisitMut::visit_function_mut()`].
pub fn visit_function_mut<Prim, V>(visitor: &mut V, function: &mut Function<Prim>)
where
    Prim: PrimitiveType,
    V: VisitMut<Prim> + ?Sized,
{
    visitor.visit_tuple_mut(&mut function.args);
    visitor.visit_type_mut(&mut function.return_type);
}