// arithmetic_typing/arith/substitutions/fns.rs

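//! Functional type substitutions: converting functions between the polymorphic form
//! (type / length params) and the monomorphic form (type / length vars).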
2
3use crate::{
4    alloc::{hash_map::Entry, Arc, HashMap},
5    arith::{CompleteConstraints, Substitutions},
6    types::{FnParams, ParamConstraints, ParamQuantifier},
7    visit::{self, VisitMut},
8    Function, Object, PrimitiveType, TupleLen, Type, UnknownLen,
9};
10
11impl<Prim: PrimitiveType> Function<Prim> {
12    /// Performs final transformations on this function, bounding all of its type vars
13    /// to the function or its child functions.
14    pub(crate) fn finalize(&mut self, substitutions: &Substitutions<Prim>) {
15        // 1. Replace `Var`s with `Param`s.
16        let mut transformer = PolyTypeTransformer::new(substitutions);
17        transformer.visit_function_mut(self);
18        let mapping = transformer.mapping;
19        let mut resolved_objects = transformer.resolved_objects;
20
21        // 2. Extract constraints on type params and lengths.
22        let type_params = mapping
23            .types
24            .into_iter()
25            .filter_map(|(var_idx, param_idx)| {
26                let constraints = substitutions.constraints.get(&var_idx);
27                constraints
28                    .filter(|constraints| !constraints.is_empty())
29                    .cloned()
30                    .map(|constraints| {
31                        let resolved = constraints.map_object(|object| {
32                            if let Some(resolved) = resolved_objects.remove(&var_idx) {
33                                *object = resolved;
34                            }
35                        });
36                        (param_idx, resolved)
37                    })
38            })
39            .collect();
40
41        let static_lengths = mapping
42            .lengths
43            .into_iter()
44            .filter_map(|(var_idx, param_idx)| {
45                if substitutions.static_lengths.contains(&var_idx) {
46                    Some(param_idx)
47                } else {
48                    None
49                }
50            })
51            .collect();
52
53        // 3. Set constraints for the function.
54        ParamQuantifier::fill_params(
55            self,
56            ParamConstraints {
57                type_params,
58                static_lengths,
59            },
60        );
61    }
62}
63
64#[derive(Debug, Default)]
65pub(super) struct ParamMapping {
66    pub types: HashMap<usize, usize>,
67    pub lengths: HashMap<usize, usize>,
68}
69
70/// Replaces `Var`s with `Param`s and creates the corresponding `mapping`.
71#[derive(Debug)]
72struct PolyTypeTransformer<'a, Prim: PrimitiveType> {
73    mapping: ParamMapping,
74    resolved_objects: HashMap<usize, Object<Prim>>,
75    substitutions: &'a Substitutions<Prim>,
76}
77
78impl<'a, Prim: PrimitiveType> PolyTypeTransformer<'a, Prim> {
79    fn new(substitutions: &'a Substitutions<Prim>) -> Self {
80        Self {
81            mapping: ParamMapping::default(),
82            resolved_objects: HashMap::new(),
83            substitutions,
84        }
85    }
86
87    fn object_constraint(&self, var_idx: usize) -> Option<&'a Object<Prim>> {
88        let constraints = self.substitutions.constraints.get(&var_idx)?;
89        constraints.object.as_ref()
90    }
91}
92
93impl<Prim: PrimitiveType> VisitMut<Prim> for PolyTypeTransformer<'_, Prim> {
94    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
95        match ty {
96            Type::Var(var) if var.is_free() => {
97                let type_count = self.mapping.types.len();
98                let var_idx = var.index();
99                let entry = self.mapping.types.entry(var_idx);
100                let is_new_var = matches!(entry, Entry::Vacant(_));
101                let param_idx = *entry.or_insert(type_count);
102                *ty = Type::param(param_idx);
103
104                if is_new_var {
105                    // Resolve object constraints only when we're visiting the variable the
106                    // first time.
107                    if let Some(object) = self.object_constraint(var_idx) {
108                        let mut resolved_object = object.clone();
109                        self.substitutions
110                            .resolver()
111                            .visit_object_mut(&mut resolved_object);
112                        self.visit_object_mut(&mut resolved_object);
113                        self.resolved_objects.insert(var_idx, resolved_object);
114                    }
115                }
116            }
117            _ => visit::visit_type_mut(self, ty),
118        }
119    }
120
121    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
122        let (Some(target_len), _) = len.components_mut() else {
123            return;
124        };
125        if let UnknownLen::Var(var) = target_len {
126            debug_assert!(var.is_free());
127            let len_count = self.mapping.lengths.len();
128            let param_idx = *self.mapping.lengths.entry(var.index()).or_insert(len_count);
129            *target_len = UnknownLen::param(param_idx);
130        }
131    }
132}
133
134/// Makes functional types monomorphic by replacing type / length params with vars.
135#[derive(Debug)]
136pub(super) struct MonoTypeTransformer<'a> {
137    mapping: &'a ParamMapping,
138}
139
140impl<'a> MonoTypeTransformer<'a> {
141    pub fn transform<Prim: PrimitiveType>(
142        mapping: &'a ParamMapping,
143        function: &mut Function<Prim>,
144    ) {
145        function.params = None;
146        Self { mapping }.visit_function_mut(function);
147    }
148
149    pub fn transform_constraints<Prim: PrimitiveType>(
150        mapping: &'a ParamMapping,
151        constraints: &CompleteConstraints<Prim>,
152    ) -> CompleteConstraints<Prim> {
153        constraints.clone().map_object(|object| {
154            Self { mapping }.visit_object_mut(object);
155        })
156    }
157}
158
159impl<Prim: PrimitiveType> VisitMut<Prim> for MonoTypeTransformer<'_> {
160    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
161        match ty {
162            Type::Var(var) if !var.is_free() => {
163                if let Some(mapped_idx) = self.mapping.types.get(&var.index()) {
164                    *ty = Type::free_var(*mapped_idx);
165                }
166            }
167            _ => visit::visit_type_mut(self, ty),
168        }
169    }
170
171    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
172        let (Some(target_len), _) = len.components_mut() else {
173            return;
174        };
175        if let UnknownLen::Var(var) = target_len {
176            if !var.is_free() {
177                if let Some(mapped_len) = self.mapping.lengths.get(&var.index()) {
178                    *target_len = UnknownLen::free_var(*mapped_len);
179                }
180            }
181        }
182    }
183
184    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
185        visit::visit_function_mut(self, function);
186
187        if let Some(params) = function.params.as_deref() {
188            // TODO: make this check more precise?
189            let needs_modifying = params
190                .type_params
191                .iter()
192                .any(|(_, type_params)| type_params.object.is_some());
193
194            // We need to monomorphize types in the object constraint as well.
195            if needs_modifying {
196                let mapped_params = params.type_params.iter().map(|(i, constraints)| {
197                    let mapped_constraints = constraints
198                        .clone()
199                        .map_object(|object| self.visit_object_mut(object));
200                    (*i, mapped_constraints)
201                });
202                function.params = Some(Arc::new(FnParams {
203                    type_params: mapped_params.collect(),
204                    len_params: params.len_params.clone(),
205                    constraints: None,
206                }));
207            }
208        }
209    }
210}