arithmetic_typing/ast/
mod.rs

1//! ASTs for type annotations and their parsing logic.
2//!
3//! # Overview
4//!
//! This module contains types representing the AST of parsed type annotations; for example,
//! [`TypeAst`] and [`FunctionAst`]. Both of these types expose a `parse` method, which
//! allows integrating them into `nom` parsers.
8
9use arithmetic_parser::{with_span, ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned};
10use nom::{
11    branch::alt,
12    bytes::complete::{tag, take, take_until, take_while, take_while1, take_while_m_n},
13    character::complete::char as tag_char,
14    combinator::{cut, map, map_res, not, opt, peek, recognize},
15    multi::{many0, separated_list0, separated_list1},
16    sequence::{delimited, preceded, separated_pair, terminated},
17    Parser as _,
18};
19
20pub use self::conversion::AstConversionError;
21pub(crate) use self::conversion::AstConversionState;
22use crate::alloc::{Box, Vec};
23
24mod conversion;
25#[cfg(test)]
26mod tests;
27
28/// Type annotation after parsing.
29///
30/// Compared to [`Type`], this enum corresponds to AST, not to the logical presentation
31/// of a type.
32///
33/// [`Type`]: crate::Type
34///
35/// # Examples
36///
37/// ```
38/// use arithmetic_parser::InputSpan;
39/// # use arithmetic_typing::ast::TypeAst;
40/// # use assert_matches::assert_matches;
41///
42/// # fn main() -> anyhow::Result<()> {
43/// let input = InputSpan::new("(Num, ('T) -> ('T, 'T))");
44/// let (_, ty) = TypeAst::parse(input)?;
45/// let TypeAst::Tuple(elements) = ty.extra else {
46///     unreachable!();
47/// };
48/// assert_eq!(elements.start[0].extra, TypeAst::Ident);
49/// assert_matches!(
50///     &elements.start[1].extra,
51///     TypeAst::Function { .. }
52/// );
53/// # Ok(())
54/// # }
55/// ```
56#[derive(Debug, Clone, PartialEq)]
57#[non_exhaustive]
58pub enum TypeAst<'a> {
59    /// Type placeholder (`_`). Corresponds to a certain type that is not specified, like `_`
60    /// in type annotations in Rust.
61    Some,
62    /// Any type (`any`).
63    Any,
64    /// Dynamically applied constraints (`dyn _`).
65    Dyn(TypeConstraintsAst<'a>),
66    /// Non-ticked identifier, e.g., `Bool`.
67    Ident,
68    /// Ticked identifier, e.g., `'T`.
69    Param,
70    /// Functional type.
71    Function(Box<FunctionAst<'a>>),
72    /// Functional type with constraints.
73    FunctionWithConstraints {
74        /// Constraints on function params.
75        constraints: Spanned<'a, ConstraintsAst<'a>>,
76        /// Function body.
77        function: Box<Spanned<'a, FunctionAst<'a>>>,
78    },
79    /// Tuple type; for example, `(Num, Bool)`.
80    Tuple(TupleAst<'a>),
81    /// Slice type; for example, `[Num]` or `[(Num, T); N]`.
82    Slice(SliceAst<'a>),
83    /// Object type; for example, `{ len: Num }`. Not to be confused with object constraints.
84    Object(ObjectAst<'a>),
85}
86
87impl<'a> TypeAst<'a> {
88    /// Parses `input` as a type. This parser can be composed using `nom` infrastructure.
89    pub fn parse(input: InputSpan<'a>) -> NomResult<'a, Spanned<'a, Self>> {
90        with_span(type_definition).parse(input)
91    }
92}
93
94/// Spanned [`TypeAst`].
95pub type SpannedTypeAst<'a> = Spanned<'a, TypeAst<'a>>;
96
97/// Parsed tuple type, such as `(Num, Bool)` or `(fn() -> Num, ...[Num; _])`.
98#[derive(Debug, Clone, PartialEq)]
99pub struct TupleAst<'a> {
100    /// Elements at the beginning of the tuple, e.g., `Num` and `Bool`
101    /// in `(Num, Bool, ...[T; _])`.
102    pub start: Vec<SpannedTypeAst<'a>>,
103    /// Middle of the tuple, e.g., `[T; _]` in `(Num, Bool, ...[T; _])`.
104    pub middle: Option<Spanned<'a, SliceAst<'a>>>,
105    /// Elements at the end of the tuple, e.g., `Bool` in `(...[Num; _], Bool)`.
106    /// Guaranteed to be empty if `middle` is not present.
107    pub end: Vec<SpannedTypeAst<'a>>,
108}
109
110/// Parsed slice type, such as `[Num; N]`.
111#[derive(Debug, Clone, PartialEq)]
112pub struct SliceAst<'a> {
113    /// Element of this slice; for example, `Num` in `[Num; N]`.
114    pub element: Box<SpannedTypeAst<'a>>,
115    /// Length of this slice; for example, `N` in `[Num; N]`.
116    pub length: Spanned<'a, TupleLenAst>,
117}
118
119/// Parsed functional type.
120///
121/// In contrast to [`Function`], this struct corresponds to AST, not to the logical representation
122/// of functional types.
123///
124/// [`Function`]: crate::Function
125///
126/// # Examples
127///
128/// ```
129/// use arithmetic_parser::InputSpan;
130/// # use assert_matches::assert_matches;
131/// # use arithmetic_typing::ast::{FunctionAst, TypeAst};
132///
133/// # fn main() -> anyhow::Result<()> {
134/// let input = InputSpan::new("([Num; N]) -> Num");
135/// let (rest, ty) = FunctionAst::parse(input)?;
136/// assert!(rest.fragment().is_empty());
137/// assert_matches!(ty.args.extra.start[0].extra, TypeAst::Slice(_));
138/// assert_eq!(ty.return_type.extra, TypeAst::Ident);
139/// # Ok(())
140/// # }
141/// ```
142#[derive(Debug, Clone, PartialEq)]
143#[non_exhaustive]
144pub struct FunctionAst<'a> {
145    /// Function arguments.
146    pub args: Spanned<'a, TupleAst<'a>>,
147    /// Return type of the function.
148    pub return_type: SpannedTypeAst<'a>,
149}
150
151impl<'a> FunctionAst<'a> {
152    /// Parses `input` as a functional type. This parser can be composed using `nom` infrastructure.
153    pub fn parse(input: InputSpan<'a>) -> NomResult<'a, Self> {
154        fn_definition(input)
155    }
156}
157
/// AST of a tuple / slice length.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TupleLenAst {
    /// Placeholder length written as `_`; stands for any single length.
    Some,
    /// Dynamic length. It is *implicit* in the source (as in `[Num]`),
    /// so the corresponding span is empty.
    Dynamic,
    /// Named length; for example, `N` in `[Num; N]`.
    Ident,
}
170
171/// Parameter constraints, e.g. `for<len! N; T: Lin>`.
172#[derive(Debug, Clone, PartialEq)]
173#[non_exhaustive]
174pub struct ConstraintsAst<'a> {
175    /// Static lengths, e.g., `N` in `for<len! N>`.
176    pub static_lengths: Vec<Spanned<'a>>,
177    /// Type constraints.
178    pub type_params: Vec<(Spanned<'a>, TypeConstraintsAst<'a>)>,
179}
180
181/// Bounds that can be placed on a type variable.
182#[derive(Debug, Default, Clone, PartialEq)]
183#[non_exhaustive]
184pub struct TypeConstraintsAst<'a> {
185    /// Object constraint, such as `{ x: 'T }`.
186    pub object: Option<ObjectAst<'a>>,
187    /// Spans corresponding to constraints, e.g. `Foo` and `Bar` in `Foo + Bar`.
188    pub terms: Vec<Spanned<'a>>,
189}
190
191/// Object type or constraint, such as `{ x: Num, y: [(Num, Bool)] }`.
192#[derive(Debug, Clone, PartialEq)]
193#[non_exhaustive]
194pub struct ObjectAst<'a> {
195    /// Fields of the object.
196    pub fields: Vec<(Spanned<'a>, SpannedTypeAst<'a>)>,
197}
198
199/// Whitespace and comments.
200fn ws(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
201    fn narrow_ws(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
202        take_while1(|c: char| c.is_ascii_whitespace())(input)
203    }
204
205    fn long_comment_body(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
206        cut(take_until("*/")).parse(input)
207    }
208
209    let comment = preceded(tag("//"), take_while(|c: char| c != '\n'));
210    let long_comment = delimited(tag("/*"), long_comment_body, tag("*/"));
211    let ws_line = alt((narrow_ws, comment, long_comment));
212    recognize(many0(ws_line)).parse(input)
213}
214
215/// Comma separator.
216fn comma_sep(input: InputSpan<'_>) -> NomResult<'_, char> {
217    delimited(ws, tag_char(','), ws).parse(input)
218}
219
220fn ident(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
221    preceded(
222        peek(take_while_m_n(1, 1, |c: char| {
223            c.is_ascii_alphabetic() || c == '_'
224        })),
225        map(
226            take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_'),
227            Spanned::from,
228        ),
229    )
230    .parse(input)
231}
232
233fn not_keyword(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
234    map_res(ident, |ident| {
235        if *ident.fragment() == "as" {
236            Err(ParseErrorKind::Type(anyhow::anyhow!(
237                "`as` is a reserved keyword"
238            )))
239        } else {
240            Ok(ident)
241        }
242    })
243    .parse(input)
244}
245
246fn type_param_ident(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
247    preceded(tag_char('\''), ident).parse(input)
248}
249
250fn comma_separated_types(input: InputSpan<'_>) -> NomResult<'_, Vec<SpannedTypeAst<'_>>> {
251    separated_list0(delimited(ws, tag_char(','), ws), with_span(type_definition)).parse(input)
252}
253
254fn tuple_middle(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_, SliceAst<'_>>> {
255    preceded(terminated(tag("..."), ws), with_span(slice_definition)).parse(input)
256}
257
258type TupleTailAst<'a> = (Spanned<'a, SliceAst<'a>>, Vec<SpannedTypeAst<'a>>);
259
260fn tuple_tail(input: InputSpan<'_>) -> NomResult<'_, TupleTailAst<'_>> {
261    (
262        tuple_middle,
263        map(
264            opt(preceded(comma_sep, comma_separated_types)),
265            Option::unwrap_or_default,
266        ),
267    )
268        .parse(input)
269}
270
271fn tuple_definition(input: InputSpan<'_>) -> NomResult<'_, TupleAst<'_>> {
272    let maybe_comma = opt(comma_sep);
273
274    let main_parser = alt((
275        map(tuple_tail, |(middle, end)| TupleAst {
276            start: Vec::new(),
277            middle: Some(middle),
278            end,
279        }),
280        map(
281            (comma_separated_types, opt(preceded(comma_sep, tuple_tail))),
282            |(start, maybe_tail)| {
283                if let Some((middle, end)) = maybe_tail {
284                    TupleAst {
285                        start,
286                        middle: Some(middle),
287                        end,
288                    }
289                } else {
290                    TupleAst {
291                        start,
292                        middle: None,
293                        end: Vec::new(),
294                    }
295                }
296            },
297        ),
298    ));
299
300    preceded(
301        terminated(tag_char('('), ws),
302        // Once we've encountered the opening `(`, the input *must* correspond to the parser.
303        cut(terminated(main_parser, (maybe_comma, ws, tag_char(')')))),
304    )
305    .parse(input)
306}
307
308fn tuple_len(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_, TupleLenAst>> {
309    let semicolon = (ws, tag_char(';'), ws);
310    let empty = map(take(0_usize), Spanned::from);
311    map(alt((preceded(semicolon, not_keyword), empty)), |id| {
312        id.map_extra(|()| match *id.fragment() {
313            "_" => TupleLenAst::Some,
314            "" => TupleLenAst::Dynamic,
315            _ => TupleLenAst::Ident,
316        })
317    })
318    .parse(input)
319}
320
321fn slice_definition(input: InputSpan<'_>) -> NomResult<'_, SliceAst<'_>> {
322    preceded(
323        terminated(tag_char('['), ws),
324        // Once we've encountered the opening `[`, the input *must* correspond to the parser.
325        cut(terminated(
326            map(
327                (with_span(type_definition), tuple_len),
328                |(element, length)| SliceAst {
329                    element: Box::new(element),
330                    length,
331                },
332            ),
333            (ws, tag_char(']')),
334        )),
335    )
336    .parse(input)
337}
338
339fn object(input: InputSpan<'_>) -> NomResult<'_, ObjectAst<'_>> {
340    let colon = (ws, tag_char(':'), ws);
341    let object_field = separated_pair(ident, colon, with_span(type_definition));
342    let object_body = terminated(separated_list1(comma_sep, object_field), opt(comma_sep));
343    let object = preceded(
344        terminated(tag_char('{'), ws),
345        cut(terminated(object_body, (ws, tag_char('}')))),
346    );
347    map(object, |fields| ObjectAst { fields }).parse(input)
348}
349
350fn constraint_sep(input: InputSpan<'_>) -> NomResult<'_, ()> {
351    map((ws, tag_char('+'), ws), drop).parse(input)
352}
353
354fn simple_type_bounds(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
355    map(separated_list1(constraint_sep, not_keyword), |terms| {
356        TypeConstraintsAst {
357            object: None,
358            terms,
359        }
360    })
361    .parse(input)
362}
363
364fn type_bounds(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
365    alt((
366        map(
367            (
368                object,
369                opt(preceded(
370                    constraint_sep,
371                    separated_list1(constraint_sep, not_keyword),
372                )),
373            ),
374            |(object, terms)| TypeConstraintsAst {
375                object: Some(object),
376                terms: terms.unwrap_or_default(),
377            },
378        ),
379        simple_type_bounds,
380    ))
381    .parse(input)
382}
383
384fn type_params(input: InputSpan<'_>) -> NomResult<'_, Vec<(Spanned<'_>, TypeConstraintsAst<'_>)>> {
385    let type_bounds = preceded((ws, tag_char(':'), ws), type_bounds);
386    let type_param = (type_param_ident, type_bounds);
387    separated_list1(comma_sep, type_param).parse(input)
388}
389
390/// Function params, including the `for` keyword and `<>` brackets.
391fn constraints(input: InputSpan<'_>) -> NomResult<'_, ConstraintsAst<'_>> {
392    let semicolon = (ws, tag_char(';'), ws);
393
394    let len_params = preceded(
395        terminated(tag("len!"), ws),
396        separated_list1(comma_sep, not_keyword),
397    );
398
399    let params_parser = alt((
400        map(
401            (len_params, opt(preceded(semicolon, type_params))),
402            |(static_lengths, type_params)| (static_lengths, type_params.unwrap_or_default()),
403        ),
404        map(type_params, |type_params| (Vec::new(), type_params)),
405    ));
406
407    let constraints_parser = (
408        terminated(tag("for"), ws),
409        terminated(tag_char('<'), ws),
410        cut(terminated(params_parser, (ws, tag_char('>')))),
411    );
412
413    map(
414        constraints_parser,
415        |(_, _, (static_lengths, type_params))| ConstraintsAst {
416            static_lengths,
417            type_params,
418        },
419    )
420    .parse(input)
421}
422
423fn return_type(input: InputSpan<'_>) -> NomResult<'_, SpannedTypeAst<'_>> {
424    preceded((ws, tag("->"), ws), cut(with_span(type_definition))).parse(input)
425}
426
427fn fn_or_tuple(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
428    map(
429        (with_span(tuple_definition), opt(return_type)),
430        |(args, return_type)| {
431            if let Some(return_type) = return_type {
432                TypeAst::Function(Box::new(FunctionAst { args, return_type }))
433            } else {
434                TypeAst::Tuple(args.extra)
435            }
436        },
437    )
438    .parse(input)
439}
440
441fn fn_definition(input: InputSpan<'_>) -> NomResult<'_, FunctionAst<'_>> {
442    map(
443        (with_span(tuple_definition), return_type),
444        |(args, return_type)| FunctionAst { args, return_type },
445    )
446    .parse(input)
447}
448
449fn fn_definition_with_constraints(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
450    map(
451        (with_span(constraints), ws, cut(with_span(fn_definition))),
452        |(constraints, _, function)| TypeAst::FunctionWithConstraints {
453            constraints,
454            function: Box::new(function),
455        },
456    )
457    .parse(input)
458}
459
460fn not_ident_char(input: InputSpan<'_>) -> NomResult<'_, ()> {
461    peek(not(take_while_m_n(1, 1, |c: char| {
462        c.is_ascii_alphanumeric() || c == '_'
463    })))
464    .parse(input)
465}
466
467fn any_type(input: InputSpan<'_>) -> NomResult<'_, ()> {
468    terminated(map(tag("any"), drop), not_ident_char).parse(input)
469}
470
471fn dyn_type(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
472    map(
473        preceded(
474            terminated(tag("dyn"), not_ident_char),
475            opt(preceded(ws, type_bounds)),
476        ),
477        Option::unwrap_or_default,
478    )
479    .parse(input)
480}
481
482fn free_ident(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
483    map(not_keyword, |id| match *id.fragment() {
484        "_" => TypeAst::Some,
485        _ => TypeAst::Ident,
486    })
487    .parse(input)
488}
489
490fn type_definition(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
491    alt((
492        fn_or_tuple,
493        fn_definition_with_constraints,
494        map(type_param_ident, |_| TypeAst::Param),
495        map(slice_definition, TypeAst::Slice),
496        map(object, TypeAst::Object),
497        map(dyn_type, TypeAst::Dyn),
498        map(any_type, |()| TypeAst::Any),
499        free_ident,
500    ))
501    .parse(input)
502}