use arithmetic_parser::{
grammars::{Features, NumGrammar, Parse, Untyped},
BinaryOp, Block, Expr, Lvalue, Spanned, SpannedExpr, Statement, UnaryOp,
};
use num_complex::Complex32;
use thiserror::Error;
use std::{collections::HashSet, error::Error, fmt, iter, mem, ops, str::FromStr};
/// Error that can occur when creating a [`Function`] from its string representation.
///
/// Besides the underlying cause, the error records where in the source code the problem
/// was found. The default `Display` output is `line:column: cause`; the alternate form
/// (`{:#}`) additionally prints the offending code fragment on a new line.
#[derive(Debug)]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "dyn_cpu_backend",
feature = "opencl_backend",
feature = "vulkan_backend"
)))
)]
pub struct FnError {
/// Fragment of the source code that the error relates to.
fragment: String,
/// Line of the fragment, as reported by the parser span.
line: u32,
/// Column of the fragment, as reported by the parser span.
column: usize,
/// Underlying cause of the error (parsing or evaluation).
source: ErrorSource,
}
/// Cause of an [`FnError`].
#[derive(Debug)]
enum ErrorSource {
/// The source code could not be parsed. Holds the stringified kind of the parser error.
Parse(String),
/// The source code was parsed, but could not be converted into a `Function`.
Eval(EvalError),
}
/// Renders the cause prefixed with a tag naming the failed stage:
/// `[PARSE]` for parser errors and `[EVAL]` for evaluation errors.
impl fmt::Display for ErrorSource {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Parse(err) => formatter.write_fmt(format_args!("[PARSE] {err}")),
            Self::Eval(err) => formatter.write_fmt(format_args!("[EVAL] {err}")),
        }
    }
}
/// Error that can occur while converting a parsed block of statements into a `Function`.
#[derive(Debug, Error)]
pub(crate) enum EvalError {
/// The block does not end with a return expression.
#[error("Last statement in function body is not an expression")]
NoReturn,
/// A statement consists of an expression whose value is not used.
#[error("Useless expression")]
UselessExpr,
/// An assignment targets a variable that is already defined.
#[error("Cannot redefine variable")]
RedefinedVar,
/// An expression refers to a variable that is not defined.
#[error("Undefined variable")]
UndefinedVar,
/// A function call uses a name that is not a known `UnaryFunction`.
#[error("Undefined function")]
UndefinedFn,
/// A function is called with a number of arguments other than one.
#[error("Function call has bogus arity")]
FnArity,
/// The source code uses a construct that is not supported (e.g., an unsupported operation).
#[error("Unsupported language construct")]
Unsupported,
}
/// Function of a single complex argument that can be used in a `Function` definition.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum UnaryFunction {
/// Argument of a complex number (`arg`).
Arg,
/// Square root (`sqrt`).
Sqrt,
/// Exponent (`exp`).
Exp,
/// Natural logarithm (`log`).
Log,
/// Hyperbolic sine (`sinh`).
Sinh,
/// Hyperbolic cosine (`cosh`).
Cosh,
/// Hyperbolic tangent (`tanh`).
Tanh,
/// Inverse hyperbolic sine (`asinh`).
Asinh,
/// Inverse hyperbolic cosine (`acosh`).
Acosh,
/// Inverse hyperbolic tangent (`atanh`).
Atanh,
}
impl UnaryFunction {
/// Returns the name of this function. These are the same names that the `FromStr`
/// implementation for this type accepts.
#[cfg(any(feature = "opencl_backend", feature = "vulkan_backend"))]
pub fn as_str(self) -> &'static str {
match self {
Self::Arg => "arg",
Self::Sqrt => "sqrt",
Self::Exp => "exp",
Self::Log => "log",
Self::Sinh => "sinh",
Self::Cosh => "cosh",
Self::Tanh => "tanh",
Self::Asinh => "asinh",
Self::Acosh => "acosh",
Self::Atanh => "atanh",
}
}
/// Applies this function to `arg` and returns the result.
#[cfg(feature = "dyn_cpu_backend")]
pub fn eval(self, arg: Complex32) -> Complex32 {
match self {
// The argument of a complex number is real; it is returned with a zero imaginary part.
Self::Arg => Complex32::new(arg.arg(), 0.0),
Self::Sqrt => arg.sqrt(),
Self::Exp => arg.exp(),
// `log` is the natural logarithm.
Self::Log => arg.ln(),
Self::Sinh => arg.sinh(),
Self::Cosh => arg.cosh(),
Self::Tanh => arg.tanh(),
Self::Asinh => arg.asinh(),
Self::Acosh => arg.acosh(),
Self::Atanh => arg.atanh(),
}
}
}
/// Looks up a function by its lower-case name (e.g., `"sinh"`).
///
/// # Errors
///
/// Returns [`EvalError::UndefinedFn`] if the name does not match any known function.
impl FromStr for UnaryFunction {
    type Err = EvalError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let function = match s {
            "arg" => Self::Arg,
            "sqrt" => Self::Sqrt,
            "exp" => Self::Exp,
            "log" => Self::Log,
            "sinh" => Self::Sinh,
            "cosh" => Self::Cosh,
            "tanh" => Self::Tanh,
            "asinh" => Self::Asinh,
            "acosh" => Self::Acosh,
            "atanh" => Self::Atanh,
            _ => return Err(EvalError::UndefinedFn),
        };
        Ok(function)
    }
}
/// Expression tree produced by evaluating a parsed expression.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Evaluated {
/// Constant complex value.
Value(Complex32),
/// Reference to a variable by its name.
Variable(String),
/// Negation of the inner expression.
Negation(Box<Evaluated>),
/// Binary operation applied to two expressions.
Binary {
/// Operation to perform.
op: BinaryOp,
/// Left-hand side operand.
lhs: Box<Evaluated>,
/// Right-hand side operand.
rhs: Box<Evaluated>,
},
/// Call of a unary function.
FunctionCall {
/// Function being called.
function: UnaryFunction,
/// The only argument of the call.
arg: Box<Evaluated>,
},
}
impl Evaluated {
/// Checks whether the operands of `op` may be swapped without changing the result.
/// Only addition and multiplication are treated as commutative.
fn is_commutative(op: BinaryOp) -> bool {
matches!(op, BinaryOp::Add | BinaryOp::Mul)
}
/// Checks whether the right-hand operands of two nested operations, `(a op b) other_op c`,
/// may be regrouped: the operations must have the same priority and `op` must not be
/// `Power`. The tests in this file exercise this for `+` / `-` chains.
fn is_commutative_pair(op: BinaryOp, other_op: BinaryOp) -> bool {
op.priority() == other_op.priority() && op != BinaryOp::Power
}
/// Builds the expression `lhs op rhs`, simplifying it along the way:
///
/// - two constants are computed immediately;
/// - a constant operand is moved to the right-hand side where possible, with
///   `x - c` rewritten as `x + (-c)` and `x / c` as `x * (1 / c)`;
/// - in chains of same-priority operations the constant is moved outwards
///   and merged with another constant if there is one.
///
/// # Panics
///
/// Panics if both operands are constants and `op` is not one of `Add`, `Sub`, `Mul`, `Div`
/// or `Power`. The caller in this file only passes these five operations.
fn fold(mut op: BinaryOp, mut lhs: Self, mut rhs: Self) -> Self {
// Case 1: both operands are constants, so the operation is computed right away.
if let (Self::Value(lhs_val), Self::Value(rhs_val)) = (&lhs, &rhs) {
return Self::Value(match op {
BinaryOp::Add => *lhs_val + *rhs_val,
BinaryOp::Sub => *lhs_val - *rhs_val,
BinaryOp::Mul => *lhs_val * *rhs_val,
BinaryOp::Div => *lhs_val / *rhs_val,
BinaryOp::Power => lhs_val.powc(*rhs_val),
_ => unreachable!(),
});
}
// Case 2: exactly one operand is a constant. Normalize so that the constant ends up
// on the right and is combined via a commutative operation where possible.
if let Self::Value(val) = rhs {
match op {
BinaryOp::Sub => {
// `x - c` => `x + (-c)`
op = BinaryOp::Add;
rhs = Self::Value(-val);
}
BinaryOp::Div => {
// `x / c` => `x * (1 / c)`
op = BinaryOp::Mul;
rhs = Self::Value(1.0 / val);
}
_ => { }
}
} else if let Self::Value(_) = lhs {
// `c + x` => `x + c`, `c * x` => `x * c`. Non-commutative operations are left as is.
if Self::is_commutative(op) {
mem::swap(&mut lhs, &mut rhs);
}
}
// Case 3: the left operand is itself a binary operation `a inner_op k` with a constant `k`
// on the right, and `inner_op` is compatible with `op`.
if let Self::Binary {
op: inner_op,
rhs: inner_rhs,
..
} = &mut lhs
{
if Self::is_commutative_pair(*inner_op, op) {
if let Self::Value(inner_val) = **inner_rhs {
if let Self::Value(val) = rhs {
// Both `k` and `rhs` are constants: `(a + k) + c` => `a + (k + c)`, and
// likewise for multiplication. `Sub` / `Div` with a constant right-hand side
// were rewritten above, so only `Add` / `Mul` are expected here (this assumes
// that `Power` does not share its priority with another operation).
let new_rhs = match op {
BinaryOp::Add => inner_val + val,
BinaryOp::Mul => inner_val * val,
_ => unreachable!(),
};
*inner_rhs = Box::new(Self::Value(new_rhs));
return lhs;
}
// `rhs` is not a constant: exchange it with `k` (together with the operations)
// so that the constant becomes the outermost right operand:
// `(a inner_op k) op b` => `(a op b) inner_op k`.
mem::swap(&mut rhs, inner_rhs);
mem::swap(&mut op, inner_op);
}
}
}
Self::Binary {
op,
lhs: Box::new(lhs),
rhs: Box::new(rhs),
}
}
}
/// Negates an expression, simplifying where possible: constants are negated in place
/// and a double negation cancels out.
impl ops::Neg for Evaluated {
    type Output = Self;

    fn neg(self) -> Self::Output {
        // A constant is negated directly.
        if let Self::Value(constant) = self {
            return Self::Value(-constant);
        }
        // `-(-x)` is just `x`.
        if let Self::Negation(negated) = self {
            return *negated;
        }
        Self::Negation(Box::new(self))
    }
}
impl FnError {
    /// Creates an error from a parser error. `s` is the source code that was being parsed;
    /// it is used to extract the code fragment at the error location.
    fn parse(source: &arithmetic_parser::Error, s: &str) -> Self {
        let location = source.location();
        Self {
            fragment: location.span(s).to_owned(),
            line: location.location_line(),
            column: location.get_column(),
            source: ErrorSource::Parse(source.kind().to_string()),
        }
    }

    /// Creates an evaluation error located at the specified `span` of the source code.
    fn eval<T>(span: &Spanned<'_, T>, source: EvalError) -> Self {
        let fragment = String::from(*span.fragment());
        Self {
            fragment,
            line: span.location_line(),
            column: span.get_column(),
            source: ErrorSource::Eval(source),
        }
    }
}
/// Formats the error as `line:column: cause`. In the alternate form (`{:#}`),
/// the offending code fragment follows on a separate line.
impl fmt::Display for FnError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            fragment,
            line,
            column,
            source,
        } = self;

        write!(formatter, "{}:{}: {}", line, column, source)?;
        if !formatter.alternate() {
            return Ok(());
        }
        formatter.write_str("\n")?;
        formatter.pad(fragment)
    }
}
impl Error for FnError {
    /// Returns the underlying evaluation error, if any. Parser errors are stored
    /// as plain strings and thus have no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let ErrorSource::Eval(eval_error) = &self.source {
            Some(eval_error)
        } else {
            None
        }
    }
}
/// Base grammar for function definitions: `Complex32` literals wrapped into `Untyped`.
type FnGrammarBase = Untyped<NumGrammar<Complex32>>;
/// Grammar used to parse function definitions.
#[derive(Debug, Clone, Copy)]
struct FnGrammar;
impl Parse for FnGrammar {
type Base = FnGrammarBase;
// No optional parser features are switched on. The `unreachable!` branches in `Context`
// ("... disabled in parser") rely on this.
const FEATURES: Features = Features::empty();
}
/// Context for converting parsed statements into a `Function`.
#[derive(Debug)]
pub(crate) struct Context {
/// Names of the variables defined so far. Initially contains only the argument name;
/// each processed assignment adds its variable.
variables: HashSet<String>,
}
impl Context {
pub(crate) fn new(arg_name: &str) -> Self {
Self {
variables: iter::once(arg_name.to_owned()).collect(),
}
}
fn process(
&mut self,
block: &Block<'_, FnGrammarBase>,
total_span: Spanned<'_>,
) -> Result<Function, FnError> {
let mut assignments = vec![];
for statement in &block.statements {
match &statement.extra {
Statement::Assignment { lhs, rhs } => {
let variable_name = match lhs.extra {
Lvalue::Variable { .. } => *lhs.fragment(),
_ => unreachable!("Tuples are disabled in parser"),
};
if self.variables.contains(variable_name) {
let err = FnError::eval(lhs, EvalError::RedefinedVar);
return Err(err);
}
let value = self.eval_expr(rhs)?;
self.variables.insert(variable_name.to_owned());
assignments.push((variable_name.to_owned(), value));
}
Statement::Expr(_) => {
return Err(FnError::eval(statement, EvalError::UselessExpr));
}
_ => return Err(FnError::eval(statement, EvalError::Unsupported)),
}
}
let return_value = block
.return_value
.as_ref()
.ok_or_else(|| FnError::eval(&total_span, EvalError::NoReturn))?;
let value = self.eval_expr(return_value)?;
assignments.push((String::new(), value));
Ok(Function { assignments })
}
fn eval_expr(&self, expr: &SpannedExpr<'_, FnGrammarBase>) -> Result<Evaluated, FnError> {
match &expr.extra {
Expr::Variable => {
let var_name = *expr.fragment();
self.variables
.get(var_name)
.ok_or_else(|| FnError::eval(expr, EvalError::UndefinedVar))?;
Ok(Evaluated::Variable(var_name.to_owned()))
}
Expr::Literal(lit) => Ok(Evaluated::Value(*lit)),
Expr::Unary { op, inner } => match op.extra {
UnaryOp::Neg => Ok(-self.eval_expr(inner)?),
_ => Err(FnError::eval(op, EvalError::Unsupported)),
},
Expr::Binary { lhs, op, rhs } => {
let lhs_value = self.eval_expr(lhs)?;
let rhs_value = self.eval_expr(rhs)?;
Ok(match op.extra {
BinaryOp::Add
| BinaryOp::Sub
| BinaryOp::Mul
| BinaryOp::Div
| BinaryOp::Power => Evaluated::fold(op.extra, lhs_value, rhs_value),
_ => {
return Err(FnError::eval(op, EvalError::Unsupported));
}
})
}
Expr::Function { name, args } => {
let fn_name = *name.fragment();
let function: UnaryFunction =
fn_name.parse().map_err(|e| FnError::eval(name, e))?;
if args.len() != 1 {
return Err(FnError::eval(expr, EvalError::FnArity));
}
Ok(Evaluated::FunctionCall {
function,
arg: Box::new(self.eval_expr(&args[0])?),
})
}
Expr::FnDefinition(_) | Expr::Block(_) | Expr::Tuple(_) | Expr::Method { .. } => {
unreachable!("Disabled in parser")
}
_ => Err(FnError::eval(expr, EvalError::Unsupported)),
}
}
}
/// Parsed function of a single complex variable.
///
/// A `Function` is created from a string via its `FromStr` implementation. The string
/// consists of zero or more variable assignments followed by a return expression,
/// e.g., `c = 0.5 - 0.2i; z*z + c`. The function argument is named `z`.
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "dyn_cpu_backend",
feature = "opencl_backend",
feature = "vulkan_backend"
)))
)]
#[derive(Debug, Clone)]
pub struct Function {
/// Named assignments in the order of their definition, followed by the return value,
/// which is stored as the last element under an empty name.
assignments: Vec<(String, Evaluated)>,
}
impl Function {
    /// Iterates over the named assignments of this function in the order of their
    /// definition. The return value (stored under an empty name) is skipped.
    pub(crate) fn assignments(&self) -> impl Iterator<Item = (&str, &Evaluated)> + '_ {
        self.assignments
            .iter()
            .filter(|(name, _)| !name.is_empty())
            .map(|(name, value)| (name.as_str(), value))
    }

    /// Returns the expression for the return value of this function, i.e., the last
    /// stored assignment.
    ///
    /// # Panics
    ///
    /// Panics if there are no assignments at all.
    pub(crate) fn return_value(&self) -> &Evaluated {
        let (_, value) = self.assignments.last().unwrap();
        value
    }
}
/// Parses a function from its source code. The function argument is named `z`.
///
/// # Errors
///
/// Returns an [`FnError`] if the source code cannot be parsed, or if the parsed
/// statements cannot be converted into a function.
impl FromStr for Function {
    type Err = FnError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let block = match FnGrammar::parse_statements(s) {
            Ok(block) => block,
            Err(err) => return Err(FnError::parse(&err, s)),
        };
        // Span covering the entire input; used to report a missing return value.
        let whole_input = Spanned::from_str(s, ..);
        let mut context = Context::new("z");
        context.process(&block, whole_input)
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Expected representation of `z*z`.
fn z_square() -> Evaluated {
Evaluated::Binary {
op: BinaryOp::Mul,
lhs: Box::new(Evaluated::Variable("z".to_owned())),
rhs: Box::new(Evaluated::Variable("z".to_owned())),
}
}
// The parenthesized constant is folded into a single value.
#[test]
fn simple_function() {
let function: Function = "z*z + (0.77 - 0.2i)".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(z_square()),
rhs: Box::new(Evaluated::Value(Complex32::new(0.77, -0.2))),
};
assert_eq!(function.assignments, vec![(String::new(), expected_expr)]);
}
// Division by a constant becomes multiplication by its reciprocal; subtraction of a constant
// becomes addition; adjacent additive constants are merged.
#[test]
fn simple_function_with_rewrite_rules() {
let function: Function = "z / 0.25 - 0.1i + (0.77 - 0.1i)".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(Evaluated::Binary {
op: BinaryOp::Mul,
lhs: Box::new(Evaluated::Variable("z".to_owned())),
rhs: Box::new(Evaluated::Value(Complex32::new(4.0, 0.0))),
}),
rhs: Box::new(Evaluated::Value(Complex32::new(0.77, -0.2))),
};
assert_eq!(function.assignments, vec![(String::new(), expected_expr)]);
}
// The constant `0.1` is moved past the `- z*z` term and merged with `0.3i`.
#[test]
fn function_with_several_rewrite_rules() {
let function: Function = "z + 0.1 - z*z + 0.3i".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(Evaluated::Binary {
op: BinaryOp::Sub,
lhs: Box::new(Evaluated::Variable("z".to_owned())),
rhs: Box::new(z_square()),
}),
rhs: Box::new(Evaluated::Value(Complex32::new(0.1, 0.3))),
};
assert_eq!(function.assignments, vec![(String::new(), expected_expr)]);
}
// Multiplicative constants (`/ 4.` and `* 6i`) are merged into a single factor.
#[test]
fn simple_function_with_mul_rewrite_rules() {
let function: Function = "sinh(z - 5) / 4. * 6i".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Mul,
lhs: Box::new(Evaluated::FunctionCall {
function: UnaryFunction::Sinh,
arg: Box::new(Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(Evaluated::Variable("z".to_owned())),
rhs: Box::new(Evaluated::Value(Complex32::new(-5.0, 0.0))),
}),
}),
rhs: Box::new(Evaluated::Value(Complex32::new(0.0, 1.5))),
};
assert_eq!(function.assignments, vec![(String::new(), expected_expr)]);
}
// A constant on the left of a commutative operation is moved to the right and then merged
// with the trailing constant.
#[test]
fn simple_function_with_redistribution() {
let function: Function = "0.5 + sinh(z) - 0.2i".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(Evaluated::FunctionCall {
function: UnaryFunction::Sinh,
arg: Box::new(Evaluated::Variable("z".to_owned())),
}),
rhs: Box::new(Evaluated::Value(Complex32::new(0.5, -0.2))),
};
assert_eq!(function.assignments, vec![(String::new(), expected_expr)]);
}
// Assignments are recorded in order; the return value comes last under an empty name.
#[test]
fn function_with_assignments() {
let function: Function = "c = 0.5 - 0.2i; z*z + c".parse().unwrap();
let expected_expr = Evaluated::Binary {
op: BinaryOp::Add,
lhs: Box::new(z_square()),
rhs: Box::new(Evaluated::Variable("c".to_owned())),
};
assert_eq!(
function.assignments,
vec![
("c".to_owned(), Evaluated::Value(Complex32::new(0.5, -0.2))),
(String::new(), expected_expr),
]
);
}
}