1use std::{convert::TryFrom, marker::PhantomData, sync::Mutex};
2
3use ndarray::{Array4, ArrayView4};
4use ocl::{
5 builders::KernelBuilder, prm::Uint3, Buffer, Context, Device, Kernel, Platform, ProQue,
6 Program, Queue,
7};
8
9use crate::{
10 buffers::{FeatureMap, FeatureMapShape, Filters, InputAndOutput, Pinned},
11 params::{OutputParams, Params, WithParams},
12 ConvElement,
13};
14
15#[derive(Debug)]
21pub struct ConvolutionBuilder<T> {
22 program: ProQue,
23 filter_size: u32,
24 _element_type: PhantomData<T>,
25}
26
27impl<T: ConvElement> ConvolutionBuilder<T> {
28 pub(crate) fn new(
30 filter_size: u32,
31 defines: &[(&'static str, i32)],
32 source: &str,
33 ) -> ocl::Result<Self> {
34 static MUTEX: Mutex<()> = Mutex::new(());
38
39 assert_eq!(
40 filter_size % 2,
41 1,
42 "Even convolution sizes are not supported"
43 );
44
45 let mut program_builder = Program::builder();
46 program_builder.cmplr_def(
47 "FILTER_SIZE",
48 i32::try_from(filter_size).expect("Cannot convert filter size to i32"),
49 );
50 for &(name, value) in defines {
51 program_builder.cmplr_def(name, value);
52 }
53 program_builder.source(source);
54
55 let (platform, device) = {
56 let _lock = MUTEX.lock().ok();
57 let platform = Platform::first()?;
58 (platform, Device::first(platform)?)
59 };
60
61 let context = Context::builder()
62 .platform(platform)
63 .devices(device)
64 .build()?;
65 let program = ProQue::new(
66 context.clone(),
67 Queue::new(&context, device, None)?,
68 program_builder.build(&context)?,
69 None::<usize>,
70 );
71
72 Ok(Self {
73 program,
74 filter_size,
75 _element_type: PhantomData,
76 })
77 }
78
79 fn kernel_builder(&self) -> KernelBuilder<'_> {
80 self.program.kernel_builder("conv")
81 }
82}
83
84fn create_io<T: ConvElement, U: WithParams>(
85 signal_shape: FeatureMapShape,
86 filters: &Filters<T>,
87 conv: &Base<U>,
88) -> ocl::Result<InputAndOutput<T>> {
89 assert_eq!(
90 signal_shape.channels,
91 filters.channel_count() * Into::<Params>::into(conv.params).groups,
92 "Channel dimensionality in signal and filters must agree"
93 );
94 let io = InputAndOutput::new(signal_shape, filters.filter_count(), conv)?;
95 io.pass_as_arguments(&conv.kernel).map(|()| io)
96}
97
98#[derive(Debug)]
99pub(crate) struct Base<T: WithParams> {
100 size: u32,
101 params: T::Params,
102 kernel: Kernel,
103 buffers: T,
104 context: Context,
105}
106
107impl<T: WithParams> Base<T> {
108 pub fn kernel(&self) -> &Kernel {
109 &self.kernel
110 }
111
112 pub fn queue(&self) -> &Queue {
113 self.kernel
114 .default_queue()
115 .expect("kernel must come with a pre-configured queue")
116 }
117
118 pub fn size(&self) -> u32 {
119 self.size
120 }
121
122 pub fn params(&self) -> T::Params {
123 self.params
124 }
125
126 pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
127 self.params = params;
128 self.kernel
129 .set_arg("params", Into::<T::ClParams>::into(params))
130 }
131}
132
133impl<T: ConvElement> Base<PhantomData<T>> {
134 pub fn new(builder: &ConvolutionBuilder<T>, params: T::Params) -> ocl::Result<Self> {
135 let kernel = builder
136 .kernel_builder()
137 .arg_named("output", None::<&Buffer<T>>)
138 .arg_named("out_params", OutputParams::default())
139 .arg_named("signal", None::<&Buffer<T>>)
140 .arg_named("signal_dims", Uint3::new(0, 0, 0))
141 .arg_named("filters", None::<&Buffer<T>>)
142 .arg_named("filter_biases", None::<&Buffer<T::Acc>>)
143 .arg_named("params", Into::<T::ClParams>::into(params))
144 .build()?;
145 Ok(Base {
146 size: builder.filter_size,
147 params,
148 kernel,
149 buffers: PhantomData,
150 context: builder.program.context().clone(),
151 })
152 }
153
154 pub fn with_filters(
155 self,
156 filters: &ArrayView4<'_, T>,
157 filter_biases: Option<&[T::Acc]>,
158 ) -> ocl::Result<Base<Filters<T>>> {
159 let filters = Filters::new(filters, filter_biases, &self)?;
160 Ok(Base {
161 buffers: filters,
162 size: self.size,
163 params: self.params,
164 kernel: self.kernel,
165 context: self.context,
166 })
167 }
168
169 pub fn compute(
170 &self,
171 signal: FeatureMap<'_, T>,
172 filters: &ArrayView4<'_, T>,
173 filter_biases: Option<&[T::Acc]>,
174 ) -> ocl::Result<Array4<T>> {
175 let filter_channels =
176 u32::try_from(filters.shape()[3]).expect("Cannot convert filter dimension to `u32`");
177 assert_eq!(
178 signal.shape().channels,
179 filter_channels * Into::<Params>::into(self.params).groups,
180 "Channel dimensionality in signal and filters must agree"
181 );
182
183 let filter_count =
184 u32::try_from(filters.shape()[0]).expect("Cannot convert filter count to `u32`");
185 let filters = Filters::new(filters, filter_biases, self)?;
186 filters.pass_as_arguments(&self.kernel)?;
187 let io = InputAndOutput::new(signal.shape(), filter_count, self)?;
188 io.write_signal(signal)?;
189 io.pass_as_arguments(&self.kernel)?;
190 io.execute(&self.kernel, signal.layout())
191 }
192}
193
194impl<T: ConvElement> Base<Filters<T>> {
195 pub fn pinned(self, signal_shape: FeatureMapShape) -> ocl::Result<Base<Pinned<T>>> {
196 let io = create_io(signal_shape, &self.buffers, &self)?;
197 Ok(Base {
198 size: self.size,
199 params: self.params,
200 kernel: self.kernel,
201 buffers: Pinned { io, signal_shape },
202 context: self.context,
203 })
204 }
205
206 pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
207 let io = create_io(signal.shape(), &self.buffers, self)?;
208 io.write_signal(signal)?;
209 io.execute(&self.kernel, signal.layout())
210 }
211}
212
213impl<T: ConvElement> Base<Pinned<T>> {
214 pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
215 assert_eq!(
216 signal.shape(),
217 self.buffers.signal_shape,
218 "Signal dimensions differ from the ones set when pinning signal memory"
219 );
220 self.buffers.io.write_signal(signal)?;
221 self.buffers.io.execute(&self.kernel, signal.layout())
222 }
223}