1use std::{collections::HashSet, error::Error, fmt, iter, mem, ops, str::FromStr};
2
3use arithmetic_parser::{
4 BinaryOp, Block, Expr, Lvalue, Spanned, SpannedExpr, Statement, UnaryOp,
5 grammars::{Features, NumGrammar, Parse, Untyped},
6};
7use num_complex::Complex32;
8use thiserror::Error;
9
/// Error that can occur when parsing a [`Function`] from a string.
///
/// The error records the position (line and column) and the source text of the code
/// fragment it refers to, together with its cause. The plain `Display` output is
/// `line:column: cause`; the alternate form (`{:#}`) additionally prints the fragment
/// on the next line.
#[derive(Debug)]
#[cfg_attr(
    docsrs,
    doc(cfg(any(
        feature = "dyn_cpu_backend",
        feature = "opencl_backend",
        feature = "vulkan_backend"
    )))
)]
pub struct FnError {
    /// Source text of the span the error refers to.
    fragment: String,
    /// Line of the span, as reported by the parser's span `location_line`.
    line: u32,
    /// Column of the span, as reported by the parser's span `get_column`.
    column: usize,
    /// Underlying cause: a parsing failure or a failure to convert parsed code.
    source: ErrorSource,
}
/// Cause of an [`FnError`].
#[derive(Debug)]
enum ErrorSource {
    /// Parser error, kept as the string representation of the parser error kind.
    Parse(String),
    /// Error converting successfully parsed code into a [`Function`].
    Eval(EvalError),
}
32
33impl fmt::Display for ErrorSource {
34 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 Self::Parse(err) => write!(formatter, "[PARSE] {err}"),
37 Self::Eval(err) => write!(formatter, "[EVAL] {err}"),
38 }
39 }
40}
41
/// Semantic error detected while converting parsed statements into a [`Function`].
#[derive(Debug, Error)]
pub(crate) enum EvalError {
    /// The function body does not end with a return expression.
    #[error("Last statement in function body is not an expression")]
    NoReturn,
    /// A statement before the return value is a bare expression, so its value would be lost.
    #[error("Useless expression")]
    UselessExpr,
    /// An assignment targets an already defined name (including the function argument).
    #[error("Cannot redefine variable")]
    RedefinedVar,
    /// An expression references a variable that is not defined at that point.
    #[error("Undefined variable")]
    UndefinedVar,
    /// A called function is not one of the supported [`UnaryFunction`]s.
    #[error("Undefined function")]
    UndefinedFn,
    /// A function is called with a number of arguments other than one.
    #[error("Function call has bogus arity")]
    FnArity,
    /// The code uses an operator, expression or statement kind that is not supported.
    #[error("Unsupported language construct")]
    Unsupported,
}
59
/// Unary function that can be called in a [`Function`] body.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum UnaryFunction {
    /// Argument (phase angle) of a complex number; the result has zero imaginary part.
    Arg,
    /// Square root.
    Sqrt,
    /// Exponent.
    Exp,
    /// Natural logarithm.
    Log,
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,
    /// Inverse hyperbolic sine.
    Asinh,
    /// Inverse hyperbolic cosine.
    Acosh,
    /// Inverse hyperbolic tangent.
    Atanh,
}
73
74impl UnaryFunction {
75 #[cfg(any(feature = "opencl_backend", feature = "vulkan_backend"))]
76 pub fn as_str(self) -> &'static str {
77 match self {
78 Self::Arg => "arg",
79 Self::Sqrt => "sqrt",
80 Self::Exp => "exp",
81 Self::Log => "log",
82 Self::Sinh => "sinh",
83 Self::Cosh => "cosh",
84 Self::Tanh => "tanh",
85 Self::Asinh => "asinh",
86 Self::Acosh => "acosh",
87 Self::Atanh => "atanh",
88 }
89 }
90
91 #[cfg(feature = "dyn_cpu_backend")]
92 pub fn eval(self, arg: Complex32) -> Complex32 {
93 match self {
94 Self::Arg => Complex32::new(arg.arg(), 0.0),
95 Self::Sqrt => arg.sqrt(),
96 Self::Exp => arg.exp(),
97 Self::Log => arg.ln(),
98 Self::Sinh => arg.sinh(),
99 Self::Cosh => arg.cosh(),
100 Self::Tanh => arg.tanh(),
101 Self::Asinh => arg.asinh(),
102 Self::Acosh => arg.acosh(),
103 Self::Atanh => arg.atanh(),
104 }
105 }
106}
107
108impl FromStr for UnaryFunction {
109 type Err = EvalError;
110
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 match s {
113 "arg" => Ok(Self::Arg),
114 "sqrt" => Ok(Self::Sqrt),
115 "exp" => Ok(Self::Exp),
116 "log" => Ok(Self::Log),
117 "sinh" => Ok(Self::Sinh),
118 "cosh" => Ok(Self::Cosh),
119 "tanh" => Ok(Self::Tanh),
120 "asinh" => Ok(Self::Asinh),
121 "acosh" => Ok(Self::Acosh),
122 "atanh" => Ok(Self::Atanh),
123 _ => Err(EvalError::UndefinedFn),
124 }
125 }
126}
127
/// Expression tree produced from parsed function code.
///
/// Trees are built bottom-up by `Context::eval_expr`. Binary operations go through
/// `Evaluated::fold`, which folds constants and rewrites some operations.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Evaluated {
    /// Complex constant.
    Value(Complex32),
    /// Reference to a named variable: the function argument or a local assignment.
    Variable(String),
    /// Negation of the inner expression.
    Negation(Box<Evaluated>),
    /// Binary arithmetic operation `lhs op rhs`.
    Binary {
        op: BinaryOp,
        lhs: Box<Evaluated>,
        rhs: Box<Evaluated>,
    },
    /// Call of a unary function with a single argument.
    FunctionCall {
        function: UnaryFunction,
        arg: Box<Evaluated>,
    },
}
143
impl Evaluated {
    /// Checks whether the operands of `op` may be swapped (true for `Add` and `Mul`).
    fn is_commutative(op: BinaryOp) -> bool {
        matches!(op, BinaryOp::Add | BinaryOp::Mul)
    }

    /// Checks whether `op` and `other_op` have the same priority, with `op` not being `Power`.
    ///
    /// `fold` uses this to detect pairs like `+` / `-` or `*` / `/`, where a constant operand
    /// can be moved between the inner and the outer operation.
    fn is_commutative_pair(op: BinaryOp, other_op: BinaryOp) -> bool {
        op.priority() == other_op.priority() && op != BinaryOp::Power
    }

    /// Builds the expression `lhs op rhs`, simplifying it along the way.
    ///
    /// Simplification steps, in order:
    ///
    /// 1. If both operands are constants, the operation is computed immediately.
    /// 2. A constant right operand of `-` / `/` is turned into `+` with the negated constant /
    ///    `*` with the inverted constant. Otherwise, a constant left operand of a commutative
    ///    operation is swapped to the right.
    /// 3. If `lhs` is a compatible binary operation whose right operand is a constant, that
    ///    constant is either merged with a constant `rhs`, or moved to the outer level. Either
    ///    way, constants collect at the right-hand side of the outermost operation.
    ///
    /// `op` is expected to be one of `Add`, `Sub`, `Mul`, `Div` or `Power`; these are the only
    /// operations `Context::eval_expr` passes here. Other operations may hit `unreachable!`.
    fn fold(mut op: BinaryOp, mut lhs: Self, mut rhs: Self) -> Self {
        // Step 1: both operands are constants, so compute the result right away.
        if let (Self::Value(lhs_val), Self::Value(rhs_val)) = (&lhs, &rhs) {
            return Self::Value(match op {
                BinaryOp::Add => *lhs_val + *rhs_val,
                BinaryOp::Sub => *lhs_val - *rhs_val,
                BinaryOp::Mul => *lhs_val * *rhs_val,
                BinaryOp::Div => *lhs_val / *rhs_val,
                BinaryOp::Power => lhs_val.powc(*rhs_val),
                _ => unreachable!(),
            });
        }

        if let Self::Value(val) = rhs {
            // Step 2a: `x - c` becomes `x + (-c)`, and `x / c` becomes `x * (1 / c)`. As a result,
            // a constant right operand is only ever attached via `+` or `*` (or `Power`).
            match op {
                BinaryOp::Sub => {
                    op = BinaryOp::Add;
                    rhs = Self::Value(-val);
                }
                BinaryOp::Div => {
                    op = BinaryOp::Mul;
                    rhs = Self::Value(1.0 / val);
                }
                _ => {}
            }
        } else if let Self::Value(_) = lhs {
            // Step 2b: `c + x` becomes `x + c`, and `c * x` becomes `x * c`. Non-commutative
            // operations such as `c - x` or `c / x` are left as is.
            if Self::is_commutative(op) {
                mem::swap(&mut lhs, &mut rhs);
            }
        }

        // Step 3: `lhs` has the form `x inner_op a`, where `a` is a constant and `inner_op`
        // is compatible with `op` (see `is_commutative_pair`).
        if let Self::Binary {
            op: inner_op,
            rhs: inner_rhs,
            ..
        } = &mut lhs
        {
            if Self::is_commutative_pair(*inner_op, op) {
                if let Self::Value(inner_val) = **inner_rhs {
                    if let Self::Value(val) = rhs {
                        // Step 3a: `(x + a) + b` -> `x + (a + b)`, `(x * a) * b` -> `x * (a * b)`.
                        // After step 2, `op` is expected to be `Add` or `Mul` here.
                        // NOTE(review): the inner operation is assumed to equal `op`. This holds
                        // when `lhs` was itself built by `fold`, because step 2 never leaves a
                        // constant right operand of `-` or `/`.
                        let new_rhs = match op {
                            BinaryOp::Add => inner_val + val,
                            BinaryOp::Mul => inner_val * val,
                            _ => unreachable!(),
                        };

                        *inner_rhs = Box::new(Self::Value(new_rhs));
                        return lhs;
                    }
                    // Step 3b: `rhs` is not a constant. Swap the right operands and the operations
                    // between the inner and outer level, e.g. `(z + 0.1) - z*z` -> `(z - z*z) + 0.1`.
                    // This keeps the constant outermost, so it can merge with constants folded in
                    // later.
                    mem::swap(&mut rhs, inner_rhs);
                    mem::swap(&mut op, inner_op);
                }
            }
        }

        Self::Binary {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }
}
233
234impl ops::Neg for Evaluated {
235 type Output = Self;
236
237 fn neg(self) -> Self::Output {
238 match self {
239 Self::Value(val) => Self::Value(-val),
240 Self::Negation(inner) => *inner,
241 other => Self::Negation(Box::new(other)),
242 }
243 }
244}
245
246impl FnError {
247 fn parse(source: &arithmetic_parser::Error, s: &str) -> Self {
248 let column = source.location().get_column();
249 Self {
250 fragment: source.location().span(s).to_owned(),
251 line: source.location().location_line(),
252 column,
253 source: ErrorSource::Parse(source.kind().to_string()),
254 }
255 }
256
257 fn eval<T>(span: &Spanned<'_, T>, source: EvalError) -> Self {
258 let column = span.get_column();
259 Self {
260 fragment: (*span.fragment()).to_owned(),
261 line: span.location_line(),
262 column,
263 source: ErrorSource::Eval(source),
264 }
265 }
266}
267
268impl fmt::Display for FnError {
269 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
270 write!(formatter, "{}:{}: {}", self.line, self.column, self.source)?;
271 if formatter.alternate() {
272 formatter.write_str("\n")?;
273 formatter.pad(&self.fragment)?;
274 }
275 Ok(())
276 }
277}
278
279impl Error for FnError {
280 fn source(&self) -> Option<&(dyn Error + 'static)> {
281 match &self.source {
282 ErrorSource::Eval(e) => Some(e),
283 ErrorSource::Parse(_) => None,
284 }
285 }
286}
287
/// Base grammar for function code: untyped, with `Complex32` numeric literals.
type FnGrammarBase = Untyped<NumGrammar<Complex32>>;

/// Parser configuration for function bodies.
#[derive(Debug, Clone, Copy)]
struct FnGrammar;

impl Parse for FnGrammar {
    type Base = FnGrammarBase;
    // No optional syntax features. `Context` treats tuples, blocks, function definitions
    // and method calls as unreachable, relying on them being disabled here.
    const FEATURES: Features = Features::empty();
}
297
/// State used while converting parsed statements into a [`Function`].
#[derive(Debug)]
pub(crate) struct Context {
    /// Names of all variables defined so far: the function argument plus every local
    /// assignment processed so far.
    variables: HashSet<String>,
}
302
303impl Context {
304 pub(crate) fn new(arg_name: &str) -> Self {
305 Self {
306 variables: iter::once(arg_name.to_owned()).collect(),
307 }
308 }
309
310 fn process(
311 &mut self,
312 block: &Block<'_, FnGrammarBase>,
313 total_span: Spanned<'_>,
314 ) -> Result<Function, FnError> {
315 let mut assignments = vec![];
316 for statement in &block.statements {
317 match &statement.extra {
318 Statement::Assignment { lhs, rhs } => {
319 let variable_name = match lhs.extra {
320 Lvalue::Variable { .. } => *lhs.fragment(),
321 _ => unreachable!("Tuples are disabled in parser"),
322 };
323
324 if self.variables.contains(variable_name) {
325 let err = FnError::eval(lhs, EvalError::RedefinedVar);
326 return Err(err);
327 }
328
329 let value = self.eval_expr(rhs)?;
331 self.variables.insert(variable_name.to_owned());
332 assignments.push((variable_name.to_owned(), value));
333 }
334
335 Statement::Expr(_) => {
336 return Err(FnError::eval(statement, EvalError::UselessExpr));
337 }
338
339 _ => return Err(FnError::eval(statement, EvalError::Unsupported)),
340 }
341 }
342
343 let return_value = block
344 .return_value
345 .as_ref()
346 .ok_or_else(|| FnError::eval(&total_span, EvalError::NoReturn))?;
347 let value = self.eval_expr(return_value)?;
348 assignments.push((String::new(), value));
349
350 Ok(Function { assignments })
351 }
352
353 fn eval_expr(&self, expr: &SpannedExpr<'_, FnGrammarBase>) -> Result<Evaluated, FnError> {
354 match &expr.extra {
355 Expr::Variable => {
356 let var_name = *expr.fragment();
357 self.variables
358 .get(var_name)
359 .ok_or_else(|| FnError::eval(expr, EvalError::UndefinedVar))?;
360
361 Ok(Evaluated::Variable(var_name.to_owned()))
362 }
363 Expr::Literal(lit) => Ok(Evaluated::Value(*lit)),
364
365 Expr::Unary { op, inner } => match op.extra {
366 UnaryOp::Neg => Ok(-self.eval_expr(inner)?),
367 _ => Err(FnError::eval(op, EvalError::Unsupported)),
368 },
369
370 Expr::Binary { lhs, op, rhs } => {
371 let lhs_value = self.eval_expr(lhs)?;
372 let rhs_value = self.eval_expr(rhs)?;
373
374 Ok(match op.extra {
375 BinaryOp::Add
376 | BinaryOp::Sub
377 | BinaryOp::Mul
378 | BinaryOp::Div
379 | BinaryOp::Power => Evaluated::fold(op.extra, lhs_value, rhs_value),
380 _ => {
381 return Err(FnError::eval(op, EvalError::Unsupported));
382 }
383 })
384 }
385
386 Expr::Function { name, args } => {
387 let fn_name = *name.fragment();
388 let function: UnaryFunction =
389 fn_name.parse().map_err(|e| FnError::eval(name, e))?;
390
391 if args.len() != 1 {
392 return Err(FnError::eval(expr, EvalError::FnArity));
393 }
394
395 Ok(Evaluated::FunctionCall {
396 function,
397 arg: Box::new(self.eval_expr(&args[0])?),
398 })
399 }
400
401 Expr::FnDefinition(_) | Expr::Block(_) | Expr::Tuple(_) | Expr::Method { .. } => {
402 unreachable!("Disabled in parser")
403 }
404
405 _ => Err(FnError::eval(expr, EvalError::Unsupported)),
406 }
407 }
408}
409
/// Parsed complex-valued function of a single variable.
///
/// A `Function` is created via its [`FromStr`] implementation. The function argument is
/// named `z`. The function body may use:
///
/// - complex constants, e.g. `0.77 - 0.2i`;
/// - addition, subtraction, multiplication, division, exponentiation and unary negation;
/// - calls of the unary functions `arg`, `sqrt`, `exp`, `log`, `sinh`, `cosh`, `tanh`,
///   `asinh`, `acosh` and `atanh`, each taking exactly one argument.
///
/// The body may start with local variable assignments separated by `;`, e.g.
/// `c = 0.5 - 0.2i; z*z + c`. A variable cannot be redefined, and the argument `z`
/// cannot be reassigned. The body must end with an expression, which is the function's
/// return value.
#[cfg_attr(
    docsrs,
    doc(cfg(any(
        feature = "dyn_cpu_backend",
        feature = "opencl_backend",
        feature = "vulkan_backend"
    )))
)]
#[derive(Debug, Clone)]
pub struct Function {
    /// Local variable assignments in definition order, followed by the return value
    /// stored under an empty name.
    assignments: Vec<(String, Evaluated)>,
}
448
449impl Function {
450 pub(crate) fn assignments(&self) -> impl Iterator<Item = (&str, &Evaluated)> + '_ {
451 self.assignments.iter().filter_map(|(name, value)| {
452 if name.is_empty() {
453 None
454 } else {
455 Some((name.as_str(), value))
456 }
457 })
458 }
459
460 pub(crate) fn return_value(&self) -> &Evaluated {
461 &self.assignments.last().unwrap().1
462 }
463}
464
465impl FromStr for Function {
466 type Err = FnError;
467
468 fn from_str(s: &str) -> Result<Self, Self::Err> {
469 let statements = FnGrammar::parse_statements(s).map_err(|err| FnError::parse(&err, s))?;
470 let body_span = Spanned::from_str(s, ..);
471 Context::new("z").process(&statements, body_span)
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference to the variable called `name`.
    fn var(name: &str) -> Evaluated {
        Evaluated::Variable(name.to_owned())
    }

    /// Complex constant `re + im * i`.
    fn constant(re: f32, im: f32) -> Evaluated {
        Evaluated::Value(Complex32::new(re, im))
    }

    /// Binary expression `lhs <op> rhs`.
    fn binary(op: BinaryOp, lhs: Evaluated, rhs: Evaluated) -> Evaluated {
        Evaluated::Binary {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    /// `sinh(arg)`.
    fn sinh(arg: Evaluated) -> Evaluated {
        Evaluated::FunctionCall {
            function: UnaryFunction::Sinh,
            arg: Box::new(arg),
        }
    }

    fn z_square() -> Evaluated {
        binary(BinaryOp::Mul, var("z"), var("z"))
    }

    /// Parses `source` and compares all stored entries, including the unnamed return value.
    #[track_caller]
    fn assert_assignments(source: &str, expected: Vec<(String, Evaluated)>) {
        let function: Function = source.parse().unwrap();
        assert_eq!(function.assignments, expected);
    }

    /// Parses `source`, which must consist of a single return expression.
    #[track_caller]
    fn assert_return_only(source: &str, expected: Evaluated) {
        assert_assignments(source, vec![(String::new(), expected)]);
    }

    #[test]
    fn simple_function() {
        assert_return_only(
            "z*z + (0.77 - 0.2i)",
            binary(BinaryOp::Add, z_square(), constant(0.77, -0.2)),
        );
    }

    #[test]
    fn simple_function_with_rewrite_rules() {
        let scaled_z = binary(BinaryOp::Mul, var("z"), constant(4.0, 0.0));
        assert_return_only(
            "z / 0.25 - 0.1i + (0.77 - 0.1i)",
            binary(BinaryOp::Add, scaled_z, constant(0.77, -0.2)),
        );
    }

    #[test]
    fn function_with_several_rewrite_rules() {
        let difference = binary(BinaryOp::Sub, var("z"), z_square());
        assert_return_only(
            "z + 0.1 - z*z + 0.3i",
            binary(BinaryOp::Add, difference, constant(0.1, 0.3)),
        );
    }

    #[test]
    fn simple_function_with_mul_rewrite_rules() {
        let shifted_z = binary(BinaryOp::Add, var("z"), constant(-5.0, 0.0));
        assert_return_only(
            "sinh(z - 5) / 4. * 6i",
            binary(BinaryOp::Mul, sinh(shifted_z), constant(0.0, 1.5)),
        );
    }

    #[test]
    fn simple_function_with_redistribution() {
        assert_return_only(
            "0.5 + sinh(z) - 0.2i",
            binary(BinaryOp::Add, sinh(var("z")), constant(0.5, -0.2)),
        );
    }

    #[test]
    fn function_with_assignments() {
        assert_assignments(
            "c = 0.5 - 0.2i; z*z + c",
            vec![
                ("c".to_owned(), constant(0.5, -0.2)),
                (String::new(), binary(BinaryOp::Add, z_square(), var("c"))),
            ],
        );
    }
}