use core::fmt;
use arithmetic_parser::{Location, LvalueLen};
use crate::{
alloc::{Arc, HashMap, String, ToOwned, Vec},
arith::OrdArithmetic,
error::{Backtrace, Error, ErrorKind, LocationInModule},
exec::{ExecutableFn, ModuleId, Operations},
fns::ValueCell,
Environment, EvalResult, SpannedValue, Value,
};
/// Context passed to every function call (both [`NativeFn`] and [`InterpretedFn`] calls).
///
/// Bundles the location of the call site, an optional backtrace, and the operations
/// (including arithmetic) available to the called function.
#[derive(Debug)]
pub struct CallContext<'r, T> {
// Location of the call site; used to attribute errors via `call_site_error()`.
call_location: LocationInModule,
// Optional backtrace associated with the call; `None` for contexts created by `mock()`.
backtrace: Option<&'r mut Backtrace>,
// Operations available to the call; `arithmetic()` reads its `arithmetic` field.
operations: Operations<'r, T>,
}
impl<'r, T> CallContext<'r, T> {
pub fn mock<ID: ModuleId>(module_id: ID, location: Location, env: &'r Environment<T>) -> Self {
Self {
call_location: LocationInModule::new(Arc::new(module_id), location),
backtrace: None,
operations: env.operations(),
}
}
pub(crate) fn new(
call_location: LocationInModule,
backtrace: Option<&'r mut Backtrace>,
operations: Operations<'r, T>,
) -> Self {
Self {
call_location,
backtrace,
operations,
}
}
#[allow(clippy::needless_option_as_deref)] pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace> {
self.backtrace.as_deref_mut()
}
pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
self.operations.arithmetic
}
pub fn call_location(&self) -> &LocationInModule {
&self.call_location
}
pub fn apply_call_location<U>(&self, value: U) -> Location<U> {
self.call_location.in_module().copy_with_extra(value)
}
pub fn call_site_error(&self, error: ErrorKind) -> Error {
Error::from_parts(self.call_location.clone(), error)
}
pub fn check_args_count(
&self,
args: &[SpannedValue<T>],
expected_count: impl Into<LvalueLen>,
) -> Result<(), Error> {
let expected_count = expected_count.into();
if expected_count.matches(args.len()) {
Ok(())
} else {
Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
def: expected_count,
call: args.len(),
}))
}
}
}
/// Function implemented in Rust that can be wrapped into a [`Function`]
/// (see [`Function::native()`]).
///
/// Any `'static` closure or function with the signature
/// `Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>`
/// implements this trait automatically.
pub trait NativeFn<T> {
/// Evaluates the function on the provided `args`.
///
/// `context` provides the call site location and other call-related data
/// (see [`CallContext`]).
fn evaluate(
&self,
args: Vec<SpannedValue<T>>,
context: &mut CallContext<'_, T>,
) -> EvalResult<T>;
}
/// Any `'static` closure or function with the matching signature is a native function.
impl<T, F> NativeFn<T> for F
where
    F: 'static + Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>,
{
    fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        context: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        // Delegate directly to the wrapped callable.
        (*self)(args, context)
    }
}
/// Native functions are opaque, so their debug representation is just a type label.
impl<T> fmt::Debug for dyn NativeFn<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Equivalent to an empty `debug_tuple("NativeFn")`, which writes only the name.
        formatter.write_str("NativeFn")
    }
}
impl<T> dyn NativeFn<T> {
pub(crate) fn data_ptr(&self) -> *const () {
(self as *const dyn NativeFn<T>).cast()
}
}
/// Function defined in interpreted code, together with the values it captures.
#[derive(Debug, Clone)]
pub struct InterpretedFn<T> {
// Executable definition of the function; shared between clones via `Arc`.
definition: Arc<ExecutableFn<T>>,
// Captured values; `captures[i]` corresponds to `capture_names[i]`.
captures: Vec<Value<T>>,
// Names of captured variables, parallel to `captures`.
capture_names: Vec<String>,
}
impl<T> InterpretedFn<T> {
    /// Creates a function from its executable `definition` and captured values.
    ///
    /// `captures` and `capture_names` are paired up by index.
    pub(crate) fn new(
        definition: Arc<ExecutableFn<T>>,
        captures: Vec<Value<T>>,
        capture_names: Vec<String>,
    ) -> Self {
        Self {
            definition,
            captures,
            capture_names,
        }
    }

    /// Returns the ID of the module containing this function's definition.
    pub fn module_id(&self) -> &Arc<dyn ModuleId> {
        let executable = &self.definition.inner;
        executable.id()
    }

    /// Returns the number of arguments this function accepts.
    pub fn arg_count(&self) -> LvalueLen {
        self.definition.arg_count
    }

    /// Returns the values captured by this function, keyed by captured variable name.
    pub fn captures(&self) -> HashMap<&str, &Value<T>> {
        let names = self.capture_names.iter().map(String::as_str);
        names.zip(self.captures.iter()).collect()
    }
}
impl<T: 'static + Clone> InterpretedFn<T> {
    /// Evaluates this function with the provided `args` in the given call context.
    ///
    /// The argument count is checked against the definition first. Captured values are
    /// then resolved (captures stored in a [`ValueCell`] are dereferenced), and the
    /// function body is executed.
    ///
    /// # Errors
    ///
    /// - [`ErrorKind::ArgsLenMismatch`] at the call site if the number of `args`
    ///   does not match [`arg_count()`](Self::arg_count).
    /// - [`ErrorKind::Uninitialized`] at the call site if a captured [`ValueCell`]
    ///   holds no value.
    /// - Any error returned when executing the function body.
    pub fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        ctx: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        // Use the shared arity check so interpreted and native functions report
        // mismatches identically.
        ctx.check_args_count(&args, self.arg_count())?;

        let args = args.into_iter().map(|arg| arg.extra).collect();
        let captures = self
            .captures
            .iter()
            .zip(&self.capture_names)
            .map(|(capture, name)| Self::deref_capture(capture, name))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| ctx.call_site_error(err))?;
        self.definition.inner.call_function(captures, args, ctx)
    }

    /// Resolves a single captured value.
    ///
    /// A capture that references a [`ValueCell`] is replaced by a clone of the cell's
    /// contents. Every other capture is cloned as-is.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::Uninitialized`] with the capture `name` if the cell is empty.
    fn deref_capture(capture: &Value<T>, name: &str) -> Result<Value<T>, ErrorKind> {
        if let Value::Ref(opaque_ref) = capture {
            if let Some(cell) = opaque_ref.downcast_ref::<ValueCell<T>>() {
                return cell
                    .get()
                    .cloned()
                    .ok_or_else(|| ErrorKind::Uninitialized(name.to_owned()));
            }
        }
        Ok(capture.clone())
    }
}
/// Function value: either a native Rust function or a function defined in interpreted code.
#[derive(Debug)]
pub enum Function<T> {
/// Function implemented in Rust (see [`NativeFn`]).
Native(Arc<dyn NativeFn<T>>),
/// Function defined in interpreted code (see [`InterpretedFn`]).
Interpreted(Arc<InterpretedFn<T>>),
}
/// Cloning only bumps the reference count of the shared function.
///
/// Implemented manually because `#[derive(Clone)]` would add a needless `T: Clone` bound.
impl<T> Clone for Function<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Native(native) => Self::Native(Arc::clone(native)),
            Self::Interpreted(interpreted) => Self::Interpreted(Arc::clone(interpreted)),
        }
    }
}
impl<T: fmt::Display> fmt::Display for Function<T> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Native(_) => formatter.write_str("(native fn)"),
Self::Interpreted(function) => {
formatter.write_str("(interpreted fn @ ")?;
let location = LocationInModule::new(
function.module_id().clone(),
function.definition.def_location,
);
location.fmt_location(formatter)?;
formatter.write_str(")")
}
}
}
}
/// Functions are compared by identity rather than behavior: two values are equal only
/// if they refer to the same function (see [`Function::is_same_function()`]).
///
/// The comparison never inspects `T`, so no `T: PartialEq` bound is required.
impl<T> PartialEq for Function<T> {
    fn eq(&self, other: &Self) -> bool {
        self.is_same_function(other)
    }
}
impl<T> Function<T> {
    /// Wraps a native `function`, e.g. a closure with the [`NativeFn`] signature.
    pub fn native(function: impl NativeFn<T> + 'static) -> Self {
        let shared: Arc<dyn NativeFn<T>> = Arc::new(function);
        Self::Native(shared)
    }

    /// Checks whether `self` and `other` refer to the same function.
    ///
    /// Native functions are compared by the address of their data. Interpreted functions
    /// are compared by the address of their shared allocation. A native function is never
    /// the same as an interpreted one.
    pub fn is_same_function(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Native(lhs), Self::Native(rhs)) => lhs.data_ptr() == rhs.data_ptr(),
            (Self::Interpreted(lhs), Self::Interpreted(rhs)) => Arc::ptr_eq(lhs, rhs),
            (Self::Native(_), Self::Interpreted(_)) | (Self::Interpreted(_), Self::Native(_)) => {
                false
            }
        }
    }

    /// Returns the location where the function is defined, or `None` for native functions.
    pub(crate) fn def_location(&self) -> Option<LocationInModule> {
        match self {
            Self::Native(_) => None,
            Self::Interpreted(function) => {
                let module_id = Arc::clone(function.module_id());
                Some(LocationInModule::new(
                    module_id,
                    function.definition.def_location,
                ))
            }
        }
    }
}
impl<T: 'static + Clone> Function<T> {
    /// Calls this function with the given `args` in the provided call context.
    ///
    /// Dispatches to [`NativeFn::evaluate()`] or [`InterpretedFn::evaluate()`]
    /// depending on the function kind.
    pub fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        ctx: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        match self {
            Self::Native(native) => native.evaluate(args, ctx),
            Self::Interpreted(interpreted) => interpreted.evaluate(args, ctx),
        }
    }
}