ocl_convolution/
base.rs

1use std::{convert::TryFrom, marker::PhantomData, sync::Mutex};
2
3use ndarray::{Array4, ArrayView4};
4use ocl::{
5    builders::KernelBuilder, prm::Uint3, Buffer, Context, Device, Kernel, Platform, ProQue,
6    Program, Queue,
7};
8
9use crate::{
10    buffers::{FeatureMap, FeatureMapShape, Filters, InputAndOutput, Pinned},
11    params::{OutputParams, Params, WithParams},
12    ConvElement,
13};
14
15/// Convolution builder. The same builder can be used to create multiple `Convolution`s
16/// which share the same spatial size.
17///
18/// A builder can be created using [`Convolution::f32()`](crate::Convolution::f32()) or
19/// [`Convolution::i8()`](crate::Convolution::i8()) methods.
20#[derive(Debug)]
21pub struct ConvolutionBuilder<T> {
22    program: ProQue,
23    filter_size: u32,
24    _element_type: PhantomData<T>,
25}
26
27impl<T: ConvElement> ConvolutionBuilder<T> {
28    /// Initializes a builder with a specific filter size.
29    pub(crate) fn new(
30        filter_size: u32,
31        defines: &[(&'static str, i32)],
32        source: &str,
33    ) -> ocl::Result<Self> {
34        // For some reason, certain OpenCL implementations (e.g., POCL) do not work well
35        // when the list of devices for a platform is queried from multiple threads.
36        // Hence, we introduce a `Mutex` to serialize these calls.
37        static MUTEX: Mutex<()> = Mutex::new(());
38
39        assert_eq!(
40            filter_size % 2,
41            1,
42            "Even convolution sizes are not supported"
43        );
44
45        let mut program_builder = Program::builder();
46        program_builder.cmplr_def(
47            "FILTER_SIZE",
48            i32::try_from(filter_size).expect("Cannot convert filter size to i32"),
49        );
50        for &(name, value) in defines {
51            program_builder.cmplr_def(name, value);
52        }
53        program_builder.source(source);
54
55        let (platform, device) = {
56            let _lock = MUTEX.lock().ok();
57            let platform = Platform::first()?;
58            (platform, Device::first(platform)?)
59        };
60
61        let context = Context::builder()
62            .platform(platform)
63            .devices(device)
64            .build()?;
65        let program = ProQue::new(
66            context.clone(),
67            Queue::new(&context, device, None)?,
68            program_builder.build(&context)?,
69            None::<usize>,
70        );
71
72        Ok(Self {
73            program,
74            filter_size,
75            _element_type: PhantomData,
76        })
77    }
78
79    fn kernel_builder(&self) -> KernelBuilder<'_> {
80        self.program.kernel_builder("conv")
81    }
82}
83
84fn create_io<T: ConvElement, U: WithParams>(
85    signal_shape: FeatureMapShape,
86    filters: &Filters<T>,
87    conv: &Base<U>,
88) -> ocl::Result<InputAndOutput<T>> {
89    assert_eq!(
90        signal_shape.channels,
91        filters.channel_count() * Into::<Params>::into(conv.params).groups,
92        "Channel dimensionality in signal and filters must agree"
93    );
94    let io = InputAndOutput::new(signal_shape, filters.filter_count(), conv)?;
95    io.pass_as_arguments(&conv.kernel).map(|()| io)
96}
97
98#[derive(Debug)]
99pub(crate) struct Base<T: WithParams> {
100    size: u32,
101    params: T::Params,
102    kernel: Kernel,
103    buffers: T,
104    context: Context,
105}
106
107impl<T: WithParams> Base<T> {
108    pub fn kernel(&self) -> &Kernel {
109        &self.kernel
110    }
111
112    pub fn queue(&self) -> &Queue {
113        self.kernel
114            .default_queue()
115            .expect("kernel must come with a pre-configured queue")
116    }
117
118    pub fn size(&self) -> u32 {
119        self.size
120    }
121
122    pub fn params(&self) -> T::Params {
123        self.params
124    }
125
126    pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
127        self.params = params;
128        self.kernel
129            .set_arg("params", Into::<T::ClParams>::into(params))
130    }
131}
132
133impl<T: ConvElement> Base<PhantomData<T>> {
134    pub fn new(builder: &ConvolutionBuilder<T>, params: T::Params) -> ocl::Result<Self> {
135        let kernel = builder
136            .kernel_builder()
137            .arg_named("output", None::<&Buffer<T>>)
138            .arg_named("out_params", OutputParams::default())
139            .arg_named("signal", None::<&Buffer<T>>)
140            .arg_named("signal_dims", Uint3::new(0, 0, 0))
141            .arg_named("filters", None::<&Buffer<T>>)
142            .arg_named("filter_biases", None::<&Buffer<T::Acc>>)
143            .arg_named("params", Into::<T::ClParams>::into(params))
144            .build()?;
145        Ok(Base {
146            size: builder.filter_size,
147            params,
148            kernel,
149            buffers: PhantomData,
150            context: builder.program.context().clone(),
151        })
152    }
153
154    pub fn with_filters(
155        self,
156        filters: &ArrayView4<'_, T>,
157        filter_biases: Option<&[T::Acc]>,
158    ) -> ocl::Result<Base<Filters<T>>> {
159        let filters = Filters::new(filters, filter_biases, &self)?;
160        Ok(Base {
161            buffers: filters,
162            size: self.size,
163            params: self.params,
164            kernel: self.kernel,
165            context: self.context,
166        })
167    }
168
169    pub fn compute(
170        &self,
171        signal: FeatureMap<'_, T>,
172        filters: &ArrayView4<'_, T>,
173        filter_biases: Option<&[T::Acc]>,
174    ) -> ocl::Result<Array4<T>> {
175        let filter_channels =
176            u32::try_from(filters.shape()[3]).expect("Cannot convert filter dimension to `u32`");
177        assert_eq!(
178            signal.shape().channels,
179            filter_channels * Into::<Params>::into(self.params).groups,
180            "Channel dimensionality in signal and filters must agree"
181        );
182
183        let filter_count =
184            u32::try_from(filters.shape()[0]).expect("Cannot convert filter count to `u32`");
185        let filters = Filters::new(filters, filter_biases, self)?;
186        filters.pass_as_arguments(&self.kernel)?;
187        let io = InputAndOutput::new(signal.shape(), filter_count, self)?;
188        io.write_signal(signal)?;
189        io.pass_as_arguments(&self.kernel)?;
190        io.execute(&self.kernel, signal.layout())
191    }
192}
193
194impl<T: ConvElement> Base<Filters<T>> {
195    pub fn pinned(self, signal_shape: FeatureMapShape) -> ocl::Result<Base<Pinned<T>>> {
196        let io = create_io(signal_shape, &self.buffers, &self)?;
197        Ok(Base {
198            size: self.size,
199            params: self.params,
200            kernel: self.kernel,
201            buffers: Pinned { io, signal_shape },
202            context: self.context,
203        })
204    }
205
206    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
207        let io = create_io(signal.shape(), &self.buffers, self)?;
208        io.write_signal(signal)?;
209        io.execute(&self.kernel, signal.layout())
210    }
211}
212
213impl<T: ConvElement> Base<Pinned<T>> {
214    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
215        assert_eq!(
216            signal.shape(),
217            self.buffers.signal_shape,
218            "Signal dimensions differ from the ones set when pinning signal memory"
219        );
220        self.buffers.io.write_signal(signal)?;
221        self.buffers.io.execute(&self.kernel, signal.layout())
222    }
223}