1use core::{fmt, iter, ops};
4
5use elliptic_curve::{rand_core::CryptoRng, zeroize::Zeroizing};
6use merlin::Transcript;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9
10use crate::{
11 Ciphertext, CiphertextWithValue, LogEqualityProof, PublicKey, RingProof, RingProofBuilder,
12 VerificationError,
13 alloc::{Vec, vec},
14 group::Group,
15};
16
/// Strategy for proving (and verifying) a constraint on the sum of all per-option choices
/// in an [`EncryptedChoice`].
///
/// The trait is sealed via `crate::sealed::Sealed`; in this module it is implemented by
/// [`SingleChoice`] and [`MultiChoice`].
pub trait ProveSum<G: Group>: Clone + crate::sealed::Sealed {
    /// Proof attached to an [`EncryptedChoice`] attesting the sum constraint.
    #[cfg(not(feature = "serde"))]
    type Proof: Sized;
    /// Proof attached to an [`EncryptedChoice`] attesting the sum constraint.
    /// With the `serde` feature enabled, the proof must also be (de)serializable.
    #[cfg(feature = "serde")]
    type Proof: Sized + Serialize + DeserializeOwned;

    /// Creates a sum proof for `ciphertext`, the homomorphic sum of all choice ciphertexts
    /// (together with its plaintext value and randomness), encrypted for `receiver`.
    #[doc(hidden)]
    fn prove<R: CryptoRng>(
        &self,
        ciphertext: &CiphertextWithValue<G, u64>,
        receiver: &PublicKey<G>,
        rng: &mut R,
    ) -> Self::Proof;

    /// Verifies a sum `proof` against `ciphertext`, the homomorphic sum of all choice
    /// ciphertexts, encrypted for `receiver`.
    ///
    /// # Errors
    ///
    /// Returns an error if the proof does not verify.
    #[doc(hidden)]
    fn verify(
        &self,
        ciphertext: &Ciphertext<G>,
        proof: &Self::Proof,
        receiver: &PublicKey<G>,
    ) -> Result<(), ChoiceVerificationError>;
}
45
/// Marker for single-choice polls: exactly one option must be selected.
///
/// The sum of all choice ciphertexts is proven to encrypt 1 (see the `ProveSum` impl
/// for this type), which together with the per-option 0/1 range proofs means exactly
/// one option is selected.
#[derive(Debug, Clone, Copy)]
pub struct SingleChoice(());

impl crate::sealed::Sealed for SingleChoice {}
55
56impl<G: Group> ProveSum<G> for SingleChoice {
57 type Proof = LogEqualityProof<G>;
58
59 fn prove<R: CryptoRng>(
60 &self,
61 ciphertext: &CiphertextWithValue<G, u64>,
62 receiver: &PublicKey<G>,
63 rng: &mut R,
64 ) -> Self::Proof {
65 LogEqualityProof::new(
66 receiver,
67 ciphertext.randomness(),
68 (
69 ciphertext.inner().random_element,
70 ciphertext.inner().blinded_element - G::generator(),
71 ),
72 &mut Transcript::new(b"choice_encryption_sum"),
73 rng,
74 )
75 }
76
77 fn verify(
78 &self,
79 ciphertext: &Ciphertext<G>,
80 proof: &Self::Proof,
81 receiver: &PublicKey<G>,
82 ) -> Result<(), ChoiceVerificationError> {
83 let powers = (
84 ciphertext.random_element,
85 ciphertext.blinded_element - G::generator(),
86 );
87 proof
88 .verify(
89 receiver,
90 powers,
91 &mut Transcript::new(b"choice_encryption_sum"),
92 )
93 .map_err(ChoiceVerificationError::Sum)
94 }
95}
96
/// Marker for multi-choice polls: any number of options (including none or all) may be
/// selected.
///
/// No constraint is placed on the sum of choices, so the associated sum proof is `()`;
/// only the per-option 0/1 range proofs apply.
#[derive(Debug, Clone, Copy)]
pub struct MultiChoice(());

impl crate::sealed::Sealed for MultiChoice {}
107
108impl<G: Group> ProveSum<G> for MultiChoice {
109 type Proof = ();
110
111 fn prove<R: CryptoRng>(
112 &self,
113 _ciphertext: &CiphertextWithValue<G, u64>,
114 _receiver: &PublicKey<G>,
115 _rng: &mut R,
116 ) -> Self::Proof {
117 }
119
120 fn verify(
121 &self,
122 _ciphertext: &Ciphertext<G>,
123 _proof: &Self::Proof,
124 _receiver: &PublicKey<G>,
125 ) -> Result<(), ChoiceVerificationError> {
126 Ok(()) }
128}
129
/// Parameters of an encrypted-choice poll: the number of options, the sum-proof strategy
/// `S`, and the public key that choices are encrypted for.
#[derive(Debug)]
pub struct ChoiceParams<G: Group, S: ProveSum<G>> {
    /// Number of options in the poll; the `single` / `multi` constructors assert it is positive.
    options_count: usize,
    /// Prover / verifier for the constraint on the sum of choices.
    sum_prover: S,
    /// Public key that choices are encrypted for and that proofs are made against.
    receiver: PublicKey<G>,
}
137
138impl<G: Group, S: ProveSum<G>> Clone for ChoiceParams<G, S> {
139 fn clone(&self) -> Self {
140 Self {
141 options_count: self.options_count,
142 sum_prover: self.sum_prover.clone(),
143 receiver: self.receiver.clone(),
144 }
145 }
146}
147
148impl<G: Group, S: ProveSum<G>> ChoiceParams<G, S> {
149 fn check_options_count(&self, actual_count: usize) -> Result<(), ChoiceVerificationError> {
150 if self.options_count == actual_count {
151 Ok(())
152 } else {
153 Err(ChoiceVerificationError::OptionsLenMismatch {
154 expected: self.options_count,
155 actual: actual_count,
156 })
157 }
158 }
159
160 pub fn receiver(&self) -> &PublicKey<G> {
162 &self.receiver
163 }
164
165 pub fn options_count(&self) -> usize {
167 self.options_count
168 }
169}
170
171impl<G: Group> ChoiceParams<G, SingleChoice> {
172 pub fn single(receiver: PublicKey<G>, options_count: usize) -> Self {
178 assert!(options_count > 0, "Number of options must be positive");
179 Self {
180 options_count,
181 sum_prover: SingleChoice(()),
182 receiver,
183 }
184 }
185}
186
187impl<G: Group> ChoiceParams<G, MultiChoice> {
188 pub fn multi(receiver: PublicKey<G>, options_count: usize) -> Self {
194 assert!(options_count > 0, "Number of options must be positive");
195 Self {
196 options_count,
197 sum_prover: MultiChoice(()),
198 receiver,
199 }
200 }
201}
202
/// Per-option encrypted choices together with zero-knowledge proofs of their validity.
///
/// The proofs are a range proof that each ciphertext encrypts 0 or 1 (encoded as the
/// group identity or generator) and a sum proof whose form is determined by `S`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// `bound = ""` suppresses serde's inferred bounds on the `G` / `S` type parameters.
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct EncryptedChoice<G: Group, S: ProveSum<G>> {
    /// One ciphertext per option.
    choices: Vec<Ciphertext<G>>,
    /// Ring proof that every ciphertext in `choices` encrypts 0 or 1.
    range_proof: RingProof<G>,
    /// Sum-constraint proof, verified against the homomorphic sum of `choices`.
    sum_proof: S::Proof,
}
281
282impl<G: Group> EncryptedChoice<G, SingleChoice> {
283 pub fn single<R: CryptoRng>(
289 params: &ChoiceParams<G, SingleChoice>,
290 choice: usize,
291 rng: &mut R,
292 ) -> Self {
293 assert!(
294 choice < params.options_count,
295 "invalid choice {choice}; expected a value in 0..{}",
296 params.options_count
297 );
298 let choices: Vec<_> = (0..params.options_count).map(|i| choice == i).collect();
299 Self::new(params, &Zeroizing::new(choices), rng)
300 }
301}
302
303#[allow(clippy::len_without_is_empty)] impl<G: Group, S: ProveSum<G>> EncryptedChoice<G, S> {
305 pub fn new<R: CryptoRng>(params: &ChoiceParams<G, S>, choices: &[bool], rng: &mut R) -> Self {
314 assert!(!choices.is_empty(), "No choices provided");
315 assert_eq!(
316 choices.len(),
317 params.options_count,
318 "Mismatch between expected and actual number of choices"
319 );
320
321 let admissible_values = [G::identity(), G::generator()];
322 let mut ring_responses = vec![G::Scalar::default(); 2 * params.options_count];
323 let mut transcript = Transcript::new(b"encrypted_choice_ranges");
324 let mut proof_builder = RingProofBuilder::new(
325 ¶ms.receiver,
326 params.options_count,
327 &mut ring_responses,
328 &mut transcript,
329 rng,
330 );
331
332 let sum = choices.iter().map(|&flag| u64::from(flag)).sum::<u64>();
333 let choices: Vec<_> = choices
334 .iter()
335 .map(|&flag| proof_builder.add_value(&admissible_values, usize::from(flag)))
336 .collect();
337 let range_proof = RingProof::new(proof_builder.build(), ring_responses);
338
339 let sum_ciphertext = choices.iter().cloned().reduce(ops::Add::add).unwrap();
340 let sum_ciphertext = sum_ciphertext.with_value(sum);
341 let sum_proof = params
342 .sum_prover
343 .prove(&sum_ciphertext, ¶ms.receiver, rng);
344 Self {
345 choices: choices.into_iter().map(|choice| choice.inner).collect(),
346 range_proof,
347 sum_proof,
348 }
349 }
350
351 #[allow(clippy::missing_panics_doc)]
358 pub fn verify(
359 &self,
360 params: &ChoiceParams<G, S>,
361 ) -> Result<&[Ciphertext<G>], ChoiceVerificationError> {
362 params.check_options_count(self.choices.len())?;
363 let sum_of_ciphertexts = self.choices.iter().copied().reduce(ops::Add::add);
364 let sum_of_ciphertexts = sum_of_ciphertexts.unwrap();
365 params
367 .sum_prover
368 .verify(&sum_of_ciphertexts, &self.sum_proof, ¶ms.receiver)?;
369
370 let admissible_values = [G::identity(), G::generator()];
371 self.range_proof
372 .verify(
373 ¶ms.receiver,
374 iter::repeat_n(&admissible_values as &[_], self.choices.len()),
375 self.choices.iter().copied(),
376 &mut Transcript::new(b"encrypted_choice_ranges"),
377 )
378 .map(|()| self.choices.as_slice())
379 .map_err(ChoiceVerificationError::Range)
380 }
381
382 pub fn len(&self) -> usize {
385 self.choices.len()
386 }
387
388 pub fn choices_unchecked(&self) -> &[Ciphertext<G>] {
390 &self.choices
391 }
392
393 pub fn range_proof(&self) -> &RingProof<G> {
395 &self.range_proof
396 }
397
398 pub fn sum_proof(&self) -> &S::Proof {
400 &self.sum_proof
401 }
402}
403
/// Error verifying an [`EncryptedChoice`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ChoiceVerificationError {
    /// The number of options in the ballot differs from the number in [`ChoiceParams`].
    OptionsLenMismatch {
        /// Number of options expected by the params.
        expected: usize,
        /// Number of options (ciphertexts) in the ballot.
        actual: usize,
    },
    /// The sum proof failed verification.
    Sum(VerificationError),
    /// The range proof (each choice encrypts 0 or 1) failed verification.
    Range(VerificationError),
}
420
421impl fmt::Display for ChoiceVerificationError {
422 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
423 match self {
424 Self::OptionsLenMismatch { expected, actual } => write!(
425 formatter,
426 "number of options in the ballot ({actual}) differs from expected ({expected})",
427 ),
428 Self::Sum(err) => write!(formatter, "cannot verify sum proof: {err}"),
429 Self::Range(err) => write!(formatter, "cannot verify range proofs: {err}"),
430 }
431 }
432}
433
#[cfg(feature = "std")]
impl std::error::Error for ChoiceVerificationError {
    /// Returns the underlying proof-verification error for proof failures; `None` otherwise.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive match (no `_` arm) so that adding a variant forces a decision here.
        match self {
            Self::Sum(err) | Self::Range(err) => Some(err),
            Self::OptionsLenMismatch { .. } => None,
        }
    }
}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Keypair,
        group::{Generic, Ristretto},
    };

    /// Checks that honest ballots verify while tampered or mismatched ballots are rejected.
    fn test_bogus_encrypted_choice_does_not_work<G: Group>() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<G>::generate(&mut rng).into_tuple();
        let params = ChoiceParams::single(receiver.clone(), 5);

        // Sanity check: an untampered ballot verifies, so the failures below are due to tampering.
        let honest = EncryptedChoice::single(&params, 2, &mut rng);
        assert!(honest.verify(&params).is_ok());

        // Two options encrypt 1: the sum proof must fail.
        let mut choice = EncryptedChoice::single(&params, 2, &mut rng);
        let (encrypted_one, _) = receiver.encrypt_bool(true, &mut rng);
        choice.choices[0] = encrypted_one;
        assert!(choice.verify(&params).is_err());

        // No option encrypts 1: the sum proof must fail.
        let mut choice = EncryptedChoice::single(&params, 4, &mut rng);
        let (encrypted_zero, _) = receiver.encrypt_bool(false, &mut rng);
        choice.choices[4] = encrypted_zero;
        assert!(choice.verify(&params).is_err());

        // Sum preserved but individual values out of {0, 1}: the range proof must fail.
        let mut choice = EncryptedChoice::single(&params, 4, &mut rng);
        choice.choices[4].blinded_element =
            choice.choices[4].blinded_element + G::mul_generator(&G::Scalar::from(10));
        choice.choices[3].blinded_element =
            choice.choices[3].blinded_element - G::mul_generator(&G::Scalar::from(10));
        assert!(choice.verify(&params).is_err());

        // Ballot built for a different number of options is rejected up front.
        let other_params = ChoiceParams::single(receiver.clone(), 4);
        let choice = EncryptedChoice::single(&other_params, 1, &mut rng);
        assert!(matches!(
            choice.verify(&params),
            Err(ChoiceVerificationError::OptionsLenMismatch {
                expected: 5,
                actual: 4
            })
        ));
    }

    #[test]
    fn bogus_encrypted_choice_does_not_work_for_edwards() {
        test_bogus_encrypted_choice_does_not_work::<Ristretto>();
    }

    #[test]
    fn bogus_encrypted_choice_does_not_work_for_k256() {
        test_bogus_encrypted_choice_does_not_work::<Generic<k256::Secp256k1>>();
    }
}