1use core::{fmt, marker::PhantomData};
22
23use nom::{
24 bytes::complete::take_while_m_n,
25 character::complete::{char as tag_char, digit1},
26 combinator::{map_res, not, opt, peek, recognize},
27 number::complete::{double, float},
28 sequence::terminated,
29 Input, Parser as _,
30};
31
32pub use self::traits::{
33 Features, Grammar, IntoInputSpan, MockTypes, Parse, ParseLiteral, Typed, Untyped,
34 WithMockedTypes,
35};
36use crate::{spans::NomResult, ErrorKind, InputSpan};
37
38mod traits;
39
/// Grammar marker for numeric literals of type `T`.
///
/// The struct holds no data: `T` is tracked only through [`PhantomData`]. Literal parsing is
/// delegated to `T`'s [`NumLiteral`] implementation by the [`ParseLiteral`] impl for this type.
#[derive(Debug)]
pub struct NumGrammar<T>(PhantomData<T>);
43
44pub type F32Grammar = NumGrammar<f32>;
46pub type F64Grammar = NumGrammar<f64>;
48
49impl<T: NumLiteral> ParseLiteral for NumGrammar<T> {
50 type Lit = T;
51
52 fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
53 T::parse(input)
54 }
55}
56
57pub trait NumLiteral: 'static + Clone + fmt::Debug {
59 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self>;
61}
62
63pub fn ensure_no_overlap<'a, F>(
69 mut parser: F,
70) -> impl nom::Parser<InputSpan<'a>, Output = F::Output, Error = F::Error>
71where
72 F: nom::Parser<InputSpan<'a>>,
73{
74 let truncating_parser = move |input| {
75 parser
76 .parse(input)
77 .map(|(rest, number)| (maybe_truncate_consumed_input(input, rest), number))
78 };
79
80 terminated(
81 truncating_parser,
82 peek(not(take_while_m_n(1, 1, |c: char| {
83 c.is_ascii_alphabetic() || c == '_'
84 }))),
85 )
86}
87
/// Checks whether `byte` is `_` or an ASCII letter, i.e., whether it can be the first byte
/// of a variable name.
fn can_start_a_var_name(byte: u8) -> bool {
    matches!(byte, b'_' | b'a'..=b'z' | b'A'..=b'Z')
}
91
92fn maybe_truncate_consumed_input<'a>(input: InputSpan<'a>, rest: InputSpan<'a>) -> InputSpan<'a> {
93 let relative_offset = rest.location_offset() - input.location_offset();
94 debug_assert!(relative_offset > 0, "num parser succeeded for empty string");
95 let last_consumed_byte_index = relative_offset - 1;
96
97 let input_fragment = *input.fragment();
98 let input_as_bytes = input_fragment.as_bytes();
99 if relative_offset < input_fragment.len()
100 && input_fragment.is_char_boundary(last_consumed_byte_index)
101 && input_as_bytes[last_consumed_byte_index] == b'.'
102 && can_start_a_var_name(input_as_bytes[relative_offset])
103 {
104 input.take_from(last_consumed_byte_index)
107 } else {
108 rest
109 }
110}
111
112macro_rules! impl_num_literal_for_uint {
113 ($($num:ident),+) => {
114 $(
115 impl NumLiteral for $num {
116 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
117 let parser = |s: InputSpan<'_>| {
118 s.fragment()
119 .parse::<$num>()
120 .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
121 };
122 map_res(digit1, parser).parse(input)
123 }
124 }
125 )+
126 };
127}
128
129impl_num_literal_for_uint!(u8, u16, u32, u64, u128);
130
131macro_rules! impl_num_literal_for_int {
132 ($($num:ident),+) => {
133 $(
134 impl NumLiteral for $num {
135 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
136 let parser = |s: InputSpan<'_>| {
137 s.fragment()
138 .parse::<$num>()
139 .map_err(|err| ErrorKind::literal(anyhow::anyhow!(err)))
140 };
141 map_res(recognize((opt(tag_char('-')), digit1)), parser).parse(input)
142 }
143 }
144 )+
145 };
146}
147
148impl_num_literal_for_int!(i8, i16, i32, i64, i128);
149
150impl NumLiteral for f32 {
151 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
152 ensure_no_overlap(float).parse(input)
153 }
154}
155
156impl NumLiteral for f64 {
157 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
158 ensure_no_overlap(double).parse(input)
159 }
160}
161
/// `NumLiteral` implementations for complex numbers (requires the `num-complex` feature).
#[cfg(feature = "num-complex")]
mod complex {
    use nom::{
        branch::alt,
        character::complete::one_of,
        combinator::{map, opt},
        number::complete::{double, float},
        Parser as _,
    };
    use num_complex::Complex;
    use num_traits::Num;

    use super::{ensure_no_overlap, NumLiteral};
    use crate::{InputSpan, NomResult};

    /// Builds a complex number parser on top of a parser for the component type.
    ///
    /// The resulting parser accepts, in this order of preference:
    ///
    /// - a lone `i` or `j`, producing the imaginary unit;
    /// - a number recognized by `num_parser`, optionally followed by `i` or `j`. With
    ///   the suffix, the number becomes the imaginary part; without it, the real part.
    fn complex_parser<'a, T: Num, F>(
        num_parser: F,
    ) -> impl nom::Parser<InputSpan<'a>, Output = Complex<T>, Error = F::Error>
    where
        F: nom::Parser<InputSpan<'a>, Output = T>,
    {
        let imaginary_unit = map(one_of("ij"), |_| Complex::new(T::zero(), T::one()));

        let number_and_suffix = (num_parser, opt(one_of("ij")));
        let number = map(number_and_suffix, |(value, suffix)| match suffix {
            Some(_) => Complex::new(T::zero(), value),
            None => Complex::new(value, T::zero()),
        });

        alt((imaginary_unit, number))
    }

    impl NumLiteral for num_complex::Complex32 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded = ensure_no_overlap(complex_parser(float));
            guarded.parse(input)
        }
    }

    impl NumLiteral for num_complex::Complex64 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut guarded = ensure_no_overlap(complex_parser(double));
            guarded.parse(input)
        }
    }
}
209
/// `NumLiteral` implementations for arbitrary-precision integers (requires the `num-bigint`
/// feature).
#[cfg(feature = "num-bigint")]
mod bigint {
    use nom::{
        character::complete::{char as tag_char, digit1},
        combinator::{map_res, opt, recognize},
        Parser as _,
    };
    use num_bigint::{BigInt, BigUint};
    use num_traits::Num;

    use super::NumLiteral;
    use crate::{ErrorKind, InputSpan, NomResult};

    /// Signed big integers: an optional leading `-` followed by one or more decimal digits.
    impl NumLiteral for BigInt {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let to_number = |text: InputSpan<'_>| {
                match BigInt::from_str_radix(text.fragment(), 10) {
                    Ok(value) => Ok(value),
                    Err(err) => Err(ErrorKind::literal(anyhow::anyhow!(err))),
                }
            };
            let signed_digits = recognize((opt(tag_char('-')), digit1));
            map_res(signed_digits, to_number).parse(input)
        }
    }

    /// Unsigned big integers: one or more decimal digits.
    impl NumLiteral for BigUint {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let to_number = |digits: InputSpan<'_>| {
                match BigUint::from_str_radix(digits.fragment(), 10) {
                    Ok(value) => Ok(value),
                    Err(err) => Err(ErrorKind::literal(anyhow::anyhow!(err))),
                }
            };
            map_res(digit1, to_number).parse(input)
        }
    }
}
243
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use nom::Err as NomErr;

    use super::*;

    /// Checks both the parsed value and the number of consumed bytes for `f32` literals
    /// that contain (or are followed by) a dot.
    #[test]
    fn parsing_numbers_with_dot() {
        #[derive(Debug, Clone, Copy)]
        struct Sample {
            input: &'static str,
            consumed: usize,
            value: f32,
        }

        #[rustfmt::skip]
        const SAMPLES: &[Sample] = &[
            // Ordinary number with a fractional part.
            Sample { input: "1.25+3", consumed: 4, value: 1.25 },

            // The trailing dot is consumed if no variable name follows it.
            Sample { input: "1.", consumed: 2, value: 1.0 },
            Sample { input: "-1.", consumed: 3, value: -1.0 },
            Sample { input: "1. + 2.", consumed: 2, value: 1.0 },
            Sample { input: "1.+2.", consumed: 2, value: 1.0 },
            Sample { input: "1. .sin()", consumed: 2, value: 1.0 },

            // The dot is left unconsumed if a name immediately follows it.
            Sample { input: "1.sin()", consumed: 1, value: 1.0 },
            Sample { input: "-3.sin()", consumed: 2, value: -3.0 },
            Sample { input: "-3.5.sin()", consumed: 4, value: -3.5 },
        ];

        for sample in SAMPLES.iter().copied() {
            let span = InputSpan::new(sample.input);
            let (rest, number) = <f32 as NumLiteral>::parse(span).unwrap();

            let deviation = (number - sample.value).abs();
            assert!(deviation < f32::EPSILON, "Failed sample: {sample:?}");
            assert_eq!(
                rest.location_offset(),
                sample.consumed,
                "Failed sample: {sample:?}"
            );
        }
    }

    /// Checks that a standalone `i` is parsed as the imaginary unit, while names
    /// starting with `i` are parsed as variables.
    #[cfg(feature = "num-complex")]
    #[test]
    fn parsing_i() {
        use num_complex::Complex32;

        use crate::{Expr, UnaryOp};

        type C32Grammar = Untyped<NumGrammar<Complex32>>;

        // `i` on its own.
        let block = C32Grammar::parse_statements("i").unwrap();
        let ret = block.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the left-hand side of a binary expression.
        let block = C32Grammar::parse_statements("i + 5").unwrap();
        let ret = block.return_value.unwrap().extra;
        let i_as_lhs = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*i_as_lhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` as the right-hand side of a binary expression.
        let block = C32Grammar::parse_statements("5 - i").unwrap();
        let ret = block.return_value.unwrap().extra;
        let i_as_rhs = &ret.binary_rhs().unwrap().extra;
        assert_matches!(*i_as_rhs, Expr::Literal(lit) if lit == Complex32::i());

        // `ix` must be a variable rather than `i` glued to `x`.
        let block = C32Grammar::parse_statements("ix + 5").unwrap();
        let ret = block.return_value.unwrap().extra;
        let variable = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*variable, Expr::Variable);

        // `-i` is a negation applied to the imaginary unit.
        let block = C32Grammar::parse_statements("-i + 5").unwrap();
        let ret = block.return_value.unwrap().extra;
        let negation_expr = &ret.binary_lhs().unwrap().extra;
        let inner_lhs = match negation_expr {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {negation_expr:?}"),
        };
        assert_matches!(inner_lhs, Expr::Literal(lit) if *lit == Complex32::i());

        // `-ix` is a negation applied to a variable.
        let block = C32Grammar::parse_statements("-ix + 5").unwrap();
        let ret = block.return_value.unwrap().extra;
        let var_negation = &ret.binary_lhs().unwrap().extra;
        let negated_var = match var_negation {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {var_negation:?}"),
        };
        assert_matches!(negated_var, Expr::Variable);
    }

    /// Checks unsigned integer parsers, including the maximum values of the wide types.
    #[test]
    fn uint_parsers() {
        let (_, parsed_u8) = <u8 as NumLiteral>::parse(InputSpan::new("3")).unwrap();
        assert_eq!(parsed_u8, 3);

        let (_, parsed_u16) = <u16 as NumLiteral>::parse(InputSpan::new("33333")).unwrap();
        assert_eq!(parsed_u16, 33_333);

        let (_, parsed_u32) = <u32 as NumLiteral>::parse(InputSpan::new("1111111111")).unwrap();
        assert_eq!(parsed_u32, 1_111_111_111);

        let max_u64 = u64::MAX.to_string();
        let (_, parsed_u64) = <u64 as NumLiteral>::parse(InputSpan::new(&max_u64)).unwrap();
        assert_eq!(parsed_u64, u64::MAX);

        let max_u128 = u128::MAX.to_string();
        let (_, parsed_u128) = <u128 as NumLiteral>::parse(InputSpan::new(&max_u128)).unwrap();
        assert_eq!(parsed_u128, u128::MAX);
    }

    /// Checks the `i8` parser at both ends of its range and just past the upper end.
    #[test]
    fn int_parsers() {
        let (_, lowest) = <i8 as NumLiteral>::parse(InputSpan::new("-128")).unwrap();
        assert_eq!(lowest, -128);
        let (_, highest) = <i8 as NumLiteral>::parse(InputSpan::new("127")).unwrap();
        assert_eq!(highest, 127);

        // A value that does not fit into `i8` must result in a literal error.
        let err = <i8 as NumLiteral>::parse(InputSpan::new("128")).unwrap_err();
        let NomErr::Error(err) = &err else {
            panic!("Unexpected error type: {err:?}");
        };
        assert_matches!(err.kind(), ErrorKind::Literal(_));
    }

    /// Checks big integer parsers against `parse_bytes` on inputs of increasing length.
    #[cfg(feature = "num-bigint")]
    #[test]
    fn bigint_parsers() {
        use num_bigint::{BigInt, BigUint};

        for len in 1..500 {
            let digits = "1".repeat(len);

            let (_, unsigned) = <BigUint as NumLiteral>::parse(InputSpan::new(&digits)).unwrap();
            assert_eq!(unsigned, BigUint::parse_bytes(digits.as_bytes(), 10).unwrap());

            let (_, positive) = <BigInt as NumLiteral>::parse(InputSpan::new(&digits)).unwrap();
            let expected_value = BigInt::parse_bytes(digits.as_bytes(), 10).unwrap();
            assert_eq!(positive, expected_value);

            let negated_digits = format!("-{digits}");
            let (_, negative) =
                <BigInt as NumLiteral>::parse(InputSpan::new(&negated_digits)).unwrap();
            assert_eq!(negative, -expected_value);
        }
    }
}