1use core::fmt;
4
5use arithmetic_parser::{Location, LvalueLen};
6
7use crate::{
8 alloc::{Arc, HashMap, String, ToOwned, Vec},
9 arith::OrdArithmetic,
10 error::{Backtrace, Error, ErrorKind, LocationInModule},
11 exec::{ExecutableFn, ModuleId, Operations},
12 fns::ValueCell,
13 Environment, EvalResult, SpannedValue, Value,
14};
15
16#[derive(Debug)]
18pub struct CallContext<'r, T> {
19 call_location: LocationInModule,
20 backtrace: Option<&'r mut Backtrace>,
21 operations: Operations<'r, T>,
22}
23
24impl<'r, T> CallContext<'r, T> {
25 pub fn mock<ID: ModuleId>(module_id: ID, location: Location, env: &'r Environment<T>) -> Self {
28 Self {
29 call_location: LocationInModule::new(Arc::new(module_id), location),
30 backtrace: None,
31 operations: env.operations(),
32 }
33 }
34
35 pub(crate) fn new(
36 call_location: LocationInModule,
37 backtrace: Option<&'r mut Backtrace>,
38 operations: Operations<'r, T>,
39 ) -> Self {
40 Self {
41 call_location,
42 backtrace,
43 operations,
44 }
45 }
46
47 #[allow(clippy::needless_option_as_deref)] pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace> {
49 self.backtrace.as_deref_mut()
50 }
51
52 pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
53 self.operations.arithmetic
54 }
55
56 pub fn call_location(&self) -> &LocationInModule {
58 &self.call_location
59 }
60
61 pub fn apply_call_location<U>(&self, value: U) -> Location<U> {
63 self.call_location.in_module().copy_with_extra(value)
64 }
65
66 pub fn call_site_error(&self, error: ErrorKind) -> Error {
68 Error::from_parts(self.call_location.clone(), error)
69 }
70
71 pub fn check_args_count(
73 &self,
74 args: &[SpannedValue<T>],
75 expected_count: impl Into<LvalueLen>,
76 ) -> Result<(), Error> {
77 let expected_count = expected_count.into();
78 if expected_count.matches(args.len()) {
79 Ok(())
80 } else {
81 Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
82 def: expected_count,
83 call: args.len(),
84 }))
85 }
86 }
87}
88
89pub trait NativeFn<T> {
94 fn evaluate(
96 &self,
97 args: Vec<SpannedValue<T>>,
98 context: &mut CallContext<'_, T>,
99 ) -> EvalResult<T>;
100}
101
102impl<T, F> NativeFn<T> for F
103where
104 F: 'static + Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>,
105{
106 fn evaluate(
107 &self,
108 args: Vec<SpannedValue<T>>,
109 context: &mut CallContext<'_, T>,
110 ) -> EvalResult<T> {
111 self(args, context)
112 }
113}
114
115impl<T> fmt::Debug for dyn NativeFn<T> {
116 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
117 formatter.debug_tuple("NativeFn").finish()
118 }
119}
120
121impl<T> dyn NativeFn<T> {
122 pub(crate) fn data_ptr(&self) -> *const () {
124 (self as *const dyn NativeFn<T>).cast()
129 }
130}
131
132#[derive(Debug, Clone)]
134pub struct InterpretedFn<T> {
135 definition: Arc<ExecutableFn<T>>,
136 captures: Vec<Value<T>>,
137 capture_names: Vec<String>,
138}
139
140impl<T> InterpretedFn<T> {
141 pub(crate) fn new(
142 definition: Arc<ExecutableFn<T>>,
143 captures: Vec<Value<T>>,
144 capture_names: Vec<String>,
145 ) -> Self {
146 Self {
147 definition,
148 captures,
149 capture_names,
150 }
151 }
152
153 pub fn module_id(&self) -> &Arc<dyn ModuleId> {
155 self.definition.inner.id()
156 }
157
158 pub fn arg_count(&self) -> LvalueLen {
160 self.definition.arg_count
161 }
162
163 pub fn captures(&self) -> HashMap<&str, &Value<T>> {
165 self.capture_names
166 .iter()
167 .zip(&self.captures)
168 .map(|(name, val)| (name.as_str(), val))
169 .collect()
170 }
171}
172
173impl<T: 'static + Clone> InterpretedFn<T> {
174 pub fn evaluate(
176 &self,
177 args: Vec<SpannedValue<T>>,
178 ctx: &mut CallContext<'_, T>,
179 ) -> EvalResult<T> {
180 if !self.arg_count().matches(args.len()) {
181 let err = ErrorKind::ArgsLenMismatch {
182 def: self.arg_count(),
183 call: args.len(),
184 };
185 return Err(ctx.call_site_error(err));
186 }
187
188 let args = args.into_iter().map(|arg| arg.extra).collect();
189 let captures: Result<Vec<_>, _> = self
190 .captures
191 .iter()
192 .zip(&self.capture_names)
193 .map(|(capture, name)| Self::deref_capture(capture, name))
194 .collect();
195 let captures = captures.map_err(|err| ctx.call_site_error(err))?;
196
197 self.definition.inner.call_function(captures, args, ctx)
198 }
199
200 fn deref_capture(capture: &Value<T>, name: &str) -> Result<Value<T>, ErrorKind> {
201 Ok(match capture {
202 Value::Ref(opaque_ref) => {
203 if let Some(cell) = opaque_ref.downcast_ref::<ValueCell<T>>() {
204 cell.get()
205 .cloned()
206 .ok_or_else(|| ErrorKind::Uninitialized(name.to_owned()))?
207 } else {
208 capture.clone()
209 }
210 }
211 _ => capture.clone(),
212 })
213 }
214}
215
216#[derive(Debug)]
219pub enum Function<T> {
220 Native(Arc<dyn NativeFn<T>>),
222 Interpreted(Arc<InterpretedFn<T>>),
224}
225
226impl<T> Clone for Function<T> {
227 fn clone(&self) -> Self {
228 match self {
229 Self::Native(function) => Self::Native(Arc::clone(function)),
230 Self::Interpreted(function) => Self::Interpreted(Arc::clone(function)),
231 }
232 }
233}
234
235impl<T: fmt::Display> fmt::Display for Function<T> {
236 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
237 match self {
238 Self::Native(_) => formatter.write_str("(native fn)"),
239 Self::Interpreted(function) => {
240 formatter.write_str("(interpreted fn @ ")?;
241 let location = LocationInModule::new(
242 function.module_id().clone(),
243 function.definition.def_location,
244 );
245 location.fmt_location(formatter)?;
246 formatter.write_str(")")
247 }
248 }
249 }
250}
251
252impl<T: PartialEq> PartialEq for Function<T> {
253 fn eq(&self, other: &Self) -> bool {
254 self.is_same_function(other)
255 }
256}
257
258impl<T> Function<T> {
259 pub fn native(function: impl NativeFn<T> + 'static) -> Self {
261 Self::Native(Arc::new(function))
262 }
263
264 pub fn is_same_function(&self, other: &Self) -> bool {
266 match (self, other) {
267 (Self::Native(this), Self::Native(other)) => this.data_ptr() == other.data_ptr(),
268 (Self::Interpreted(this), Self::Interpreted(other)) => Arc::ptr_eq(this, other),
269 _ => false,
270 }
271 }
272
273 pub(crate) fn def_location(&self) -> Option<LocationInModule> {
274 match self {
275 Self::Native(_) => None,
276 Self::Interpreted(function) => Some(LocationInModule::new(
277 function.module_id().clone(),
278 function.definition.def_location,
279 )),
280 }
281 }
282}
283
284impl<T: 'static + Clone> Function<T> {
285 pub fn evaluate(
287 &self,
288 args: Vec<SpannedValue<T>>,
289 ctx: &mut CallContext<'_, T>,
290 ) -> EvalResult<T> {
291 match self {
292 Self::Native(function) => function.evaluate(args, ctx),
293 Self::Interpreted(function) => function.evaluate(args, ctx),
294 }
295 }
296}