arithmetic_typing/
defs.rs

1//! Type definitions for the standard types from the [`arithmetic-eval`] crate.
2//!
3//! [`arithmetic-eval`]: https://docs.rs/arithmetic-eval/
4
5use core::iter;
6
7use crate::{arith::WithBoolean, Function, Object, PrimitiveType, Type, UnknownLen};
8
/// Type definitions for all variables from `Prelude` in the eval crate.
///
/// # Contents
///
/// - `true` and `false` Boolean constants
/// - `if`, `while` and `defer` functions
/// - `map`, `filter`, `fold`, `push`, `merge`, `all` and `any` functions, each of which
///   takes a slice as its first argument. These functions are additionally grouped
///   into the `Array` object yielded by [`Self::iter()`].
///
/// The `merge` function has somewhat imprecise typing; its return value is
/// a dynamically-sized slice.
///
/// The `array` function is available separately via [`Self::array()`].
///
/// # Examples
///
/// Function counting number of zeros in a slice:
///
/// ```
/// use arithmetic_parser::grammars::{F32Grammar, Parse};
/// use arithmetic_typing::{defs::Prelude, Annotated, TypeEnvironment, Type};
///
/// # fn main() -> anyhow::Result<()> {
/// let code = "|xs| xs.fold(0, |acc, x| if(x == 0, acc + 1, acc))";
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let mut env: TypeEnvironment = Prelude::iter().collect();
/// let count_zeros_fn = env.process_statements(&ast)?;
/// assert_eq!(count_zeros_fn.to_string(), "([Num; N]) -> Num");
/// # Ok(())
/// # }
/// ```
///
/// Limitations of `merge`:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse};
/// # use arithmetic_typing::{defs::Prelude, error::ErrorKind, Annotated, TypeEnvironment, Type};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let code = "
///     len = |xs| xs.fold(0, |acc, _| acc + 1);
///     slice = (1, 2).merge((3, 4));
///     slice.len(); // methods working on slices are applicable
///     (_, _, _, z) = slice; // but destructuring is not
/// ";
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let mut env: TypeEnvironment = Prelude::iter().collect();
/// let errors = env.process_statements(&ast).unwrap_err();
/// assert_eq!(errors.len(), 1);
/// let err = errors.iter().next().unwrap();
/// assert_eq!(err.main_location().span(code), "(_, _, _, z)");
/// # assert_matches!(err.kind(), ErrorKind::TupleLenMismatch { .. });
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Prelude {
    /// `false` type (Boolean).
    False,
    /// `true` type (Boolean).
    True,
    /// Type of the `if` function.
    If,
    /// Type of the `while` function.
    While,
    /// Type of the `defer` function.
    Defer,
    /// Type of the `map` function.
    Map,
    /// Type of the `filter` function.
    Filter,
    /// Type of the `fold` function.
    Fold,
    /// Type of the `push` function.
    Push,
    /// Type of the `merge` function.
    Merge,
    /// Type of the `all` function.
    All,
    /// Type of the `any` function.
    Any,
}
92
93impl<Prim: WithBoolean> From<Prelude> for Type<Prim> {
94    fn from(value: Prelude) -> Self {
95        match value {
96            Prelude::True | Prelude::False => Type::BOOL,
97
98            Prelude::If => Function::builder()
99                .with_arg(Type::BOOL)
100                .with_arg(Type::param(0))
101                .with_arg(Type::param(0))
102                .returning(Type::param(0))
103                .into(),
104
105            Prelude::While => {
106                let condition_fn = Function::builder()
107                    .with_arg(Type::param(0))
108                    .returning(Type::BOOL);
109                let iter_fn = Function::builder()
110                    .with_arg(Type::param(0))
111                    .returning(Type::param(0));
112
113                Function::builder()
114                    .with_arg(Type::param(0)) // state
115                    .with_arg(condition_fn)
116                    .with_arg(iter_fn)
117                    .returning(Type::param(0))
118                    .into()
119            }
120
121            Prelude::Defer => {
122                let fn_arg = Function::builder()
123                    .with_arg(Type::param(0))
124                    .returning(Type::param(0));
125                Function::builder()
126                    .with_arg(fn_arg)
127                    .returning(Type::param(0))
128                    .into()
129            }
130
131            Prelude::Map => {
132                let map_arg = Function::builder()
133                    .with_arg(Type::param(0))
134                    .returning(Type::param(1));
135
136                Function::builder()
137                    .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
138                    .with_arg(map_arg)
139                    .returning(Type::param(1).repeat(UnknownLen::param(0)))
140                    .into()
141            }
142
143            Prelude::Filter => {
144                let predicate_arg = Function::builder()
145                    .with_arg(Type::param(0))
146                    .returning(Type::BOOL);
147
148                Function::builder()
149                    .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
150                    .with_arg(predicate_arg)
151                    .returning(Type::param(0).repeat(UnknownLen::Dynamic))
152                    .into()
153            }
154
155            Prelude::Fold => {
156                // 0th type param is slice element, 1st is accumulator
157                let fold_arg = Function::builder()
158                    .with_arg(Type::param(1))
159                    .with_arg(Type::param(0))
160                    .returning(Type::param(1));
161
162                Function::builder()
163                    .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
164                    .with_arg(Type::param(1))
165                    .with_arg(fold_arg)
166                    .returning(Type::param(1))
167                    .into()
168            }
169
170            Prelude::Push => Function::builder()
171                .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
172                .with_arg(Type::param(0))
173                .returning(Type::param(0).repeat(UnknownLen::param(0) + 1))
174                .into(),
175
176            Prelude::Merge => Function::builder()
177                .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
178                .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
179                .returning(Type::param(0).repeat(UnknownLen::Dynamic))
180                .into(),
181
182            Prelude::All | Prelude::Any => {
183                let predicate_arg = Function::builder()
184                    .with_arg(Type::param(0))
185                    .returning(Type::BOOL);
186
187                Function::builder()
188                    .with_arg(Type::param(0).repeat(UnknownLen::Dynamic))
189                    .with_arg(predicate_arg)
190                    .returning(Type::BOOL)
191                    .into()
192            }
193        }
194    }
195}
196
197impl Prelude {
198    const VALUES: &'static [Self] = &[Self::True, Self::False, Self::If, Self::While, Self::Defer];
199
200    const ARRAY_FUNCTIONS: &'static [Self] = &[
201        Self::Map,
202        Self::Filter,
203        Self::Fold,
204        Self::Push,
205        Self::Merge,
206        Self::All,
207        Self::Any,
208    ];
209
210    fn as_str(self) -> &'static str {
211        match self {
212            Self::True => "true",
213            Self::False => "false",
214            Self::If => "if",
215            Self::While => "while",
216            Self::Defer => "defer",
217            Self::Map => "map",
218            Self::Filter => "filter",
219            Self::Fold => "fold",
220            Self::Push => "push",
221            Self::Merge => "merge",
222            Self::All => "all",
223            Self::Any => "any",
224        }
225    }
226
227    /// Returns the type of the `array` generation function from the eval crate.
228    ///
229    /// The `array` function is not included into [`Self::iter()`] because in the general case
230    /// we don't know the type of indexes.
231    pub fn array<T: PrimitiveType>(index_type: T) -> Function<T> {
232        Function::builder()
233            .with_arg(Type::Prim(index_type.clone()))
234            .with_arg(
235                Function::builder()
236                    .with_arg(Type::Prim(index_type))
237                    .returning(Type::param(0)),
238            )
239            .returning(Type::param(0).repeat(UnknownLen::Dynamic))
240    }
241
242    fn array_namespace<Prim: WithBoolean>() -> Object<Prim> {
243        Self::ARRAY_FUNCTIONS
244            .iter()
245            .map(|&value| (value.as_str(), Type::from(value)))
246            .collect()
247    }
248
249    /// Returns an iterator over all type definitions in the `Prelude`.
250    pub fn iter<Prim: WithBoolean>() -> impl Iterator<Item = (&'static str, Type<Prim>)> {
251        Self::VALUES
252            .iter()
253            .chain(Self::ARRAY_FUNCTIONS)
254            .map(|&value| (value.as_str(), value.into()))
255            .chain(iter::once(("Array", Self::array_namespace().into())))
256    }
257}
258
/// Type definitions for the assertion functions from the eval crate:
/// `assert`, `assert_eq` and `assert_fails`.
///
/// All of these functions return the unit type `()`. The `assert_close` function is
/// available separately via [`Self::assert_close()`] because its argument type
/// depends on the primitive type in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Assertions {
    /// Type of the `assert` function.
    Assert,
    /// Type of the `assert_eq` function.
    AssertEq,
    /// Type of the `assert_fails` function.
    AssertFails,
}
270
271impl<Prim: WithBoolean> From<Assertions> for Type<Prim> {
272    fn from(value: Assertions) -> Self {
273        match value {
274            Assertions::Assert => Function::builder()
275                .with_arg(Type::BOOL)
276                .returning(Type::void())
277                .into(),
278            Assertions::AssertEq => Function::builder()
279                .with_arg(Type::param(0))
280                .with_arg(Type::param(0))
281                .returning(Type::void())
282                .into(),
283            Assertions::AssertFails => {
284                let checked_fn = Function::builder().returning(Type::param(0));
285                Function::builder()
286                    .with_arg(checked_fn)
287                    .returning(Type::void())
288                    .into()
289            }
290        }
291    }
292}
293
294impl Assertions {
295    const VALUES: &'static [Self] = &[Self::Assert, Self::AssertEq, Self::AssertFails];
296
297    fn as_str(self) -> &'static str {
298        match self {
299            Self::Assert => "assert",
300            Self::AssertEq => "assert_eq",
301            Self::AssertFails => "assert_fails",
302        }
303    }
304
305    /// Returns an iterator over all type definitions in `Assertions`.
306    pub fn iter<Prim: WithBoolean>() -> impl Iterator<Item = (&'static str, Type<Prim>)> {
307        Self::VALUES.iter().map(|&val| (val.as_str(), val.into()))
308    }
309
310    /// Returns the type of the `assert_close` function from the eval crate.
311    ///
312    /// This function is not included into [`Self::iter()`] because in the general case
313    /// we don't know the type of arguments it accepts.
314    pub fn assert_close<T: PrimitiveType>(value: T) -> Function<T> {
315        Function::builder()
316            .with_arg(Type::Prim(value.clone()))
317            .with_arg(Type::Prim(value))
318            .returning(Type::void())
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{HashMap, HashSet, ToString},
        arith::Num,
    };

    /// Expected string presentations of every `Prelude` item except the `Array` namespace.
    const EXPECTED_PRELUDE_TYPES: &[(&str, &str)] = &[
        ("false", "Bool"),
        ("true", "Bool"),
        ("if", "(Bool, 'T, 'T) -> 'T"),
        ("while", "('T, ('T) -> Bool, ('T) -> 'T) -> 'T"),
        ("defer", "(('T) -> 'T) -> 'T"),
        ("map", "(['T; N], ('T) -> 'U) -> ['U; N]"),
        ("filter", "(['T], ('T) -> Bool) -> ['T]"),
        ("fold", "(['T], 'U, ('U, 'T) -> 'U) -> 'U"),
        ("push", "(['T; N], 'T) -> ['T; N + 1]"),
        ("merge", "(['T], ['T]) -> ['T]"),
        ("all", "(['T], ('T) -> Bool) -> Bool"),
        ("any", "(['T], ('T) -> Bool) -> Bool"),
    ];

    #[test]
    fn string_presentations_of_prelude_types() {
        let expected_types: HashMap<_, _> = EXPECTED_PRELUDE_TYPES.iter().copied().collect();

        let non_namespace_items = Prelude::iter::<Num>().filter(|(name, _)| *name != "Array");
        for (name, ty) in non_namespace_items {
            assert_eq!(ty.to_string(), expected_types[name]);
        }

        let actual_names: HashSet<_> = Prelude::iter::<Num>()
            .map(|(name, _)| name)
            .filter(|&name| name != "Array")
            .collect();
        let expected_names: HashSet<_> = expected_types.keys().copied().collect();
        assert_eq!(actual_names, expected_names);
    }

    #[test]
    fn string_presentation_of_array_type() {
        assert_eq!(
            Prelude::array(Num::Num).to_string(),
            "(Num, (Num) -> 'T) -> ['T]"
        );
    }

    /// Expected string presentations of every item yielded by `Assertions::iter()`.
    const EXPECTED_ASSERT_TYPES: &[(&str, &str)] = &[
        ("assert", "(Bool) -> ()"),
        ("assert_eq", "('T, 'T) -> ()"),
        ("assert_fails", "(() -> 'T) -> ()"),
    ];

    #[test]
    fn string_representation_of_assert_types() {
        let expected_types: HashMap<_, _> = EXPECTED_ASSERT_TYPES.iter().copied().collect();

        for (name, ty) in Assertions::iter::<Num>() {
            assert_eq!(ty.to_string(), expected_types[name]);
        }

        let actual_names: HashSet<_> = Assertions::iter::<Num>().map(|(name, _)| name).collect();
        let expected_names: HashSet<_> = expected_types.keys().copied().collect();
        assert_eq!(actual_names, expected_names);
    }

    #[test]
    fn string_representation_of_assert_close() {
        assert_eq!(
            Assertions::assert_close(Num::Num).to_string(),
            "(Num, Num) -> ()"
        );
    }
}