1use std::fmt;
22
23use image::{ImageBuffer, Luma, Pixel};
24
25pub trait PixelTransform<Pix: Pixel> {
31 type Output: Pixel + 'static;
33
34 fn transform_pixel(&self, pixel: Pix) -> Self::Output;
36}
37
38impl<Pix: Pixel + 'static> PixelTransform<Pix> for () {
40 type Output = Pix;
41
42 #[inline]
43 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
44 pixel
45 }
46}
47
48impl<Pix, O> PixelTransform<Pix> for Box<dyn PixelTransform<Pix, Output = O>>
49where
50 Pix: Pixel,
51 O: Pixel + 'static,
52{
53 type Output = O;
54
55 #[inline]
56 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
57 (**self).transform_pixel(pixel)
58 }
59}
60
61impl<F, G, Pix: Pixel> PixelTransform<Pix> for (F, G)
64where
65 F: PixelTransform<Pix>,
66 G: PixelTransform<F::Output>,
67{
68 type Output = G::Output;
69
70 #[inline]
71 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
72 self.1.transform_pixel(self.0.transform_pixel(pixel))
73 }
74}
75
76#[derive(Debug, Clone, Copy, Default)]
78pub struct Negative;
79
80impl PixelTransform<Luma<u8>> for Negative {
81 type Output = Luma<u8>;
82
83 #[inline]
84 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
85 Luma([u8::MAX - pixel[0]])
86 }
87}
88
89#[derive(Debug, Clone, Copy, Default)]
92pub struct Smoothstep;
93
94impl PixelTransform<Luma<u8>> for Smoothstep {
95 type Output = Luma<u8>;
96
97 #[inline]
98 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
99 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
100 let clamped_x = f32::from(pixel[0]) / 255.0;
101 let output = clamped_x * clamped_x * (3.0 - 2.0 * clamped_x);
102 Luma([(output * 255.0).round() as u8])
103 }
104}
105
/// Lookup table that maps each of the 256 possible 8-bit grayscale intensities
/// to a color (built by `Palette::new` through linear interpolation).
#[derive(Clone)]
pub struct Palette<T> {
    pixels: [T; 256],
}

impl<T: fmt::Debug> fmt::Debug for Palette<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Present the lookup table as a slice of its entries.
        let entries: &[T] = &self.pixels;
        formatter
            .debug_struct("Palette")
            .field("pixels", &entries)
            .finish()
    }
}
120
121impl<T> Palette<T>
122where
123 T: Pixel<Subpixel = u8> + 'static,
124{
125 #[allow(
133 clippy::cast_precision_loss,
134 clippy::cast_possible_truncation,
135 clippy::cast_sign_loss
136 )]
137 pub fn new(colors: &[T]) -> Self {
138 assert!(colors.len() >= 2, "palette must contain at least 2 colors");
139 assert!(
140 colors.len() <= 256,
141 "palette cannot contain more than 256 colors"
142 );
143 let len_scale = (colors.len() - 1) as f32;
144 let zero_slice = [0_u8; 4];
145 let zero_slice = &zero_slice[..T::CHANNEL_COUNT as usize];
146
147 let mut pixels = [*T::from_slice(zero_slice); 256];
148 for (i, pixel) in pixels.iter_mut().enumerate() {
149 let float_i = i as f32 / 255.0 * len_scale;
150
151 let mut prev_color_idx = float_i as usize; if prev_color_idx == colors.len() - 1 {
153 prev_color_idx -= 1;
154 }
155 debug_assert!(prev_color_idx + 1 < colors.len());
156
157 let prev_color = colors[prev_color_idx].channels();
158 let next_color = colors[prev_color_idx + 1].channels();
159 let blend_factor = float_i - prev_color_idx as f32;
160 debug_assert!((0.0..=1.0).contains(&blend_factor));
161
162 let mut blended_channels = [0_u8; 4];
163 let channel_count = T::CHANNEL_COUNT as usize;
164 for (ch, blended_channel) in blended_channels[..channel_count].iter_mut().enumerate() {
165 let blended = f32::from(prev_color[ch]) * (1.0 - blend_factor)
166 + f32::from(next_color[ch]) * blend_factor;
167 *blended_channel = blended.round() as u8;
168 }
169 *pixel = *T::from_slice(&blended_channels[..channel_count]);
170 }
171
172 Self { pixels }
173 }
174}
175
176impl<Pix: Pixel + 'static> PixelTransform<Luma<u8>> for Palette<Pix> {
177 type Output = Pix;
178
179 #[inline]
180 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
181 self.pixels[pixel[0] as usize]
182 }
183}
184
185pub trait ApplyTransform<Pix: Pixel, F> {
190 type CombinedTransform: PixelTransform<Pix>;
192 fn apply(self, transform: F) -> ImageAndTransform<Pix, Self::CombinedTransform>;
194}
195
196#[derive(Debug)]
199pub struct ImageAndTransform<Pix, F>
200where
201 Pix: Pixel,
202{
203 source_image: ImageBuffer<Pix, Vec<Pix::Subpixel>>,
204 transform: F,
205}
206
207impl<Pix, F> ImageAndTransform<Pix, F>
208where
209 Pix: Pixel + Copy + 'static,
210 F: PixelTransform<Pix>,
211 <F::Output as Pixel>::Subpixel: 'static,
212{
213 pub fn transform(&self) -> ImageBuffer<F::Output, Vec<<F::Output as Pixel>::Subpixel>> {
215 let mut output = ImageBuffer::new(self.source_image.width(), self.source_image.height());
216
217 let output_iter = self
218 .source_image
219 .enumerate_pixels()
220 .map(|(x, y, pixel)| (x, y, self.transform.transform_pixel(*pixel)));
221 for (x, y, out_pixel) in output_iter {
222 output[(x, y)] = out_pixel;
223 }
224 output
225 }
226}
227
228impl<Pix, F> ApplyTransform<Pix, F> for ImageBuffer<Pix, Vec<Pix::Subpixel>>
229where
230 Pix: Pixel,
231 F: PixelTransform<Pix>,
232{
233 type CombinedTransform = F;
234
235 fn apply(self, transform: F) -> ImageAndTransform<Pix, F> {
236 ImageAndTransform {
237 source_image: self,
238 transform,
239 }
240 }
241}
242
243impl<Pix, F, G> ApplyTransform<Pix, G> for ImageAndTransform<Pix, F>
244where
245 Pix: Pixel,
246 F: PixelTransform<Pix>,
247 G: PixelTransform<F::Output>,
248{
249 type CombinedTransform = (F, G);
250
251 fn apply(self, transform: G) -> ImageAndTransform<Pix, (F, G)> {
252 ImageAndTransform {
253 source_image: self.source_image,
254 transform: (self.transform, transform),
255 }
256 }
257}
258
#[cfg(test)]
// Test arithmetic works on small, known-range values (coordinates below 100 and
// normalized floats), so these casts cannot misbehave here.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss
)]
mod tests {
    use image::{GrayImage, Rgb};

    use super::*;

    /// Reference smoothstep on a value normalized to `[0, 1]`.
    fn smoothstep(x: f32) -> f32 {
        x * x * (3.0 - 2.0 * x)
    }

    #[test]
    fn simple_transform() {
        // x + y <= 198, so every source intensity fits in a u8.
        let source = GrayImage::from_fn(100, 100, |x, y| Luma::from([(x + y) as u8]));
        let transformed = source.apply(Negative).apply(Smoothstep).transform();
        for (x, y, pixel) in transformed.enumerate_pixels() {
            let negated = (255 - x - y) as f32 / 255.0;
            let expected = (smoothstep(negated) * 255.0).round() as u8;
            assert_eq!(pixel[0], expected);
        }
    }

    #[test]
    fn palette_basics() {
        let palette = Palette::new(&[Rgb([0, 255, 0]), Rgb([255, 255, 255])]);
        for (i, &color) in palette.pixels.iter().enumerate() {
            let ramp = i as u8;
            assert_eq!(color, Rgb([ramp, 255, ramp]));
        }
    }
}