1use core::fmt;
4
5use elliptic_curve::rand_core::{CryptoRng, RngCore};
6use merlin::Transcript;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 Ciphertext, PreparedRange, PublicKey, RangeDecomposition, RangeProof, SumOfSquaresProof,
12 VerificationError, alloc::Vec, group::Group,
13};
14
#[derive(Debug, Clone)]
/// Parameters of a quadratic voting poll.
///
/// Bundles the public key ballots are encrypted for, the number of options in each ballot,
/// and the prepared ranges used to bound each per-option vote count and the total credit
/// spent by a ballot.
pub struct QuadraticVotingParams<G: Group> {
    /// Range for per-option vote counts; its upper bound is `max_votes + 1`.
    vote_count_range: PreparedRange<G>,
    /// Range for the credit spent by a ballot; its upper bound is `credits + 1`.
    credit_range: PreparedRange<G>,
    /// Number of options in each ballot.
    options_count: usize,
    /// Public key that ballot ciphertexts are produced for.
    receiver: PublicKey<G>,
}
53
54impl<G: Group> QuadraticVotingParams<G> {
55 pub fn new(receiver: PublicKey<G>, options: usize, credits: u64) -> Self {
64 assert!(options > 0, "Number of options must be positive");
65 assert!(credits > 0, "Number of credits must be positive");
66
67 let max_votes = isqrt(credits);
68 let vote_count_range = RangeDecomposition::optimal(max_votes + 1);
69 let credit_range = RangeDecomposition::optimal(credits + 1);
70 Self {
71 vote_count_range: vote_count_range.into(),
72 credit_range: credit_range.into(),
73 options_count: options,
74 receiver,
75 }
76 }
77
78 pub fn receiver(&self) -> &PublicKey<G> {
80 &self.receiver
81 }
82
83 pub fn options_count(&self) -> usize {
85 self.options_count
86 }
87
88 pub fn credits(&self) -> u64 {
90 self.credit_range.decomposition().upper_bound() - 1
91 }
92
93 pub fn max_votes(&self) -> u64 {
95 self.vote_count_range.decomposition().upper_bound() - 1
96 }
97
98 pub fn set_max_votes(&mut self, max_votes: u64) {
105 assert!(
106 max_votes * max_votes <= self.credits(),
107 "Vote bound {max_votes} is too large; its square is greater than credit bound {}",
108 self.credits()
109 );
110 self.vote_count_range = RangeDecomposition::optimal(max_votes + 1).into();
111 }
112
113 fn check_options_count(&self, actual_count: usize) -> Result<(), QuadraticVotingError> {
114 if self.options_count == actual_count {
115 Ok(())
116 } else {
117 Err(QuadraticVotingError::OptionsLenMismatch {
118 expected: self.options_count,
119 actual: actual_count,
120 })
121 }
122 }
123}
124
/// Integer square root rounded down, i.e., the largest `r` such that `r * r <= x`.
///
/// Uses the bitwise digit-by-digit method, so no floating-point arithmetic is involved
/// and the result is exact for the entire `u64` range.
fn isqrt(mut x: u64) -> u64 {
    // Begin with the largest power of 4 that does not exceed `x`.
    let mut bit = 1_u64 << 62;
    while bit > x {
        bit >>= 2;
    }

    let mut root = 0_u64;
    while bit != 0 {
        let candidate = root + bit;
        root >>= 1;
        if x >= candidate {
            // The current result bit is set.
            x -= candidate;
            root += bit;
        }
        bit >>= 2;
    }
    root
}
144
#[derive(Debug, Clone)]
/// Encrypted ballot for quadratic voting together with proofs of its validity.
///
/// The ballot holds one encrypted vote count per option (each with a range proof against
/// the vote-count range of [`QuadraticVotingParams`]), the encrypted credit spent by the ballot
/// (with a range proof against the credit range), and a [`SumOfSquaresProof`] linking the credit
/// to the squared vote counts.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct QuadraticVotingBallot<G: Group> {
    /// Encrypted vote counts, one per option, in option order.
    votes: Vec<CiphertextWithRangeProof<G>>,
    /// Encrypted credit (sum of squared vote counts) spent by the ballot.
    credit: CiphertextWithRangeProof<G>,
    /// Proof that `credit` encrypts the sum of squares of the values in `votes`.
    credit_equivalence_proof: SumOfSquaresProof<G>,
}
210
/// Ciphertext bundled with the range proof for its encrypted value.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
struct CiphertextWithRangeProof<G: Group> {
    /// Encrypted value.
    ciphertext: Ciphertext<G>,
    /// Range proof for the value encrypted in `ciphertext`.
    range_proof: RangeProof<G>,
}
218
219impl<G: Group> CiphertextWithRangeProof<G> {
220 fn new(ciphertext: Ciphertext<G>, range_proof: RangeProof<G>) -> Self {
221 Self {
222 ciphertext,
223 range_proof,
224 }
225 }
226}
227
228impl<G: Group> QuadraticVotingBallot<G> {
229 pub fn new<R: CryptoRng + RngCore>(
235 params: &QuadraticVotingParams<G>,
236 votes: &[u64],
237 rng: &mut R,
238 ) -> Self {
239 assert_eq!(
240 votes.len(),
241 params.options_count,
242 "Mismatch between expected and actual number of choices"
243 );
244 let credit = votes.iter().map(|&x| x * x).sum::<u64>();
245
246 let votes: Vec<_> = votes
247 .iter()
248 .map(|&vote_count| {
249 let (ciphertext, proof) = RangeProof::new(
250 ¶ms.receiver,
251 ¶ms.vote_count_range,
252 vote_count,
253 &mut Transcript::new(b"quadratic_voting_variant"),
254 rng,
255 );
256 (ciphertext.generalize(), proof)
257 })
258 .collect();
259 let (credit, credit_range_proof) = RangeProof::new(
260 ¶ms.receiver,
261 ¶ms.credit_range,
262 credit,
263 &mut Transcript::new(b"quadratic_voting_credit_range"),
264 rng,
265 );
266 let credit = credit.generalize();
267
268 let credit_equivalence_proof = SumOfSquaresProof::new(
269 votes.iter().map(|(ciphertext, _)| ciphertext),
270 &credit,
271 ¶ms.receiver,
272 &mut Transcript::new(b"quadratic_voting_credit_equiv"),
273 rng,
274 );
275
276 Self {
277 votes: votes
278 .into_iter()
279 .map(|(ciphertext, proof)| CiphertextWithRangeProof::new(ciphertext.into(), proof))
280 .collect(),
281 credit: CiphertextWithRangeProof::new(credit.into(), credit_range_proof),
282 credit_equivalence_proof,
283 }
284 }
285
286 pub fn verify(
292 &self,
293 params: &QuadraticVotingParams<G>,
294 ) -> Result<impl Iterator<Item = Ciphertext<G>> + '_, QuadraticVotingError> {
295 params.check_options_count(self.votes.len())?;
296
297 for (i, vote_count) in self.votes.iter().enumerate() {
298 vote_count
299 .range_proof
300 .verify(
301 ¶ms.receiver,
302 ¶ms.vote_count_range,
303 vote_count.ciphertext,
304 &mut Transcript::new(b"quadratic_voting_variant"),
305 )
306 .map_err(|error| QuadraticVotingError::Variant { index: i, error })?;
307 }
308
309 self.credit
310 .range_proof
311 .verify(
312 ¶ms.receiver,
313 ¶ms.credit_range,
314 self.credit.ciphertext,
315 &mut Transcript::new(b"quadratic_voting_credit_range"),
316 )
317 .map_err(QuadraticVotingError::CreditRange)?;
318
319 self.credit_equivalence_proof
320 .verify(
321 self.votes.iter().map(|c| &c.ciphertext),
322 &self.credit.ciphertext,
323 ¶ms.receiver,
324 &mut Transcript::new(b"quadratic_voting_credit_equiv"),
325 )
326 .map_err(QuadraticVotingError::CreditEquivalence)?;
327
328 Ok(self.votes.iter().map(|c| c.ciphertext))
329 }
330}
331
#[derive(Debug)]
/// Errors that can occur when verifying a [`QuadraticVotingBallot`].
#[non_exhaustive]
pub enum QuadraticVotingError {
    /// Error verifying the range proof for an encrypted vote count.
    Variant {
        /// Zero-based index of the option in the ballot.
        index: usize,
        /// Underlying verification error.
        error: VerificationError,
    },
    /// Error verifying the range proof for the encrypted credit.
    CreditRange(VerificationError),
    /// Error verifying the proof that the credit equals the sum of squared vote counts.
    CreditEquivalence(VerificationError),
    /// The number of options in the ballot differs from the number in the params.
    OptionsLenMismatch {
        /// Number of options expected by the params.
        expected: usize,
        /// Number of options in the ballot.
        actual: usize,
    },
}
355
356impl fmt::Display for QuadraticVotingError {
357 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358 match self {
359 Self::Variant { index, error } => write!(
360 formatter,
361 "error verifying range proof for option #{}: {error}",
362 *index + 1
363 ),
364 Self::CreditRange(err) => {
365 write!(formatter, "error verifying range proof for credits: {err}")
366 }
367 Self::CreditEquivalence(err) => {
368 write!(formatter, "error verifying credit equivalence proof: {err}")
369 }
370 Self::OptionsLenMismatch { expected, actual } => write!(
371 formatter,
372 "number of options in the ballot ({actual}) differs from expected ({expected})"
373 ),
374 }
375 }
376}
377
#[cfg(feature = "std")]
impl std::error::Error for QuadraticVotingError {
    /// Returns the underlying proof verification error, if the failure was caused by one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Variant { error, .. }
            | Self::CreditRange(error)
            | Self::CreditEquivalence(error) => Some(error),
            // Listed explicitly rather than via `_`, so that adding a variant
            // forces a conscious decision about its error source.
            Self::OptionsLenMismatch { .. } => None,
        }
    }
}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DiscreteLogTable, Keypair,
        group::{ElementOps, Ristretto},
    };

    /// Checks `isqrt` against its defining property `sqrt^2 <= x < (sqrt + 1)^2` on small
    /// values, multiples of 1000, and values near the top of the `u64` range.
    #[test]
    fn isqrt_is_correct() {
        let edge_cases = [u64::MAX, u64::MAX - 1, 1 << 63, 1 << 62, (1 << 62) - 1];
        let samples = (0..1_000)
            .chain((0..1_000).map(|x| x * 1_000))
            .chain(edge_cases);

        for sample in samples {
            let sqrt = isqrt(sample);
            assert!(sqrt * sqrt <= sample, "sqrt({sample}) ?= {sqrt}");

            // `None` means `(sqrt + 1)^2` exceeds `u64::MAX`, hence it also exceeds `sample`.
            let next_square_exceeds = match (sqrt + 1).checked_mul(sqrt + 1) {
                Some(square) => square > sample,
                None => true,
            };
            assert!(next_square_exceeds, "sqrt({sample}) ?= {sqrt}");
        }
    }

    /// Round-trips an honest ballot, then checks that tampered ballots are rejected
    /// with the expected errors.
    #[test]
    fn quadratic_voting() {
        let mut rng = rand::rng();
        let (pk, sk) = Keypair::generate(&mut rng).into_tuple();
        let params = QuadraticVotingParams::<Ristretto>::new(pk, 5, 25);
        let expected_votes: [u64; 5] = [1, 3, 0, 3, 2];
        let ballot = QuadraticVotingBallot::new(&params, &expected_votes, &mut rng);

        let lookup_table = DiscreteLogTable::new(0..=5);
        let decrypted: Vec<_> = ballot
            .verify(&params)
            .unwrap()
            .map(|choice| sk.decrypt(choice, &lookup_table).unwrap())
            .collect();
        assert_eq!(decrypted, expected_votes);

        // Shifting an encrypted vote invalidates its range proof.
        let mut tampered = ballot.clone();
        tampered.votes[0].ciphertext.blinded_element += Ristretto::generator();
        let err = tampered.verify(&params).map(drop).unwrap_err();
        assert!(matches!(
            err,
            QuadraticVotingError::Variant {
                index: 0,
                error: VerificationError::ChallengeMismatch
            }
        ));

        // Shifting the encrypted credit invalidates the credit range proof.
        let mut tampered = ballot.clone();
        tampered.credit.ciphertext.blinded_element -= Ristretto::generator();
        let err = tampered.verify(&params).map(drop).unwrap_err();
        assert!(matches!(err, QuadraticVotingError::CreditRange(_)));

        // Replacing vote #0 (1 -> 3) with a fresh, in-range encryption keeps all range proofs
        // valid, but the squares now sum to 31 rather than the encrypted credit of 23.
        let (ciphertext, proof) = RangeProof::new(
            &params.receiver,
            &params.vote_count_range,
            3,
            &mut Transcript::new(b"quadratic_voting_variant"),
            &mut rng,
        );
        let mut tampered = ballot;
        tampered.votes[0] = CiphertextWithRangeProof::new(ciphertext.into(), proof);
        let err = tampered.verify(&params).map(drop).unwrap_err();
        assert!(matches!(err, QuadraticVotingError::CreditEquivalence(_)));
    }
}