arithmetic_typing/types/
fn_type.rs

1//! Functional type (`Function`) and closely related types.
2
3use core::fmt;
4
5use crate::{
6    alloc::{Arc, HashMap, HashSet, Vec},
7    arith::{CompleteConstraints, Constraint, ConstraintSet, Num},
8    types::ParamQuantifier,
9    LengthVar, PrimitiveType, Tuple, TupleLen, Type, TypeVar,
10};
11
12#[derive(Debug, Clone)]
13pub(crate) struct ParamConstraints<Prim: PrimitiveType> {
14    pub type_params: HashMap<usize, CompleteConstraints<Prim>>,
15    pub static_lengths: HashSet<usize>,
16}
17
18impl<Prim: PrimitiveType> Default for ParamConstraints<Prim> {
19    fn default() -> Self {
20        Self {
21            type_params: HashMap::new(),
22            static_lengths: HashSet::new(),
23        }
24    }
25}
26
27impl<Prim: PrimitiveType> fmt::Display for ParamConstraints<Prim> {
28    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
29        if !self.static_lengths.is_empty() {
30            formatter.write_str("len! ")?;
31            for (i, len) in self.static_lengths.iter().enumerate() {
32                write!(formatter, "{}", LengthVar::param_str(*len))?;
33                if i + 1 < self.static_lengths.len() {
34                    formatter.write_str(", ")?;
35                }
36            }
37
38            if !self.type_params.is_empty() {
39                formatter.write_str("; ")?;
40            }
41        }
42
43        let type_param_count = self.type_params.len();
44        for (i, (idx, constraints)) in self.type_params().enumerate() {
45            write!(formatter, "'{}: {constraints}", TypeVar::param_str(idx))?;
46            if i + 1 < type_param_count {
47                formatter.write_str(", ")?;
48            }
49        }
50
51        Ok(())
52    }
53}
54
55impl<Prim: PrimitiveType> ParamConstraints<Prim> {
56    fn is_empty(&self) -> bool {
57        self.type_params.is_empty() && self.static_lengths.is_empty()
58    }
59
60    fn type_params(&self) -> impl Iterator<Item = (usize, &CompleteConstraints<Prim>)> + '_ {
61        let mut type_params: Vec<_> = self.type_params.iter().map(|(&idx, c)| (idx, c)).collect();
62        type_params.sort_unstable_by_key(|(idx, _)| *idx);
63        type_params.into_iter()
64    }
65}
66
67#[derive(Debug)]
68pub(crate) struct FnParams<Prim: PrimitiveType> {
69    /// Type params associated with this function. Filled in by `FnQuantifier`.
70    pub type_params: Vec<(usize, CompleteConstraints<Prim>)>,
71    /// Length params associated with this function. Filled in by `FnQuantifier`.
72    pub len_params: Vec<(usize, bool)>,
73    /// Constraints for params of this function and child functions.
74    pub constraints: Option<ParamConstraints<Prim>>,
75}
76
77impl<Prim: PrimitiveType> Default for FnParams<Prim> {
78    fn default() -> Self {
79        Self {
80            type_params: Vec::new(),
81            len_params: Vec::new(),
82            constraints: None,
83        }
84    }
85}
86
87impl<Prim: PrimitiveType> PartialEq for FnParams<Prim> {
88    fn eq(&self, other: &Self) -> bool {
89        self.type_params == other.type_params && self.len_params == other.len_params
90    }
91}
92
93impl<Prim: PrimitiveType> FnParams<Prim> {
94    fn is_empty(&self) -> bool {
95        self.len_params.is_empty() && self.type_params.is_empty()
96    }
97}
98
99/// Functional type.
100///
101/// # Notation
102///
103/// Functional types are denoted as follows:
104///
105/// ```text
106/// for<len! M; 'T: Lin> (['T; N], 'T) -> ['T; M]
107/// ```
108///
109/// Here:
110///
111/// - `len! M` and `'T: Lin` are constraints on [length params] and [type params], respectively.
112///   Length and/or type params constraints may be empty. Unconstrained type / length params
113///   (such as length `N` in the example) do not need to be mentioned.
114/// - `len! M` means that `M` is a [static length](TupleLen#static-lengths).
115/// - `Lin` is a [constraint] on the type param.
116/// - `N`, `M` and `'T` are parameter names. The args and the return type may reference these
117///   parameters.
118/// - `['T; N]` and `'T` are types of the function arguments.
119/// - `['T; M]` is the return type.
120///
121/// The `for` constraints can only be present on top-level functions, but not in functions
122/// mentioned in args / return types of other functions.
123///
124/// The `-> _` part is mandatory, even if the function returns [`Type::void()`].
125///
126/// A function may accept variable number of arguments of the same type along
127/// with other args. (This construction is known as *varargs*.) This is denoted similarly
128/// to middles in [`Tuple`]s. For example, `(...[Num; N]) -> Num` denotes a function
129/// that accepts any number of `Num` args and returns a `Num` value.
130///
131/// [length params]: crate::LengthVar
132/// [type params]: crate::TypeVar
133/// [constraint]: crate::arith::Constraint
134/// [dynamic length]: crate::TupleLen#static-lengths
135///
136/// # Construction
137///
138/// Functional types can be constructed via [`Self::builder()`] or parsed from a string.
139///
140/// With [`Self::builder()`], type / length params are *implicit*; they are computed automatically
141/// when a function or [`FnWithConstraints`] is supplied to a [`TypeEnvironment`]. Computations
142/// include both the function itself, and any child functions.
143///
144/// [`TypeEnvironment`]: crate::TypeEnvironment
145///
146/// # Examples
147///
148/// ```
149/// # use arithmetic_typing::{ast::FunctionAst, Function, Slice, Type};
150/// # use std::convert::TryFrom;
151/// # use assert_matches::assert_matches;
152/// # fn main() -> anyhow::Result<()> {
153/// let fn_type: Function = FunctionAst::try_from("([Num; N]) -> Num")?
154///     .try_convert()?;
155/// assert_eq!(*fn_type.return_type(), Type::NUM);
156/// assert_matches!(
157///     fn_type.args().parts(),
158///     ([Type::Tuple(t)], None, [])
159///         if t.as_slice().map(Slice::element) == Some(&Type::NUM)
160/// );
161/// # Ok(())
162/// # }
163/// ```
164#[derive(Debug, Clone, PartialEq)]
165pub struct Function<Prim: PrimitiveType = Num> {
166    /// Type of function arguments.
167    pub(crate) args: Tuple<Prim>,
168    /// Type of the value returned by the function.
169    pub(crate) return_type: Type<Prim>,
170    /// Cache for function params.
171    pub(crate) params: Option<Arc<FnParams<Prim>>>,
172}
173
174impl<Prim: PrimitiveType> fmt::Display for Function<Prim> {
175    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
176        let constraints = self
177            .params
178            .as_ref()
179            .and_then(|params| params.constraints.as_ref());
180        if let Some(constraints) = constraints {
181            if !constraints.is_empty() {
182                write!(formatter, "for<{constraints}> ")?;
183            }
184        }
185
186        self.args.format_as_tuple(formatter)?;
187        write!(formatter, " -> {}", self.return_type)?;
188        Ok(())
189    }
190}
191
192impl<Prim: PrimitiveType> Function<Prim> {
193    pub(crate) fn new(args: Tuple<Prim>, return_type: Type<Prim>) -> Self {
194        Self {
195            args,
196            return_type,
197            params: None,
198        }
199    }
200
201    /// Returns a builder for `Function`s.
202    pub fn builder() -> FunctionBuilder<Prim> {
203        FunctionBuilder::default()
204    }
205
206    /// Gets the argument types of this function.
207    pub fn args(&self) -> &Tuple<Prim> {
208        &self.args
209    }
210
211    /// Gets the return type of this function.
212    pub fn return_type(&self) -> &Type<Prim> {
213        &self.return_type
214    }
215
216    pub(crate) fn set_params(&mut self, params: FnParams<Prim>) {
217        self.params = Some(Arc::new(params));
218    }
219
220    pub(crate) fn is_parametric(&self) -> bool {
221        self.params
222            .as_ref()
223            .is_some_and(|params| !params.is_empty())
224    }
225
226    /// Returns `true` iff this type does not contain type / length variables.
227    ///
228    /// See [`TypeEnvironment`](crate::TypeEnvironment) for caveats of dealing with
229    /// non-concrete types.
230    pub fn is_concrete(&self) -> bool {
231        self.args.is_concrete() && self.return_type.is_concrete()
232    }
233
234    /// Marks type params with the specified `indexes` to have `constraints`.
235    ///
236    /// # Panics
237    ///
238    /// - Panics if parameters were already computed for the function.
239    pub fn with_constraints<C: Constraint<Prim>>(
240        self,
241        indexes: &[usize],
242        constraint: C,
243    ) -> FnWithConstraints<Prim> {
244        assert!(
245            self.params.is_none(),
246            "Cannot attach constraints to a function with computed params: `{self}`"
247        );
248
249        let constraints = CompleteConstraints::from(ConstraintSet::just(constraint));
250        let type_params = indexes
251            .iter()
252            .map(|&idx| (idx, constraints.clone()))
253            .collect();
254
255        FnWithConstraints {
256            function: self,
257            constraints: ParamConstraints {
258                type_params,
259                static_lengths: HashSet::new(),
260            },
261        }
262    }
263
264    /// Marks lengths with the specified `indexes` as static.
265    ///
266    /// # Panics
267    ///
268    /// - Panics if parameters were already computed for the function.
269    pub fn with_static_lengths(self, indexes: &[usize]) -> FnWithConstraints<Prim> {
270        assert!(
271            self.params.is_none(),
272            "Cannot attach constraints to a function with computed params: `{self}`"
273        );
274
275        FnWithConstraints {
276            function: self,
277            constraints: ParamConstraints {
278                type_params: HashMap::new(),
279                static_lengths: indexes.iter().copied().collect(),
280            },
281        }
282    }
283}
284
285/// Function together with constraints on type variables contained either in the function itself
286/// or any of the child functions.
287///
288/// Constructed via [`Function::with_constraints()`].
289#[derive(Debug)]
290pub struct FnWithConstraints<Prim: PrimitiveType> {
291    function: Function<Prim>,
292    constraints: ParamConstraints<Prim>,
293}
294
295impl<Prim: PrimitiveType> fmt::Display for FnWithConstraints<Prim> {
296    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
297        if self.constraints.is_empty() {
298            fmt::Display::fmt(&self.function, formatter)
299        } else {
300            write!(formatter, "for<{}> {}", self.constraints, self.function)
301        }
302    }
303}
304
305impl<Prim: PrimitiveType> FnWithConstraints<Prim> {
306    /// Marks type params with the specified `indexes` to have `constraints`. If some constraints
307    /// are already present for some of the types, they are overwritten.
308    #[must_use]
309    pub fn with_constraint<C>(mut self, indexes: &[usize], constraint: &C) -> Self
310    where
311        C: Constraint<Prim> + Clone,
312    {
313        for &i in indexes {
314            let constraints = self.constraints.type_params.entry(i).or_default();
315            constraints.simple.insert(constraint.clone());
316        }
317        self
318    }
319
320    /// Marks lengths with the specified `indexes` as static.
321    #[must_use]
322    pub fn with_static_lengths(mut self, indexes: &[usize]) -> Self {
323        let indexes = indexes.iter().copied();
324        self.constraints.static_lengths.extend(indexes);
325        self
326    }
327}
328
329impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Function<Prim> {
330    fn from(value: FnWithConstraints<Prim>) -> Self {
331        let mut function = value.function;
332        ParamQuantifier::fill_params(&mut function, value.constraints);
333        function
334    }
335}
336
337impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Type<Prim> {
338    fn from(value: FnWithConstraints<Prim>) -> Self {
339        Function::from(value).into()
340    }
341}
342
343/// Builder for functional types.
344///
345/// **Tip.** You may also use [`FromStr`](core::str::FromStr) implementation to parse
346/// functional types.
347///
348/// # Examples
349///
350/// Signature for a function summing a slice of numbers:
351///
352/// ```
353/// # use arithmetic_typing::{Function, UnknownLen, Type, TypeEnvironment};
354/// let sum_fn_type = Function::builder()
355///     .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
356///     .returning(Type::NUM);
357/// assert_eq!(sum_fn_type.to_string(), "([Num; N]) -> Num");
358/// ```
359///
360/// Signature for a slice mapping function:
361///
362/// ```
363/// # use arithmetic_typing::{arith::Linearity, Function, UnknownLen, Type};
364/// // Definition of the mapping arg.
365/// let map_fn_arg = <Function>::builder()
366///     .with_arg(Type::param(0))
367///     .returning(Type::param(1));
368///
369/// let map_fn_type = <Function>::builder()
370///     .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
371///     .with_arg(map_fn_arg)
372///     .returning(Type::param(1).repeat(UnknownLen::Dynamic))
373///     .with_constraints(&[1], Linearity);
374/// assert_eq!(
375///     map_fn_type.to_string(),
376///     "for<'U: Lin> (['T; N], ('T) -> 'U) -> ['U]"
377/// );
378/// ```
379///
380/// Signature of a function with varargs:
381///
382/// ```
383/// # use arithmetic_typing::{Function, UnknownLen, Type};
384/// let fn_type = <Function>::builder()
385///     .with_varargs(Type::param(0), UnknownLen::param(0))
386///     .with_arg(Type::BOOL)
387///     .returning(Type::param(0));
388/// assert_eq!(fn_type.to_string(), "(...['T; N], Bool) -> 'T");
389/// ```
390#[derive(Debug, Clone)]
391#[must_use]
392pub struct FunctionBuilder<Prim: PrimitiveType = Num> {
393    args: Tuple<Prim>,
394}
395
396impl<Prim: PrimitiveType> Default for FunctionBuilder<Prim> {
397    fn default() -> Self {
398        Self {
399            args: Tuple::empty(),
400        }
401    }
402}
403
404impl<Prim: PrimitiveType> FunctionBuilder<Prim> {
405    /// Adds a new argument to the function definition.
406    pub fn with_arg(mut self, arg: impl Into<Type<Prim>>) -> Self {
407        self.args.push(arg.into());
408        self
409    }
410
411    /// Adds or sets varargs in the function definition.
412    pub fn with_varargs(
413        mut self,
414        element: impl Into<Type<Prim>>,
415        len: impl Into<TupleLen>,
416    ) -> Self {
417        self.args.set_middle(element.into(), len.into());
418        self
419    }
420
421    /// Declares the return type of the function and builds it.
422    pub fn returning(self, return_type: impl Into<Type<Prim>>) -> Function<Prim> {
423        Function::new(self.args, return_type.into())
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{alloc::ToString, arith::Linearity, UnknownLen};

    /// `Lin` constraints for a single type param.
    fn lin_constraints() -> CompleteConstraints<Num> {
        CompleteConstraints::from(ConstraintSet::<Num>::just(Linearity))
    }

    /// Attaches an empty constraint list and converts back into a `Function`,
    /// which runs param computation.
    fn quantified(function: Function) -> Function {
        function.with_constraints(&[], Linearity).into()
    }

    #[test]
    fn constraints_display() {
        let mut type_params = HashMap::new();
        type_params.insert(0, lin_constraints());
        let types_only = ParamConstraints {
            type_params,
            static_lengths: HashSet::new(),
        };
        assert_eq!(types_only.to_string(), "'T: Lin");

        let mut type_params = HashMap::new();
        type_params.insert(0, lin_constraints());
        let mut static_lengths = HashSet::new();
        static_lengths.insert(0);
        let with_static_len: ParamConstraints<Num> = ParamConstraints {
            type_params,
            static_lengths,
        };
        assert_eq!(with_static_len.to_string(), "len! N; 'T: Lin");
    }

    #[test]
    fn fn_with_constraints_display() {
        let sum_fn = <Function>::builder()
            .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
            .returning(Type::param(0))
            .with_constraints(&[0], Linearity);
        let expected = "for<'T: Lin> (['T; N]) -> 'T";
        assert_eq!(sum_fn.to_string(), expected);
    }

    #[test]
    fn fn_builder_with_quantified_arg() {
        let sum_fn = quantified(
            Function::builder()
                .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
                .returning(Type::NUM),
        );
        assert_eq!(sum_fn.to_string(), "([Num; N]) -> Num");

        let complex_fn = quantified(
            Function::builder()
                .with_arg(Type::NUM)
                .with_arg(sum_fn.clone())
                .returning(Type::NUM),
        );
        assert_eq!(complex_fn.to_string(), "(Num, ([Num; N]) -> Num) -> Num");

        let varargs_fn = quantified(
            Function::builder()
                .with_varargs(Type::NUM, UnknownLen::param(0))
                .with_arg(sum_fn)
                .returning(Type::NUM),
        );
        assert_eq!(
            varargs_fn.to_string(),
            "(...[Num; N], ([Num; N]) -> Num) -> Num"
        );
    }
}