elastic_elgamal/proofs/
mul.rs

1//! Proofs related to multiplication.
2
3use core::iter;
4
5use elliptic_curve::{rand_core::CryptoRng, zeroize::Zeroizing};
6use merlin::Transcript;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "serde")]
11use crate::serde::{ScalarHelper, VecHelper};
12use crate::{
13    Ciphertext, CiphertextWithValue, PublicKey, SecretKey, VerificationError, alloc::Vec,
14    group::Group, proofs::TranscriptForGroup,
15};
16
/// Zero-knowledge proof that an ElGamal-encrypted value is equal to a sum of squares
/// of one or more other ElGamal-encrypted values.
///
/// # Construction
///
/// Consider the case with a single sum element (i.e., proving that an encrypted value is
/// a square of another encrypted value). The prover wants to prove the knowledge of scalars
///
/// ```text
/// r_x, x, r_z:
///   R_x = [r_x]G, X = [x]G + [r_x]K;
///   R_z = [r_z]G, Z = [x^2]G + [r_z]K,
/// ```
///
/// where
///
/// - `G` is the conventional generator of the considered prime-order group
/// - `K` is a group element equivalent to the receiver's public key
/// - `(R_x, X)` and `(R_z, Z)` are ElGamal ciphertexts of values `x` and `x^2`, respectively.
///
/// Observe that
///
/// ```text
/// r'_z := r_z - x * r_x =>
///   R_z = [r'_z]G + [x]R_x; Z = [x]X + [r'_z]K,
/// ```
///
/// and that proving the knowledge of `(r_x, x, r'_z)` is equivalent to the initial problem.
/// The new problem can be solved using a conventional sigma protocol:
///
/// 1. **Commitment.** The prover generates random scalars `e_r`, `e_x` and `e_z` and commits
///    to them via `E_r = [e_r]G`, `E_x = [e_x]G + [e_r]K`, `E_rz = [e_x]R_x + [e_z]G` and
///    `E_z = [e_x]X + [e_z]K`.
/// 2. **Challenge.** The verifier sends to the prover random scalar `c`.
/// 3. **Response.** The prover computes the following scalars and sends them to the verifier.
///
/// ```text
/// s_r = e_r + c * r_x;
/// s_x = e_x + c * x;
/// s_z = e_z + c * (r_z - x * r_x);
/// ```
///
/// The verification equations are
///
/// ```text
/// [s_r]G ?= E_r + [c]R_x;
/// [s_x]G + [s_r]K ?= E_x + [c]X;
/// [s_x]R_x + [s_z]G ?= E_rz + [c]R_z;
/// [s_x]X + [s_z]K ?= E_z + [c]Z.
/// ```
///
/// The case with multiple squares is a straightforward generalization:
///
/// - `e_r`, `E_r`, `e_x`, `E_x`, `s_r` and `s_x` are independently defined for each
///   partial ciphertext in the same way as above.
/// - Commitments `E_rz` and `E_z` sum over `[e_x]R_x` and `[e_x]X` for all ciphertexts,
///   respectively.
/// - Response `s_z` similarly substitutes `x * r_x` with the corresponding sum.
///
/// A non-interactive version of the proof is obtained by applying [Fiat–Shamir transform][fst].
/// As with [`LogEqualityProof`], it is more efficient to represent a proof as the challenge
/// and responses; in this case, the proof size is `2n + 2` scalars, where `n` is the number of
/// partial ciphertexts.
///
/// [fst]: https://en.wikipedia.org/wiki/Fiat%E2%80%93Shamir_heuristic
/// [`LogEqualityProof`]: crate::LogEqualityProof
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct SumOfSquaresProof<G: Group> {
    /// Fiat–Shamir challenge `c`.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    challenge: G::Scalar,
    /// Responses for partial ciphertexts, stored as consecutive `[s_r, s_x]` pairs in the order
    /// the ciphertexts were supplied to the prover; thus, `2n` scalars for `n` ciphertexts.
    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
    ciphertext_responses: Vec<G::Scalar>,
    /// Response `s_z` for the ciphertext of the sum of squares.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    sum_response: G::Scalar,
}
94
95impl<G: Group> SumOfSquaresProof<G> {
96    fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
97        transcript.start_proof(b"sum_of_squares");
98        transcript.append_element_bytes(b"K", receiver.as_bytes());
99    }
100
101    /// Creates a new proof that squares of values encrypted in `ciphertexts` for `receiver` sum up
102    /// to a value encrypted in `sum_of_squares_ciphertext`.
103    ///
104    /// All provided ciphertexts must be encrypted for `receiver`; otherwise, the created proof
105    /// will not verify.
106    #[allow(clippy::needless_collect)] // false positive
107    pub fn new<'a, R: CryptoRng>(
108        ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
109        sum_of_squares_ciphertext: &CiphertextWithValue<G>,
110        receiver: &PublicKey<G>,
111        transcript: &mut Transcript,
112        rng: &mut R,
113    ) -> Self {
114        Self::initialize_transcript(transcript, receiver);
115
116        let sum_scalar = SecretKey::<G>::generate(rng);
117        let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
118
119        let partial_scalars: Vec<_> = ciphertexts
120            .map(|ciphertext| {
121                transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
122                transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
123
124                let random_scalar = SecretKey::<G>::generate(rng);
125                let random_commitment = G::mul_generator(random_scalar.expose_scalar());
126                transcript.append_element::<G>(b"[e_r]G", &random_commitment);
127                let value_scalar = SecretKey::<G>::generate(rng);
128                let value_commitment = G::mul_generator(value_scalar.expose_scalar())
129                    + receiver.as_element() * random_scalar.expose_scalar();
130                transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
131
132                let neg_value = Zeroizing::new(-*ciphertext.value());
133                sum_random_scalar += ciphertext.randomness() * &neg_value;
134                (ciphertext, random_scalar, value_scalar)
135            })
136            .collect();
137
138        let scalars = partial_scalars
139            .iter()
140            .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
141            .chain(iter::once(sum_scalar.expose_scalar()));
142        let random_sum_commitment = {
143            let elements = partial_scalars
144                .iter()
145                .map(|(ciphertext, ..)| ciphertext.inner().random_element)
146                .chain(iter::once(G::generator()));
147            G::multi_mul(scalars.clone(), elements)
148        };
149        let value_sum_commitment = {
150            let elements = partial_scalars
151                .iter()
152                .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
153                .chain(iter::once(receiver.as_element()));
154            G::multi_mul(scalars, elements)
155        };
156
157        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
158        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
159        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
160        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
161        let challenge = transcript.challenge_scalar::<G>(b"c");
162
163        let ciphertext_responses = partial_scalars
164            .into_iter()
165            .flat_map(|(ciphertext, random_scalar, value_scalar)| {
166                [
167                    challenge * ciphertext.randomness().expose_scalar()
168                        + random_scalar.expose_scalar(),
169                    challenge * ciphertext.value() + value_scalar.expose_scalar(),
170                ]
171            })
172            .collect();
173        let sum_response =
174            challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
175
176        Self {
177            challenge,
178            ciphertext_responses,
179            sum_response,
180        }
181    }
182
183    /// Verifies this proof against the provided partial ciphertexts and the ciphertext of the
184    /// sum of their squares. The order of partial ciphertexts must correspond to their order
185    /// when creating the proof.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if this proof does not verify.
190    pub fn verify<'a>(
191        &self,
192        ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
193        sum_of_squares_ciphertext: &Ciphertext<G>,
194        receiver: &PublicKey<G>,
195        transcript: &mut Transcript,
196    ) -> Result<(), VerificationError> {
197        let ciphertexts_count = ciphertexts.clone().count();
198        VerificationError::check_lengths(
199            "ciphertext responses",
200            self.ciphertext_responses.len(),
201            ciphertexts_count * 2,
202        )?;
203
204        Self::initialize_transcript(transcript, receiver);
205        let neg_challenge = -self.challenge;
206
207        for (response_chunk, ciphertext) in
208            self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
209        {
210            transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
211            transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
212
213            let r_response = &response_chunk[0];
214            let v_response = &response_chunk[1];
215            let random_commitment = G::vartime_double_mul_generator(
216                &-self.challenge,
217                ciphertext.random_element,
218                r_response,
219            );
220            transcript.append_element::<G>(b"[e_r]G", &random_commitment);
221            let value_commitment = G::vartime_multi_mul(
222                [v_response, r_response, &neg_challenge],
223                [
224                    G::generator(),
225                    receiver.as_element(),
226                    ciphertext.blinded_element,
227                ],
228            );
229            transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
230        }
231
232        let scalars = OddItems::new(self.ciphertext_responses.iter())
233            .chain([&self.sum_response, &neg_challenge]);
234        let random_sum_commitment = {
235            let elements = ciphertexts
236                .clone()
237                .map(|c| c.random_element)
238                .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
239            G::vartime_multi_mul(scalars.clone(), elements)
240        };
241        let value_sum_commitment = {
242            let elements = ciphertexts.map(|c| c.blinded_element).chain([
243                receiver.as_element(),
244                sum_of_squares_ciphertext.blinded_element,
245            ]);
246            G::vartime_multi_mul(scalars, elements)
247        };
248
249        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
250        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
251        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
252        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
253        let expected_challenge = transcript.challenge_scalar::<G>(b"c");
254
255        if expected_challenge == self.challenge {
256            Ok(())
257        } else {
258            Err(VerificationError::ChallengeMismatch)
259        }
260    }
261}
262
/// Thin wrapper around an iterator that drops its even-indexed elements, i.e., it yields
/// only the items at indices 1, 3, 5, ... of the wrapped iterator.
///
/// A dedicated type (rather than a generic filtering adapter) is necessary because
/// `Ristretto::vartime_multi_mul()` panics otherwise, which is caused by an imprecise
/// `Iterator::size_hint()` value. `OddItems` reports an exact size hint whenever
/// the wrapped iterator does.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    /// Set once the wrapped iterator has returned `None`; afterwards, it is never polled again.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Skip the even-indexed item; if there is none, the wrapped iterator is exhausted.
        if self.iter.next().is_none() {
            self.ended = true;
            return None;
        }

        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            // `next()` never polls the wrapped iterator again, so nothing else will be yielded
            // regardless of what a (possibly non-fused) wrapped iterator reports.
            return (0, Some(0));
        }
        // Items are consumed in pairs, so between calls the wrapped iterator is positioned at
        // an even index; `n` remaining items therefore yield exactly `n / 2` odd-indexed items.
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

// Once `ended` is set, `next()` always returns `None`.
impl<I: Iterator> iter::FusedIterator for OddItems<I> {}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Keypair, group::Ristretto};

    /// Runs `SumOfSquaresProof::verify()` with a freshly created transcript labelled `label`.
    fn check(
        proof: &SumOfSquaresProof<Ristretto>,
        partials: &[&Ciphertext<Ristretto>],
        sum_of_squares: &Ciphertext<Ristretto>,
        receiver: &PublicKey<Ristretto>,
        label: &'static [u8],
    ) -> Result<(), VerificationError> {
        let mut transcript = Transcript::new(label);
        proof.verify(partials.iter().copied(), sum_of_squares, receiver, &mut transcript)
    }

    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = rand::rng();
        let receiver = Keypair::<Ristretto>::generate(&mut rng).into_tuple().0;
        let enc_x = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let enc_square = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            iter::once(&enc_x),
            &enc_square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let enc_x: Ciphertext<Ristretto> = enc_x.into();
        let enc_square: Ciphertext<Ristretto> = enc_square.into();
        check(&proof, &[&enc_x], &enc_square, &receiver, b"test").unwrap();

        // Replacing either ciphertext or the transcript label must break the proof.
        let unrelated = receiver.encrypt(8_u64, &mut rng);
        let failures = [
            check(&proof, &[&enc_x], &unrelated, &receiver, b"test"),
            check(&proof, &[&unrelated], &enc_square, &receiver, b"test"),
            check(&proof, &[&enc_x], &enc_square, &receiver, b"other_transcript"),
        ];
        for outcome in failures {
            let err = outcome.unwrap_err();
            assert!(matches!(err, VerificationError::ChallengeMismatch));
        }
    }

    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = rand::rng();
        let receiver = Keypair::<Ristretto>::generate(&mut rng).into_tuple().0;
        let enc_x = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        // 10 is not 3^2, so the proof must be rejected.
        let enc_not_square = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            iter::once(&enc_x),
            &enc_not_square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let enc_x: Ciphertext<Ristretto> = enc_x.into();
        let enc_not_square: Ciphertext<Ristretto> = enc_not_square.into();
        let err = check(&proof, &[&enc_x], &enc_not_square, &receiver, b"test").unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = rand::rng();
        let receiver = Keypair::<Ristretto>::generate(&mut rng).into_tuple().0;
        let partials =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        // 3^2 + 1^2 + 4^2 + 1^2 = 27
        let enc_sum = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            partials.iter(),
            &enc_sum,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let enc_sum: Ciphertext<Ristretto> = enc_sum.into();
        let inner: Vec<_> = partials.iter().map(CiphertextWithValue::inner).collect();
        check(&proof, &inner, &enc_sum, &receiver, b"test").unwrap();

        // The proof will not verify if ciphertexts are rearranged.
        let reversed: Vec<_> = inner.iter().rev().copied().collect();
        let err = check(&proof, &reversed, &enc_sum, &receiver, b"test").unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // ...nor if some of them are missing.
        let err = check(&proof, &inner[..2], &enc_sum, &receiver, b"test").unwrap_err();
        assert!(matches!(err, VerificationError::LenMismatch { .. }));
    }

    #[test]
    fn odd_items() {
        let chained = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(chained.size_hint(), (2, Some(2)));
        assert_eq!(chained.collect::<Vec<_>>(), [2, 4]);

        let from_range = OddItems::new(0..7);
        assert_eq!(from_range.size_hint(), (3, Some(3)));
        assert_eq!(from_range.collect::<Vec<_>>(), [1, 3, 5]);
    }
}