use ocl::{
prm::{Uint2, Uint4},
OclPrm,
};
use std::marker::PhantomData;
use crate::{
buffers::{Filters, Layout, Pinned},
ConvElement,
};
/// Convolution hyper-parameters shared by every element type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Params {
    /// Step between successive filter applications (presumably one entry per spatial axis).
    pub strides: [u32; 2],
    /// Padding amounts; the order of the four sides is defined by the kernels — TODO confirm.
    pub pads: [u32; 4],
    /// Number of groups the channels are split into.
    pub groups: u32,
    /// Filter dilation (presumably one entry per spatial axis).
    pub dilation: [u32; 2],
}

impl Default for Params {
    /// Unit stride and dilation, no padding, a single group.
    fn default() -> Self {
        let unit = [1_u32; 2];
        Self {
            groups: 1,
            pads: [0_u32; 4],
            strides: unit,
            dilation: unit,
        }
    }
}
/// Kernel-argument form of [`Params`], with the arrays packed into OpenCL
/// vector types.
///
/// `repr(C, packed)` fixes the field order and removes all padding, so the field
/// order and types here presumably must match the kernel-side struct exactly —
/// do not reorder or retype fields.
/// NOTE(review): OpenCL C aligns `uint2`/`uint4` to 8/16 bytes unless the kernel
/// struct is also declared packed — confirm the .cl definition matches.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(C, packed)]
pub struct ClParams {
    strides: Uint2,
    pads: Uint4,
    groups: u32,
    dilation: Uint2,
}
impl From<Params> for ClParams {
fn from(value: Params) -> Self {
ClParams {
strides: Uint2::from(value.strides),
pads: Uint4::from(value.pads),
groups: value.groups,
dilation: Uint2::from(value.dilation),
}
}
}
// SAFETY: `ClParams` is `#[repr(C, packed)]` and consists only of `Copy` integer
// scalar/vector fields (`Uint2`, `Uint4`, `u32`), so it has no padding and no
// invalid bit patterns. NOTE(review): soundness of the kernel call also relies on
// the kernel-side struct having the identical packed layout — confirm against
// the .cl source.
unsafe impl OclPrm for ClParams {}
/// Convolution parameters presumably used by the `i8` element type: the common
/// [`Params`] plus fixed-point rescaling settings.
#[derive(Debug, Clone, Copy)]
pub struct I8Params {
    /// Hyper-parameters shared with the other element types.
    pub common: Params,
    /// Number of fractional bits in `scale` (see [`I8Params::convert_scale`]).
    pub bit_shift: u8,
    /// Fixed-point scale, e.g. as produced by [`I8Params::convert_scale`].
    pub scale: i32,
    // NOTE(review): the three biases below look like zero-point offsets for the
    // output, input signal and filter respectively — confirm against the kernel.
    pub output_bias: i32,
    pub signal_bias: i32,
    pub filter_bias: i32,
}
impl From<I8Params> for Params {
fn from(value: I8Params) -> Self {
value.common
}
}
impl I8Params {
    /// Converts a floating-point `scale` into fixed-point form with `bit_shift`
    /// fractional bits, i.e. returns `round(scale * 2^bit_shift)`.
    ///
    /// # Panics
    ///
    /// Panics if the rounded value does not fit in an `i32`, including when
    /// `scale` is NaN or infinite.
    // The `as i32` cast below cannot truncate: the value is range-checked first.
    #[allow(clippy::cast_possible_truncation)]
    pub fn convert_scale(bit_shift: u8, scale: f32) -> i32 {
        // Work in f64: `f64::from(scale)` and the power of two are both exact, and
        // their product stays finite for every `u8` shift. The product is therefore
        // the exact real value, just as the f32 product was whenever it did not
        // overflow. In f32, shifts >= 128 overflow to infinity (NaN for a zero scale).
        let scaled = (f64::from(scale) * 2.0_f64.powi(i32::from(bit_shift))).round();
        // Both bounds are exact in f64. `i32::MAX as f32` rounds up to 2^31, which
        // would let 2^31 through and then silently saturate in the cast.
        assert!(
            (f64::from(i32::MIN)..=f64::from(i32::MAX)).contains(&scaled),
            "Scale is out of `i32` bounds"
        );
        scaled as i32
    }
}
/// Kernel-argument form of [`I8Params`]: the four common fields (as in
/// `ClParams`) followed by the `i8`-specific values.
///
/// `repr(C, packed)` makes field order and types part of the layout, which
/// presumably must match the kernel-side struct — do not reorder or retype fields.
/// NOTE(review): this field is named `group` while `ClParams` uses `groups`; the
/// name does not affect layout, but consider unifying.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(C, packed)]
pub struct ClI8Params {
    strides: Uint2,
    pads: Uint4,
    group: u32,
    dilation: Uint2,
    // Widened from the host-side `u8`; presumably matches an `int` on the kernel side.
    bit_shift: i32,
    scale: i32,
    output_bias: i32,
    signal_bias: i32,
    filter_bias: i32,
}
impl From<I8Params> for ClI8Params {
fn from(value: I8Params) -> Self {
let common_params = ClParams::from(value.common);
ClI8Params {
strides: common_params.strides,
pads: common_params.pads,
group: common_params.groups,
dilation: common_params.dilation,
bit_shift: i32::from(value.bit_shift),
scale: value.scale,
output_bias: value.output_bias,
signal_bias: value.signal_bias,
filter_bias: value.filter_bias,
}
}
}
// SAFETY: `ClI8Params` is `#[repr(C, packed)]` and consists only of `Copy` integer
// scalar/vector fields, so it has no padding and no invalid bit patterns.
// NOTE(review): the kernel-side struct must use the identical packed layout —
// confirm against the .cl source.
unsafe impl OclPrm for ClI8Params {}
/// Output-side parameters (batch size and memory layout) handed to OpenCL
/// kernels as a single packed argument.
///
/// NOTE(review): the struct is `repr(C, packed)`, so its size and encoding depend
/// on how `Layout` is represented (its `repr` is not visible here) — confirm it
/// matches the kernel-side definition.
#[derive(Debug, Clone, Copy, PartialEq, Hash)]
#[repr(C, packed)]
pub(crate) struct OutputParams {
    pub batch_size: u32,
    pub layout: Layout,
}
// SAFETY: `OutputParams` is `#[repr(C, packed)]` and `Copy`. NOTE(review): this is
// only sound if `Layout` has a fixed, C-compatible representation whose every
// value the kernel accepts — confirm `Layout`'s `repr` and the kernel-side type.
unsafe impl OclPrm for OutputParams {}
impl Default for OutputParams {
fn default() -> Self {
Self {
batch_size: 0,
layout: Layout::ChannelsLast,
}
}
}
/// Associates an element-typed type (e.g. a buffer wrapper or marker) with the
/// host-side and kernel-side parameter types of its element.
pub(crate) trait WithParams {
    /// Parameter form passed to OpenCL kernels.
    type ClParams: OclPrm;
    /// Host-side parameters; convertible both to the shared [`Params`] and to
    /// [`Self::ClParams`].
    type Params: Copy + Into<Params> + Into<Self::ClParams>;
}
impl<T: ConvElement> WithParams for PhantomData<T> {
type Params = T::Params;
type ClParams = T::ClParams;
}
impl<T: ConvElement> WithParams for Filters<T> {
type Params = T::Params;
type ClParams = T::ClParams;
}
impl<T: ConvElement> WithParams for Pinned<T> {
type Params = T::Params;
type ClParams = T::ClParams;
}