use core::fmt;
use arithmetic_parser::{Location, LvalueLen};
use crate::{
    alloc::{Arc, HashMap, String, ToOwned, Vec},
    arith::OrdArithmetic,
    error::{Backtrace, Error, ErrorKind, LocationInModule},
    exec::{ExecutableFn, ModuleId, Operations},
    fns::ValueCell,
    Environment, EvalResult, SpannedValue, Value,
};
/// Context passed to a function when it is called.
///
/// Bundles the call-site location, an optional backtrace, and the operations
/// (including arithmetic) that the callee should use.
#[derive(Debug)]
pub struct CallContext<'r, T> {
    /// Location of the call site within its module.
    call_location: LocationInModule,
    /// Backtrace associated with the call, if any (`None` for mock contexts).
    backtrace: Option<&'r mut Backtrace>,
    /// Operations available to the callee; provides the arithmetic.
    operations: Operations<'r, T>,
}
impl<'r, T> CallContext<'r, T> {
    /// Creates a mock context whose call site is `location` within the module `module_id`.
    ///
    /// The context has no backtrace attached; its operations are taken from `env`.
    pub fn mock<ID: ModuleId>(module_id: ID, location: Location, env: &'r Environment<T>) -> Self {
        Self {
            call_location: LocationInModule::new(Arc::new(module_id), location),
            backtrace: None,
            operations: env.operations(),
        }
    }

    /// Creates a context from the call-site location, an optional backtrace
    /// and the operations to use.
    pub(crate) fn new(
        call_location: LocationInModule,
        backtrace: Option<&'r mut Backtrace>,
        operations: Operations<'r, T>,
    ) -> Self {
        Self {
            call_location,
            backtrace,
            operations,
        }
    }

    /// Returns the backtrace attached to this context, if any.
    // `as_deref_mut` reborrows the stored `&'r mut Backtrace` for the lifetime of `&mut self`;
    // returning the field directly would attempt to move it out of `self`. Clippy's
    // `needless_option_as_deref` may flag this because the input and output types look
    // identical, hence the `allow`.
    #[allow(clippy::needless_option_as_deref)]
    pub(crate) fn backtrace(&mut self) -> Option<&mut Backtrace> {
        self.backtrace.as_deref_mut()
    }

    /// Returns the arithmetic from this context's operations.
    pub(crate) fn arithmetic(&self) -> &'r dyn OrdArithmetic<T> {
        self.operations.arithmetic
    }

    /// Returns the location of the call site.
    pub fn call_location(&self) -> &LocationInModule {
        &self.call_location
    }

    /// Copies the call-site location (within its module) and attaches `value` to it.
    pub fn apply_call_location<U>(&self, value: U) -> Location<U> {
        self.call_location.in_module().copy_with_extra(value)
    }

    /// Creates an error of the given kind located at the call site.
    pub fn call_site_error(&self, error: ErrorKind) -> Error {
        Error::from_parts(self.call_location.clone(), error)
    }

    /// Checks that the number of `args` matches `expected_count`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ArgsLenMismatch`] error located at the call site
    /// if `args.len()` does not match `expected_count`.
    pub fn check_args_count(
        &self,
        args: &[SpannedValue<T>],
        expected_count: impl Into<LvalueLen>,
    ) -> Result<(), Error> {
        let expected_count = expected_count.into();
        if expected_count.matches(args.len()) {
            Ok(())
        } else {
            Err(self.call_site_error(ErrorKind::ArgsLenMismatch {
                def: expected_count,
                call: args.len(),
            }))
        }
    }
}
/// Function implemented in Rust that can be called by the interpreter.
///
/// Implemented automatically for any `'static` closure or function with the
/// signature `Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>`.
pub trait NativeFn<T> {
    /// Evaluates the function on the provided `args` within the given call `context`.
    fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        context: &mut CallContext<'_, T>,
    ) -> EvalResult<T>;
}
/// Blanket implementation: every `'static` closure / fn pointer with the matching
/// signature is usable as a native function.
impl<T, F> NativeFn<T> for F
where
    F: 'static + Fn(Vec<SpannedValue<T>>, &mut CallContext<'_, T>) -> EvalResult<T>,
{
    fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        context: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        // Delegate straight to the underlying callable.
        (*self)(args, context)
    }
}
impl<T> fmt::Debug for dyn NativeFn<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Native functions are opaque; only the type name is printed
        // (same output as an empty `debug_tuple("NativeFn")`).
        formatter.write_str("NativeFn")
    }
}
impl<T> dyn NativeFn<T> {
    pub(crate) fn data_ptr(&self) -> *const () {
        (self as *const dyn NativeFn<T>).cast()
    }
}
/// Function defined in the interpreted code, together with the values it captures.
#[derive(Debug, Clone)]
pub struct InterpretedFn<T> {
    /// Executable definition of the function (shared between clones).
    definition: Arc<ExecutableFn<T>>,
    /// Captured values; index-aligned with `capture_names`.
    captures: Vec<Value<T>>,
    /// Names of the captured variables; index-aligned with `captures`.
    capture_names: Vec<String>,
}
impl<T> InterpretedFn<T> {
    /// Creates a function from its executable `definition` and captured values.
    ///
    /// `captures` and `capture_names` are paired up by position.
    pub(crate) fn new(
        definition: Arc<ExecutableFn<T>>,
        captures: Vec<Value<T>>,
        capture_names: Vec<String>,
    ) -> Self {
        Self {
            definition,
            captures,
            capture_names,
        }
    }

    /// Returns the ID of the module containing this function's definition.
    pub fn module_id(&self) -> &Arc<dyn ModuleId> {
        let inner = &self.definition.inner;
        inner.id()
    }

    /// Returns the expected argument count, checked via [`LvalueLen::matches`] on each call.
    pub fn arg_count(&self) -> LvalueLen {
        let count = self.definition.arg_count;
        count
    }

    /// Returns a map from captured variable names to their captured values.
    pub fn captures(&self) -> HashMap<&str, &Value<T>> {
        let names = self.capture_names.iter().map(String::as_str);
        names.zip(self.captures.iter()).collect()
    }
}
impl<T: 'static + Clone> InterpretedFn<T> {
    /// Evaluates this function on the provided `args`.
    ///
    /// Before the call, each capture that references a [`ValueCell`] is replaced
    /// with the cell's current value; other captures are passed as-is.
    ///
    /// # Errors
    ///
    /// - `ArgsLenMismatch` at the call site if `args.len()` doesn't match [`Self::arg_count()`].
    /// - `Uninitialized` at the call site if a captured [`ValueCell`] holds no value.
    /// - Any error returned while executing the function body.
    pub fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        ctx: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        // Reuse the shared arity check so the mismatch error is built in one place.
        ctx.check_args_count(&args, self.arg_count())?;

        let args = args.into_iter().map(|arg| arg.extra).collect();
        let captures: Result<Vec<_>, _> = self
            .captures
            .iter()
            .zip(&self.capture_names)
            .map(|(capture, name)| Self::deref_capture(capture, name))
            .collect();
        let captures = captures.map_err(|err| ctx.call_site_error(err))?;
        self.definition.inner.call_function(captures, args, ctx)
    }

    /// Resolves a single captured value.
    ///
    /// A reference to a [`ValueCell`] is replaced with a clone of the cell's contents;
    /// any other value is cloned unchanged.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Uninitialized(name)` if the cell's `get()` returns `None`
    /// (presumably the captured variable has not been assigned yet).
    fn deref_capture(capture: &Value<T>, name: &str) -> Result<Value<T>, ErrorKind> {
        Ok(match capture {
            Value::Ref(opaque_ref) => {
                if let Some(cell) = opaque_ref.downcast_ref::<ValueCell<T>>() {
                    cell.get()
                        .cloned()
                        .ok_or_else(|| ErrorKind::Uninitialized(name.to_owned()))?
                } else {
                    capture.clone()
                }
            }
            _ => capture.clone(),
        })
    }
}
/// Function callable by the interpreter: either native (a Rust [`NativeFn`])
/// or interpreted (an [`InterpretedFn`] defined in the interpreted code).
#[derive(Debug)]
pub enum Function<T> {
    /// Native function implemented via [`NativeFn`].
    Native(Arc<dyn NativeFn<T>>),
    /// Function defined in the interpreted code.
    Interpreted(Arc<InterpretedFn<T>>),
}
// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound even though
// only the inner `Arc`s are cloned.
impl<T> Clone for Function<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Native(native) => Self::Native(Arc::clone(native)),
            Self::Interpreted(interpreted) => Self::Interpreted(Arc::clone(interpreted)),
        }
    }
}
impl<T: fmt::Display> fmt::Display for Function<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Native(_) => formatter.write_str("(native fn)"),
            Self::Interpreted(function) => {
                formatter.write_str("(interpreted fn @ ")?;
                let location = LocationInModule::new(
                    function.module_id().clone(),
                    function.definition.def_location,
                );
                location.fmt_location(formatter)?;
                formatter.write_str(")")
            }
        }
    }
}
impl<T: PartialEq> PartialEq for Function<T> {
    fn eq(&self, other: &Self) -> bool {
        self.is_same_function(other)
    }
}
impl<T> Function<T> {
    /// Wraps a native function.
    pub fn native(function: impl NativeFn<T> + 'static) -> Self {
        Self::Native(Arc::new(function))
    }

    /// Checks whether `self` and `other` refer to the same function instance.
    ///
    /// Native functions are compared by data address, and interpreted functions by
    /// `Arc` pointer identity. A native function never equals an interpreted one.
    pub fn is_same_function(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Native(lhs), Self::Native(rhs)) => lhs.data_ptr() == rhs.data_ptr(),
            (Self::Interpreted(lhs), Self::Interpreted(rhs)) => Arc::ptr_eq(lhs, rhs),
            (Self::Native(_), Self::Interpreted(_)) | (Self::Interpreted(_), Self::Native(_)) => {
                false
            }
        }
    }

    /// Returns the module-qualified definition location for interpreted functions,
    /// or `None` for native ones.
    pub(crate) fn def_location(&self) -> Option<LocationInModule> {
        match self {
            Self::Interpreted(function) => {
                let module_id = Arc::clone(function.module_id());
                Some(LocationInModule::new(module_id, function.definition.def_location))
            }
            Self::Native(_) => None,
        }
    }
}
impl<T: 'static + Clone> Function<T> {
    /// Evaluates the function on `args`, dispatching to the native or interpreted
    /// implementation.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying implementation.
    pub fn evaluate(
        &self,
        args: Vec<SpannedValue<T>>,
        ctx: &mut CallContext<'_, T>,
    ) -> EvalResult<T> {
        match self {
            Self::Native(native) => native.evaluate(args, ctx),
            Self::Interpreted(interpreted) => interpreted.evaluate(args, ctx),
        }
    }
}