arithmetic_typing/arith/substitutions/
fns.rs1use crate::{
4 alloc::{hash_map::Entry, Arc, HashMap},
5 arith::{CompleteConstraints, Substitutions},
6 types::{FnParams, ParamConstraints, ParamQuantifier},
7 visit::{self, VisitMut},
8 Function, Object, PrimitiveType, TupleLen, Type, UnknownLen,
9};
10
11impl<Prim: PrimitiveType> Function<Prim> {
12 pub(crate) fn finalize(&mut self, substitutions: &Substitutions<Prim>) {
15 let mut transformer = PolyTypeTransformer::new(substitutions);
17 transformer.visit_function_mut(self);
18 let mapping = transformer.mapping;
19 let mut resolved_objects = transformer.resolved_objects;
20
21 let type_params = mapping
23 .types
24 .into_iter()
25 .filter_map(|(var_idx, param_idx)| {
26 let constraints = substitutions.constraints.get(&var_idx);
27 constraints
28 .filter(|constraints| !constraints.is_empty())
29 .cloned()
30 .map(|constraints| {
31 let resolved = constraints.map_object(|object| {
32 if let Some(resolved) = resolved_objects.remove(&var_idx) {
33 *object = resolved;
34 }
35 });
36 (param_idx, resolved)
37 })
38 })
39 .collect();
40
41 let static_lengths = mapping
42 .lengths
43 .into_iter()
44 .filter_map(|(var_idx, param_idx)| {
45 if substitutions.static_lengths.contains(&var_idx) {
46 Some(param_idx)
47 } else {
48 None
49 }
50 })
51 .collect();
52
53 ParamQuantifier::fill_params(
55 self,
56 ParamConstraints {
57 type_params,
58 static_lengths,
59 },
60 );
61 }
62}
63
64#[derive(Debug, Default)]
65pub(super) struct ParamMapping {
66 pub types: HashMap<usize, usize>,
67 pub lengths: HashMap<usize, usize>,
68}
69
70#[derive(Debug)]
72struct PolyTypeTransformer<'a, Prim: PrimitiveType> {
73 mapping: ParamMapping,
74 resolved_objects: HashMap<usize, Object<Prim>>,
75 substitutions: &'a Substitutions<Prim>,
76}
77
78impl<'a, Prim: PrimitiveType> PolyTypeTransformer<'a, Prim> {
79 fn new(substitutions: &'a Substitutions<Prim>) -> Self {
80 Self {
81 mapping: ParamMapping::default(),
82 resolved_objects: HashMap::new(),
83 substitutions,
84 }
85 }
86
87 fn object_constraint(&self, var_idx: usize) -> Option<&'a Object<Prim>> {
88 let constraints = self.substitutions.constraints.get(&var_idx)?;
89 constraints.object.as_ref()
90 }
91}
92
93impl<Prim: PrimitiveType> VisitMut<Prim> for PolyTypeTransformer<'_, Prim> {
94 fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
95 match ty {
96 Type::Var(var) if var.is_free() => {
97 let type_count = self.mapping.types.len();
98 let var_idx = var.index();
99 let entry = self.mapping.types.entry(var_idx);
100 let is_new_var = matches!(entry, Entry::Vacant(_));
101 let param_idx = *entry.or_insert(type_count);
102 *ty = Type::param(param_idx);
103
104 if is_new_var {
105 if let Some(object) = self.object_constraint(var_idx) {
108 let mut resolved_object = object.clone();
109 self.substitutions
110 .resolver()
111 .visit_object_mut(&mut resolved_object);
112 self.visit_object_mut(&mut resolved_object);
113 self.resolved_objects.insert(var_idx, resolved_object);
114 }
115 }
116 }
117 _ => visit::visit_type_mut(self, ty),
118 }
119 }
120
121 fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
122 let (Some(target_len), _) = len.components_mut() else {
123 return;
124 };
125 if let UnknownLen::Var(var) = target_len {
126 debug_assert!(var.is_free());
127 let len_count = self.mapping.lengths.len();
128 let param_idx = *self.mapping.lengths.entry(var.index()).or_insert(len_count);
129 *target_len = UnknownLen::param(param_idx);
130 }
131 }
132}
133
134#[derive(Debug)]
136pub(super) struct MonoTypeTransformer<'a> {
137 mapping: &'a ParamMapping,
138}
139
140impl<'a> MonoTypeTransformer<'a> {
141 pub fn transform<Prim: PrimitiveType>(
142 mapping: &'a ParamMapping,
143 function: &mut Function<Prim>,
144 ) {
145 function.params = None;
146 Self { mapping }.visit_function_mut(function);
147 }
148
149 pub fn transform_constraints<Prim: PrimitiveType>(
150 mapping: &'a ParamMapping,
151 constraints: &CompleteConstraints<Prim>,
152 ) -> CompleteConstraints<Prim> {
153 constraints.clone().map_object(|object| {
154 Self { mapping }.visit_object_mut(object);
155 })
156 }
157}
158
159impl<Prim: PrimitiveType> VisitMut<Prim> for MonoTypeTransformer<'_> {
160 fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
161 match ty {
162 Type::Var(var) if !var.is_free() => {
163 if let Some(mapped_idx) = self.mapping.types.get(&var.index()) {
164 *ty = Type::free_var(*mapped_idx);
165 }
166 }
167 _ => visit::visit_type_mut(self, ty),
168 }
169 }
170
171 fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
172 let (Some(target_len), _) = len.components_mut() else {
173 return;
174 };
175 if let UnknownLen::Var(var) = target_len {
176 if !var.is_free() {
177 if let Some(mapped_len) = self.mapping.lengths.get(&var.index()) {
178 *target_len = UnknownLen::free_var(*mapped_len);
179 }
180 }
181 }
182 }
183
184 fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
185 visit::visit_function_mut(self, function);
186
187 if let Some(params) = function.params.as_deref() {
188 let needs_modifying = params
190 .type_params
191 .iter()
192 .any(|(_, type_params)| type_params.object.is_some());
193
194 if needs_modifying {
196 let mapped_params = params.type_params.iter().map(|(i, constraints)| {
197 let mapped_constraints = constraints
198 .clone()
199 .map_object(|object| self.visit_object_mut(object));
200 (*i, mapped_constraints)
201 });
202 function.params = Some(Arc::new(FnParams {
203 type_params: mapped_params.collect(),
204 len_params: params.len_params.clone(),
205 constraints: None,
206 }));
207 }
208 }
209 }
210}