1use core::{fmt, marker::PhantomData};
4
5use base64ct::{Base64UrlUnpadded, Encoding};
6use elliptic_curve::zeroize::Zeroizing;
7use serde::{
8 Deserialize, Deserializer, Serialize, Serializer,
9 de::{DeserializeOwned, Error as DeError, SeqAccess, Unexpected, Visitor},
10};
11
12use crate::{
13 Keypair, PublicKey, SecretKey,
14 alloc::{ToString, Vec, vec},
15 dkg::Opening,
16 group::Group,
17};
18
19fn serialize_bytes<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
20where
21 S: Serializer,
22{
23 if serializer.is_human_readable() {
24 serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
25 } else {
26 serializer.serialize_bytes(value)
27 }
28}
29
30fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
31where
32 D: Deserializer<'de>,
33{
34 struct Base64Visitor;
35
36 impl Visitor<'_> for Base64Visitor {
37 type Value = Vec<u8>;
38
39 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40 formatter.write_str("base64url-encoded data")
41 }
42
43 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
44 Base64UrlUnpadded::decode_vec(value)
45 .map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
46 }
47
48 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
49 Ok(value.to_vec())
50 }
51
52 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
53 Ok(value)
54 }
55 }
56
57 struct BytesVisitor;
58
59 impl Visitor<'_> for BytesVisitor {
60 type Value = Vec<u8>;
61
62 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63 formatter.write_str("byte buffer")
64 }
65
66 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
67 Ok(value.to_vec())
68 }
69
70 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
71 Ok(value)
72 }
73 }
74
75 if deserializer.is_human_readable() {
76 deserializer.deserialize_str(Base64Visitor)
77 } else {
78 deserializer.deserialize_byte_buf(BytesVisitor)
79 }
80}
81
82impl Serialize for Opening {
83 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84 where
85 S: Serializer,
86 {
87 serialize_bytes(self.0.as_slice(), serializer)
88 }
89}
90
91impl<'de> Deserialize<'de> for Opening {
92 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93 where
94 D: Deserializer<'de>,
95 {
96 let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
97 let mut opening = Opening(Zeroizing::new([0_u8; 32]));
98 if bytes.len() == 32 {
99 opening.0.copy_from_slice(&bytes);
100 Ok(opening)
101 } else {
102 Err(D::Error::invalid_length(bytes.len(), &"32"))
103 }
104 }
105}
106
107impl<G: Group> Serialize for PublicKey<G> {
108 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
109 where
110 S: Serializer,
111 {
112 serialize_bytes(self.as_bytes(), serializer)
113 }
114}
115
116impl<'de, G: Group> Deserialize<'de> for PublicKey<G> {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: Deserializer<'de>,
120 {
121 let bytes = deserialize_bytes(deserializer)?;
122 Self::from_bytes(&bytes).map_err(D::Error::custom)
123 }
124}
125
126impl<G: Group> Serialize for SecretKey<G> {
127 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128 where
129 S: Serializer,
130 {
131 let mut bytes = Zeroizing::new(vec![0_u8; G::SCALAR_SIZE]);
132 G::serialize_scalar(self.expose_scalar(), &mut bytes);
133 serialize_bytes(&bytes, serializer)
134 }
135}
136
137impl<'de, G: Group> Deserialize<'de> for SecretKey<G> {
138 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139 where
140 D: Deserializer<'de>,
141 {
142 let bytes = Zeroizing::new(deserialize_bytes(deserializer)?);
143 Self::from_bytes(&bytes)
144 .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
145 }
146}
147
148impl<G: Group> Serialize for Keypair<G> {
149 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150 where
151 S: Serializer,
152 {
153 self.secret().serialize(serializer)
154 }
155}
156
157impl<'de, G: Group> Deserialize<'de> for Keypair<G> {
158 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
159 where
160 D: Deserializer<'de>,
161 {
162 SecretKey::<G>::deserialize(deserializer).map(From::from)
163 }
164}
165
166pub(crate) trait Helper: Serialize + DeserializeOwned {
168 const PLURAL_DESCRIPTION: &'static str;
169 type Target;
170
171 fn from_target(target: &Self::Target) -> Self;
172 fn into_target(self) -> Self::Target;
173}
174
175#[derive(Debug)]
179pub(crate) struct ScalarHelper<G: Group>(G::Scalar);
180
181impl<G: Group> ScalarHelper<G> {
182 pub fn serialize<S>(scalar: &G::Scalar, serializer: S) -> Result<S::Ok, S::Error>
183 where
184 S: Serializer,
185 {
186 let mut bytes = vec![0_u8; G::SCALAR_SIZE];
187 G::serialize_scalar(scalar, &mut bytes);
188 serialize_bytes(&bytes, serializer)
189 }
190
191 pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Scalar, D::Error>
192 where
193 D: Deserializer<'de>,
194 {
195 let bytes = deserialize_bytes(deserializer)?;
196 if bytes.len() == G::SCALAR_SIZE {
197 G::deserialize_scalar(&bytes)
198 .ok_or_else(|| D::Error::custom("bytes do not represent a group scalar"))
199 } else {
200 let expected_len = G::SCALAR_SIZE.to_string();
201 Err(D::Error::invalid_length(
202 bytes.len(),
203 &expected_len.as_str(),
204 ))
205 }
206 }
207}
208
209impl<G: Group> Serialize for ScalarHelper<G> {
210 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
211 where
212 S: Serializer,
213 {
214 Self::serialize(&self.0, serializer)
215 }
216}
217
218impl<'de, G: Group> Deserialize<'de> for ScalarHelper<G> {
219 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220 where
221 D: Deserializer<'de>,
222 {
223 Self::deserialize(deserializer).map(Self)
224 }
225}
226
227impl<G: Group> Helper for ScalarHelper<G> {
228 const PLURAL_DESCRIPTION: &'static str = "group scalars";
229 type Target = G::Scalar;
230
231 fn from_target(target: &Self::Target) -> Self {
232 Self(*target)
233 }
234
235 fn into_target(self) -> Self::Target {
236 self.0
237 }
238}
239
240#[derive(Debug)]
242pub(crate) struct ElementHelper<G: Group>(G::Element);
243
244impl<G: Group> ElementHelper<G> {
245 pub fn serialize<S>(element: &G::Element, serializer: S) -> Result<S::Ok, S::Error>
246 where
247 S: Serializer,
248 {
249 let mut bytes = vec![0_u8; G::ELEMENT_SIZE];
250 G::serialize_element(element, &mut bytes);
251 serialize_bytes(&bytes, serializer)
252 }
253
254 pub fn deserialize<'de, D>(deserializer: D) -> Result<G::Element, D::Error>
255 where
256 D: Deserializer<'de>,
257 {
258 let bytes = deserialize_bytes(deserializer)?;
259 if bytes.len() == G::ELEMENT_SIZE {
260 G::deserialize_element(&bytes)
261 .ok_or_else(|| D::Error::custom("bytes do not represent a group element"))
262 } else {
263 let expected_len = G::ELEMENT_SIZE.to_string();
264 Err(D::Error::invalid_length(
265 bytes.len(),
266 &expected_len.as_str(),
267 ))
268 }
269 }
270}
271
272impl<G: Group> Serialize for ElementHelper<G> {
273 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
274 where
275 S: Serializer,
276 {
277 Self::serialize(&self.0, serializer)
278 }
279}
280
281impl<'de, G: Group> Deserialize<'de> for ElementHelper<G> {
282 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283 where
284 D: Deserializer<'de>,
285 {
286 Self::deserialize(deserializer).map(Self)
287 }
288}
289
290impl<G: Group> Helper for ElementHelper<G> {
291 const PLURAL_DESCRIPTION: &'static str = "group elements";
292 type Target = G::Element;
293
294 fn from_target(target: &Self::Target) -> Self {
295 Self(*target)
296 }
297
298 fn into_target(self) -> Self::Target {
299 self.0
300 }
301}
302
303pub(crate) struct VecHelper<T, const MIN: usize>(PhantomData<T>);
304
305impl<T: Helper, const MIN: usize> VecHelper<T, MIN> {
306 fn new() -> Self {
307 Self(PhantomData)
308 }
309
310 pub fn serialize<S>(values: &[T::Target], serializer: S) -> Result<S::Ok, S::Error>
311 where
312 S: Serializer,
313 {
314 debug_assert!(values.len() >= MIN);
315 serializer.collect_seq(values.iter().map(T::from_target))
316 }
317
318 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<T::Target>, D::Error>
319 where
320 D: Deserializer<'de>,
321 {
322 deserializer.deserialize_seq(Self::new())
323 }
324}
325
326impl<'de, T: Helper, const MIN: usize> Visitor<'de> for VecHelper<T, MIN> {
327 type Value = Vec<T::Target>;
328
329 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
330 write!(formatter, "at least {MIN} {}", T::PLURAL_DESCRIPTION)
331 }
332
333 fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
334 where
335 S: SeqAccess<'de>,
336 {
337 let mut scalars: Vec<T::Target> = if let Some(size) = access.size_hint() {
338 if size < MIN {
339 return Err(S::Error::invalid_length(size, &self));
340 }
341 Vec::with_capacity(size)
342 } else {
343 Vec::new()
344 };
345
346 while let Some(value) = access.next_element::<T>()? {
347 scalars.push(value.into_target());
348 }
349 if scalars.len() >= MIN {
350 Ok(scalars)
351 } else {
352 Err(S::Error::invalid_length(scalars.len(), &self))
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::group::Ristretto;

    /// An `Opening` serializes to a JSON string and round-trips to the same bytes.
    #[test]
    fn opening_roundtrip() {
        let opening = Opening(Zeroizing::new([6; 32]));
        let json = serde_json::to_value(&opening).unwrap();
        assert!(json.is_string(), "{json:?}");
        let opening_copy: Opening = serde_json::from_value(json).unwrap();
        assert_eq!(opening_copy.0, opening.0);
    }

    /// Keypairs, public keys and secret keys each serialize to a JSON string and round-trip.
    #[test]
    fn key_roundtrip() {
        let keypair = Keypair::<Ristretto>::generate(&mut rand::rng());
        let json = serde_json::to_value(&keypair).unwrap();
        assert!(json.is_string(), "{json:?}");
        let keypair_copy: Keypair<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(keypair_copy.public(), keypair.public());

        let json = serde_json::to_value(keypair.public()).unwrap();
        assert!(json.is_string(), "{json:?}");
        let public_key: PublicKey<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(public_key, *keypair.public());

        let json = serde_json::to_value(keypair.secret()).unwrap();
        assert!(json.is_string(), "{json:?}");
        let secret_key: SecretKey<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(secret_key.expose_scalar(), keypair.secret().expose_scalar());
    }

    /// A 4-byte buffer ("dGVzdA" is base64url for `test`) is rejected as a public key.
    #[test]
    fn public_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("invalid size of the byte buffer"),
            "{err_string}"
        );
    }

    /// 32 bytes that (presumably) are not a valid Ristretto point encoding are rejected.
    #[test]
    fn public_key_deserialization_of_non_element() {
        let err = serde_json::from_str::<PublicKey<Ristretto>>(
            "\"tNDkeYUVQWgh34d-RqaElOk7yFB8d2qCh5f4Vi2euT0\"",
        )
        .unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("does not represent a group element"),
            "{err_string}"
        );
    }

    /// A wrong-length secret key is reported with the generic scalar error message
    /// (not a serde `invalid_length` error).
    #[test]
    fn secret_key_deserialization_with_incorrect_length() {
        let err = serde_json::from_str::<SecretKey<Ristretto>>("\"dGVzdA\"").unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("bytes do not represent a group scalar"),
            "{err_string}"
        );
    }

    /// 32 bytes whose last byte is 0xff are rejected as a secret key. Assuming a
    /// little-endian scalar encoding, that value exceeds the Ristretto group order.
    #[test]
    fn secret_key_deserialization_of_invalid_scalar() {
        let err = serde_json::from_str::<SecretKey<Ristretto>>(
            "\"nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8\"",
        )
        .unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("bytes do not represent a group scalar"),
            "{err_string}"
        );
    }

    /// Exercises `ScalarHelper`, `ElementHelper` and `VecHelper` via `#[serde(with)]`.
    /// `bound = ""` suppresses serde's inferred `G: Serialize/Deserialize` bounds.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    #[serde(bound = "")]
    struct TestObject<G: Group> {
        #[serde(with = "ScalarHelper::<G>")]
        scalar: G::Scalar,
        #[serde(with = "ElementHelper::<G>")]
        element: G::Element,
        #[serde(with = "VecHelper::<ScalarHelper<G>, 2>")]
        more_scalars: Vec<G::Scalar>,
    }

    impl TestObject<Ristretto> {
        /// Fixed sample with exactly `MIN = 2` entries in `more_scalars`.
        fn sample() -> Self {
            Self {
                scalar: 12345_u64.into(),
                element: Ristretto::mul_generator(&54321_u64.into()),
                more_scalars: vec![7_u64.into(), 890_u64.into()],
            }
        }
    }

    #[test]
    fn helpers_roundtrip() {
        let object = TestObject::sample();
        let json = serde_json::to_value(&object).unwrap();
        let object_copy: TestObject<Ristretto> = serde_json::from_value(json).unwrap();
        assert_eq!(object_copy, object);
    }

    /// `ScalarHelper` reports a length mismatch first, then a non-canonical scalar.
    #[test]
    fn scalar_helper_invalid_scalar() {
        let object = TestObject::sample();
        let mut json = serde_json::to_value(object).unwrap();
        json.as_object_mut()
            .unwrap()
            .insert("scalar".into(), "dGVzdA".into());

        let err = serde_json::from_value::<TestObject<Ristretto>>(json.clone()).unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("invalid length 4, expected 32"),
            "{err_string}"
        );

        json.as_object_mut().unwrap().insert(
            "scalar".into(),
            "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8".into(),
        );
        let err = serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("bytes do not represent a group scalar"),
            "{err_string}"
        );
    }

    /// `ElementHelper` reports a length mismatch first, then an invalid point encoding.
    #[test]
    fn element_helper_invalid_element() {
        let object = TestObject::sample();
        let mut json = serde_json::to_value(object).unwrap();
        json.as_object_mut()
            .unwrap()
            .insert("element".into(), "dGVzdA".into());

        let err = serde_json::from_value::<TestObject<Ristretto>>(json.clone()).unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("invalid length 4, expected 32"),
            "{err_string}"
        );

        json.as_object_mut().unwrap().insert(
            "element".into(),
            "nN3xf7lSOX0_zs6QPBwWHYi0Dkx2Ln_z1MPwnbzaM_8".into(),
        );
        let err = serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("bytes do not represent a group element"),
            "{err_string}"
        );
    }

    /// Dropping `more_scalars` below `MIN = 2` yields the `VecHelper` length error.
    #[test]
    fn vec_helper_invalid_length() {
        let object = TestObject::sample();
        let mut json = serde_json::to_value(object).unwrap();
        let more_scalars = &mut json.as_object_mut().unwrap()["more_scalars"];
        more_scalars.as_array_mut().unwrap().pop();

        let err = serde_json::from_value::<TestObject<Ristretto>>(json).unwrap_err();
        let err_string = err.to_string();
        assert!(
            err_string.contains("invalid length 1, expected at least 2 group scalars"),
            "{err_string}"
        );
    }
}