1use core::fmt;
4
5use crate::{
6 alloc::{Arc, HashMap, HashSet, Vec},
7 arith::{CompleteConstraints, Constraint, ConstraintSet, Num},
8 types::ParamQuantifier,
9 LengthVar, PrimitiveType, Tuple, TupleLen, Type, TypeVar,
10};
11
/// Constraints attached to the params of a function before quantification: constraints
/// on type params (keyed by type param index) and the indexes of length params marked
/// as static (rendered with a `len!` prefix by the `Display` impl).
#[derive(Debug, Clone)]
pub(crate) struct ParamConstraints<Prim: PrimitiveType> {
    /// Constraints keyed by the index of the type param they apply to.
    pub type_params: HashMap<usize, CompleteConstraints<Prim>>,
    /// Indexes of length params marked as static.
    pub static_lengths: HashSet<usize>,
}
17
18impl<Prim: PrimitiveType> Default for ParamConstraints<Prim> {
19 fn default() -> Self {
20 Self {
21 type_params: HashMap::new(),
22 static_lengths: HashSet::new(),
23 }
24 }
25}
26
27impl<Prim: PrimitiveType> fmt::Display for ParamConstraints<Prim> {
28 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
29 if !self.static_lengths.is_empty() {
30 formatter.write_str("len! ")?;
31 for (i, len) in self.static_lengths.iter().enumerate() {
32 write!(formatter, "{}", LengthVar::param_str(*len))?;
33 if i + 1 < self.static_lengths.len() {
34 formatter.write_str(", ")?;
35 }
36 }
37
38 if !self.type_params.is_empty() {
39 formatter.write_str("; ")?;
40 }
41 }
42
43 let type_param_count = self.type_params.len();
44 for (i, (idx, constraints)) in self.type_params().enumerate() {
45 write!(formatter, "'{}: {constraints}", TypeVar::param_str(idx))?;
46 if i + 1 < type_param_count {
47 formatter.write_str(", ")?;
48 }
49 }
50
51 Ok(())
52 }
53}
54
55impl<Prim: PrimitiveType> ParamConstraints<Prim> {
56 fn is_empty(&self) -> bool {
57 self.type_params.is_empty() && self.static_lengths.is_empty()
58 }
59
60 fn type_params(&self) -> impl Iterator<Item = (usize, &CompleteConstraints<Prim>)> + '_ {
61 let mut type_params: Vec<_> = self.type_params.iter().map(|(&idx, c)| (idx, c)).collect();
62 type_params.sort_unstable_by_key(|(idx, _)| *idx);
63 type_params.into_iter()
64 }
65}
66
/// Params of a function as attached via `Function::set_params` (a function with these
/// set is referred to as having "computed params").
#[derive(Debug)]
pub(crate) struct FnParams<Prim: PrimitiveType> {
    /// Type params as `(index, constraints)` pairs.
    pub type_params: Vec<(usize, CompleteConstraints<Prim>)>,
    /// Length params as `(index, flag)` pairs.
    /// NOTE(review): the flag presumably marks static lengths (cf.
    /// `ParamConstraints::static_lengths`) — confirm against the quantifier.
    pub len_params: Vec<(usize, bool)>,
    /// Constraints rendered as the `for<...>` prefix when displaying the owning `Function`.
    /// Not considered by `PartialEq` or `is_empty` for this type.
    pub constraints: Option<ParamConstraints<Prim>>,
}
76
77impl<Prim: PrimitiveType> Default for FnParams<Prim> {
78 fn default() -> Self {
79 Self {
80 type_params: Vec::new(),
81 len_params: Vec::new(),
82 constraints: None,
83 }
84 }
85}
86
87impl<Prim: PrimitiveType> PartialEq for FnParams<Prim> {
88 fn eq(&self, other: &Self) -> bool {
89 self.type_params == other.type_params && self.len_params == other.len_params
90 }
91}
92
93impl<Prim: PrimitiveType> FnParams<Prim> {
94 fn is_empty(&self) -> bool {
95 self.len_params.is_empty() && self.type_params.is_empty()
96 }
97}
98
/// Function type: an argument tuple, a return type, and optionally attached params.
///
/// Displayed as `(args) -> return_type`, prefixed with `for<...> ` when params carry
/// non-empty constraints.
#[derive(Debug, Clone, PartialEq)]
pub struct Function<Prim: PrimitiveType = Num> {
    /// Function arguments.
    pub(crate) args: Tuple<Prim>,
    /// Function return type.
    pub(crate) return_type: Type<Prim>,
    /// Params of the function; `None` when created via `Function::new` and until
    /// `set_params` is called. Kept behind an `Arc`, so cloning a function shares them.
    pub(crate) params: Option<Arc<FnParams<Prim>>>,
}
173
174impl<Prim: PrimitiveType> fmt::Display for Function<Prim> {
175 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
176 let constraints = self
177 .params
178 .as_ref()
179 .and_then(|params| params.constraints.as_ref());
180 if let Some(constraints) = constraints {
181 if !constraints.is_empty() {
182 write!(formatter, "for<{constraints}> ")?;
183 }
184 }
185
186 self.args.format_as_tuple(formatter)?;
187 write!(formatter, " -> {}", self.return_type)?;
188 Ok(())
189 }
190}
191
192impl<Prim: PrimitiveType> Function<Prim> {
193 pub(crate) fn new(args: Tuple<Prim>, return_type: Type<Prim>) -> Self {
194 Self {
195 args,
196 return_type,
197 params: None,
198 }
199 }
200
201 pub fn builder() -> FunctionBuilder<Prim> {
203 FunctionBuilder::default()
204 }
205
206 pub fn args(&self) -> &Tuple<Prim> {
208 &self.args
209 }
210
211 pub fn return_type(&self) -> &Type<Prim> {
213 &self.return_type
214 }
215
216 pub(crate) fn set_params(&mut self, params: FnParams<Prim>) {
217 self.params = Some(Arc::new(params));
218 }
219
220 pub(crate) fn is_parametric(&self) -> bool {
221 self.params
222 .as_ref()
223 .is_some_and(|params| !params.is_empty())
224 }
225
226 pub fn is_concrete(&self) -> bool {
231 self.args.is_concrete() && self.return_type.is_concrete()
232 }
233
234 pub fn with_constraints<C: Constraint<Prim>>(
240 self,
241 indexes: &[usize],
242 constraint: C,
243 ) -> FnWithConstraints<Prim> {
244 assert!(
245 self.params.is_none(),
246 "Cannot attach constraints to a function with computed params: `{self}`"
247 );
248
249 let constraints = CompleteConstraints::from(ConstraintSet::just(constraint));
250 let type_params = indexes
251 .iter()
252 .map(|&idx| (idx, constraints.clone()))
253 .collect();
254
255 FnWithConstraints {
256 function: self,
257 constraints: ParamConstraints {
258 type_params,
259 static_lengths: HashSet::new(),
260 },
261 }
262 }
263
264 pub fn with_static_lengths(self, indexes: &[usize]) -> FnWithConstraints<Prim> {
270 assert!(
271 self.params.is_none(),
272 "Cannot attach constraints to a function with computed params: `{self}`"
273 );
274
275 FnWithConstraints {
276 function: self,
277 constraints: ParamConstraints {
278 type_params: HashMap::new(),
279 static_lengths: indexes.iter().copied().collect(),
280 },
281 }
282 }
283}
284
/// Function together with constraints on its params that have not been applied yet.
///
/// Created by `Function::with_constraints` / `Function::with_static_lengths`. Converting
/// it into a [`Function`] (or [`Type`]) passes the constraints to
/// `ParamQuantifier::fill_params`.
#[derive(Debug)]
pub struct FnWithConstraints<Prim: PrimitiveType> {
    /// Function the constraints apply to (has no params set; checked on creation).
    function: Function<Prim>,
    /// Accumulated constraints on type params and static length params.
    constraints: ParamConstraints<Prim>,
}
294
295impl<Prim: PrimitiveType> fmt::Display for FnWithConstraints<Prim> {
296 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
297 if self.constraints.is_empty() {
298 fmt::Display::fmt(&self.function, formatter)
299 } else {
300 write!(formatter, "for<{}> {}", self.constraints, self.function)
301 }
302 }
303}
304
305impl<Prim: PrimitiveType> FnWithConstraints<Prim> {
306 #[must_use]
309 pub fn with_constraint<C>(mut self, indexes: &[usize], constraint: &C) -> Self
310 where
311 C: Constraint<Prim> + Clone,
312 {
313 for &i in indexes {
314 let constraints = self.constraints.type_params.entry(i).or_default();
315 constraints.simple.insert(constraint.clone());
316 }
317 self
318 }
319
320 #[must_use]
322 pub fn with_static_lengths(mut self, indexes: &[usize]) -> Self {
323 let indexes = indexes.iter().copied();
324 self.constraints.static_lengths.extend(indexes);
325 self
326 }
327}
328
329impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Function<Prim> {
330 fn from(value: FnWithConstraints<Prim>) -> Self {
331 let mut function = value.function;
332 ParamQuantifier::fill_params(&mut function, value.constraints);
333 function
334 }
335}
336
337impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Type<Prim> {
338 fn from(value: FnWithConstraints<Prim>) -> Self {
339 Function::from(value).into()
340 }
341}
342
/// Builder of [`Function`]s: accumulates arguments via `with_arg` / `with_varargs`
/// and produces the function via `returning`.
#[derive(Debug, Clone)]
#[must_use]
pub struct FunctionBuilder<Prim: PrimitiveType = Num> {
    /// Arguments accumulated so far; empty for a freshly created builder.
    args: Tuple<Prim>,
}
395
396impl<Prim: PrimitiveType> Default for FunctionBuilder<Prim> {
397 fn default() -> Self {
398 Self {
399 args: Tuple::empty(),
400 }
401 }
402}
403
404impl<Prim: PrimitiveType> FunctionBuilder<Prim> {
405 pub fn with_arg(mut self, arg: impl Into<Type<Prim>>) -> Self {
407 self.args.push(arg.into());
408 self
409 }
410
411 pub fn with_varargs(
413 mut self,
414 element: impl Into<Type<Prim>>,
415 len: impl Into<TupleLen>,
416 ) -> Self {
417 self.args.set_middle(element.into(), len.into());
418 self
419 }
420
421 pub fn returning(self, return_type: impl Into<Type<Prim>>) -> Function<Prim> {
423 Function::new(self.args, return_type.into())
424 }
425}
426
#[cfg(test)]
mod tests {
    use core::iter;

    use super::*;
    use crate::{alloc::ToString, arith::Linearity, UnknownLen};

    /// `ParamConstraints` display with and without static length params.
    #[test]
    fn constraints_display() {
        let linear = CompleteConstraints::from(ConstraintSet::<Num>::just(Linearity));
        let param_entry = (0, linear);

        let without_lengths = ParamConstraints {
            type_params: iter::once(param_entry.clone()).collect(),
            static_lengths: HashSet::new(),
        };
        assert_eq!(without_lengths.to_string(), "'T: Lin");

        let with_lengths: ParamConstraints<Num> = ParamConstraints {
            type_params: iter::once(param_entry).collect(),
            static_lengths: iter::once(0).collect(),
        };
        assert_eq!(with_lengths.to_string(), "len! N; 'T: Lin");
    }

    /// Constraints attached to a function are rendered as a `for<...>` prefix.
    #[test]
    fn fn_with_constraints_display() {
        let linear_sum = <Function>::builder()
            .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
            .returning(Type::param(0))
            .with_constraints(&[0], Linearity);
        assert_eq!(linear_sum.to_string(), "for<'T: Lin> (['T; N]) -> 'T");
    }

    /// Functions without constrained type params display without a `for<...>` prefix,
    /// including when nested as arguments and combined with varargs.
    #[test]
    fn fn_builder_with_quantified_arg() {
        let sum: Function = Function::builder()
            .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(sum.to_string(), "([Num; N]) -> Num");

        let higher_order: Function = Function::builder()
            .with_arg(Type::NUM)
            .with_arg(sum.clone())
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(higher_order.to_string(), "(Num, ([Num; N]) -> Num) -> Num");

        let variadic_higher_order: Function = Function::builder()
            .with_varargs(Type::NUM, UnknownLen::param(0))
            .with_arg(sum)
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(
            variadic_higher_order.to_string(),
            "(...[Num; N], ([Num; N]) -> Num) -> Num"
        );
    }
}
490}