arithmetic_typing/arith/substitutions/
mod.rs

1//! Substitutions type and dependencies.
2
3use core::{cmp::Ordering, iter, ops, ptr};
4
5use self::fns::{MonoTypeTransformer, ParamMapping};
6use crate::{
7    alloc::{vec, Box, HashMap, HashSet, String, Vec},
8    arith::{CompleteConstraints, Constraint},
9    error::{ErrorKind, ErrorPathFragment, OpErrors, TupleContext},
10    visit::{self, Visit, VisitMut},
11    Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar, UnknownLen,
12};
13
14mod fns;
15#[cfg(test)]
16mod tests;
17
/// Internal error produced when unifying tuple lengths or applying a static length restriction.
///
/// Values of this type are converted to the public `ErrorKind` in `apply_static_len`
/// and `unify_lengths`.
#[derive(Debug, Clone, Copy)]
enum LenErrorKind {
    /// A bound (non-free) length variable was encountered; converted to
    /// `ErrorKind::UnresolvedParam`.
    UnresolvedParam,
    /// Lengths cannot be made equal; converted to `ErrorKind::TupleLenMismatch`
    /// in `unify_lengths`.
    Mismatch,
    /// The contained length has a dynamic component where a static length is required;
    /// converted to `ErrorKind::DynamicLen`.
    Dynamic(TupleLen),
}
24
/// Set of equations and constraints on type variables.
///
/// Besides type variables, the set tracks length variables (used in tuple lengths),
/// equations on them, and which of them are restricted to be static.
#[derive(Debug, Clone)]
pub struct Substitutions<Prim: PrimitiveType> {
    /// Number of type variables. Also serves as the index of the next variable
    /// created by `new_type_var`.
    type_var_count: usize,
    /// Type variable equations, encoded as `type_var[key] = value`.
    /// Keys are indexes of free type variables.
    eqs: HashMap<usize, Type<Prim>>,
    /// Constraints on type variables, keyed by the free type variable index.
    constraints: HashMap<usize, CompleteConstraints<Prim>>,
    /// Number of length variables. Also serves as the index of the next variable
    /// created by `new_len_var`.
    len_var_count: usize,
    /// Length variable equations, keyed by the free length variable index.
    length_eqs: HashMap<usize, TupleLen>,
    /// Lengths that have static restriction (stored as indexes of free length variables).
    static_lengths: HashSet<usize>,
}
41
42impl<Prim: PrimitiveType> Default for Substitutions<Prim> {
43    fn default() -> Self {
44        Self {
45            type_var_count: 0,
46            eqs: HashMap::new(),
47            constraints: HashMap::new(),
48            len_var_count: 0,
49            length_eqs: HashMap::new(),
50            static_lengths: HashSet::new(),
51        }
52    }
53}
54
55impl<Prim: PrimitiveType> Substitutions<Prim> {
56    /// Inserts `constraints` for a type var with the specified index and all vars
57    /// it is equivalent to.
58    pub fn insert_constraint<C>(
59        &mut self,
60        var_idx: usize,
61        constraint: &C,
62        mut errors: OpErrors<'_, Prim>,
63    ) where
64        C: Constraint<Prim> + Clone,
65    {
66        for idx in self.equivalent_vars(var_idx) {
67            let mut current_constraints = self.constraints.remove(&idx).unwrap_or_default();
68            current_constraints.insert(constraint.clone(), self, errors.by_ref());
69            self.constraints.insert(idx, current_constraints);
70        }
71    }
72
73    /// Returns an object constraint associated with the specified type var. The returned type
74    /// is resolved.
75    pub(crate) fn object_constraint(&self, var: TypeVar) -> Option<Object<Prim>> {
76        if var.is_free() {
77            let mut ty = self.constraints.get(&var.index())?.object.clone()?;
78            self.resolver().visit_object_mut(&mut ty);
79            Some(ty)
80        } else {
81            None
82        }
83    }
84
85    /// Inserts an object constraint for a type var with the specified index.
86    pub(crate) fn insert_obj_constraint(
87        &mut self,
88        var_idx: usize,
89        constraint: &Object<Prim>,
90        mut errors: OpErrors<'_, Prim>,
91    ) {
92        // Check whether the constraint is recursive.
93        let mut checker = OccurrenceChecker::new(self, self.equivalent_vars(var_idx));
94        checker.visit_object(constraint);
95        if let Some(var) = checker.recursive_var {
96            self.handle_recursive_type(Type::Object(constraint.clone()), var, &mut errors);
97            return;
98        }
99
100        for idx in self.equivalent_vars(var_idx) {
101            let mut current_constraints = self.constraints.remove(&idx).unwrap_or_default();
102            current_constraints.insert_obj_constraint(constraint.clone(), self, errors.by_ref());
103            self.constraints.insert(idx, current_constraints);
104        }
105    }
106
107    // TODO: If recursion is manifested via constraints, the returned type is not informative.
108    fn handle_recursive_type(
109        &self,
110        ty: Type<Prim>,
111        recursive_var: usize,
112        errors: &mut OpErrors<'_, Prim>,
113    ) {
114        let mut resolved_ty = ty;
115        self.resolver().visit_type_mut(&mut resolved_ty);
116        TypeSanitizer::new(recursive_var).visit_type_mut(&mut resolved_ty);
117        errors.push(ErrorKind::RecursiveType(Box::new(resolved_ty)));
118    }
119
120    /// Returns type var indexes that are equivalent to the provided `var_idx`,
121    /// including `var_idx` itself.
122    fn equivalent_vars(&self, var_idx: usize) -> Vec<usize> {
123        let ty = Type::free_var(var_idx);
124        let mut ty = &ty;
125        let mut equivalent_vars = vec![];
126
127        while let Type::Var(var) = ty {
128            debug_assert!(var.is_free());
129            equivalent_vars.push(var.index());
130            if let Some(resolved) = self.eqs.get(&var.index()) {
131                ty = resolved;
132            } else {
133                break;
134            }
135        }
136        equivalent_vars
137    }
138
139    /// Marks `len` as static, i.e., not containing [`UnknownLen::Dynamic`] components.
140    #[allow(clippy::missing_panics_doc)]
141    pub fn apply_static_len(&mut self, len: TupleLen) -> Result<(), ErrorKind<Prim>> {
142        let resolved = self.resolve_len(len);
143        self.apply_static_len_inner(resolved)
144            .map_err(|err| match err {
145                LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
146                LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
147                LenErrorKind::Mismatch => unreachable!(),
148            })
149    }
150
151    // Assumes that `len` is resolved.
152    fn apply_static_len_inner(&mut self, len: TupleLen) -> Result<(), LenErrorKind> {
153        match len.components().0 {
154            None => Ok(()),
155            Some(UnknownLen::Dynamic) => Err(LenErrorKind::Dynamic(len)),
156            Some(UnknownLen::Var(var)) => {
157                if var.is_free() {
158                    self.static_lengths.insert(var.index());
159                    Ok(())
160                } else {
161                    Err(LenErrorKind::UnresolvedParam)
162                }
163            }
164        }
165    }
166
167    /// Resolves the type by following established equality links between type variables.
168    pub fn fast_resolve<'a>(&'a self, mut ty: &'a Type<Prim>) -> &'a Type<Prim> {
169        while let Type::Var(var) = ty {
170            if !var.is_free() {
171                // Bound variables cannot be resolved further.
172                break;
173            }
174
175            if let Some(resolved) = self.eqs.get(&var.index()) {
176                ty = resolved;
177            } else {
178                break;
179            }
180        }
181        ty
182    }
183
184    /// Returns a visitor that resolves the type using equality relations in these `Substitutions`.
185    pub fn resolver(&self) -> impl VisitMut<Prim> + '_ {
186        TypeResolver {
187            substitutions: self,
188        }
189    }
190
191    /// Resolves the provided `len` given length equations in this instance.
192    pub(crate) fn resolve_len(&self, len: TupleLen) -> TupleLen {
193        let mut resolved = len;
194        while let (Some(UnknownLen::Var(var)), exact) = resolved.components() {
195            if !var.is_free() {
196                break;
197            }
198
199            if let Some(eq_rhs) = self.length_eqs.get(&var.index()) {
200                resolved = *eq_rhs + exact;
201            } else {
202                break;
203            }
204        }
205        resolved
206    }
207
208    /// Creates and returns a new type variable.
209    pub fn new_type_var(&mut self) -> Type<Prim> {
210        let new_type = Type::free_var(self.type_var_count);
211        self.type_var_count += 1;
212        new_type
213    }
214
215    /// Creates and returns a new length variable.
216    pub(crate) fn new_len_var(&mut self) -> UnknownLen {
217        let new_length = UnknownLen::free_var(self.len_var_count);
218        self.len_var_count += 1;
219        new_length
220    }
221
    /// Unifies types in `lhs` and `rhs`.
    ///
    /// - LHS corresponds to the lvalue in assignments and to called function signature in fn calls.
    /// - RHS corresponds to the rvalue in assignments and to the type of the called function.
    ///
    /// Both types are first resolved at the top level via `fast_resolve`; unification then
    /// proceeds depending on the shapes of the resolved types. The order of the match arms
    /// below is significant.
    ///
    /// If unification is impossible, the corresponding error(s) will be put into `errors`.
    pub fn unify(&mut self, lhs: &Type<Prim>, rhs: &Type<Prim>, mut errors: OpErrors<'_, Prim>) {
        // Resolved types are cloned so that `self` can be mutated during unification.
        let resolved_lhs = self.fast_resolve(lhs).clone();
        let resolved_rhs = self.fast_resolve(rhs).clone();

        // **NB.** LHS and RHS should never switch sides; the side is important for
        // accuracy of error reporting, and for some cases of type inference (e.g.,
        // instantiation of parametric functions).
        match (&resolved_lhs, &resolved_rhs) {
            // Variables should be assigned *before* the equality check and dealing with `Any`
            // to account for `Var <- Any` assignment.
            (Type::Var(var), ty) => {
                if var.is_free() {
                    self.unify_var(var.index(), ty, true, errors);
                } else {
                    // Bound variables cannot be assigned.
                    errors.push(ErrorKind::UnresolvedParam);
                }
            }

            // This takes care of `Any` types because they are equal to anything.
            (ty, other_ty) if ty == other_ty => {
                // We already know that types are equal.
            }

            // LHS is a `dyn` type: all of its constraints are applied to the RHS type.
            (Type::Dyn(constraints), ty) => {
                constraints.inner.apply_all(ty, self, errors);
            }

            // RHS is a variable; LHS is not (that case is handled by the first arm).
            (ty, Type::Var(var)) => {
                if var.is_free() {
                    self.unify_var(var.index(), ty, false, errors);
                } else {
                    errors.push(ErrorKind::UnresolvedParam);
                }
            }

            // Types of the same shape are unified component-wise.
            (Type::Tuple(lhs_tuple), Type::Tuple(rhs_tuple)) => {
                self.unify_tuples(lhs_tuple, rhs_tuple, TupleContext::Generic, errors);
            }
            (Type::Object(lhs_obj), Type::Object(rhs_obj)) => {
                self.unify_objects(lhs_obj, rhs_obj, errors);
            }

            (Type::Function(lhs_fn), Type::Function(rhs_fn)) => {
                self.unify_fn_types(lhs_fn, rhs_fn, errors);
            }

            // Remaining combinations cannot be unified. Both types are resolved in full
            // before being reported in the error.
            (ty, other_ty) => {
                let mut resolver = self.resolver();
                let mut ty = ty.clone();
                resolver.visit_type_mut(&mut ty);
                let mut other_ty = other_ty.clone();
                resolver.visit_type_mut(&mut other_ty);
                errors.push(ErrorKind::TypeMismatch(Box::new(ty), Box::new(other_ty)));
            }
        }
    }
284
285    fn unify_tuples(
286        &mut self,
287        lhs: &Tuple<Prim>,
288        rhs: &Tuple<Prim>,
289        context: TupleContext,
290        mut errors: OpErrors<'_, Prim>,
291    ) {
292        let resolved_len = self.unify_lengths(lhs.len(), rhs.len(), context);
293        let resolved_len = match resolved_len {
294            Ok(len) => len,
295            Err(err) => {
296                self.unify_tuples_after_error(lhs, rhs, &err, context, errors.by_ref());
297                errors.push(err);
298                return;
299            }
300        };
301
302        if let (None, exact) = resolved_len.components() {
303            self.unify_tuple_elements(lhs.iter(exact), rhs.iter(exact), context, errors);
304        } else {
305            // TODO: is this always applicable?
306            for (lhs_elem, rhs_elem) in lhs.equal_elements_dyn(rhs) {
307                let elem_errors = errors.join_path(match context {
308                    TupleContext::Generic => ErrorPathFragment::TupleElement(None),
309                    TupleContext::FnArgs => ErrorPathFragment::FnArg(None),
310                });
311                self.unify(lhs_elem, rhs_elem, elem_errors);
312            }
313        }
314    }
315
316    #[inline]
317    fn unify_tuple_elements<'it>(
318        &mut self,
319        lhs_elements: impl Iterator<Item = &'it Type<Prim>>,
320        rhs_elements: impl Iterator<Item = &'it Type<Prim>>,
321        context: TupleContext,
322        mut errors: OpErrors<'_, Prim>,
323    ) {
324        for (i, (lhs_elem, rhs_elem)) in lhs_elements.zip(rhs_elements).enumerate() {
325            let location = context.element(i);
326            self.unify(lhs_elem, rhs_elem, errors.join_path(location));
327        }
328    }
329
330    /// Tries to unify tuple elements after an error has occurred when unifying their lengths.
331    fn unify_tuples_after_error(
332        &mut self,
333        lhs: &Tuple<Prim>,
334        rhs: &Tuple<Prim>,
335        err: &ErrorKind<Prim>,
336        context: TupleContext,
337        errors: OpErrors<'_, Prim>,
338    ) {
339        let (lhs_len, rhs_len) = match err {
340            ErrorKind::TupleLenMismatch {
341                lhs: lhs_len,
342                rhs: rhs_len,
343                ..
344            } => (*lhs_len, *rhs_len),
345            _ => return,
346        };
347        let (lhs_var, lhs_exact) = lhs_len.components();
348        let (rhs_var, rhs_exact) = rhs_len.components();
349
350        match (lhs_var, rhs_var) {
351            (None, None) => {
352                // We've attempted to unify tuples with different known lengths.
353                // Iterate over common elements and unify them.
354                debug_assert_ne!(lhs_exact, rhs_exact);
355                self.unify_tuple_elements(
356                    lhs.iter(lhs_exact),
357                    rhs.iter(rhs_exact),
358                    context,
359                    errors,
360                );
361            }
362
363            (None, Some(UnknownLen::Dynamic)) => {
364                // We've attempted to unify static LHS with a dynamic RHS
365                // e.g., `(x, y) = filter(...)`.
366                self.unify_tuple_elements(
367                    lhs.iter(lhs_exact),
368                    rhs.iter(rhs_exact),
369                    context,
370                    errors,
371                );
372            }
373
374            _ => { /* Do nothing. */ }
375        }
376    }
377
378    /// Returns the resolved length that `lhs` and `rhs` are equal to.
379    fn unify_lengths(
380        &mut self,
381        lhs: TupleLen,
382        rhs: TupleLen,
383        context: TupleContext,
384    ) -> Result<TupleLen, ErrorKind<Prim>> {
385        let resolved_lhs = self.resolve_len(lhs);
386        let resolved_rhs = self.resolve_len(rhs);
387
388        self.unify_lengths_inner(resolved_lhs, resolved_rhs)
389            .map_err(|err| match err {
390                LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
391                LenErrorKind::Mismatch => ErrorKind::TupleLenMismatch {
392                    lhs: resolved_lhs,
393                    rhs: resolved_rhs,
394                    context,
395                },
396                LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
397            })
398    }
399
    /// Unifies two lengths that are already resolved, and returns the length both of them
    /// are equal to after unification.
    ///
    /// Each length is split into an optional variable component and an exact component
    /// (`components()`); unification is reduced to assigning a "simple" length (a sole
    /// variable component) via `unify_simple_length`.
    fn unify_lengths_inner(
        &mut self,
        resolved_lhs: TupleLen,
        resolved_rhs: TupleLen,
    ) -> Result<TupleLen, LenErrorKind> {
        let (lhs_var, lhs_exact) = resolved_lhs.components();
        let (rhs_var, rhs_exact) = resolved_rhs.components();

        // First, consider a case when at least one of resolved lengths is exact.
        let (lhs_var, rhs_var) = match (lhs_var, rhs_var) {
            // Both lengths have a variable component; handled below.
            (Some(lhs_var), Some(rhs_var)) => (lhs_var, rhs_var),

            // `lhs_var + lhs_exact == rhs_exact`: the variable component must be equal
            // to the (non-negative) difference of exact components.
            (Some(lhs_var), None) if rhs_exact >= lhs_exact => {
                return self
                    .unify_simple_length(lhs_var, TupleLen::from(rhs_exact - lhs_exact), true)
                    .map(|len| len + lhs_exact);
            }
            // Mirror case: `lhs_exact == rhs_var + rhs_exact`.
            (None, Some(rhs_var)) if lhs_exact >= rhs_exact => {
                return self
                    .unify_simple_length(rhs_var, TupleLen::from(lhs_exact - rhs_exact), false)
                    .map(|len| len + rhs_exact);
            }

            // Both lengths are exact and equal.
            (None, None) if lhs_exact == rhs_exact => return Ok(TupleLen::from(lhs_exact)),

            // Exact lengths differ, or the exact side is too short to accommodate
            // the exact component of the other side.
            _ => return Err(LenErrorKind::Mismatch),
        };

        // Both lengths have variable components. The side with the smaller exact component
        // gets its variable assigned to the other side's variable plus the difference
        // of exact components.
        match lhs_exact.cmp(&rhs_exact) {
            Ordering::Equal => self.unify_simple_length(lhs_var, TupleLen::from(rhs_var), true),
            Ordering::Greater => {
                let reduced = lhs_var + (lhs_exact - rhs_exact);
                self.unify_simple_length(rhs_var, reduced, false)
                    .map(|len| len + rhs_exact)
            }
            Ordering::Less => {
                let reduced = rhs_var + (rhs_exact - lhs_exact);
                self.unify_simple_length(lhs_var, reduced, true)
                    .map(|len| len + lhs_exact)
            }
        }
    }
442
443    fn unify_simple_length(
444        &mut self,
445        simple_len: UnknownLen,
446        source: TupleLen,
447        is_lhs: bool,
448    ) -> Result<TupleLen, LenErrorKind> {
449        match simple_len {
450            UnknownLen::Var(var) if var.is_free() => self.unify_var_length(var.index(), source),
451            UnknownLen::Dynamic => self.unify_dyn_length(source, is_lhs),
452            _ => Err(LenErrorKind::UnresolvedParam),
453        }
454    }
455
456    #[inline]
457    fn unify_var_length(
458        &mut self,
459        var_idx: usize,
460        source: TupleLen,
461    ) -> Result<TupleLen, LenErrorKind> {
462        // Check that the source is valid.
463        match source.components() {
464            (Some(UnknownLen::Var(var)), _) if !var.is_free() => Err(LenErrorKind::UnresolvedParam),
465
466            // Special case is uniting a var with self.
467            (Some(UnknownLen::Var(var)), offset) if var.index() == var_idx => {
468                if offset == 0 {
469                    Ok(source)
470                } else {
471                    Err(LenErrorKind::Mismatch)
472                }
473            }
474
475            _ => {
476                if self.static_lengths.contains(&var_idx) {
477                    self.apply_static_len_inner(source)?;
478                }
479                self.length_eqs.insert(var_idx, source);
480                Ok(source)
481            }
482        }
483    }
484
485    #[inline]
486    fn unify_dyn_length(
487        &mut self,
488        source: TupleLen,
489        is_lhs: bool,
490    ) -> Result<TupleLen, LenErrorKind> {
491        if is_lhs {
492            Ok(source) // assignment to dyn length always succeeds
493        } else {
494            let source_var_idx = match source.components() {
495                (Some(UnknownLen::Var(var)), 0) if var.is_free() => var.index(),
496                (Some(UnknownLen::Dynamic), 0) => return Ok(source),
497                _ => return Err(LenErrorKind::Mismatch),
498            };
499            self.unify_var_length(source_var_idx, UnknownLen::Dynamic.into())
500        }
501    }
502
503    fn unify_objects(
504        &mut self,
505        lhs: &Object<Prim>,
506        rhs: &Object<Prim>,
507        mut errors: OpErrors<'_, Prim>,
508    ) {
509        let lhs_fields: HashSet<_> = lhs.field_names().collect();
510        let rhs_fields: HashSet<_> = rhs.field_names().collect();
511
512        if lhs_fields == rhs_fields {
513            for (field_name, ty) in lhs.iter() {
514                self.unify(ty, &rhs[field_name], errors.join_path(field_name));
515            }
516        } else {
517            errors.push(ErrorKind::FieldsMismatch {
518                lhs_fields: lhs_fields.into_iter().map(String::from).collect(),
519                rhs_fields: rhs_fields.into_iter().map(String::from).collect(),
520            });
521        }
522    }
523
524    fn unify_fn_types(
525        &mut self,
526        lhs: &Function<Prim>,
527        rhs: &Function<Prim>,
528        mut errors: OpErrors<'_, Prim>,
529    ) {
530        if lhs.is_parametric() {
531            errors.push(ErrorKind::UnsupportedParam);
532            return;
533        }
534
535        let instantiated_lhs = self.instantiate_function(lhs);
536        let instantiated_rhs = self.instantiate_function(rhs);
537
538        // Swapping args is intentional. To see why, consider a function
539        // `fn(T, U) -> V` called as `fn(A, B) -> C` (`T`, ... `C` are types).
540        // In this case, the first arg of actual type `A` will be assigned to type `T`
541        // (i.e., `T` is LHS and `A` is RHS); same with `U` and `B`. In contrast,
542        // after function execution the return value of type `V` will be assigned
543        // to type `C`. (I.e., unification of return values is not swapped.)
544        self.unify_tuples(
545            &instantiated_rhs.args,
546            &instantiated_lhs.args,
547            TupleContext::FnArgs,
548            errors.by_ref(),
549        );
550
551        self.unify(
552            &instantiated_lhs.return_type,
553            &instantiated_rhs.return_type,
554            errors.join_path(ErrorPathFragment::FnReturnType),
555        );
556    }
557
558    /// Instantiates a functional type by replacing all type arguments with new type vars.
559    fn instantiate_function(&mut self, fn_type: &Function<Prim>) -> Function<Prim> {
560        if !fn_type.is_parametric() {
561            // Fast path: just clone the function type.
562            return fn_type.clone();
563        }
564        let fn_params = fn_type.params.as_ref().expect("fn with params");
565
566        // Map type vars in the function into newly created type vars.
567        let mapping = ParamMapping {
568            types: fn_params
569                .type_params
570                .iter()
571                .enumerate()
572                .map(|(i, (var_idx, _))| (*var_idx, self.type_var_count + i))
573                .collect(),
574            lengths: fn_params
575                .len_params
576                .iter()
577                .enumerate()
578                .map(|(i, (var_idx, _))| (*var_idx, self.len_var_count + i))
579                .collect(),
580        };
581        self.type_var_count += fn_params.type_params.len();
582        self.len_var_count += fn_params.len_params.len();
583
584        let mut instantiated_fn_type = fn_type.clone();
585        MonoTypeTransformer::transform(&mapping, &mut instantiated_fn_type);
586
587        // Copy constraints on the newly generated length and type vars
588        // from the function definition.
589        for (original_idx, is_static) in &fn_params.len_params {
590            if *is_static {
591                let new_idx = mapping.lengths[original_idx];
592                self.static_lengths.insert(new_idx);
593            }
594        }
595        for (original_idx, constraints) in &fn_params.type_params {
596            let new_idx = mapping.types[original_idx];
597            let mono_constraints =
598                MonoTypeTransformer::transform_constraints(&mapping, constraints);
599            self.constraints.insert(new_idx, mono_constraints);
600        }
601
602        instantiated_fn_type
603    }
604
    /// Unifies a type variable with the specified index and the specified type.
    ///
    /// `is_lhs` is `true` if the variable is on the left-hand side of unification
    /// (i.e., `ty` is assigned to the variable), and `false` if it is on the right-hand side.
    ///
    /// The method expects both the variable and `ty` to be resolved by the caller.
    /// If `ty` contains the variable, a recursive type error is reported via `errors`;
    /// otherwise, the equation `var = ty` is recorded, and constraints on the variable
    /// (if any) are applied to `ty`.
    fn unify_var(
        &mut self,
        var_idx: usize,
        ty: &Type<Prim>,
        is_lhs: bool,
        mut errors: OpErrors<'_, Prim>,
    ) {
        // Variables should be resolved in `unify`.
        debug_assert!(is_lhs || !matches!(ty, Type::Any | Type::Dyn(_)));
        debug_assert!(!self.eqs.contains_key(&var_idx));
        debug_assert!(if let Type::Var(var) = ty {
            !self.eqs.contains_key(&var.index())
        } else {
            true
        });

        if let Type::Var(var) = ty {
            if !var.is_free() {
                // A free variable cannot be equated to a bound one.
                errors.push(ErrorKind::UnresolvedParam);
                return;
            } else if var.index() == var_idx {
                // Uniting a variable with itself is a no-op.
                return;
            }
        }

        // Check that `ty` does not contain the variable (directly or via equations).
        let mut checker = OccurrenceChecker::new(self, iter::once(var_idx));
        checker.visit_type(ty);

        if let Some(var) = checker.recursive_var {
            self.handle_recursive_type(ty.clone(), var, &mut errors);
        } else {
            let mut ty = ty.clone();
            if !is_lhs {
                // We need to swap `any` types / lengths with new vars so that this type
                // can be specified further.
                TypeSpecifier::new(self).visit_type_mut(&mut ty);
            }
            self.eqs.insert(var_idx, ty.clone());

            // Constraints need to be applied *after* adding a type equation in order to
            // account for recursive constraints (e.g., object ones) - otherwise,
            // constraints on some type vars may be lost.
            // TODO: is it possible (or necessary?) to detect recursion in order to avoid cloning?
            if let Some(constraints) = self.constraints.get(&var_idx).cloned() {
                constraints.apply_all(&ty, self, errors);
            }
        }
    }
654}
655
/// Checks if any of type variables with the specified indexes is present in the visited type,
/// taking into account equations in `Substitutions`. This visitor is used to check that
/// types are not recursive.
#[derive(Debug)]
struct OccurrenceChecker<'a, Prim: PrimitiveType> {
    /// Substitutions used to look through variables with known equations.
    substitutions: &'a Substitutions<Prim>,
    /// Indexes of free type variables to search for.
    var_indexes: HashSet<usize>,
    /// Index of the first searched variable encountered during the visit, if any.
    recursive_var: Option<usize>,
}
664
665impl<'a, Prim: PrimitiveType> OccurrenceChecker<'a, Prim> {
666    fn new(
667        substitutions: &'a Substitutions<Prim>,
668        var_indexes: impl IntoIterator<Item = usize>,
669    ) -> Self {
670        Self {
671            substitutions,
672            var_indexes: var_indexes.into_iter().collect(),
673            recursive_var: None,
674        }
675    }
676}
677
678impl<Prim: PrimitiveType> Visit<Prim> for OccurrenceChecker<'_, Prim> {
679    fn visit_type(&mut self, ty: &Type<Prim>) {
680        if self.recursive_var.is_some() {
681            // Skip recursion; we already have our answer at this point.
682        } else {
683            visit::visit_type(self, ty);
684        }
685    }
686
687    fn visit_var(&mut self, var: TypeVar) {
688        if !var.is_free() {
689            // Can happen with assigned generic functions, e.g., `reduce = fold; ...`.
690            return;
691        }
692
693        let var_idx = var.index();
694        if self.var_indexes.contains(&var_idx) {
695            self.recursive_var = Some(var_idx);
696        } else if let Some(ty) = self.substitutions.eqs.get(&var_idx) {
697            self.visit_type(ty);
698        }
699        // TODO: we don't check object constraints since they are fine (probably).
700    }
701}
702
/// Removes excessive information about type vars. This visitor is used when types are
/// provided to `Error`.
#[derive(Debug)]
struct TypeSanitizer {
    /// Index of the type variable that is replaced with `Type::param(0)` by the visitor.
    fixed_idx: usize,
}

impl TypeSanitizer {
    /// Creates a sanitizer for the type variable with the specified index.
    fn new(fixed_idx: usize) -> Self {
        TypeSanitizer { fixed_idx }
    }
}
715
716impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSanitizer {
717    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
718        match ty {
719            Type::Var(var) if var.index() == self.fixed_idx => {
720                *ty = Type::param(0);
721            }
722            _ => visit::visit_type_mut(self, ty),
723        }
724    }
725}
726
/// Visitor that performs type resolution based on `Substitutions`.
///
/// Instances are returned (as an opaque `VisitMut` implementation) from
/// `Substitutions::resolver()`.
#[derive(Debug, Clone, Copy)]
struct TypeResolver<'a, Prim: PrimitiveType> {
    /// Substitutions providing type and length equations used for resolution.
    substitutions: &'a Substitutions<Prim>,
}
732
733impl<Prim: PrimitiveType> VisitMut<Prim> for TypeResolver<'_, Prim> {
734    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
735        let fast_resolved = self.substitutions.fast_resolve(ty);
736        if !ptr::eq(ty, fast_resolved) {
737            *ty = fast_resolved.clone();
738        }
739        visit::visit_type_mut(self, ty);
740    }
741
742    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
743        *len = self.substitutions.resolve_len(*len);
744    }
745}
746
/// Variance of a position within a type.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Variance {
    /// Covariant position.
    Co,
    /// Contravariant position.
    Contra,
}

impl ops::Not for Variance {
    type Output = Self;

    /// Flips the variance: covariant becomes contravariant and vice versa.
    fn not(self) -> Self {
        if self == Self::Co {
            Self::Contra
        } else {
            Self::Co
        }
    }
}
763
/// Visitor that swaps `any` types / lengths with new vars, but only if they are in a covariant
/// position (return types, args of function args, etc.).
///
/// This is used when assigning to a type containing `any`.
#[derive(Debug)]
struct TypeSpecifier<'a, Prim: PrimitiveType> {
    /// Substitutions in which new type / length vars are created.
    substitutions: &'a mut Substitutions<Prim>,
    /// Variance of the currently visited position.
    variance: Variance,
}
773
774impl<'a, Prim: PrimitiveType> TypeSpecifier<'a, Prim> {
775    fn new(substitutions: &'a mut Substitutions<Prim>) -> Self {
776        Self {
777            substitutions,
778            variance: Variance::Co,
779        }
780    }
781}
782
783impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSpecifier<'_, Prim> {
784    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
785        match ty {
786            Type::Any if self.variance == Variance::Co => {
787                *ty = self.substitutions.new_type_var();
788            }
789
790            Type::Dyn(constraints) if self.variance == Variance::Co => {
791                let var_idx = self.substitutions.type_var_count;
792                self.substitutions
793                    .constraints
794                    .insert(var_idx, constraints.inner.clone());
795                *ty = Type::free_var(var_idx);
796                self.substitutions.type_var_count += 1;
797            }
798
799            _ => visit::visit_type_mut(self, ty),
800        }
801    }
802
803    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
804        if self.variance != Variance::Co {
805            return;
806        }
807        if let (Some(var_len @ UnknownLen::Dynamic), _) = len.components_mut() {
808            *var_len = self.substitutions.new_len_var();
809        }
810    }
811
812    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
813        // Since the visiting order doesn't matter, we visit the return type (which preserves
814        // variance) first.
815        self.visit_type_mut(&mut function.return_type);
816
817        let old_variance = self.variance;
818        self.variance = !self.variance;
819        self.visit_tuple_mut(&mut function.args);
820        self.variance = old_variance;
821    }
822}