elastic_elgamal/
serde.rs

1//! (De)serialization utils.
2
3use core::{fmt, marker::PhantomData};
4
5use base64ct::{Base64UrlUnpadded, Encoding};
6use elliptic_curve::zeroize::Zeroizing;
7use serde::{
8    Deserialize, Deserializer, Serialize, Serializer,
9    de::{DeserializeOwned, Error as DeError, SeqAccess, Unexpected, Visitor},
10};
11
12use crate::{
13    Keypair, PublicKey, SecretKey,
14    alloc::{ToString, Vec, vec},
15    dkg::Opening,
16    group::Group,
17};
18
19fn serialize_bytes<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
20where
21    S: Serializer,
22{
23    if serializer.is_human_readable() {
24        serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
25    } else {
26        serializer.serialize_bytes(value)
27    }
28}
29
30fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
31where
32    D: Deserializer<'de>,
33{
34    struct Base64Visitor;
35
36    impl Visitor<'_> for Base64Visitor {
37        type Value = Vec<u8>;
38
39        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40            formatter.write_str("base64url-encoded data")
41        }
42
43        fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
44            Base64UrlUnpadded::decode_vec(value)
45                .map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
46        }
47
48        fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
49            Ok(value.to_vec())
50        }
51
52        fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
53            Ok(value)
54        }
55    }
56
57    struct BytesVisitor;
58
59    impl Visitor<'_> for BytesVisitor {
60        type Value = Vec<u8>;
61
62        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63            formatter.write_str("byte buffer")
64        }
65
66        fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
67            Ok(value.to_vec())
68        }
69
70        fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
71            Ok(value)
72        }
73    }
74
75    if deserializer.is_human_readable() {
76        deserializer.deserialize_str(Base64Visitor)
77    } else {
78        deserializer.deserialize_byte_buf(BytesVisitor)
79    }
80}
81
82impl Serialize for Opening {
83    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84    where
85        S: Serializer,
86    {
87        serialize_bytes(self.0.as_slice(), serializer)
88    }
89}
90
91impl<'de> Deserialize<'de> for Opening {
92    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93    where
94        D: Deserializer<'de>,
95    {
96        let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
97        let mut opening = Opening(Zeroizing::new([0_u8; 32]));
98        if bytes.len() == 32 {
99            opening.0.copy_from_slice(&bytes);
100            Ok(opening)
101        } else {
102            Err(D::Error::invalid_length(bytes.len(), &"32"))
103        }
104    }
105}
106
107impl<G: Group> Serialize for PublicKey<G> {
108    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
109    where
110        S: Serializer,
111    {
112        serialize_bytes(self.as_bytes(), serializer)
113    }
114}
115
116impl<'de, G: Group> Deserialize<'de> for PublicKey<G> {
117    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118    where
119        D: Deserializer<'de>,
120    {
121        let bytes = deserialize_bytes(deserializer)?;
122        Self::from_bytes(&bytes).map_err(D::Error::custom)
123    }
124}
125
126impl<G: Group> Serialize for SecretKey<G> {
127    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128    where
129        S: Serializer,
130    {
131        let mut bytes = Zeroizing::new(vec![0_u8; G::SCALAR_SIZE]);
132        G::serialize_scalar(self.expose_scalar(), &mut bytes);
133        serialize_bytes(&bytes, serializer)
134    }
135}
136
137impl<'de, G: Group> Deserialize<'de> for SecretKey<G> {
138    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139    where
140        D: Deserializer<'de>,
141    {
142        let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
143        Self::from_bytes(&bytes)
144            .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
145    }
146}
147
148impl<G: Group> Serialize for Keypair<G> {
149    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150    where
151        S: Serializer,
152    {
153        self.secret().serialize(serializer)
154    }
155}
156
157impl<'de, G: Group> Deserialize<'de> for Keypair<G> {
158    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
159    where
160        D: Deserializer<'de>,
161    {
162        SecretKey::<G>::deserialize(deserializer).map(From::from)
163    }
164}
165
166/// Common functionality for serialization helpers.
167pub(crate) trait Helper: Serialize + DeserializeOwned {
168    const PLURAL_DESCRIPTION: &'static str;
169    type Target;
170
171    fn from_target(target: &Self::Target) -> Self;
172    fn into_target(self) -> Self::Target;
173}
174
175/// Helper type to deserialize scalars.
176///
177/// **NB.** Scalars are assumed to be public! Secret scalars must be serialized via `SecretKey`.
178#[derive(Debug)]
179pub(crate) struct ScalarHelper<G: Group>(G::Scalar);
180
181impl<G: Group> ScalarHelper<G> {
182    pub fn serialize<S>(scalar: &G::Scalar, serializer: S) -> Result<S::Ok, S::Error>
183    where
184        S: Serializer,
185    {
186        let mut bytes = vec![0_u8; G::SCALAR_SIZE];
187        G::serialize_scalar(scalar, &mut bytes);
188        serialize_bytes(&bytes, serializer)
189    }
190
191    pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Scalar, D::Error>
192    where
193        D: Deserializer<'de>,
194    {
195        let bytes = deserialize_bytes(deserializer)?;
196        if bytes.len() == G::SCALAR_SIZE {
197            G::deserialize_scalar(&bytes)
198                .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
199        } else {
200            let expected_len = G::SCALAR_SIZE.to_string();
201            Err(D::Error::invalid_length(
202                bytes.len(),
203                &expected_len.as_str(),
204            ))
205        }
206    }
207}
208
209impl<G: Group> Serialize for ScalarHelper<G> {
210    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
211    where
212        S: Serializer,
213    {
214        Self::serialize(&self.0, serializer)
215    }
216}
217
218impl<'de, G: Group> Deserialize<'de> for ScalarHelper<G> {
219    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220    where
221        D: Deserializer<'de>,
222    {
223        Self::deserialize(deserializer).map(Self)
224    }
225}
226
227impl<G: Group> Helper for ScalarHelper<G> {
228    const PLURAL_DESCRIPTION: &'static str = "group scalars";
229    type Target = G::Scalar;
230
231    fn from_target(target: &Self::Target) -> Self {
232        Self(*target)
233    }
234
235    fn into_target(self) -> Self::Target {
236        self.0
237    }
238}
239
240/// Helper type to deserialize group elements.
241#[derive(Debug)]
242pub(crate) struct ElementHelper<G: Group>(G::Element);
243
244impl<G: Group> ElementHelper<G> {
245    pub fn serialize<S>(element: &G::Element, serializer: S) -> Result<S::Ok, S::Error>
246    where
247        S: Serializer,
248    {
249        let mut bytes = vec![0_u8; G::ELEMENT_SIZE];
250        G::serialize_element(element, &mut bytes);
251        serialize_bytes(&bytes, serializer)
252    }
253
254    pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Element, D::Error>
255    where
256        D: Deserializer<'de>,
257    {
258        let bytes = deserialize_bytes(deserializer)?;
259        if bytes.len() == G::ELEMENT_SIZE {
260            G::deserialize_element(&bytes)
261                .ok_or_else(|| D::Error::custom("bytes do not represent a group element"))
262        } else {
263            let expected_len = G::ELEMENT_SIZE.to_string();
264            Err(D::Error::invalid_length(
265                bytes.len(),
266                &expected_len.as_str(),
267            ))
268        }
269    }
270}
271
272impl<G: Group> Serialize for ElementHelper<G> {
273    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
274    where
275        S: Serializer,
276    {
277        Self::serialize(&self.0, serializer)
278    }
279}
280
281impl<'de, G: Group> Deserialize<'de> for ElementHelper<G> {
282    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283    where
284        D: Deserializer<'de>,
285    {
286        Self::deserialize(deserializer).map(Self)
287    }
288}
289
290impl<G: Group> Helper for ElementHelper<G> {
291    const PLURAL_DESCRIPTION: &'static str = "group elements";
292    type Target = G::Element;
293
294    fn from_target(target: &Self::Target) -> Self {
295        Self(*target)
296    }
297
298    fn into_target(self) -> Self::Target {
299        self.0
300    }
301}
302
303pub(crate) struct VecHelper<T, const MIN: usize>(PhantomData<T>);
304
305impl<T: Helper, const MIN: usize> VecHelper<T, MIN> {
306    fn new() -> Self {
307        Self(PhantomData)
308    }
309
310    pub fn serialize<S>(values: &[T::Target], serializer: S) -> Result<S::Ok, S::Error>
311    where
312        S: Serializer,
313    {
314        debug_assert!(values.len() >= MIN);
315        serializer.collect_seq(values.iter().map(T::from_target))
316    }
317
318    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<T::Target>, D::Error>
319    where
320        D: Deserializer<'de>,
321    {
322        deserializer.deserialize_seq(Self::new())
323    }
324}
325
326impl<'de, T: Helper, const MIN: usize> Visitor<'de> for VecHelper<T, MIN> {
327    type Value = Vec<T::Target>;
328
329    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
330        write!(formatter, "at least {MIN} {}", T::PLURAL_DESCRIPTION)
331    }
332
333    fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
334    where
335        S: SeqAccess<'de>,
336    {
337        let mut scalars: Vec<T::Target> = if let Some(size) = access.size_hint() {
338            if size < MIN {
339                return Err(S::Error::invalid_length(size, &self));
340            }
341            Vec::with_capacity(size)
342        } else {
343            Vec::new()
344        };
345
346        while let Some(value) = access.next_element::<T>()? {
347            scalars.push(value.into_target());
348        }
349        if scalars.len() >= MIN {
350            Ok(scalars)
351        } else {
352            Err(S::Error::invalid_length(scalars.len(), &self))
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::group::Ristretto;

    /// Serializes `value` to JSON, checks that it is encoded as a single string,
    /// and deserializes it back.
    fn string_roundtrip<T: Serialize + DeserializeOwned>(value: &T) -> T {
        let json = serde_json::to_value(value).unwrap();
        assert!(json.is_string(), "{json:?}");
        serde_json::from_value(json).unwrap()
    }

    /// Asserts that the message of `err` contains `needle`, printing the message on failure.
    #[track_caller]
    fn assert_error_contains(err: &serde_json::Error, needle: &str) {
        let message = err.to_string();
        assert!(message.contains(needle), "{message}");
    }

    /// Serializes `TestObject::sample()`, overwrites `field` with the string `value`,
    /// and returns the error from deserializing the result.
    fn corrupted_object_error(field: &str, value: &str) -> serde_json::Error {
        let mut json = serde_json::to_value(TestObject::sample()).unwrap();
        json.as_object_mut()
            .unwrap()
            .insert(field.into(), value.into());
        serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err()
    }

    #[test]
    fn opening_roundtrip() {
        let opening = Opening(Zeroizing::new([6; 32]));
        let restored = string_roundtrip(&opening);
        assert_eq!(restored.0, opening.0);
    }

    #[test]
    fn key_roundtrip() {
        let keypair = Keypair::<Ristretto>::generate(&mut rand::rng());

        let keypair_copy = string_roundtrip(&keypair);
        assert_eq!(keypair_copy.public(), keypair.public());

        let public_key = string_roundtrip(keypair.public());
        assert_eq!(public_key, *keypair.public());

        let secret_key = string_roundtrip(keypair.secret());
        assert_eq!(secret_key.expose_scalar(), keypair.secret().expose_scalar());
    }

    #[test]
    fn public_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        assert_error_contains(&err, "invalid size of the byte buffer");
    }

    #[test]
    fn public_key_deserialization_of_non_element() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>(
            "\"tNDkeYUVQWgh34d-RqaElOk7yFB8d2qCh5f4Vi2euT0\"",
        )
        .unwrap_err();
        assert_error_contains(&err, "does not represent a group element");
    }

    #[test]
    fn secret_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<SecretKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        assert_error_contains(&err, "bytes do not represent a group scalar");
    }

    #[test]
    fn secret_key_deserialization_of_invalid_scalar() {
        // The trailing `_8` chars set the upper byte of the scalar bytes to 0xff, which is
        // invalid (all scalars are less than 2^253).
        let err = serde_json::from_str::<SecretKey<Ristretto>>(
            "\"nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8\"",
        )
        .unwrap_err();
        assert_error_contains(&err, "bytes do not represent a group scalar");
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    #[serde(bound = "")]
    struct TestObject<G: Group> {
        #[serde(with = "ScalarHelper::<G>")]
        scalar: G::Scalar,
        #[serde(with = "ElementHelper::<G>")]
        element: G::Element,
        #[serde(with = "VecHelper::<ScalarHelper<G>, 2>")]
        more_scalars: Vec<G::Scalar>,
    }

    impl TestObject<Ristretto> {
        fn sample() -> Self {
            Self {
                scalar: 12345_u64.into(),
                element: Ristretto::mul_generator(&54321_u64.into()),
                more_scalars: vec![7_u64.into(), 890_u64.into()],
            }
        }
    }

    #[test]
    fn helpers_roundtrip() {
        let object = TestObject::sample();
        let json = serde_json::to_value(&object).unwrap();
        let object_copy: TestObject<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(object_copy, object);
    }

    #[test]
    fn scalar_helper_invalid_scalar() {
        let err = corrupted_object_error("scalar", "dGVzdA");
        assert_error_contains(&err, "invalid length 4, expected 32");

        let err =
            corrupted_object_error("scalar", "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8");
        assert_error_contains(&err, "bytes do not represent a group scalar");
    }

    #[test]
    fn element_helper_invalid_element() {
        let err = corrupted_object_error("element", "dGVzdA");
        assert_error_contains(&err, "invalid length 4, expected 32");

        let err =
            corrupted_object_error("element", "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8");
        assert_error_contains(&err, "bytes do not represent a group element");
    }

    #[test]
    fn vec_helper_invalid_length() {
        let mut json = serde_json::to_value(TestObject::sample()).unwrap();
        json["more_scalars"].as_array_mut().unwrap().pop();

        let err = serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err();
        assert_error_contains(&err, "invalid length 1, expected at least 2 group scalars");
    }
}