1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
//! Grammar functionality and a collection of standard grammars.
//!
//! # Defining grammars
//!
//! To define a [`Grammar`], you'll need a [`ParseLiteral`] implementation, which defines
//! how literals are parsed (numbers, strings, chars, hex- / base64-encoded byte sequences, etc.).
//! There are standard impls for parsing integers and floating-point numbers, as well as
//! complex numbers and arbitrary-precision integers (if the relevant features are on).
//!
//! You may define how to parse type annotations by implementing `Grammar` explicitly.
//! Alternatively, if you don't need type annotations, a `Grammar` can be obtained from
//! a [`ParseLiteral`] impl by wrapping it into [`Untyped`].
//!
//! Once you have a `Grammar`, you can supply it as a `Base` for [`Parse`]. `Parse` methods
//! allow parsing complete or streaming [`Block`](crate::Block)s of statements.
//! Note that the `Untyped` and [`Typed`] wrappers make an explicit `Parse` impl unnecessary.
//!
//! See the [`ParseLiteral`], [`Grammar`] and [`Parse`] docs for examples of various grammar
//! definitions.

use nom::{
    bytes::complete::take_while_m_n,
    character::complete::{char as tag_char, digit1},
    combinator::{map_res, not, opt, peek, recognize},
    number::complete::{double, float},
    sequence::{terminated, tuple},
    Slice,
};

use core::{fmt, marker::PhantomData};

mod traits;

pub use self::traits::{
    Features, Grammar, IntoInputSpan, MockTypes, Parse, ParseLiteral, Typed, Untyped,
    WithMockedTypes,
};

use crate::{spans::NomResult, ErrorKind, InputSpan};

/// Single-type numeric grammar parameterized by the literal type.
///
/// The grammar parses literals by delegating to the [`NumLiteral`] implementation of `T`.
/// The type acts as a type-level marker: its only field is a private [`PhantomData`],
/// and literal parsing is an associated function which does not take `self`.
#[derive(Debug)]
pub struct NumGrammar<T>(PhantomData<T>);

/// Type alias for a grammar on `f32` literals.
pub type F32Grammar = NumGrammar<f32>;
/// Type alias for a grammar on `f64` literals.
pub type F64Grammar = NumGrammar<f64>;

/// Literal parsing for a numeric grammar is fully determined by its literal type.
impl<T> ParseLiteral for NumGrammar<T>
where
    T: NumLiteral,
{
    type Lit = T;

    /// Forwards to the [`NumLiteral`] parser of the literal type.
    fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
        <T as NumLiteral>::parse(input)
    }
}

/// Numeric literal used in `NumGrammar`s.
///
/// Implementing this trait for a type allows using `NumGrammar<T>` as a [`ParseLiteral`]
/// implementation, which forwards literal parsing to [`Self::parse()`].
pub trait NumLiteral: 'static + Clone + fmt::Debug {
    /// Tries to parse a literal from the start of `input`.
    ///
    /// On success, the result contains the unparsed remainder of the input together with
    /// the parsed literal value.
    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self>;
}

/// Ensures that the child parser does not consume a part of a larger expression by rejecting
/// if the part following the input is an ASCII alphanumeric char or `_`.
///
/// For example, `float` parses `-Inf`, which can lead to parser failure if it's a part of
/// a larger expression (e.g., `-Infer(2, 3)`). Likewise, a literal must not be split off
/// an identifier that continues with a digit (e.g., `inf2` or `nan1`).
///
/// Additionally, if the child parser has consumed a trailing `.` which is immediately followed
/// by a char that can start a variable name (e.g., in `1.sin()`), the `.` is returned
/// to the remaining input so that it can be parsed as a part of a method call.
///
/// # Arguments
///
/// - `parser` is the child parser producing a literal of type `T`.
///
/// # Return value
///
/// Returns a parser that succeeds with the output of the child parser, and fails if the child
/// parser fails or if the (possibly adjusted) remaining input starts with an ASCII alphanumeric
/// char or `_`. The look-ahead does not consume any input.
pub fn ensure_no_overlap<'a, T>(
    mut parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, T> {
    let truncating_parser = move |input| {
        parser(input).map(|(rest, number)| (maybe_truncate_consumed_input(input, rest), number))
    };

    terminated(
        truncating_parser,
        // A digit must be rejected as well: otherwise, the `inf` prefix of an identifier
        // like `inf2` would be accepted as a literal, leaving `2` as unparseable remainder.
        peek(not(take_while_m_n(1, 1, |c: char| {
            c.is_ascii_alphanumeric() || c == '_'
        }))),
    )
}

/// Checks whether `byte` is an ASCII letter or `_`, i.e., may be the first byte
/// of a variable name.
fn can_start_a_var_name(byte: u8) -> bool {
    matches!(byte, b'_' | b'a'..=b'z' | b'A'..=b'Z')
}

/// Adjusts the remaining input returned by a number parser so that a trailing `.` consumed
/// by the parser is handed back if it looks like the start of a method call.
///
/// `input` is the span passed to the number parser, and `rest` is the remainder the parser
/// has returned. If the last consumed byte is `.` and the byte right after it can start
/// a variable name, the returned span starts at that `.`; otherwise, `rest` is returned as is.
fn maybe_truncate_consumed_input<'a>(input: InputSpan<'a>, rest: InputSpan<'a>) -> InputSpan<'a> {
    let consumed_len = rest.location_offset() - input.location_offset();
    debug_assert!(consumed_len > 0, "num parser succeeded for empty string");
    let dot_index = consumed_len - 1;

    let fragment = *input.fragment();
    let bytes = fragment.as_bytes();
    // The checks are ordered so that both index accesses are only reached once the indices
    // are known to be within the fragment.
    let dot_starts_method_call = consumed_len < fragment.len()
        && fragment.is_char_boundary(dot_index)
        && bytes[dot_index] == b'.'
        && can_start_a_var_name(bytes[consumed_len]);

    if !dot_starts_method_call {
        return rest;
    }
    // The last char consumed by the parser is '.' and the next part looks like
    // a method call. Shift the `rest` boundary to include '.'.
    input.slice(dot_index..)
}

/// Implements [`NumLiteral`] for the listed unsigned integer types. A literal is a non-empty
/// sequence of decimal digits converted with the standard `FromStr` implementation
/// of the type; conversion failures are reported as literal errors.
macro_rules! impl_num_literal_for_uint {
    ($($num:ident),+) => {
        $(
        impl NumLiteral for $num {
            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
                let mut literal = map_res(digit1, |digits: InputSpan<'_>| {
                    // The conversion fails if the digits do not fit into the target type.
                    let converted: Result<$num, _> = digits.fragment().parse();
                    converted.map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
                });
                literal(input)
            }
        }
        )+
    };
}

impl_num_literal_for_uint!(u8, u16, u32, u64, u128);

/// Implements [`NumLiteral`] for the listed signed integer types. A literal is an optional
/// `-` sign followed by a non-empty sequence of decimal digits; the recognized text is
/// converted with the standard `FromStr` implementation of the type, and conversion failures
/// are reported as literal errors.
macro_rules! impl_num_literal_for_int {
    ($($num:ident),+) => {
        $(
        impl NumLiteral for $num {
            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
                // `recognize` yields the whole matched span (sign included) for conversion.
                let signed_digits = recognize(tuple((opt(tag_char('-')), digit1)));
                let mut literal = map_res(signed_digits, |text: InputSpan<'_>| {
                    let converted: Result<$num, _> = text.fragment().parse();
                    converted.map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
                });
                literal(input)
            }
        }
        )+
    };
}

impl_num_literal_for_int!(i8, i16, i32, i64, i128);

/// `f32` literals are parsed by the `float` parser from `nom` guarded by
/// [`ensure_no_overlap()`].
impl NumLiteral for f32 {
    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
        let mut guarded_parser = ensure_no_overlap(float);
        guarded_parser(input)
    }
}

/// `f64` literals are parsed by the `double` parser from `nom` guarded by
/// [`ensure_no_overlap()`].
impl NumLiteral for f64 {
    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
        let mut guarded_parser = ensure_no_overlap(double);
        guarded_parser(input)
    }
}

#[cfg(feature = "num-complex")]
mod complex {
    use nom::{
        branch::alt,
        character::complete::one_of,
        combinator::{map, opt},
        number::complete::{double, float},
        sequence::tuple,
    };
    use num_complex::Complex;
    use num_traits::Num;

    use super::{ensure_no_overlap, NumLiteral};
    use crate::{InputSpan, NomResult};

    /// Builds a parser for complex literals on top of a parser for their components.
    ///
    /// The returned parser accepts either a standalone imaginary unit (`i` or `j`),
    /// or a number parsed by `num_parser` with an optional `i` / `j` suffix. With the suffix,
    /// the number becomes the imaginary part of the result; without it, the real part.
    fn complex_parser<'a, T: Num>(
        num_parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
    ) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, Complex<T>> {
        let imaginary_unit = map(one_of("ij"), |_| Complex::new(T::zero(), T::one()));

        let number_with_suffix = map(
            tuple((num_parser, opt(one_of("ij")))),
            |(value, suffix)| match suffix {
                Some(_) => Complex::new(T::zero(), value),
                None => Complex::new(value, T::zero()),
            },
        );

        // The standalone unit is tried first; the numeric branch is the fallback.
        alt((imaginary_unit, number_with_suffix))
    }

    /// `Complex32` literals use `float` for parsing the numeric component.
    impl NumLiteral for num_complex::Complex32 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded_parser = ensure_no_overlap(complex_parser(float));
            guarded_parser(input)
        }
    }

    /// `Complex64` literals use `double` for parsing the numeric component.
    impl NumLiteral for num_complex::Complex64 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded_parser = ensure_no_overlap(complex_parser(double));
            guarded_parser(input)
        }
    }
}

#[cfg(feature = "num-bigint")]
mod bigint {
    use nom::{
        character::complete::{char as tag_char, digit1},
        combinator::{map_res, opt, recognize},
        sequence::tuple,
    };
    use num_bigint::{BigInt, BigUint};
    use num_traits::Num;

    use super::NumLiteral;
    use crate::{ErrorKind, InputSpan, NomResult};

    /// `BigInt` literals consist of an optional `-` sign followed by decimal digits.
    impl NumLiteral for BigInt {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            // `recognize` yields the whole matched span (sign included) for conversion.
            let signed_digits = recognize(tuple((opt(tag_char('-')), digit1)));
            let mut literal = map_res(signed_digits, |text: InputSpan<'_>| {
                BigInt::from_str_radix(text.fragment(), 10)
                    .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
            });
            literal(input)
        }
    }

    /// `BigUint` literals consist of decimal digits only.
    impl NumLiteral for BigUint {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut literal = map_res(digit1, |digits: InputSpan<'_>| {
                BigUint::from_str_radix(digits.fragment(), 10)
                    .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
            });
            literal(input)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use nom::Err as NomErr;

    /// Checks the parsed value and the number of consumed bytes for `f32` literals
    /// with various uses of `.` in and after the literal.
    #[test]
    fn parsing_numbers_with_dot() {
        #[derive(Debug, Clone, Copy)]
        struct Sample {
            input: &'static str,
            consumed: usize,
            value: f32,
        }

        #[rustfmt::skip]
        const SAMPLES: &[Sample] = &[
            Sample { input: "1.25+3", consumed: 4, value: 1.25 },

            // Cases in which '.' should be consumed.
            Sample { input: "1.", consumed: 2, value: 1.0 },
            Sample { input: "-1.", consumed: 3, value: -1.0 },
            Sample { input: "1. + 2.", consumed: 2, value: 1.0 },
            Sample { input: "1.+2.", consumed: 2, value: 1.0 },
            Sample { input: "1. .sin()", consumed: 2, value: 1.0 },

            // Cases in which '.' should not be consumed.
            Sample { input: "1.sin()", consumed: 1, value: 1.0 },
            Sample { input: "-3.sin()", consumed: 2, value: -3.0 },
            Sample { input: "-3.5.sin()", consumed: 4, value: -3.5 },
        ];

        for sample in SAMPLES.iter().copied() {
            let span = InputSpan::new(sample.input);
            let (rest, number) = <f32 as NumLiteral>::parse(span).unwrap();

            let deviation = (number - sample.value).abs();
            assert!(deviation < f32::EPSILON, "Failed sample: {sample:?}");

            let consumed = rest.location_offset();
            assert_eq!(consumed, sample.consumed, "Failed sample: {sample:?}");
        }
    }

    /// Checks that the imaginary unit `i` is parsed as a literal when it stands alone,
    /// and is not split off a variable name such as `ix`.
    #[cfg(feature = "num-complex")]
    #[test]
    fn parsing_i() {
        use crate::{Expr, UnaryOp};
        use num_complex::Complex32;

        type C32Grammar = Untyped<NumGrammar<Complex32>>;

        // Standalone `i`.
        let block = C32Grammar::parse_statements("i").unwrap();
        let expr = block.return_value.unwrap().extra;
        assert_matches!(expr, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the left operand of a binary operation.
        let block = C32Grammar::parse_statements("i + 5").unwrap();
        let sum = block.return_value.unwrap().extra;
        let lhs = &sum.binary_lhs().unwrap().extra;
        assert_matches!(*lhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the right operand of a binary operation.
        let block = C32Grammar::parse_statements("5 - i").unwrap();
        let difference = block.return_value.unwrap().extra;
        let rhs = &difference.binary_rhs().unwrap().extra;
        assert_matches!(*rhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` should not be parsed as a literal if it's a part of larger expression.
        let block = C32Grammar::parse_statements("ix + 5").unwrap();
        let sum = block.return_value.unwrap().extra;
        let lhs = &sum.binary_lhs().unwrap().extra;
        assert_matches!(*lhs, Expr::Variable);

        // Negated `i` is a negation of the literal.
        let block = C32Grammar::parse_statements("-i + 5").unwrap();
        let sum = block.return_value.unwrap().extra;
        let negation_expr = &sum.binary_lhs().unwrap().extra;
        let Expr::Unary { inner, op } = negation_expr else {
            panic!("Unexpected LHS: {negation_expr:?}");
        };
        assert!(op.extra == UnaryOp::Neg, "Unexpected LHS: {negation_expr:?}");
        let inner_lhs = &inner.extra;
        assert_matches!(inner_lhs, Expr::Literal(lit) if *lit == Complex32::i());

        // Negated `ix` is a negation of the variable.
        let block = C32Grammar::parse_statements("-ix + 5").unwrap();
        let sum = block.return_value.unwrap().extra;
        let var_negation = &sum.binary_lhs().unwrap().extra;
        let Expr::Unary { inner, op } = var_negation else {
            panic!("Unexpected LHS: {var_negation:?}");
        };
        assert!(op.extra == UnaryOp::Neg, "Unexpected LHS: {var_negation:?}");
        let negated_var = &inner.extra;
        assert_matches!(negated_var, Expr::Variable);
    }

    /// Checks that each unsigned integer type parses a representative literal.
    #[test]
    fn uint_parsers() {
        /// Parses `text` as a literal of type `T` and discards the remaining input.
        fn parse_value<T: NumLiteral>(text: &str) -> T {
            let (_, value) = T::parse(InputSpan::new(text)).unwrap();
            value
        }

        assert_eq!(parse_value::<u8>("3"), 3);
        assert_eq!(parse_value::<u16>("33333"), 33_333);
        assert_eq!(parse_value::<u32>("1111111111"), 1_111_111_111);

        let max_u64 = u64::MAX.to_string();
        assert_eq!(parse_value::<u64>(&max_u64), u64::MAX);
        let max_u128 = u128::MAX.to_string();
        assert_eq!(parse_value::<u128>(&max_u128), u128::MAX);
    }

    /// Checks the boundary values of `i8` literals and the error returned for a value
    /// that is out of range.
    #[test]
    fn int_parsers() {
        let parse_i8 = |text: &'static str| <i8 as NumLiteral>::parse(InputSpan::new(text));

        let (_, min_val) = parse_i8("-128").unwrap();
        assert_eq!(min_val, -128);
        let (_, max_val) = parse_i8("127").unwrap();
        assert_eq!(max_val, 127);

        // 128 does not fit into `i8`, so the conversion must fail with a literal error.
        let err = parse_i8("128").unwrap_err();
        let err = match &err {
            NomErr::Error(inner) => inner,
            _ => panic!("Unexpected error type: {err:?}"),
        };
        assert_matches!(err.kind(), ErrorKind::Literal(_));
    }

    /// Checks that `BigUint` / `BigInt` literals of various lengths are parsed to the same
    /// values as produced by `parse_bytes`, including negative `BigInt` values.
    #[cfg(feature = "num-bigint")]
    #[test]
    fn bigint_parsers() {
        use num_bigint::{BigInt, BigUint};

        for len in 1..500 {
            let input = "1".repeat(len);

            let (_, unsigned) = <BigUint as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            let expected_unsigned = BigUint::parse_bytes(input.as_bytes(), 10).unwrap();
            assert_eq!(unsigned, expected_unsigned);

            let (_, positive) = <BigInt as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            let expected_value = BigInt::parse_bytes(input.as_bytes(), 10).unwrap();
            assert_eq!(positive, expected_value);

            let negative_input = format!("-{input}");
            let (_, negative) =
                <BigInt as NumLiteral>::parse(InputSpan::new(&negative_input)).unwrap();
            assert_eq!(negative, -expected_value);
        }
    }
}