1use std::fmt::Write as _;
4
5use arithmetic_parser::BinaryOp;
6
7use crate::{Function, function::Evaluated};
8
/// Name of the function argument. It is emitted verbatim, without `VAR_PREFIX`,
/// because `Compiler::compile` never declares it.
const COMPUTE_ARGUMENT: &str = "z";
/// Prefix prepended to every other variable name, both where `Compiler::compile`
/// declares it and where it is referenced. Presumably this avoids clashes with
/// identifiers of the target language (TODO confirm).
const VAR_PREFIX: &str = "__var_";
/// Prefix prepended to built-in function names (e.g., `sinh` is emitted as `complex_sinh`).
const FN_PREFIX: &str = "complex_";
12
/// Translates a parsed `Function` into backend source code (OpenCL C or the
/// `vec2`-based syntax used by the Vulkan backend).
///
/// The fields hold the target-specific syntax this compiler varies over. Everything else,
/// including the `complex_*` helper names, is shared between targets.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Compiler {
    /// Type name used when declaring complex variables (e.g., `float2`).
    complex_ty: &'static str,
    /// Text placed immediately before `(re, im)` to build a complex literal (e.g., `(float2)`).
    complex_init: &'static str,
}
18
19impl Compiler {
20 #[cfg(any(test, feature = "opencl_backend"))]
21 pub fn for_ocl() -> Self {
22 Self {
23 complex_ty: "float2",
24 complex_init: "(float2)",
25 }
26 }
27
28 #[cfg(any(test, feature = "vulkan_backend"))]
29 pub fn for_gl() -> Self {
30 Self {
31 complex_ty: "vec2",
32 complex_init: "vec2",
33 }
34 }
35
36 pub fn compile(self, function: &Function) -> String {
37 let mut code = String::new();
38 for (var_name, value) in function.assignments() {
39 write!(&mut code, "{} {VAR_PREFIX}{var_name} = ", self.complex_ty).unwrap();
41 self.compile_expr(&mut code, value);
42 code += "; ";
43 }
44
45 code += "return ";
46 self.compile_expr(&mut code, function.return_value());
47 code += ";";
48 code
49 }
50
51 fn op_function(op: BinaryOp) -> &'static str {
52 match op {
53 BinaryOp::Mul => "complex_mul",
54 BinaryOp::Div => "complex_div",
55 BinaryOp::Power => "complex_pow",
56 _ => unreachable!(),
57 }
58 }
59
60 fn compile_expr(self, dest: &mut String, expr: &Evaluated) {
61 match expr {
62 Evaluated::Variable(name) => {
63 if name != COMPUTE_ARGUMENT {
64 dest.push_str(VAR_PREFIX);
65 }
66 dest.push_str(name);
67 }
68
69 Evaluated::Value(val) => {
70 dest.push_str(self.complex_init);
71 dest.push('(');
72 dest.push_str(&val.re.to_string());
73 dest.push_str(", ");
74 dest.push_str(&val.im.to_string());
75 dest.push(')');
76 }
77
78 Evaluated::Negation(inner) => {
79 dest.push('-');
80 self.compile_expr(dest, inner);
81 }
82
83 Evaluated::Binary { op, lhs, rhs } => match op {
84 BinaryOp::Add | BinaryOp::Sub => {
85 self.compile_expr(dest, lhs);
86 dest.push(' ');
87 dest.push_str(op.as_str());
88 dest.push(' ');
89 self.compile_expr(dest, rhs);
90 }
91
92 _ => {
93 let function_name = Self::op_function(*op);
94 dest.push_str(function_name);
95 dest.push('(');
96 self.compile_expr(dest, lhs);
97 dest.push_str(", ");
98 self.compile_expr(dest, rhs);
99 dest.push(')');
100 }
101 },
102
103 Evaluated::FunctionCall { function, arg } => {
104 dest.push_str(FN_PREFIX);
105 dest.push_str(function.as_str());
106 dest.push('(');
107 self.compile_expr(dest, arg);
108 dest.push(')');
109 }
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    // Single return expressions, checked against both backends to pin the differing complex
    // type / literal syntax. The expected strings also show that constant terms such as
    // `0.2 + 0.5i` reach the compiler already folded into a single value.
    #[test]
    fn compiling_simple_fns() {
        let function = "z*z + 0.2 + 0.5i".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(code, "return complex_mul(z, z) + (float2)(0.2, 0.5);");
        let code = Compiler::for_gl().compile(&function);
        assert_eq!(code, "return complex_mul(z, z) + vec2(0.2, 0.5);");

        let function = "z^3 * sinh(0.2 + z*z)".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(
            code,
            "return complex_mul(complex_pow(z, (float2)(3, 0)), \
             complex_sinh(complex_mul(z, z) + (float2)(0.2, 0)));"
        );
        let code = Compiler::for_gl().compile(&function);
        assert_eq!(
            code,
            "return complex_mul(complex_pow(z, vec2(3, 0)), \
             complex_sinh(complex_mul(z, z) + vec2(0.2, 0)));"
        );
    }

    // Non-trivial arguments to built-in functions. Negated constants show up as negated
    // literal components (including `-0`) rather than as a unary minus in the output.
    #[test]
    fn complex_function_arg() {
        let function = "sinh(z^2 + 2i * z * -0.5)".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(
            code,
            "return complex_sinh(complex_pow(z, (float2)(2, 0)) + \
             complex_mul(z, (float2)(-0, -1)));"
        );

        let function = "0.7 + cosh(z*z - 0.5i) * z".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(
            code,
            "return complex_mul(complex_cosh(complex_mul(z, z) + (float2)(-0, -0.5)), z) + \
             (float2)(0.7, 0);"
        );
    }

    // Assignments become prefixed local declarations emitted before the `return` statement.
    // References to them use the same `__var_` prefix.
    #[test]
    fn compiling_fn_with_assignment() {
        let function = "c = 0.5 + 0.4i; z*z + c".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(
            code,
            "float2 __var_c = (float2)(0.5, 0.4); \
             return complex_mul(z, z) + __var_c;"
        );

        let function = "d = sinh(z) * z * 1.1; z*z - 0.5 + d".parse().unwrap();
        let code = Compiler::for_ocl().compile(&function);
        assert_eq!(
            code,
            "float2 __var_d = complex_mul(complex_mul(complex_sinh(z), z), (float2)(1.1, 0)); \
             return complex_mul(z, z) + __var_d + (float2)(-0.5, -0);"
        );
    }
}