arithmetic_eval/values/
function.rs

1//! `Function` and closely related types.
2
3use core::fmt;
4
5use arithmetic_parser::{Location, LvalueLen};
6
7use crate::{
8    alloc::{Arc, HashMap, String, ToOwned, Vec},
9    arith::OrdArithmetic,
10    error::{Backtrace, Error, ErrorKind, LocationInModule},
11    exec::{ExecutableFn, ModuleId, Operations},
12    fns::ValueCell,
13    Environment, EvalResult, SpannedValue, Value,
14};
15
16/// Context for native function calls.
17#[derive(Debug)]
18pub struct CallContext<'r, T> {
19    call_location: LocationInModule,
20    backtrace: Option<&'r mut Backtrace>,
21    operations: Operations<'r, T>,
22}
23
24impl<'r, T> CallContext<'r, T> {
25    /// Creates a mock call context with the specified module ID and call span.
26    /// The provided [`Environment`] is used to extract an [`OrdArithmetic`] implementation.
27    pub fn mock<ID: ModuleId>(module_id: ID, location: Location, env: &'r Environment<T>) -> Self {
28        Self {
29            call_location: LocationInModule::new(Arc::new(module_id), location),
30            backtrace: None,
31            operations: env.operations(),
32        }
33    }
34
35    pub(crate) fn new(
36        call_location: LocationInModule,
37        backtrace: Option<&'r mut Backtrace>,
38        operations: Operations<'r, T>,
39    ) -> Self {
40        Self {
41            call_location,
42            backtrace,
43            operations,
44        }
45    }
46
47    #[allow(clippy::needless_option_as_deref)] // false positive
48    pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace> {
49        self.backtrace.as_deref_mut()
50    }
51
52    pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
53        self.operations.arithmetic
54    }
55
56    /// Returns the call location of the currently executing function.
57    pub fn call_location(&self) -> &LocationInModule {
58        &self.call_location
59    }
60
61    /// Applies the call span to the specified `value`.
62    pub fn apply_call_location<U>(&self, value: U) -> Location<U> {
63        self.call_location.in_module().copy_with_extra(value)
64    }
65
66    /// Creates an error spanning the call site.
67    pub fn call_site_error(&self, error: ErrorKind) -> Error {
68        Error::from_parts(self.call_location.clone(), error)
69    }
70
71    /// Checks argument count and returns an error if it doesn't match.
72    pub fn check_args_count(
73        &self,
74        args: &[SpannedValue<T>],
75        expected_count: impl Into<LvalueLen>,
76    ) -> Result<(), Error> {
77        let expected_count = expected_count.into();
78        if expected_count.matches(args.len()) {
79            Ok(())
80        } else {
81            Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
82                def: expected_count,
83                call: args.len(),
84            }))
85        }
86    }
87}
88
89/// Function on zero or more [`Value`]s.
90///
91/// Native functions are defined in the Rust code and then can be used from the interpreted
92/// code. See [`fns`](crate::fns) module docs for different ways to define native functions.
93pub trait NativeFn<T> {
94    /// Executes the function on the specified arguments.
95    fn evaluate(
96        &self,
97        args: Vec<SpannedValue<T>>,
98        context: &mut CallContext<'_, T>,
99    ) -> EvalResult<T>;
100}
101
102impl<T, F> NativeFn<T> for F
103where
104    F: 'static + Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>,
105{
106    fn evaluate(
107        &self,
108        args: Vec<SpannedValue<T>>,
109        context: &mut CallContext<'_, T>,
110    ) -> EvalResult<T> {
111        self(args, context)
112    }
113}
114
115impl<T> fmt::Debug for dyn NativeFn<T> {
116    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
117        formatter.debug_tuple("NativeFn").finish()
118    }
119}
120
121impl<T> dyn NativeFn<T> {
122    /// Extracts a data pointer from this trait object reference.
123    pub(crate) fn data_ptr(&self) -> *const () {
124        // `*const dyn Trait as *const ()` extracts the data pointer,
125        // see https://github.com/rust-lang/rust/issues/27751. This is seemingly
126        // the simplest way to extract the data pointer; `TraitObject` in `std::raw` is
127        // a more future-proof alternative, but it is unstable.
128        (self as *const dyn NativeFn<T>).cast()
129    }
130}
131
132/// Function defined within the interpreter.
133#[derive(Debug, Clone)]
134pub struct InterpretedFn<T> {
135    definition: Arc<ExecutableFn<T>>,
136    captures: Vec<Value<T>>,
137    capture_names: Vec<String>,
138}
139
140impl<T> InterpretedFn<T> {
141    pub(crate) fn new(
142        definition: Arc<ExecutableFn<T>>,
143        captures: Vec<Value<T>>,
144        capture_names: Vec<String>,
145    ) -> Self {
146        Self {
147            definition,
148            captures,
149            capture_names,
150        }
151    }
152
153    /// Returns ID of the module defining this function.
154    pub fn module_id(&self) -> &Arc<dyn ModuleId> {
155        self.definition.inner.id()
156    }
157
158    /// Returns the number of arguments for this function.
159    pub fn arg_count(&self) -> LvalueLen {
160        self.definition.arg_count
161    }
162
163    /// Returns values captured by this function.
164    pub fn captures(&self) -> HashMap<&str, &Value<T>> {
165        self.capture_names
166            .iter()
167            .zip(&self.captures)
168            .map(|(name, val)| (name.as_str(), val))
169            .collect()
170    }
171}
172
173impl<T: 'static + Clone> InterpretedFn<T> {
174    /// Evaluates this function with the provided arguments and the execution context.
175    pub fn evaluate(
176        &self,
177        args: Vec<SpannedValue<T>>,
178        ctx: &mut CallContext<'_, T>,
179    ) -> EvalResult<T> {
180        if !self.arg_count().matches(args.len()) {
181            let err = ErrorKind::ArgsLenMismatch {
182                def: self.arg_count(),
183                call: args.len(),
184            };
185            return Err(ctx.call_site_error(err));
186        }
187
188        let args = args.into_iter().map(|arg| arg.extra).collect();
189        let captures: Result<Vec<_>, _> = self
190            .captures
191            .iter()
192            .zip(&self.capture_names)
193            .map(|(capture, name)| Self::deref_capture(capture, name))
194            .collect();
195        let captures = captures.map_err(|err| ctx.call_site_error(err))?;
196
197        self.definition.inner.call_function(captures, args, ctx)
198    }
199
200    fn deref_capture(capture: &Value<T>, name: &str) -> Result<Value<T>, ErrorKind> {
201        Ok(match capture {
202            Value::Ref(opaque_ref) => {
203                if let Some(cell) = opaque_ref.downcast_ref::<ValueCell<T>>() {
204                    cell.get()
205                        .cloned()
206                        .ok_or_else(|| ErrorKind::Uninitialized(name.to_owned()))?
207                } else {
208                    capture.clone()
209                }
210            }
211            _ => capture.clone(),
212        })
213    }
214}
215
216/// Function definition. Functions can be either native (defined in the Rust code) or defined
217/// in the interpreter.
218#[derive(Debug)]
219pub enum Function<T> {
220    /// Native function.
221    Native(Arc<dyn NativeFn<T>>),
222    /// Interpreted function.
223    Interpreted(Arc<InterpretedFn<T>>),
224}
225
226impl<T> Clone for Function<T> {
227    fn clone(&self) -> Self {
228        match self {
229            Self::Native(function) => Self::Native(Arc::clone(function)),
230            Self::Interpreted(function) => Self::Interpreted(Arc::clone(function)),
231        }
232    }
233}
234
235impl<T: fmt::Display> fmt::Display for Function<T> {
236    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
237        match self {
238            Self::Native(_) => formatter.write_str("(native fn)"),
239            Self::Interpreted(function) => {
240                formatter.write_str("(interpreted fn @ ")?;
241                let location = LocationInModule::new(
242                    function.module_id().clone(),
243                    function.definition.def_location,
244                );
245                location.fmt_location(formatter)?;
246                formatter.write_str(")")
247            }
248        }
249    }
250}
251
252impl<T: PartialEq> PartialEq for Function<T> {
253    fn eq(&self, other: &Self) -> bool {
254        self.is_same_function(other)
255    }
256}
257
258impl<T> Function<T> {
259    /// Creates a native function.
260    pub fn native(function: impl NativeFn<T> + 'static) -> Self {
261        Self::Native(Arc::new(function))
262    }
263
264    /// Checks if the provided function is the same as this one.
265    pub fn is_same_function(&self, other: &Self) -> bool {
266        match (self, other) {
267            (Self::Native(this), Self::Native(other)) => this.data_ptr() == other.data_ptr(),
268            (Self::Interpreted(this), Self::Interpreted(other)) => Arc::ptr_eq(this, other),
269            _ => false,
270        }
271    }
272
273    pub(crate) fn def_location(&self) -> Option<LocationInModule> {
274        match self {
275            Self::Native(_) => None,
276            Self::Interpreted(function) => Some(LocationInModule::new(
277                function.module_id().clone(),
278                function.definition.def_location,
279            )),
280        }
281    }
282}
283
284impl<T: 'static + Clone> Function<T> {
285    /// Evaluates the function on the specified arguments.
286    pub fn evaluate(
287        &self,
288        args: Vec<SpannedValue<T>>,
289        ctx: &mut CallContext<'_, T>,
290    ) -> EvalResult<T> {
291        match self {
292            Self::Native(function) => function.evaluate(args, ctx),
293            Self::Interpreted(function) => function.evaluate(args, ctx),
294        }
295    }
296}