arithmetic_parser/grammars/
mod.rs

1//! Grammar functionality and a collection of standard grammars.
2//!
3//! # Defining grammars
4//!
5//! To define a [`Grammar`], you'll need a [`ParseLiteral`] implementation, which defines
6//! how literals are parsed (numbers, strings, chars, hex- / base64-encoded byte sequences, etc.).
7//! There are standard impls for floating-point number parsing and the complex numbers
8//! (if the relevant feature is on).
9//!
10//! You may define how to parse type annotations by implementing `Grammar` explicitly.
11//! Alternatively, if you don't need type annotations, a `Grammar` can be obtained from
12//! a [`ParseLiteral`] impl by wrapping it into [`Untyped`].
13//!
//! Once you have a `Grammar`, you can supply it as a `Base` for [`Parse`]. `Parse` methods
//! allow parsing complete or streaming [`Block`](crate::Block)s of statements.
//! Note that the `Untyped` and [`Typed`] wrappers make an explicit `Parse` impl unnecessary.
17//!
//! See the [`ParseLiteral`], [`Grammar`] and [`Parse`] docs for examples of various grammar
//! definitions.
20
21use core::{fmt, marker::PhantomData};
22
23use nom::{
24    bytes::complete::take_while_m_n,
25    character::complete::{char as tag_char, digit1},
26    combinator::{map_res, not, opt, peek, recognize},
27    number::complete::{double, float},
28    sequence::terminated,
29    Input, Parser as _,
30};
31
32pub use self::traits::{
33    Features, Grammar, IntoInputSpan, MockTypes, Parse, ParseLiteral, Typed, Untyped,
34    WithMockedTypes,
35};
36use crate::{spans::NomResult, ErrorKind, InputSpan};
37
38mod traits;
39
/// Single-type numeric grammar parameterized by the literal type.
///
/// The type parameter `T` is used purely at the type level (the struct only holds
/// a [`PhantomData`] marker); literal parsing is delegated to the [`NumLiteral`]
/// implementation of `T` via the [`ParseLiteral`] impl for this type.
#[derive(Debug)]
pub struct NumGrammar<T>(PhantomData<T>);
43
/// Type alias for a grammar on `f32` literals (i.e., `NumGrammar<f32>`).
pub type F32Grammar = NumGrammar<f32>;
/// Type alias for a grammar on `f64` literals (i.e., `NumGrammar<f64>`).
pub type F64Grammar = NumGrammar<f64>;
48
49impl<T: NumLiteral> ParseLiteral for NumGrammar<T> {
50    type Lit = T;
51
52    fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
53        T::parse(input)
54    }
55}
56
/// Numeric literal used in `NumGrammar`s.
///
/// This module implements the trait for the built-in unsigned (`u8`..`u128`) and signed
/// (`i8`..`i128`) integers, for `f32` / `f64`, and (behind the `num-complex` and `num-bigint`
/// features) for complex numbers and arbitrary-precision integers.
pub trait NumLiteral: 'static + Clone + fmt::Debug {
    /// Tries to parse a literal.
    ///
    /// On success, returns the unconsumed remainder of `input` together with the parsed value.
    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self>;
}
62
63/// Ensures that the child parser does not consume a part of a larger expression by rejecting
64/// if the part following the input is an alphanumeric char or `_`.
65///
66/// For example, `float` parses `-Inf`, which can lead to parser failure if it's a part of
67/// a larger expression (e.g., `-Infer(2, 3)`).
68pub fn ensure_no_overlap<'a, F>(
69    mut parser: F,
70) -> impl nom::Parser<InputSpan<'a>, Output = F::Output, Error = F::Error>
71where
72    F: nom::Parser<InputSpan<'a>>,
73{
74    let truncating_parser = move |input| {
75        parser
76            .parse(input)
77            .map(|(rest, number)| (maybe_truncate_consumed_input(input, rest), number))
78    };
79
80    terminated(
81        truncating_parser,
82        peek(not(take_while_m_n(1, 1, |c: char| {
83            c.is_ascii_alphabetic() || c == '_'
84        }))),
85    )
86}
87
/// Checks whether `byte` may be the first byte of a variable name, i.e., whether it is `_`
/// or an ASCII letter.
fn can_start_a_var_name(byte: u8) -> bool {
    matches!(byte, b'_' | b'a'..=b'z' | b'A'..=b'Z')
}
91
92fn maybe_truncate_consumed_input<'a>(input: InputSpan<'a>, rest: InputSpan<'a>) -> InputSpan<'a> {
93    let relative_offset = rest.location_offset() - input.location_offset();
94    debug_assert!(relative_offset > 0, "num parser succeeded for empty string");
95    let last_consumed_byte_index = relative_offset - 1;
96
97    let input_fragment = *input.fragment();
98    let input_as_bytes = input_fragment.as_bytes();
99    if relative_offset < input_fragment.len()
100        && input_fragment.is_char_boundary(last_consumed_byte_index)
101        && input_as_bytes[last_consumed_byte_index] == b'.'
102        && can_start_a_var_name(input_as_bytes[relative_offset])
103    {
104        // The last char consumed by the parser is '.' and the next part looks like
105        // a method call. Shift the `rest` boundary to include '.'.
106        input.take_from(last_consumed_byte_index)
107    } else {
108        rest
109    }
110}
111
112macro_rules! impl_num_literal_for_uint {
113    ($($num:ident),+) => {
114        $(
115        impl NumLiteral for $num {
116            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
117                let parser = |s: InputSpan<'_>| {
118                    s.fragment()
119                        .parse::<$num>()
120                        .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
121                };
122                map_res(digit1, parser).parse(input)
123            }
124        }
125        )+
126    };
127}
128
129impl_num_literal_for_uint!(u8, u16, u32, u64, u128);
130
131macro_rules! impl_num_literal_for_int {
132    ($($num:ident),+) => {
133        $(
134        impl NumLiteral for $num {
135            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
136                let parser = |s: InputSpan<'_>| {
137                    s.fragment()
138                        .parse::<$num>()
139                        .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
140                };
141                map_res(recognize((opt(tag_char('-')), digit1)), parser).parse(input)
142            }
143        }
144        )+
145    };
146}
147
148impl_num_literal_for_int!(i8, i16, i32, i64, i128);
149
150impl NumLiteral for f32 {
151    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
152        ensure_no_overlap(float).parse(input)
153    }
154}
155
156impl NumLiteral for f64 {
157    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
158        ensure_no_overlap(double).parse(input)
159    }
160}
161
#[cfg(feature = "num-complex")]
mod complex {
    use nom::{
        branch::alt,
        character::complete::one_of,
        combinator::{map, opt},
        number::complete::{double, float},
        Parser as _,
    };
    use num_complex::Complex;
    use num_traits::Num;

    use super::{ensure_no_overlap, NumLiteral};
    use crate::{InputSpan, NomResult};

    /// Builds a parser for complex literals on top of a parser for the underlying real numbers.
    ///
    /// Recognized forms:
    ///
    /// - a bare `i` or `j`, producing the imaginary unit;
    /// - a number parsed by `num_parser` with an `i` / `j` suffix, producing an imaginary value;
    /// - a number parsed by `num_parser` without a suffix, producing a real value.
    fn complex_parser<'a, T: Num, F>(
        num_parser: F,
    ) -> impl nom::Parser<InputSpan<'a>, Output = Complex<T>, Error = F::Error>
    where
        F: nom::Parser<InputSpan<'a>, Output = T>,
    {
        let imaginary_unit = map(one_of("ij"), |_| Complex::new(T::zero(), T::one()));

        let real_or_imaginary = map(
            (num_parser, opt(one_of("ij"))),
            |(magnitude, suffix)| match suffix {
                Some(_) => Complex::new(T::zero(), magnitude),
                None => Complex::new(magnitude, T::zero()),
            },
        );

        // The bare unit is tried first; the numeric form is the fallback.
        alt((imaginary_unit, real_or_imaginary))
    }

    impl NumLiteral for num_complex::Complex32 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded = ensure_no_overlap(complex_parser(float));
            guarded.parse(input)
        }
    }

    impl NumLiteral for num_complex::Complex64 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded = ensure_no_overlap(complex_parser(double));
            guarded.parse(input)
        }
    }
}
209
#[cfg(feature = "num-bigint")]
mod bigint {
    use nom::{
        character::complete::{char as tag_char, digit1},
        combinator::{map_res, opt, recognize},
        Parser as _,
    };
    use num_bigint::{BigInt, BigUint};
    use num_traits::Num;

    use super::NumLiteral;
    use crate::{ErrorKind, InputSpan, NomResult};

    /// Signed arbitrary-precision literals: an optional `-` followed by decimal digits.
    impl NumLiteral for BigInt {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let signed_digits = recognize((opt(tag_char('-')), digit1));
            let mut int_parser = map_res(signed_digits, |text: InputSpan<'_>| {
                BigInt::from_str_radix(text.fragment(), 10)
                    .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
            });
            int_parser.parse(input)
        }
    }

    /// Unsigned arbitrary-precision literals: a non-empty run of decimal digits.
    impl NumLiteral for BigUint {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut uint_parser = map_res(digit1, |digits: InputSpan<'_>| {
                BigUint::from_str_radix(digits.fragment(), 10)
                    .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
            });
            uint_parser.parse(input)
        }
    }
}
243
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use nom::Err as NomErr;

    use super::*;

    /// Checks how much input the `f32` literal parser consumes when the number contains
    /// (or is followed by) a `.` char.
    #[test]
    fn parsing_numbers_with_dot() {
        /// Single test case: the input, the expected number of consumed bytes,
        /// and the expected parsed value.
        #[derive(Debug, Clone, Copy)]
        struct Sample {
            input: &'static str,
            consumed: usize,
            value: f32,
        }

        #[rustfmt::skip]
        const SAMPLES: &[Sample] = &[
            Sample { input: "1.25+3", consumed: 4, value: 1.25 },

            // Cases in which '.' should be consumed.
            Sample { input: "1.", consumed: 2, value: 1.0 },
            Sample { input: "-1.", consumed: 3, value: -1.0 },
            Sample { input: "1. + 2.", consumed: 2, value: 1.0 },
            Sample { input: "1.+2.", consumed: 2, value: 1.0 },
            Sample { input: "1. .sin()", consumed: 2, value: 1.0 },

            // Cases in which '.' should not be consumed.
            Sample { input: "1.sin()", consumed: 1, value: 1.0 },
            Sample { input: "-3.sin()", consumed: 2, value: -3.0 },
            Sample { input: "-3.5.sin()", consumed: 4, value: -3.5 },
        ];

        for &sample in SAMPLES {
            let (rest, number) = <f32 as NumLiteral>::parse(InputSpan::new(sample.input)).unwrap();
            assert!(
                (number - sample.value).abs() < f32::EPSILON,
                "Failed sample: {sample:?}"
            );
            // The offset of the remainder equals the number of consumed bytes.
            assert_eq!(
                rest.location_offset(),
                sample.consumed,
                "Failed sample: {sample:?}"
            );
        }
    }

    /// Checks that the imaginary unit `i` is parsed as a literal in the complex grammar,
    /// but only when it is not a prefix of a longer identifier.
    #[cfg(feature = "num-complex")]
    #[test]
    fn parsing_i() {
        use num_complex::Complex32;

        use crate::{Expr, UnaryOp};

        type C32Grammar = Untyped<NumGrammar<Complex32>>;

        // Standalone `i`.
        let parsed = C32Grammar::parse_statements("i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the left-hand side of a binary expression.
        let parsed = C32Grammar::parse_statements("i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_lhs = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*i_as_lhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the right-hand side of a binary expression.
        let parsed = C32Grammar::parse_statements("5 - i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_rhs = &ret.binary_rhs().unwrap().extra;
        assert_matches!(*i_as_rhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` should not be parsed as a literal if it's a part of larger expression.
        let parsed = C32Grammar::parse_statements("ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let variable = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*variable, Expr::Variable);

        // `-i` is expected to be a negation applied to the `i` literal.
        let parsed = C32Grammar::parse_statements("-i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let negation_expr = &ret.binary_lhs().unwrap().extra;
        let inner_lhs = match negation_expr {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {negation_expr:?}"),
        };
        assert_matches!(inner_lhs, Expr::Literal(lit) if *lit == Complex32::i());

        // `-ix` is expected to be a negation applied to the `ix` variable.
        let parsed = C32Grammar::parse_statements("-ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let var_negation = &ret.binary_lhs().unwrap().extra;
        let negated_var = match var_negation {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {var_negation:?}"),
        };
        assert_matches!(negated_var, Expr::Variable);
    }

    /// Checks that unsigned integer parsers handle values up to the type maximum.
    #[test]
    fn uint_parsers() {
        let (_, u8_val) = <u8 as NumLiteral>::parse(InputSpan::new("3")).unwrap();
        assert_eq!(u8_val, 3);
        let (_, u16_val) = <u16 as NumLiteral>::parse(InputSpan::new("33333")).unwrap();
        assert_eq!(u16_val, 33_333);
        let (_, u32_val) = <u32 as NumLiteral>::parse(InputSpan::new("1111111111")).unwrap();
        assert_eq!(u32_val, 1_111_111_111);
        let (_, u64_val) =
            <u64 as NumLiteral>::parse(InputSpan::new(&u64::MAX.to_string())).unwrap();
        assert_eq!(u64_val, u64::MAX);
        let (_, u128_val) =
            <u128 as NumLiteral>::parse(InputSpan::new(&u128::MAX.to_string())).unwrap();
        assert_eq!(u128_val, u128::MAX);
    }

    /// Checks the boundaries of the `i8` parser, including the error on overflow.
    #[test]
    fn int_parsers() {
        let (_, min_val) = <i8 as NumLiteral>::parse(InputSpan::new("-128")).unwrap();
        assert_eq!(min_val, -128);
        let (_, max_val) = <i8 as NumLiteral>::parse(InputSpan::new("127")).unwrap();
        assert_eq!(max_val, 127);

        // A value one past `i8::MAX` must be reported as a literal parsing error.
        let err = <i8 as NumLiteral>::parse(InputSpan::new("128")).unwrap_err();
        let NomErr::Error(err) = &err else {
            panic!("Unexpected error type: {err:?}");
        };
        assert_matches!(err.kind(), ErrorKind::Literal(_));
    }

    /// Checks arbitrary-precision integer parsers on inputs of 1 to 499 digits,
    /// comparing against `parse_bytes` from `num_bigint`.
    #[cfg(feature = "num-bigint")]
    #[test]
    fn bigint_parsers() {
        use num_bigint::{BigInt, BigUint};

        for len in 1..500 {
            let input = "1".repeat(len);
            let (_, value) = <BigUint as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            assert_eq!(value, BigUint::parse_bytes(input.as_bytes(), 10).unwrap());

            let (_, value) = <BigInt as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            let expected_value = BigInt::parse_bytes(input.as_bytes(), 10).unwrap();
            assert_eq!(value, expected_value);
            // The same digits with a leading `-` must produce the negated value.
            let (_, value) =
                <BigInt as NumLiteral>::parse(InputSpan::new(&format!("-{input}"))).unwrap();
            assert_eq!(value, -expected_value);
        }
    }
}