use ndarray::{Array4, ArrayView4};
use ocl::{flags, prm::Uint3, Buffer, Kernel};
use std::{borrow::Cow, convert::TryFrom};
use crate::{
    base::Base,
    params::{OutputParams, WithParams},
    ConvElement, Params,
};
/// Dimensions of a 4-dimensional feature map (a batch of multi-channel 2D signals).
///
/// Every dimension is stored as `u32`. The struct itself does not encode a memory
/// layout; the axis order is only decided when converting from / to an array shape
/// (see [`Layout`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FeatureMapShape {
    /// Number of samples in the batch.
    pub batch_size: u32,
    /// Spatial width of each sample.
    pub width: u32,
    /// Spatial height of each sample.
    pub height: u32,
    /// Number of channels in each sample.
    pub channels: u32,
}
impl FeatureMapShape {
    /// Narrows a single dimension to `u32`.
    ///
    /// # Panics
    ///
    /// Panics with `message` if `value` does not fit into `u32`.
    fn dim_to_u32(value: usize, message: &str) -> u32 {
        u32::try_from(value).expect(message)
    }

    /// Builds a shape from a 4-element slice ordered as
    /// `[batch, height, width, channels]`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` does not have exactly 4 elements, or if any dimension
    /// does not fit into `u32`.
    fn from_nhwc_slice(shape: &[usize]) -> Self {
        assert_eq!(shape.len(), 4);
        // Dimensions are converted in the order batch, height, width, channels so that
        // the first failing dimension determines the panic message.
        let batch_size = Self::dim_to_u32(shape[0], "Cannot convert batch size to `u32`");
        let height = Self::dim_to_u32(shape[1], "Cannot convert height to `u32`");
        let width = Self::dim_to_u32(shape[2], "Cannot convert width to `u32`");
        let channels = Self::dim_to_u32(shape[3], "Cannot convert channel count to `u32`");
        FeatureMapShape {
            batch_size,
            width,
            height,
            channels,
        }
    }

    /// Builds a shape from a 4-element slice ordered as
    /// `[batch, channels, height, width]`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` does not have exactly 4 elements, or if any dimension
    /// does not fit into `u32`.
    fn from_nchw_slice(shape: &[usize]) -> Self {
        assert_eq!(shape.len(), 4);
        // Same conversion order as in `from_nhwc_slice`; only the source indices differ.
        let batch_size = Self::dim_to_u32(shape[0], "Cannot convert batch size to `u32`");
        let height = Self::dim_to_u32(shape[2], "Cannot convert height to `u32`");
        let width = Self::dim_to_u32(shape[3], "Cannot convert width to `u32`");
        let channels = Self::dim_to_u32(shape[1], "Cannot convert channel count to `u32`");
        FeatureMapShape {
            batch_size,
            width,
            height,
            channels,
        }
    }

    /// Returns the total number of elements in a feature map of this shape,
    /// i.e. the product of all four dimensions.
    fn buffer_len(self) -> usize {
        let dims = [self.batch_size, self.width, self.height, self.channels];
        dims.iter().map(|&dim| dim as usize).product()
    }

    /// Returns the dimensions as an array shape, ordered according to `layout`.
    fn as_array(self, layout: Layout) -> [usize; 4] {
        let (batch, channels, height, width) = (
            self.batch_size as usize,
            self.channels as usize,
            self.height as usize,
            self.width as usize,
        );
        match layout {
            Layout::ChannelsLast => [batch, height, width, channels],
            Layout::ChannelsFirst => [batch, channels, height, width],
        }
    }
}
/// Axis ordering of a 4-dimensional feature map.
///
/// The enum has a `u8` representation with explicit discriminants.
// NOTE(review): the fixed discriminants are presumably relied upon by the OpenCL
// kernel (the layout is embedded into `OutputParams`) — confirm against the kernel source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Layout {
    /// Axes are ordered as `[batch, channels, height, width]` (NCHW).
    ChannelsFirst = 0,
    /// Axes are ordered as `[batch, height, width, channels]` (NHWC).
    ChannelsLast = 1,
}
/// Borrowed 4-dimensional array view tagged with the layout of its axes.
///
/// Created with [`FeatureMap::nchw`] or [`FeatureMap::nhwc`], which also record
/// the layout-independent shape of the view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FeatureMap<'a, T> {
    /// Order in which the axes of `inner` should be interpreted.
    layout: Layout,
    /// Underlying array view.
    inner: ArrayView4<'a, T>,
    /// Shape of `inner`, decoded according to `layout`.
    shape: FeatureMapShape,
}
impl<'a, T: ConvElement> FeatureMap<'a, T> {
    /// Wraps `inner`, decoding its shape according to `layout`.
    ///
    /// # Panics
    ///
    /// Panics if any dimension of the view does not fit into `u32`.
    fn from_view(inner: ArrayView4<'a, T>, layout: Layout) -> Self {
        let dims = inner.shape();
        let shape = match layout {
            Layout::ChannelsFirst => FeatureMapShape::from_nchw_slice(dims),
            Layout::ChannelsLast => FeatureMapShape::from_nhwc_slice(dims),
        };
        Self {
            layout,
            inner,
            shape,
        }
    }

    /// Creates a feature map from an array whose axes are ordered as
    /// `[batch, channels, height, width]`.
    ///
    /// # Panics
    ///
    /// Panics if any dimension of the array does not fit into `u32`.
    pub fn nchw(array: impl Into<ArrayView4<'a, T>>) -> Self {
        Self::from_view(array.into(), Layout::ChannelsFirst)
    }

    /// Creates a feature map from an array whose axes are ordered as
    /// `[batch, height, width, channels]`.
    ///
    /// # Panics
    ///
    /// Panics if any dimension of the array does not fit into `u32`.
    pub fn nhwc(array: impl Into<ArrayView4<'a, T>>) -> Self {
        Self::from_view(array.into(), Layout::ChannelsLast)
    }

    /// Returns the layout this feature map was created with.
    pub fn layout(self) -> Layout {
        let Self { layout, .. } = self;
        layout
    }

    /// Returns the layout-independent shape of this feature map.
    pub fn shape(self) -> FeatureMapShape {
        let Self { shape, .. } = self;
        shape
    }

    /// Returns a view of the data with axes ordered as
    /// `[batch, height, width, channels]`.
    ///
    /// For channels-first data the axes are permuted; no elements are copied here.
    fn to_nhwc(self) -> ArrayView4<'a, T> {
        let Self { layout, inner, .. } = self;
        match layout {
            Layout::ChannelsLast => inner,
            Layout::ChannelsFirst => inner.permuted_axes([0, 2, 3, 1]),
        }
    }
}
/// Convolution filters (and optional per-filter biases) uploaded to OpenCL buffers.
#[derive(Debug, Clone)]
pub(crate) struct Filters<T: ConvElement> {
    /// Buffer holding the filter elements.
    inner: Buffer<T>,
    /// Buffer holding one bias per filter, if biases were supplied.
    biases: Option<Buffer<T::Acc>>,
    /// Number of filters (size of axis 0 of the source array).
    filter_count: u32,
    /// Number of channels per filter (size of axis 3 of the source array).
    channel_count: u32,
}
impl<T: ConvElement> Filters<T> {
    /// Returns the number of filters.
    pub fn filter_count(&self) -> u32 {
        self.filter_count
    }

    /// Returns the number of channels in each filter.
    pub fn channel_count(&self) -> u32 {
        self.channel_count
    }

    /// Uploads `filters` (and `biases`, if any) to OpenCL buffers on the queue of `conv`
    /// and binds them to the `filters` / `filter_biases` arguments of its kernel.
    ///
    /// `filters` is expected to be shaped as `[filter_count, size, size, channel_count]`,
    /// where `size` is `conv.size()`. If the view is not contiguous in standard order,
    /// its elements are copied into a temporary vector before uploading.
    ///
    /// # Errors
    ///
    /// Returns an error if a buffer cannot be created or a kernel argument cannot be set.
    ///
    /// # Panics
    ///
    /// Panics if
    ///
    /// - the spatial dimensions of `filters` differ from `conv.size()`;
    /// - `biases` is provided and its length differs from the number of filters;
    /// - the filter or channel count does not fit into `u32`.
    pub fn new<U: WithParams>(
        filters: ArrayView4<'_, T>,
        biases: Option<&[T::Acc]>,
        conv: &Base<U>,
    ) -> ocl::Result<Self> {
        assert!(
            filters.shape()[1] == conv.size() as usize
                && filters.shape()[2] == conv.size() as usize,
            "Invalid filter shape: expected {0}x{0}, got {1}x{2}",
            conv.size(),
            filters.shape()[1],
            filters.shape()[2]
        );
        if let Some(biases) = biases {
            assert_eq!(
                filters.shape()[0],
                biases.len(),
                "Number of filter biases does not agree with the number of filters"
            );
        }

        // Borrow the data directly when it is contiguous; otherwise gather it in
        // logical order so that the device sees a densely packed array.
        let filters_slice = filters.as_slice().map_or_else(
            || Cow::Owned(filters.iter().copied().collect()),
            Cow::Borrowed,
        );
        let filters_buffer = Buffer::builder()
            .queue(conv.queue().clone())
            .len(filters.shape().iter().product::<usize>())
            .flags(flags::MEM_READ_ONLY)
            .copy_host_slice(filters_slice.as_ref())
            .build()?;

        let filter_biases = biases
            .map(|biases| {
                Buffer::builder()
                    .queue(conv.queue().clone())
                    .len(biases.len())
                    .flags(flags::MEM_READ_ONLY)
                    .copy_host_slice(biases)
                    .build()
            })
            .transpose()?;

        conv.kernel().set_arg("filters", &filters_buffer)?;
        // `None` is passed through as-is, so a kernel without biases gets an
        // explicitly empty `filter_biases` argument.
        conv.kernel()
            .set_arg("filter_biases", filter_biases.as_ref())?;

        Ok(Self {
            inner: filters_buffer,
            biases: filter_biases,
            filter_count: u32::try_from(filters.shape()[0])
                .expect("Cannot convert filter count to `u32`"),
            channel_count: u32::try_from(filters.shape()[3])
                .expect("Cannot convert channel count to `u32`"),
        })
    }

    /// Binds the filter and bias buffers to the `filters` / `filter_biases`
    /// arguments of `kernel`.
    ///
    /// The bias argument is always set, exactly as in [`Filters::new`]: when there are
    /// no biases, `None` is passed. Skipping the argument in that case would leave
    /// whatever bias buffer was previously bound to `kernel` in place.
    ///
    /// # Errors
    ///
    /// Returns an error if a kernel argument cannot be set.
    pub fn pass_as_arguments(&self, kernel: &Kernel) -> ocl::Result<()> {
        kernel.set_arg("filters", &self.inner)?;
        kernel.set_arg("filter_biases", self.biases.as_ref())
    }
}
/// Pair of OpenCL buffers for the convolution input (signal) and output,
/// together with the dimensions needed to launch the kernel.
#[derive(Debug, Clone)]
pub(crate) struct InputAndOutput<T: ConvElement> {
    /// Buffer the input signal is written to.
    signal_buffer: Buffer<T>,
    /// Signal dimensions packed as `(height, width, channels)`.
    signal_dims: Uint3,
    /// Buffer the kernel output is read from.
    output_buffer: Buffer<T>,
    /// Shape of the convolution output.
    output_shape: FeatureMapShape,
}
impl<T: ConvElement> InputAndOutput<T> {
    /// Computes the output size along one spatial axis:
    /// `(signal + pads - effective_kernel) / stride + 1`, where
    /// `effective_kernel = kernel + (dilation - 1) * (kernel - 1)`.
    ///
    /// Padding is added *before* the kernel size is subtracted so that a signal
    /// smaller than the kernel, but large enough once padded, does not underflow.
    ///
    /// # Panics
    ///
    /// Panics if the effective kernel size exceeds the padded signal size,
    /// or if `stride` is zero.
    fn output_extent(
        signal_size: u32,
        pad_before: u32,
        pad_after: u32,
        kernel_size: u32,
        dilation: u32,
        stride: u32,
    ) -> u32 {
        let effective_kernel = kernel_size + (dilation - 1) * (kernel_size - 1);
        let padded_size = signal_size + pad_before + pad_after;
        let span = padded_size.checked_sub(effective_kernel).unwrap_or_else(|| {
            panic!(
                "Effective filter size ({}) exceeds padded signal size ({})",
                effective_kernel, padded_size
            )
        });
        span / stride + 1
    }

    /// Allocates the signal and output buffers for a signal of `signal_shape`
    /// convolved with `filter_count` filters using the parameters of `conv`.
    ///
    /// The output keeps the batch size of the signal, has `filter_count` channels,
    /// and its spatial size is derived from the pads, strides and dilation of `conv`.
    ///
    /// # Errors
    ///
    /// Returns an error if either buffer cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if the dilated filter does not fit into the padded signal along
    /// either spatial axis.
    pub fn new<U: WithParams>(
        signal_shape: FeatureMapShape,
        filter_count: u32,
        conv: &Base<U>,
    ) -> ocl::Result<Self> {
        let Params {
            pads,
            strides,
            dilation,
            ..
        } = conv.params().into();

        // Pads at indices 0 / 2 apply to the height axis, 1 / 3 to the width axis.
        let out_h = Self::output_extent(
            signal_shape.height,
            pads[0],
            pads[2],
            conv.size(),
            dilation[0],
            strides[0],
        );
        let out_w = Self::output_extent(
            signal_shape.width,
            pads[1],
            pads[3],
            conv.size(),
            dilation[1],
            strides[1],
        );
        let output_shape = FeatureMapShape {
            height: out_h,
            width: out_w,
            channels: filter_count,
            ..signal_shape
        };

        let signal_buffer = Buffer::builder()
            .queue(conv.queue().clone())
            .len(signal_shape.buffer_len())
            .flags(flags::MEM_READ_ONLY)
            .build()?;
        let output_buffer = Buffer::builder()
            .queue(conv.queue().clone())
            .len(output_shape.buffer_len())
            .flags(flags::MEM_HOST_READ_ONLY | flags::MEM_WRITE_ONLY)
            .build()?;
        let signal_dims = Uint3::new(
            signal_shape.height,
            signal_shape.width,
            signal_shape.channels,
        );

        Ok(InputAndOutput {
            signal_buffer,
            signal_dims,
            output_buffer,
            output_shape,
        })
    }

    /// Writes `signal` into the signal buffer in channels-last order.
    ///
    /// If the channels-last view of the signal is not contiguous (e.g. the signal
    /// was supplied channels-first), its elements are first copied into a
    /// temporary vector in the required order.
    ///
    /// # Errors
    ///
    /// Returns an error if the write command fails.
    pub fn write_signal(&self, signal: FeatureMap<'_, T>) -> ocl::Result<()> {
        let signal = signal.to_nhwc();
        let signal_slice = signal.as_slice().map_or_else(
            || Cow::Owned(signal.iter().copied().collect()),
            Cow::Borrowed,
        );
        self.signal_buffer.write(signal_slice.as_ref()).enq()
    }

    /// Sets the `signal_dims` argument of `kernel` to the stored signal dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error if the kernel argument cannot be set.
    pub fn pass_as_arguments(&self, kernel: &Kernel) -> ocl::Result<()> {
        kernel.set_arg("signal_dims", self.signal_dims)
    }

    /// Runs `kernel` on the stored buffers and reads the result back as an array
    /// shaped according to `out_layout`.
    ///
    /// Sets the `out_params`, `output` and `signal` kernel arguments; all other
    /// arguments must already be set by the caller.
    ///
    /// # Errors
    ///
    /// Returns an error if setting an argument, enqueuing the kernel or reading
    /// the output buffer fails.
    pub fn execute(&self, kernel: &Kernel, out_layout: Layout) -> ocl::Result<Array4<T>> {
        let s = self.output_shape;
        kernel.set_arg(
            "out_params",
            OutputParams {
                batch_size: s.batch_size,
                layout: out_layout,
            },
        )?;
        kernel.set_arg("output", &self.output_buffer)?;
        kernel.set_arg("signal", &self.signal_buffer)?;

        // One work item per output element; the batch is folded into the first dimension.
        let command = kernel.cmd().global_work_size([
            s.height as usize * s.batch_size as usize,
            s.width as usize,
            s.channels as usize,
        ]);
        // SAFETY: the signal and output arguments were just bound to buffers owned by
        // `self`, sized from the same shapes that define the work size above. The
        // remaining arguments are assumed to have been set consistently by the caller
        // — TODO confirm against the kernel source.
        unsafe {
            command.enq()?;
        }

        let mut output_data = vec![T::default(); self.output_buffer.len()];
        self.output_buffer.read(&mut output_data).enq()?;
        let output = Array4::from_shape_vec(self.output_shape.as_array(out_layout), output_data)
            .expect("output buffer length must match the output shape");
        Ok(output)
    }
}
/// Input / output buffers bundled with a signal shape.
// NOTE(review): `signal_shape` is presumably the shape `io` was allocated for, kept so
// that incoming signals can be checked against it — verify against the callers.
#[derive(Debug, Clone)]
pub(crate) struct Pinned<T: ConvElement> {
    /// Pre-allocated signal and output buffers.
    pub io: InputAndOutput<T>,
    /// Shape of the signal associated with `io`.
    pub signal_shape: FeatureMapShape,
}