julia_set/
function.rs

1use std::{collections::HashSet, error::Error, fmt, iter, mem, ops, str::FromStr};
2
3use arithmetic_parser::{
4    BinaryOp, Block, Expr, Lvalue, Spanned, SpannedExpr, Statement, UnaryOp,
5    grammars::{Features, NumGrammar, Parse, Untyped},
6};
7use num_complex::Complex32;
8use thiserror::Error;
9
10/// Error associated with creating a [`Function`].
11#[derive(Debug)]
12#[cfg_attr(
13    docsrs,
14    doc(cfg(any(
15        feature = "dyn_cpu_backend",
16        feature = "opencl_backend",
17        feature = "vulkan_backend"
18    )))
19)]
20pub struct FnError {
21    fragment: String,
22    line: u32,
23    column: usize,
24    source: ErrorSource,
25}
26
27#[derive(Debug)]
28enum ErrorSource {
29    Parse(String),
30    Eval(EvalError),
31}
32
33impl fmt::Display for ErrorSource {
34    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::Parse(err) => write!(formatter, "[PARSE] {err}"),
37            Self::Eval(err) => write!(formatter, "[EVAL] {err}"),
38        }
39    }
40}
41
42#[derive(Debug, Error)]
43pub(crate) enum EvalError {
44    #[error("Last statement in function body is not an expression")]
45    NoReturn,
46    #[error("Useless expression")]
47    UselessExpr,
48    #[error("Cannot redefine variable")]
49    RedefinedVar,
50    #[error("Undefined variable")]
51    UndefinedVar,
52    #[error("Undefined function")]
53    UndefinedFn,
54    #[error("Function call has bogus arity")]
55    FnArity,
56    #[error("Unsupported language construct")]
57    Unsupported,
58}
59
/// Predefined unary functions that can be called from a function body.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum UnaryFunction {
    /// Argument (phase) of a complex number; spelled `arg`.
    Arg,
    /// Square root; spelled `sqrt`.
    Sqrt,
    /// Exponent; spelled `exp`.
    Exp,
    /// Natural logarithm; spelled `log`.
    Log,
    /// Hyperbolic sine; spelled `sinh`.
    Sinh,
    /// Hyperbolic cosine; spelled `cosh`.
    Cosh,
    /// Hyperbolic tangent; spelled `tanh`.
    Tanh,
    /// Inverse hyperbolic sine; spelled `asinh`.
    Asinh,
    /// Inverse hyperbolic cosine; spelled `acosh`.
    Acosh,
    /// Inverse hyperbolic tangent; spelled `atanh`.
    Atanh,
}
73
74impl UnaryFunction {
75    #[cfg(any(feature = "opencl_backend", feature = "vulkan_backend"))]
76    pub fn as_str(self) -> &'static str {
77        match self {
78            Self::Arg => "arg",
79            Self::Sqrt => "sqrt",
80            Self::Exp => "exp",
81            Self::Log => "log",
82            Self::Sinh => "sinh",
83            Self::Cosh => "cosh",
84            Self::Tanh => "tanh",
85            Self::Asinh => "asinh",
86            Self::Acosh => "acosh",
87            Self::Atanh => "atanh",
88        }
89    }
90
91    #[cfg(feature = "dyn_cpu_backend")]
92    pub fn eval(self, arg: Complex32) -> Complex32 {
93        match self {
94            Self::Arg => Complex32::new(arg.arg(), 0.0),
95            Self::Sqrt => arg.sqrt(),
96            Self::Exp => arg.exp(),
97            Self::Log => arg.ln(),
98            Self::Sinh => arg.sinh(),
99            Self::Cosh => arg.cosh(),
100            Self::Tanh => arg.tanh(),
101            Self::Asinh => arg.asinh(),
102            Self::Acosh => arg.acosh(),
103            Self::Atanh => arg.atanh(),
104        }
105    }
106}
107
108impl FromStr for UnaryFunction {
109    type Err = EvalError;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        match s {
113            "arg" => Ok(Self::Arg),
114            "sqrt" => Ok(Self::Sqrt),
115            "exp" => Ok(Self::Exp),
116            "log" => Ok(Self::Log),
117            "sinh" => Ok(Self::Sinh),
118            "cosh" => Ok(Self::Cosh),
119            "tanh" => Ok(Self::Tanh),
120            "asinh" => Ok(Self::Asinh),
121            "acosh" => Ok(Self::Acosh),
122            "atanh" => Ok(Self::Atanh),
123            _ => Err(EvalError::UndefinedFn),
124        }
125    }
126}
127
128#[derive(Debug, Clone, PartialEq)]
129pub(crate) enum Evaluated {
130    Value(Complex32),
131    Variable(String),
132    Negation(Box<Evaluated>),
133    Binary {
134        op: BinaryOp,
135        lhs: Box<Evaluated>,
136        rhs: Box<Evaluated>,
137    },
138    FunctionCall {
139        function: UnaryFunction,
140        arg: Box<Evaluated>,
141    },
142}
143
144impl Evaluated {
145    fn is_commutative(op: BinaryOp) -> bool {
146        matches!(op, BinaryOp::Add | BinaryOp::Mul)
147    }
148
149    fn is_commutative_pair(op: BinaryOp, other_op: BinaryOp) -> bool {
150        op.priority() == other_op.priority() && op != BinaryOp::Power
151    }
152
153    fn fold(mut op: BinaryOp, mut lhs: Self, mut rhs: Self) -> Self {
154        // First, check if the both operands are values. In this case, we can eagerly compute
155        // the resulting value.
156        if let (Self::Value(lhs_val), Self::Value(rhs_val)) = (&lhs, &rhs) {
157            return Self::Value(match op {
158                BinaryOp::Add => *lhs_val + *rhs_val,
159                BinaryOp::Sub => *lhs_val - *rhs_val,
160                BinaryOp::Mul => *lhs_val * *rhs_val,
161                BinaryOp::Div => *lhs_val / *rhs_val,
162                BinaryOp::Power => lhs_val.powc(*rhs_val),
163                _ => unreachable!(),
164            });
165        }
166
167        if let Self::Value(val) = rhs {
168            // Convert an RHS value to use a commutative op (e.g., `+` instead of `-`).
169            // This will come in handy during later transforms.
170            //
171            // For example, this will transform `z - 1` into `z + -1`.
172            match op {
173                BinaryOp::Sub => {
174                    op = BinaryOp::Add;
175                    rhs = Self::Value(-val);
176                }
177                BinaryOp::Div => {
178                    op = BinaryOp::Mul;
179                    rhs = Self::Value(1.0 / val);
180                }
181                _ => { /* do nothing */ }
182            }
183        } else if let Self::Value(_) = lhs {
184            // Swap LHS and RHS to move the value to the right.
185            //
186            // For example, this will transform `1 + z` into `z + 1`.
187            if Self::is_commutative(op) {
188                mem::swap(&mut lhs, &mut rhs);
189            }
190        }
191
192        if let Self::Binary {
193            op: inner_op,
194            rhs: inner_rhs,
195            ..
196        } = &mut lhs
197        {
198            if Self::is_commutative_pair(*inner_op, op) {
199                if let Self::Value(inner_val) = **inner_rhs {
200                    if let Self::Value(val) = rhs {
201                        // Make the following replacement:
202                        //
203                        //    op             op
204                        //   /  \           /  \
205                        //  op  c   ---->  a  b op c
206                        // /  \
207                        // a  b
208                        let new_rhs = match op {
209                            BinaryOp::Add => inner_val + val,
210                            BinaryOp::Mul => inner_val * val,
211                            _ => unreachable!(),
212                            // ^-- We've replaced '-' and '/' `op`s previously.
213                        };
214
215                        *inner_rhs = Box::new(Self::Value(new_rhs));
216                        return lhs;
217                    }
218                    // Switch `inner_rhs` and `rhs`, moving a `Value` to the right.
219                    // For example, this will replace `z + 1 - z^2` to `z - z^2 + 1`.
220                    mem::swap(&mut rhs, inner_rhs);
221                    mem::swap(&mut op, inner_op);
222                }
223            }
224        }
225
226        Self::Binary {
227            op,
228            lhs: Box::new(lhs),
229            rhs: Box::new(rhs),
230        }
231    }
232}
233
234impl ops::Neg for Evaluated {
235    type Output = Self;
236
237    fn neg(self) -> Self::Output {
238        match self {
239            Self::Value(val) => Self::Value(-val),
240            Self::Negation(inner) => *inner,
241            other => Self::Negation(Box::new(other)),
242        }
243    }
244}
245
246impl FnError {
247    fn parse(source: &arithmetic_parser::Error, s: &str) -> Self {
248        let column = source.location().get_column();
249        Self {
250            fragment: source.location().span(s).to_owned(),
251            line: source.location().location_line(),
252            column,
253            source: ErrorSource::Parse(source.kind().to_string()),
254        }
255    }
256
257    fn eval<T>(span: &Spanned<'_, T>, source: EvalError) -> Self {
258        let column = span.get_column();
259        Self {
260            fragment: (*span.fragment()).to_owned(),
261            line: span.location_line(),
262            column,
263            source: ErrorSource::Eval(source),
264        }
265    }
266}
267
268impl fmt::Display for FnError {
269    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
270        write!(formatter, "{}:{}: {}", self.line, self.column, self.source)?;
271        if formatter.alternate() {
272            formatter.write_str("\n")?;
273            formatter.pad(&self.fragment)?;
274        }
275        Ok(())
276    }
277}
278
279impl Error for FnError {
280    fn source(&self) -> Option<&(dyn Error + 'static)> {
281        match &self.source {
282            ErrorSource::Eval(e) => Some(e),
283            ErrorSource::Parse(_) => None,
284        }
285    }
286}
287
288type FnGrammarBase = Untyped<NumGrammar<Complex32>>;
289
290#[derive(Debug, Clone, Copy)]
291struct FnGrammar;
292
293impl Parse for FnGrammar {
294    type Base = FnGrammarBase;
295    const FEATURES: Features = Features::empty();
296}
297
/// State maintained while converting parsed function source into a [`Function`].
#[derive(Debug)]
pub(crate) struct Context {
    /// Names of the variables defined so far, including the function argument.
    variables: HashSet<String>,
}
302
303impl Context {
304    pub(crate) fn new(arg_name: &str) -> Self {
305        Self {
306            variables: iter::once(arg_name.to_owned()).collect(),
307        }
308    }
309
310    fn process(
311        &mut self,
312        block: &Block<'_, FnGrammarBase>,
313        total_span: Spanned<'_>,
314    ) -> Result<Function, FnError> {
315        let mut assignments = vec![];
316        for statement in &block.statements {
317            match &statement.extra {
318                Statement::Assignment { lhs, rhs } => {
319                    let variable_name = match lhs.extra {
320                        Lvalue::Variable { .. } => *lhs.fragment(),
321                        _ => unreachable!("Tuples are disabled in parser"),
322                    };
323
324                    if self.variables.contains(variable_name) {
325                        let err = FnError::eval(lhs, EvalError::RedefinedVar);
326                        return Err(err);
327                    }
328
329                    // Evaluate the RHS.
330                    let value = self.eval_expr(rhs)?;
331                    self.variables.insert(variable_name.to_owned());
332                    assignments.push((variable_name.to_owned(), value));
333                }
334
335                Statement::Expr(_) => {
336                    return Err(FnError::eval(statement, EvalError::UselessExpr));
337                }
338
339                _ => return Err(FnError::eval(statement, EvalError::Unsupported)),
340            }
341        }
342
343        let return_value = block
344            .return_value
345            .as_ref()
346            .ok_or_else(|| FnError::eval(&total_span, EvalError::NoReturn))?;
347        let value = self.eval_expr(return_value)?;
348        assignments.push((String::new(), value));
349
350        Ok(Function { assignments })
351    }
352
353    fn eval_expr(&self, expr: &SpannedExpr<'_, FnGrammarBase>) -> Result<Evaluated, FnError> {
354        match &expr.extra {
355            Expr::Variable => {
356                let var_name = *expr.fragment();
357                self.variables
358                    .get(var_name)
359                    .ok_or_else(|| FnError::eval(expr, EvalError::UndefinedVar))?;
360
361                Ok(Evaluated::Variable(var_name.to_owned()))
362            }
363            Expr::Literal(lit) => Ok(Evaluated::Value(*lit)),
364
365            Expr::Unary { op, inner } => match op.extra {
366                UnaryOp::Neg => Ok(-self.eval_expr(inner)?),
367                _ => Err(FnError::eval(op, EvalError::Unsupported)),
368            },
369
370            Expr::Binary { lhs, op, rhs } => {
371                let lhs_value = self.eval_expr(lhs)?;
372                let rhs_value = self.eval_expr(rhs)?;
373
374                Ok(match op.extra {
375                    BinaryOp::Add
376                    | BinaryOp::Sub
377                    | BinaryOp::Mul
378                    | BinaryOp::Div
379                    | BinaryOp::Power => Evaluated::fold(op.extra, lhs_value, rhs_value),
380                    _ => {
381                        return Err(FnError::eval(op, EvalError::Unsupported));
382                    }
383                })
384            }
385
386            Expr::Function { name, args } => {
387                let fn_name = *name.fragment();
388                let function: UnaryFunction =
389                    fn_name.parse().map_err(|e| FnError::eval(name, e))?;
390
391                if args.len() != 1 {
392                    return Err(FnError::eval(expr, EvalError::FnArity));
393                }
394
395                Ok(Evaluated::FunctionCall {
396                    function,
397                    arg: Box::new(self.eval_expr(&args[0])?),
398                })
399            }
400
401            Expr::FnDefinition(_) | Expr::Block(_) | Expr::Tuple(_) | Expr::Method { .. } => {
402                unreachable!("Disabled in parser")
403            }
404
405            _ => Err(FnError::eval(expr, EvalError::Unsupported)),
406        }
407    }
408}
409
410/// Parsed complex-valued function of a single variable.
411///
412/// A `Function` instance can be created using [`FromStr`] trait. A function must use `z`
413/// as the (only) argument. A function may use arithmetic operations (`+`, `-`, `*`, `/`, `^`)
414/// and/or predefined unary functions:
415///
416/// - General functions: `arg`, `sqrt`, `exp`, `log`
417/// - Hyperbolic trigonometry: `sinh`, `cosh`, `tanh`
418/// - Inverse hyperbolic trigonometry: `asinh`, `acosh`, `atanh`
419///
420/// A function may define local variable assignment(s). The assignment syntax is similar to Python
421/// (or Rust, just without the `let` keyword): variable name followed by `=` and then by
422/// the arithmetic expression. Assignments must be separated by semicolons `;`. As in Rust,
423/// the last expression in function body is its return value.
424///
425/// # Examples
426///
427/// ```
428/// # use julia_set::Function;
429/// # fn main() -> anyhow::Result<()> {
430/// let function: Function = "z * z - 0.5".parse()?;
431/// let fn_with_calls: Function = "0.8 * z + z / atanh(z ^ -4)".parse()?;
432/// let fn_with_vars: Function = "c = -0.5 + 0.4i; z * z + c".parse()?;
433/// # Ok(())
434/// # }
435/// ```
436#[cfg_attr(
437    docsrs,
438    doc(cfg(any(
439        feature = "dyn_cpu_backend",
440        feature = "opencl_backend",
441        feature = "vulkan_backend"
442    )))
443)]
444#[derive(Debug, Clone)]
445pub struct Function {
446    assignments: Vec<(String, Evaluated)>,
447}
448
449impl Function {
450    pub(crate) fn assignments(&self) -> impl Iterator<Item = (&str, &Evaluated)> + '_ {
451        self.assignments.iter().filter_map(|(name, value)| {
452            if name.is_empty() {
453                None
454            } else {
455                Some((name.as_str(), value))
456            }
457        })
458    }
459
460    pub(crate) fn return_value(&self) -> &Evaluated {
461        &self.assignments.last().unwrap().1
462    }
463}
464
465impl FromStr for Function {
466    type Err = FnError;
467
468    fn from_str(s: &str) -> Result<Self, Self::Err> {
469        let statements = FnGrammar::parse_statements(s).map_err(|err| FnError::parse(&err, s))?;
470        let body_span = Spanned::from_str(s, ..);
471        Context::new("z").process(&statements, body_span)
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a reference to the variable with the specified name.
    fn variable(name: &str) -> Evaluated {
        Evaluated::Variable(name.to_owned())
    }

    /// Creates a constant `re + im * i`.
    fn value(re: f32, im: f32) -> Evaluated {
        Evaluated::Value(Complex32::new(re, im))
    }

    /// Creates a binary operation as is, i.e., without any simplifications.
    fn binary(op: BinaryOp, lhs: Evaluated, rhs: Evaluated) -> Evaluated {
        Evaluated::Binary {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    /// Creates a call of `sinh` with the specified argument.
    fn sinh(arg: Evaluated) -> Evaluated {
        Evaluated::FunctionCall {
            function: UnaryFunction::Sinh,
            arg: Box::new(arg),
        }
    }

    fn z_square() -> Evaluated {
        binary(BinaryOp::Mul, variable("z"), variable("z"))
    }

    /// Asserts that `source` is parsed into a function without local variables
    /// returning the `expected` expression.
    fn assert_return_value(source: &str, expected: Evaluated) {
        let function: Function = source.parse().unwrap();
        assert_eq!(function.assignments, vec![(String::new(), expected)]);
    }

    #[test]
    fn simple_function() {
        assert_return_value(
            "z*z + (0.77 - 0.2i)",
            binary(BinaryOp::Add, z_square(), value(0.77, -0.2)),
        );
    }

    #[test]
    fn simple_function_with_rewrite_rules() {
        assert_return_value(
            "z / 0.25 - 0.1i + (0.77 - 0.1i)",
            binary(
                BinaryOp::Add,
                binary(BinaryOp::Mul, variable("z"), value(4.0, 0.0)),
                value(0.77, -0.2),
            ),
        );
    }

    #[test]
    fn function_with_several_rewrite_rules() {
        assert_return_value(
            "z + 0.1 - z*z + 0.3i",
            binary(
                BinaryOp::Add,
                binary(BinaryOp::Sub, variable("z"), z_square()),
                value(0.1, 0.3),
            ),
        );
    }

    #[test]
    fn simple_function_with_mul_rewrite_rules() {
        assert_return_value(
            "sinh(z - 5) / 4. * 6i",
            binary(
                BinaryOp::Mul,
                sinh(binary(BinaryOp::Add, variable("z"), value(-5.0, 0.0))),
                value(0.0, 1.5),
            ),
        );
    }

    #[test]
    fn simple_function_with_redistribution() {
        assert_return_value(
            "0.5 + sinh(z) - 0.2i",
            binary(BinaryOp::Add, sinh(variable("z")), value(0.5, -0.2)),
        );
    }

    #[test]
    fn function_with_assignments() {
        let function: Function = "c = 0.5 - 0.2i; z*z + c".parse().unwrap();
        let expected_assignments = vec![
            ("c".to_owned(), value(0.5, -0.2)),
            (
                String::new(),
                binary(BinaryOp::Add, z_square(), variable("c")),
            ),
        ];
        assert_eq!(function.assignments, expected_assignments);
    }
}