1use core::iter;
4
5use elliptic_curve::{
6 rand_core::{CryptoRng, RngCore},
7 zeroize::Zeroizing,
8};
9use merlin::Transcript;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13#[cfg(feature = "serde")]
14use crate::serde::{ScalarHelper, VecHelper};
15use crate::{
16 Ciphertext, CiphertextWithValue, PublicKey, SecretKey, VerificationError, alloc::Vec,
17 group::Group, proofs::TranscriptForGroup,
18};
19
20#[derive(Debug, Clone)]
87#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
88#[cfg_attr(feature = "serde", serde(bound = ""))]
89pub struct SumOfSquaresProof<G: Group> {
90 #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
91 challenge: G::Scalar,
92 #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
93 ciphertext_responses: Vec<G::Scalar>,
94 #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
95 sum_response: G::Scalar,
96}
97
98impl<G: Group> SumOfSquaresProof<G> {
99 fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
100 transcript.start_proof(b"sum_of_squares");
101 transcript.append_element_bytes(b"K", receiver.as_bytes());
102 }
103
104 #[allow(clippy::needless_collect)] pub fn new<'a, R: RngCore + CryptoRng>(
111 ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
112 sum_of_squares_ciphertext: &CiphertextWithValue<G>,
113 receiver: &PublicKey<G>,
114 transcript: &mut Transcript,
115 rng: &mut R,
116 ) -> Self {
117 Self::initialize_transcript(transcript, receiver);
118
119 let sum_scalar = SecretKey::<G>::generate(rng);
120 let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
121
122 let partial_scalars: Vec<_> = ciphertexts
123 .map(|ciphertext| {
124 transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
125 transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
126
127 let random_scalar = SecretKey::<G>::generate(rng);
128 let random_commitment = G::mul_generator(random_scalar.expose_scalar());
129 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
130 let value_scalar = SecretKey::<G>::generate(rng);
131 let value_commitment = G::mul_generator(value_scalar.expose_scalar())
132 + receiver.as_element() * random_scalar.expose_scalar();
133 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
134
135 let neg_value = Zeroizing::new(-*ciphertext.value());
136 sum_random_scalar += ciphertext.randomness() * &neg_value;
137 (ciphertext, random_scalar, value_scalar)
138 })
139 .collect();
140
141 let scalars = partial_scalars
142 .iter()
143 .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
144 .chain(iter::once(sum_scalar.expose_scalar()));
145 let random_sum_commitment = {
146 let elements = partial_scalars
147 .iter()
148 .map(|(ciphertext, ..)| ciphertext.inner().random_element)
149 .chain(iter::once(G::generator()));
150 G::multi_mul(scalars.clone(), elements)
151 };
152 let value_sum_commitment = {
153 let elements = partial_scalars
154 .iter()
155 .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
156 .chain(iter::once(receiver.as_element()));
157 G::multi_mul(scalars, elements)
158 };
159
160 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
161 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
162 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
163 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
164 let challenge = transcript.challenge_scalar::<G>(b"c");
165
166 let ciphertext_responses = partial_scalars
167 .into_iter()
168 .flat_map(|(ciphertext, random_scalar, value_scalar)| {
169 [
170 challenge * ciphertext.randomness().expose_scalar()
171 + random_scalar.expose_scalar(),
172 challenge * ciphertext.value() + value_scalar.expose_scalar(),
173 ]
174 })
175 .collect();
176 let sum_response =
177 challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
178
179 Self {
180 challenge,
181 ciphertext_responses,
182 sum_response,
183 }
184 }
185
186 pub fn verify<'a>(
194 &self,
195 ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
196 sum_of_squares_ciphertext: &Ciphertext<G>,
197 receiver: &PublicKey<G>,
198 transcript: &mut Transcript,
199 ) -> Result<(), VerificationError> {
200 let ciphertexts_count = ciphertexts.clone().count();
201 VerificationError::check_lengths(
202 "ciphertext responses",
203 self.ciphertext_responses.len(),
204 ciphertexts_count * 2,
205 )?;
206
207 Self::initialize_transcript(transcript, receiver);
208 let neg_challenge = -self.challenge;
209
210 for (response_chunk, ciphertext) in
211 self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
212 {
213 transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
214 transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
215
216 let r_response = &response_chunk[0];
217 let v_response = &response_chunk[1];
218 let random_commitment = G::vartime_double_mul_generator(
219 &-self.challenge,
220 ciphertext.random_element,
221 r_response,
222 );
223 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
224 let value_commitment = G::vartime_multi_mul(
225 [v_response, r_response, &neg_challenge],
226 [
227 G::generator(),
228 receiver.as_element(),
229 ciphertext.blinded_element,
230 ],
231 );
232 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
233 }
234
235 let scalars = OddItems::new(self.ciphertext_responses.iter())
236 .chain([&self.sum_response, &neg_challenge]);
237 let random_sum_commitment = {
238 let elements = ciphertexts
239 .clone()
240 .map(|c| c.random_element)
241 .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
242 G::vartime_multi_mul(scalars.clone(), elements)
243 };
244 let value_sum_commitment = {
245 let elements = ciphertexts.map(|c| c.blinded_element).chain([
246 receiver.as_element(),
247 sum_of_squares_ciphertext.blinded_element,
248 ]);
249 G::vartime_multi_mul(scalars, elements)
250 };
251
252 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
253 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
254 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
255 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
256 let expected_challenge = transcript.challenge_scalar::<G>(b"c");
257
258 if expected_challenge == self.challenge {
259 Ok(())
260 } else {
261 Err(VerificationError::ChallengeMismatch)
262 }
263 }
264}
265
/// Iterator adapter that yields only the items at odd positions (indices 1, 3, 5, …) of the
/// wrapped iterator.
///
/// In this module it picks the value responses out of the interleaved `[s_r, s_x, ...]`
/// response list.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    /// Set once the wrapped iterator has returned `None`; no items are produced afterwards.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Discard the even-position item preceding the one we yield.
        if self.iter.next().is_none() {
            self.ended = true;
            return None;
        }
        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            // The adapter never yields again once ended, whatever a (possibly non-fused)
            // wrapped iterator may report.
            return (0, Some(0));
        }
        // Items are consumed in pairs, so the remaining inner items always start at an even
        // position, and exactly half of them (rounded down) sit at odd positions.
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

// `next` never yields after `ended` is set, so the adapter is fused by construction.
impl<I: Iterator> iter::FusedIterator for OddItems<I> {}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Keypair, group::Ristretto};

    /// Verifies `proof` using a brand-new transcript created with `label`.
    fn check<'a>(
        proof: &SumOfSquaresProof<Ristretto>,
        ciphertexts: impl Iterator<Item = &'a Ciphertext<Ristretto>> + Clone,
        sq_ciphertext: &Ciphertext<Ristretto>,
        receiver: &PublicKey<Ristretto>,
        label: &'static [u8],
    ) -> Result<(), VerificationError> {
        proof.verify(ciphertexts, sq_ciphertext, receiver, &mut Transcript::new(label))
    }

    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let value = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let square = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&value].into_iter(),
            &square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let value: Ciphertext<Ristretto> = value.into();
        let square: Ciphertext<Ristretto> = square.into();
        check(&proof, [&value].into_iter(), &square, &receiver, b"test").unwrap();

        let unrelated = receiver.encrypt(8_u64, &mut rng);

        // Substituting the sum-of-squares ciphertext must break the proof...
        let err = check(&proof, [&value].into_iter(), &unrelated, &receiver, b"test")
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // ...as must substituting an input ciphertext...
        let err = check(&proof, [&unrelated].into_iter(), &square, &receiver, b"test")
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // ...or starting from a differently labelled transcript.
        let err = check(
            &proof,
            [&value].into_iter(),
            &square,
            &receiver,
            b"other_transcript",
        )
        .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let value = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        // 10 is not 3 * 3, so the resulting proof must be rejected.
        let bogus_square = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&value].into_iter(),
            &bogus_square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let value: Ciphertext<Ristretto> = value.into();
        let bogus_square: Ciphertext<Ristretto> = bogus_square.into();
        let err = check(&proof, [&value].into_iter(), &bogus_square, &receiver, b"test")
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let values =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        // 3² + 1² + 4² + 1² = 27
        let square = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            values.iter(),
            &square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let square: Ciphertext<Ristretto> = square.into();
        let in_order = values.iter().map(CiphertextWithValue::inner);
        check(&proof, in_order, &square, &receiver, b"test").unwrap();

        // The proof is bound to the order of the inputs.
        let reversed = values.iter().rev().map(CiphertextWithValue::inner);
        let err = check(&proof, reversed, &square, &receiver, b"test").unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Too few inputs are caught before any transcript work is done.
        let truncated = values.iter().take(2).map(CiphertextWithValue::inner);
        let err = check(&proof, truncated, &square, &receiver, b"test").unwrap_err();
        assert!(matches!(err, VerificationError::LenMismatch { .. }));
    }

    #[test]
    fn odd_items() {
        let from_chain = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(from_chain.size_hint(), (2, Some(2)));
        assert_eq!(from_chain.collect::<Vec<_>>(), [2, 4]);

        let from_range = OddItems::new(0..7);
        assert_eq!(from_range.size_hint(), (3, Some(3)));
        assert_eq!(from_range.collect::<Vec<_>>(), [1, 3, 5]);
    }
}
453}