1use core::{fmt, marker::PhantomData};
4
5pub use self::traits::{
6 ErrorOutput, FromValueError, FromValueErrorKind, FromValueErrorLocation, IntoEvalResult,
7 TryFromValue,
8};
9use crate::{
10 alloc::Vec, error::AuxErrorInfo, CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue,
11};
12
13mod traits;
14
15pub const fn wrap<const CTX: bool, T, F>(function: F) -> FnWrapper<T, F, CTX> {
20 FnWrapper::new(function)
21}
22
/// Wrapper around a Rust function or closure that allows using it as a native function.
///
/// # Type parameters
///
/// - `T` describes the function signature as a tuple: the return type followed by
///   the argument types, e.g. `(Ret, Arg0, Arg1)`. It is only a type-level marker
///   and is never stored.
/// - `F` is the type of the wrapped function.
/// - `CTX` is `true` if the wrapped function takes `&mut CallContext` as its first
///   parameter. That parameter is supplied by the interpreter and is not counted
///   among the arguments passed from a script.
pub struct FnWrapper<T, F, const CTX: bool = false> {
    /// Zero-sized marker carrying the signature type `T`.
    _arg_types: PhantomData<T>,
    /// The wrapped function.
    function: F,
}
117
118impl<T, F, const CTX: bool> fmt::Debug for FnWrapper<T, F, CTX>
119where
120 F: fmt::Debug,
121{
122 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123 formatter
124 .debug_struct("FnWrapper")
125 .field("function", &self.function)
126 .field("context", &CTX)
127 .finish()
128 }
129}
130
131impl<T, F: Clone, const CTX: bool> Clone for FnWrapper<T, F, CTX> {
132 fn clone(&self) -> Self {
133 Self {
134 function: self.function.clone(),
135 _arg_types: PhantomData,
136 }
137 }
138}
139
140impl<T, F: Copy, const CTX: bool> Copy for FnWrapper<T, F, CTX> {}
141
142impl<T, F, const CTX: bool> FnWrapper<T, F, CTX> {
145 pub const fn new(function: F) -> Self {
152 Self {
153 function,
154 _arg_types: PhantomData,
155 }
156 }
157}
158
// Implements `NativeFn<Num>` for `FnWrapper` over functions of a fixed arity.
//
// Macro arguments:
//
// - `$arity`: number of arguments passed from the script. It is checked against the
//   actual argument count before any argument is converted.
// - `$with_ctx`: `true` or `false`; becomes the `CTX` const parameter of `FnWrapper`.
// - `$ctx_name : $ctx_t` (optional): name and type of a leading context parameter.
//   It is passed to the wrapped function before the script arguments and does not
//   count towards `$arity`.
// - `$arg_name : $t`: a local name and a generic type parameter for each script argument.
//
// The signature marker of the wrapper is the tuple `(Ret, $($t,)*)`. The return value
// is converted via `IntoEvalResult`.
macro_rules! arity_fn {
    ($arity:tt, $with_ctx:tt $(, $ctx_name:ident : $ctx_t:ty)? => $($arg_name:ident : $t:ident),*) => {
        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F, $with_ctx>
        where
            F: Fn($($ctx_t,)? $($t,)*) -> Ret,
            $($t: TryFromValue<Num>,)*
            Ret: IntoEvalResult<Num>,
        {
            // `clippy::shadow_unrelated`: each `$arg_name` is rebound from the raw spanned
            // value to the converted argument.
            // `unused_variables` / `unused_mut`: for 0-ary functions, `args_iter` is never read.
            #[allow(clippy::shadow_unrelated)] #[allow(unused_variables, unused_mut)] fn evaluate(
                &self,
                args: Vec<SpannedValue<Num>>,
                context: &mut CallContext<'_, Num>,
            ) -> EvalResult<Num> {
                // NOTE(review): presumably this returns an error unless `args.len() == $arity`.
                // That guarantee is what makes the `unwrap()` calls below infallible;
                // confirm against `CallContext::check_args_count`.
                context.check_args_count(&args, $arity)?;
                let mut args_iter = args.into_iter().enumerate();

                // Convert the arguments in order; `index` is the 0-based argument position.
                $(
                    let (index, $arg_name) = args_iter.next().unwrap();
                    // Keep the argument's location for error reporting; its value
                    // (`.extra`) is moved out on the next statement.
                    let span = $arg_name.with_no_extra();
                    let $arg_name = $t::try_from_value($arg_name.extra).map_err(|mut err| {
                        // Record which argument failed, then report a call-site error
                        // pointing at the offending argument.
                        err.set_arg_index(index);
                        context
                            .call_site_error(ErrorKind::Wrapper(err))
                            .with_location(&span, AuxErrorInfo::InvalidArg)
                    })?;
                )*

                // Reborrow the context for the wrapped function so that `context` itself
                // remains usable below when attaching a span to a returned error.
                $(let $ctx_name = &mut *context;)?
                let output = (self.function)($($ctx_name,)? $($arg_name,)*);
                output.into_eval_result().map_err(|err| err.into_spanned(context))
            }
        }
    };
}
195
196arity_fn!(0, false =>);
197arity_fn!(0, true, ctx: &mut CallContext<'_, Num> =>);
198arity_fn!(1, false => x0: T);
199arity_fn!(1, true, ctx: &mut CallContext<'_, Num> => x0: T);
200arity_fn!(2, false => x0: T, x1: U);
201arity_fn!(2, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U);
202arity_fn!(3, false => x0: T, x1: U, x2: V);
203arity_fn!(3, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V);
204arity_fn!(4, false => x0: T, x1: U, x2: V, x3: W);
205arity_fn!(4, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W);
206arity_fn!(5, false => x0: T, x1: U, x2: V, x3: W, x4: X);
207arity_fn!(5, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X);
208arity_fn!(6, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
209arity_fn!(6, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
210arity_fn!(7, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
211arity_fn!(7, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
212arity_fn!(8, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
213arity_fn!(8, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
214arity_fn!(9, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
215arity_fn!(9, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
216arity_fn!(10, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
217arity_fn!(10, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
218
219pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;
221
222pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;
224
225pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;
227
228pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
230
#[cfg(test)]
mod tests {
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        alloc::{format, ToOwned},
        env::{Environment, Prelude},
        exec::{ExecutableModule, WildcardId},
        Function, Object, Tuple, Value,
    };

    /// Wrappers created via the `Unary` / `Binary` / `Ternary` aliases over `f32` arguments
    /// can be registered with `insert_native_fn` and called from a script.
    #[test]
    fn functions_with_primitive_args() -> anyhow::Result<()> {
        let unary_fn = Unary::new(|x: f32| x + 3.0);
        let binary_fn = Binary::new(f32::min);
        let ternary_fn = Ternary::new(|x: f32, y, z| if x > 0.0 { y } else { z });

        let program = "
            unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
            ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("unary_fn", unary_fn)
            .insert_native_fn("binary_fn", binary_fn)
            .insert_native_fn("ternary_fn", ternary_fn);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// Returns the minimum and maximum of `values`. An empty vector yields
    /// `(INFINITY, NEG_INFINITY)`, and NaN values never win a comparison.
    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;

        for value in values {
            if value < min {
                min = value;
            }
            if value > max {
                max = value;
            }
        }
        (min, max)
    }

    /// Sums every number inside the nested arguments. Exercises conversion of script
    /// tuples into `Vec`s of tuples and into tuples containing `Vec`s.
    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
        xs.into_iter().map(|(a, b)| a + b).sum::<f32>() + ys.0.into_iter().sum::<f32>() + ys.1
    }

    /// Wrapped functions accept composite arguments (script tuples converted to `Vec`s and
    /// Rust tuples) and return composite values (a Rust tuple compared to a script tuple).
    #[test]
    fn functions_with_composite_args() -> anyhow::Result<()> {
        let program = "
            array_min_max((1, 5, -3, 2, 1)) == (-3, 5) &&
            total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("array_min_max", array_min_max)
            .insert_wrapped_fn("total_sum", overly_convoluted_fn);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// Element-wise sum of two vectors; returns an error message if their lengths differ.
    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
        if xs.len() == ys.len() {
            Ok(xs.into_iter().zip(ys).map(|(x, y)| x + y).collect())
        } else {
            Err("Summed arrays must have the same size".to_owned())
        }
    }

    /// An `Ok` result of a fallible wrapped function is unwrapped into a script value.
    #[test]
    fn fallible_function() -> anyhow::Result<()> {
        let program = "sum_arrays((1, 2, 3), (4, 5, 6)) == (5, 7, 9)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("sum_arrays", sum_arrays);
        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// An `Err(String)` from a wrapped function makes the script fail, and the error
    /// message appears in the reported error.
    #[test]
    fn fallible_function_with_bogus_program() -> anyhow::Result<()> {
        let program = "sum_arrays((1, 2, 3), (4, 5))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("sum_arrays", sum_arrays);

        let err = module.with_env(&env)?.run().unwrap_err();
        assert!(err
            .source()
            .kind()
            .to_short_string()
            .contains("Summed arrays must have the same size"));
        Ok(())
    }

    /// A closure created with `wrap` can destructure a tuple argument and return `bool`.
    #[test]
    fn function_with_bool_return_value() -> anyhow::Result<()> {
        let contains = wrap(|(a, b): (f32, f32), x: f32| (a..=b).contains(&x));

        let program = "contains((-1, 2), 0) && !contains((1, 3), 0)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("contains", contains);
        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// `Ok(())` from a wrapped function becomes a void value; `Err(String)` surfaces
    /// as `ErrorKind::NativeCall` carrying the message.
    #[test]
    fn function_with_void_return_value() -> anyhow::Result<()> {
        let program = "assert_eq(3, 1 + 2)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
            if (expected - actual).abs() < f32::EPSILON {
                Ok(())
            } else {
                Err(format!(
                    "Assertion failed: expected {expected}, got {actual}"
                ))
            }
        });

        assert!(module.with_env(&env)?.run()?.is_void());

        let bogus_program = "assert_eq(3, 1 - 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program)?;
        let err = ExecutableModule::new(WildcardId, &bogus_block)?
            .with_env(&env)?
            .run()
            .unwrap_err();

        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("Assertion failed")
        );
        Ok(())
    }

    /// Wrapped functions can take `bool` arguments.
    #[test]
    fn function_with_bool_argument() -> anyhow::Result<()> {
        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        // NOTE(review): the prelude is presumably what defines `true` / `false` for the script.
        env.extend(Prelude::iter());
        env.insert_wrapped_fn(
            "flip_sign",
            |val: f32, flag: bool| if flag { -val } else { val },
        );

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// Wrapped functions can take a `Tuple` and return an `Object`, which the script
    /// then destructures.
    #[test]
    #[allow(clippy::cast_precision_loss)] // `tuple.len() as f32` below; lengths here are tiny
    fn function_with_object_and_tuple() -> anyhow::Result<()> {
        fn test_function(tuple: Tuple<f32>) -> Object<f32> {
            let mut obj = Object::default();
            obj.insert("len", Value::Prim(tuple.len() as f32));
            obj.insert("tuple", tuple.into());
            obj
        }

        let program = "
            { len, tuple } = test((1, 1, 1));
            len == 3 && tuple == (1, 1, 1)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let test_function = Value::native_fn(wrap(test_function));
        let mut env = Environment::new();
        env.insert("test", test_function).extend(Prelude::iter());

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// A conversion error deep inside a nested argument reports the failing location.
    /// Here `arg0[1].0` is the first argument, element 1 (`(2, 3)`), tuple field 0 (`2`),
    /// which is a number where `bool` is expected.
    #[test]
    fn error_reporting_with_destructuring() -> anyhow::Result<()> {
        let program = "destructure(((true, 1), (2, 3)))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.extend(Prelude::iter());
        env.insert_wrapped_fn("destructure", |values: Vec<(bool, f32)>| {
            values
                .into_iter()
                .map(|(flag, x)| if flag { x } else { 0.0 })
                .sum::<f32>()
        });

        let err = module.with_env(&env)?.run().unwrap_err();
        let err_message = err.source().kind().to_short_string();
        assert!(err_message.contains("Cannot convert primitive value to bool"));
        assert!(err_message.contains("location: arg0[1].0"));
        Ok(())
    }

    /// A wrapped function whose first parameter is `&mut CallContext` gets the context
    /// and can call back into a script function passed as an argument. In the script,
    /// the method-call receiver (the closure) is passed as `func`.
    #[test]
    fn function_with_context() -> anyhow::Result<()> {
        // `func` is taken by value because wrapped functions receive owned, converted
        // arguments, even though it is only used by reference here.
        #[allow(clippy::needless_pass_by_value)]
        fn call(
            ctx: &mut CallContext<'_, f32>,
            func: Function<f32>,
            value: f32,
        ) -> EvalResult<f32> {
            // NOTE(review): presumably attributes the synthesized argument to the location
            // of the native call, since it has no location of its own in the script.
            let args = vec![ctx.apply_call_location(Value::Prim(value))];
            func.evaluate(args, ctx)
        }

        let program = "(|x| { x + 1 }).call(1)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("call", call);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(2.0));
        Ok(())
    }
}