julia_set/
compiler.rs

1//! Code shared among backends.
2
3use std::fmt::Write as _;
4
5use arithmetic_parser::BinaryOp;
6
7use crate::{Function, function::Evaluated};
8
/// Name of the argument of the compiled function. Unlike user-defined variables, it is
/// emitted verbatim, without [`VAR_PREFIX`].
const COMPUTE_ARGUMENT: &str = "z";

/// Prefix prepended to the names of user-defined variables (i.e., assignments) in the
/// emitted code, both where they are declared and where they are referenced.
const VAR_PREFIX: &str = "__var_";

/// Prefix prepended to the names of called functions, e.g. `sinh` becomes `complex_sinh`.
const FN_PREFIX: &str = "complex_";
12
/// Compiles a parsed [`Function`] into source code for one of the backends
/// (OpenCL C or GLSL).
///
/// Backends differ only in how the complex number type and complex literals are spelled;
/// the names of the emitted helper calls (`complex_mul`, `complex_sinh`, ...) are the same
/// for every backend.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Compiler {
    /// Name of the complex number type, used to declare local variables (e.g. `float2`).
    complex_ty: &'static str,
    /// Prefix of a complex literal, which is followed by `(re, im)`
    /// (e.g. `(float2)` for `(float2)(0.5, 0.4)`).
    complex_init: &'static str,
}
18
19impl Compiler {
20    #[cfg(any(test, feature = "opencl_backend"))]
21    pub fn for_ocl() -> Self {
22        Self {
23            complex_ty: "float2",
24            complex_init: "(float2)",
25        }
26    }
27
28    #[cfg(any(test, feature = "vulkan_backend"))]
29    pub fn for_gl() -> Self {
30        Self {
31            complex_ty: "vec2",
32            complex_init: "vec2",
33        }
34    }
35
36    pub fn compile(self, function: &Function) -> String {
37        let mut code = String::new();
38        for (var_name, value) in function.assignments() {
39            // Writing to a `String` always succeeds.
40            write!(&mut code, "{} {VAR_PREFIX}{var_name} = ", self.complex_ty).unwrap();
41            self.compile_expr(&mut code, value);
42            code += "; ";
43        }
44
45        code += "return ";
46        self.compile_expr(&mut code, function.return_value());
47        code += ";";
48        code
49    }
50
51    fn op_function(op: BinaryOp) -> &'static str {
52        match op {
53            BinaryOp::Mul => "complex_mul",
54            BinaryOp::Div => "complex_div",
55            BinaryOp::Power => "complex_pow",
56            _ => unreachable!(),
57        }
58    }
59
60    fn compile_expr(self, dest: &mut String, expr: &Evaluated) {
61        match expr {
62            Evaluated::Variable(name) => {
63                if name != COMPUTE_ARGUMENT {
64                    dest.push_str(VAR_PREFIX);
65                }
66                dest.push_str(name);
67            }
68
69            Evaluated::Value(val) => {
70                dest.push_str(self.complex_init);
71                dest.push('(');
72                dest.push_str(&val.re.to_string());
73                dest.push_str(", ");
74                dest.push_str(&val.im.to_string());
75                dest.push(')');
76            }
77
78            Evaluated::Negation(inner) => {
79                dest.push('-');
80                self.compile_expr(dest, inner);
81            }
82
83            Evaluated::Binary { op, lhs, rhs } => match op {
84                BinaryOp::Add | BinaryOp::Sub => {
85                    self.compile_expr(dest, lhs);
86                    dest.push(' ');
87                    dest.push_str(op.as_str());
88                    dest.push(' ');
89                    self.compile_expr(dest, rhs);
90                }
91
92                _ => {
93                    let function_name = Self::op_function(*op);
94                    dest.push_str(function_name);
95                    dest.push('(');
96                    self.compile_expr(dest, lhs);
97                    dest.push_str(", ");
98                    self.compile_expr(dest, rhs);
99                    dest.push(')');
100                }
101            },
102
103            Evaluated::FunctionCall { function, arg } => {
104                dest.push_str(FN_PREFIX);
105                dest.push_str(function.as_str());
106                dest.push('(');
107                self.compile_expr(dest, arg);
108                dest.push(')');
109            }
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` and compiles it to OpenCL C.
    fn ocl(source: &str) -> String {
        let function = source.parse().unwrap();
        Compiler::for_ocl().compile(&function)
    }

    /// Parses `source` and compiles it to GLSL.
    fn gl(source: &str) -> String {
        let function = source.parse().unwrap();
        Compiler::for_gl().compile(&function)
    }

    #[test]
    fn compiling_simple_fns() {
        let source = "z*z + 0.2 + 0.5i";
        assert_eq!(ocl(source), "return complex_mul(z, z) + (float2)(0.2, 0.5);");
        assert_eq!(gl(source), "return complex_mul(z, z) + vec2(0.2, 0.5);");

        let source = "z^3 * sinh(0.2 + z*z)";
        assert_eq!(
            ocl(source),
            concat!(
                "return complex_mul(complex_pow(z, (float2)(3, 0)), ",
                "complex_sinh(complex_mul(z, z) + (float2)(0.2, 0)));"
            )
        );
        assert_eq!(
            gl(source),
            concat!(
                "return complex_mul(complex_pow(z, vec2(3, 0)), ",
                "complex_sinh(complex_mul(z, z) + vec2(0.2, 0)));"
            )
        );
    }

    #[test]
    fn complex_function_arg() {
        assert_eq!(
            ocl("sinh(z^2 + 2i * z * -0.5)"),
            concat!(
                "return complex_sinh(complex_pow(z, (float2)(2, 0)) + ",
                "complex_mul(z, (float2)(-0, -1)));"
            )
        );

        assert_eq!(
            ocl("0.7 + cosh(z*z - 0.5i) * z"),
            concat!(
                "return complex_mul(complex_cosh(complex_mul(z, z) + (float2)(-0, -0.5)), z) + ",
                "(float2)(0.7, 0);"
            )
        );
    }

    #[test]
    fn compiling_fn_with_assignment() {
        assert_eq!(
            ocl("c = 0.5 + 0.4i; z*z + c"),
            concat!(
                "float2 __var_c = (float2)(0.5, 0.4); ",
                "return complex_mul(z, z) + __var_c;"
            )
        );

        assert_eq!(
            ocl("d = sinh(z) * z * 1.1; z*z - 0.5 + d"),
            concat!(
                "float2 __var_d = complex_mul(complex_mul(complex_sinh(z), z), (float2)(1.1, 0)); ",
                "return complex_mul(z, z) + __var_d + (float2)(-0.5, -0);"
            )
        );
    }
}