1use core::iter;
4
5use elliptic_curve::{rand_core::CryptoRng, zeroize::Zeroizing};
6use merlin::Transcript;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "serde")]
11use crate::serde::{ScalarHelper, VecHelper};
12use crate::{
13 Ciphertext, CiphertextWithValue, PublicKey, SecretKey, VerificationError, alloc::Vec,
14 group::Group, proofs::TranscriptForGroup,
15};
16
/// Non-interactive zero-knowledge proof that a ciphertext encrypts the sum of squares of the
/// values encrypted by a sequence of other ciphertexts, all for the same receiver key.
///
/// The proof is created with [`SumOfSquaresProof::new()`] and checked with
/// [`SumOfSquaresProof::verify()`]. Both replay the same statements into a `merlin`
/// [`Transcript`], and the challenge is derived from that transcript. The ciphertexts are
/// bound to the transcript in order, so the verifier must supply them in the same order as
/// the prover.
///
/// The commitments used by the proof (`[e_r]G`, `[e_x]G + [e_r]K`) mirror ciphertexts of the
/// form `(R, X) = ([r]G, [x]G + [r]K)` for the receiver key `K`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct SumOfSquaresProof<G: Group> {
    /// Challenge scalar derived from the transcript. The verifier recomputes it from the
    /// responses and accepts only on an exact match.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    challenge: G::Scalar,
    /// Two responses per input ciphertext, interleaved as
    /// `[randomness response, value response]` in ciphertext order.
    // NOTE(review): the `2` in `VecHelper` presumably relates to this pairing / a length
    // constraint — confirm against `crate::serde`.
    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
    ciphertext_responses: Vec<G::Scalar>,
    /// Response for the randomness that links the sum-of-squares ciphertext to the input
    /// ciphertexts.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    sum_response: G::Scalar,
}
94
95impl<G: Group> SumOfSquaresProof<G> {
96 fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
97 transcript.start_proof(b"sum_of_squares");
98 transcript.append_element_bytes(b"K", receiver.as_bytes());
99 }
100
101 #[allow(clippy::needless_collect)] pub fn new<'a, R: CryptoRng>(
108 ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
109 sum_of_squares_ciphertext: &CiphertextWithValue<G>,
110 receiver: &PublicKey<G>,
111 transcript: &mut Transcript,
112 rng: &mut R,
113 ) -> Self {
114 Self::initialize_transcript(transcript, receiver);
115
116 let sum_scalar = SecretKey::<G>::generate(rng);
117 let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
118
119 let partial_scalars: Vec<_> = ciphertexts
120 .map(|ciphertext| {
121 transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
122 transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
123
124 let random_scalar = SecretKey::<G>::generate(rng);
125 let random_commitment = G::mul_generator(random_scalar.expose_scalar());
126 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
127 let value_scalar = SecretKey::<G>::generate(rng);
128 let value_commitment = G::mul_generator(value_scalar.expose_scalar())
129 + receiver.as_element() * random_scalar.expose_scalar();
130 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
131
132 let neg_value = Zeroizing::new(-*ciphertext.value());
133 sum_random_scalar += ciphertext.randomness() * &neg_value;
134 (ciphertext, random_scalar, value_scalar)
135 })
136 .collect();
137
138 let scalars = partial_scalars
139 .iter()
140 .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
141 .chain(iter::once(sum_scalar.expose_scalar()));
142 let random_sum_commitment = {
143 let elements = partial_scalars
144 .iter()
145 .map(|(ciphertext, ..)| ciphertext.inner().random_element)
146 .chain(iter::once(G::generator()));
147 G::multi_mul(scalars.clone(), elements)
148 };
149 let value_sum_commitment = {
150 let elements = partial_scalars
151 .iter()
152 .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
153 .chain(iter::once(receiver.as_element()));
154 G::multi_mul(scalars, elements)
155 };
156
157 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
158 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
159 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
160 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
161 let challenge = transcript.challenge_scalar::<G>(b"c");
162
163 let ciphertext_responses = partial_scalars
164 .into_iter()
165 .flat_map(|(ciphertext, random_scalar, value_scalar)| {
166 [
167 challenge * ciphertext.randomness().expose_scalar()
168 + random_scalar.expose_scalar(),
169 challenge * ciphertext.value() + value_scalar.expose_scalar(),
170 ]
171 })
172 .collect();
173 let sum_response =
174 challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
175
176 Self {
177 challenge,
178 ciphertext_responses,
179 sum_response,
180 }
181 }
182
183 pub fn verify<'a>(
191 &self,
192 ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
193 sum_of_squares_ciphertext: &Ciphertext<G>,
194 receiver: &PublicKey<G>,
195 transcript: &mut Transcript,
196 ) -> Result<(), VerificationError> {
197 let ciphertexts_count = ciphertexts.clone().count();
198 VerificationError::check_lengths(
199 "ciphertext responses",
200 self.ciphertext_responses.len(),
201 ciphertexts_count * 2,
202 )?;
203
204 Self::initialize_transcript(transcript, receiver);
205 let neg_challenge = -self.challenge;
206
207 for (response_chunk, ciphertext) in
208 self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
209 {
210 transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
211 transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
212
213 let r_response = &response_chunk[0];
214 let v_response = &response_chunk[1];
215 let random_commitment = G::vartime_double_mul_generator(
216 &-self.challenge,
217 ciphertext.random_element,
218 r_response,
219 );
220 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
221 let value_commitment = G::vartime_multi_mul(
222 [v_response, r_response, &neg_challenge],
223 [
224 G::generator(),
225 receiver.as_element(),
226 ciphertext.blinded_element,
227 ],
228 );
229 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
230 }
231
232 let scalars = OddItems::new(self.ciphertext_responses.iter())
233 .chain([&self.sum_response, &neg_challenge]);
234 let random_sum_commitment = {
235 let elements = ciphertexts
236 .clone()
237 .map(|c| c.random_element)
238 .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
239 G::vartime_multi_mul(scalars.clone(), elements)
240 };
241 let value_sum_commitment = {
242 let elements = ciphertexts.map(|c| c.blinded_element).chain([
243 receiver.as_element(),
244 sum_of_squares_ciphertext.blinded_element,
245 ]);
246 G::vartime_multi_mul(scalars, elements)
247 };
248
249 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
250 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
251 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
252 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
253 let expected_challenge = transcript.challenge_scalar::<G>(b"c");
254
255 if expected_challenge == self.challenge {
256 Ok(())
257 } else {
258 Err(VerificationError::ChallengeMismatch)
259 }
260 }
261}
262
/// Iterator adapter that yields only the items at odd positions (1, 3, 5, …) of the wrapped
/// iterator, i.e. the second item of each consecutive pair. An unpaired trailing item is
/// skipped.
#[derive(Debug, Clone)]
struct OddItems<I> {
    /// Wrapped iterator. It is always advanced two items at a time.
    iter: I,
    /// Set once the wrapped iterator has returned `None`. From then on the adapter is
    /// exhausted, even if the wrapped iterator is not fused.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    /// Wraps `iter`, skipping its even-indexed items.
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Discard the even-indexed item, then yield the odd-indexed one after it.
        let item = self.iter.next().and_then(|_| self.iter.next());
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            // `next` never yields again, whatever the wrapped iterator would report.
            return (0, Some(0));
        }
        // Between calls, the wrapped iterator is always at an even position.
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

// `next` keeps returning `None` once `ended` is set, so the adapter is fused.
impl<I: Iterator> core::iter::FusedIterator for OddItems<I> {}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Keypair, group::Ristretto};

    /// Runs [`SumOfSquaresProof::verify()`] against a fresh transcript created from `label`.
    fn verify_with<'a>(
        proof: &SumOfSquaresProof<Ristretto>,
        ciphertexts: impl Iterator<Item = &'a Ciphertext<Ristretto>> + Clone,
        sum_of_squares: &Ciphertext<Ristretto>,
        receiver: &PublicKey<Ristretto>,
        label: &'static [u8],
    ) -> Result<(), VerificationError> {
        let mut transcript = Transcript::new(label);
        proof.verify(ciphertexts, sum_of_squares, receiver, &mut transcript)
    }

    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let x = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let x_squared = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();

        let mut prover_transcript = Transcript::new(b"test");
        let proof = SumOfSquaresProof::new(
            [&x].into_iter(),
            &x_squared,
            &receiver,
            &mut prover_transcript,
            &mut rng,
        );

        let x: Ciphertext<Ristretto> = x.into();
        let x_squared: Ciphertext<Ristretto> = x_squared.into();
        verify_with(&proof, [&x].into_iter(), &x_squared, &receiver, b"test").unwrap();

        let unrelated = receiver.encrypt(8_u64, &mut rng);
        let tampered_outcomes = [
            // Wrong sum-of-squares ciphertext.
            verify_with(&proof, [&x].into_iter(), &unrelated, &receiver, b"test"),
            // Wrong squared ciphertext.
            verify_with(&proof, [&unrelated].into_iter(), &x_squared, &receiver, b"test"),
            // Transcript started with a different label.
            verify_with(
                &proof,
                [&x].into_iter(),
                &x_squared,
                &receiver,
                b"other_transcript",
            ),
        ];
        for outcome in tampered_outcomes {
            assert!(matches!(outcome, Err(VerificationError::ChallengeMismatch)));
        }
    }

    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let x = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        // 10 is not 3^2, so the resulting proof must be rejected.
        let not_x_squared = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&x].into_iter(),
            &not_x_squared,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let x: Ciphertext<Ristretto> = x.into();
        let not_x_squared: Ciphertext<Ristretto> = not_x_squared.into();
        let outcome = verify_with(&proof, [&x].into_iter(), &not_x_squared, &receiver, b"test");
        assert!(matches!(outcome, Err(VerificationError::ChallengeMismatch)));
    }

    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let values = [3_u64, 1, 4, 1];
        let ciphertexts =
            values.map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        // 3^2 + 1^2 + 4^2 + 1^2 = 27
        let sum = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            ciphertexts.iter(),
            &sum,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );
        let sum: Ciphertext<Ristretto> = sum.into();

        verify_with(
            &proof,
            ciphertexts.iter().map(CiphertextWithValue::inner),
            &sum,
            &receiver,
            b"test",
        )
        .unwrap();

        // Ciphertext order is bound to the transcript, so reversing them must fail.
        let reordered = verify_with(
            &proof,
            ciphertexts.iter().rev().map(CiphertextWithValue::inner),
            &sum,
            &receiver,
            b"test",
        );
        assert!(matches!(reordered, Err(VerificationError::ChallengeMismatch)));

        // Fewer ciphertexts than response pairs is caught by the length check.
        let truncated = verify_with(
            &proof,
            ciphertexts.iter().take(2).map(CiphertextWithValue::inner),
            &sum,
            &receiver,
            b"test",
        );
        assert!(matches!(truncated, Err(VerificationError::LenMismatch { .. })));
    }

    #[test]
    fn odd_items() {
        let from_chain = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(from_chain.size_hint(), (2, Some(2)));
        assert_eq!(from_chain.collect::<Vec<_>>(), [2, 4]);

        let from_range = OddItems::new(0..7);
        assert_eq!(from_range.size_hint(), (3, Some(3)));
        assert_eq!(from_range.collect::<Vec<_>>(), [1, 3, 5]);
    }
}