elastic_elgamal/proofs/
mul.rs

1//! Proofs related to multiplication.
2
3use core::iter;
4
5use elliptic_curve::{
6    rand_core::{CryptoRng, RngCore},
7    zeroize::Zeroizing,
8};
9use merlin::Transcript;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13#[cfg(feature = "serde")]
14use crate::serde::{ScalarHelper, VecHelper};
15use crate::{
16    Ciphertext, CiphertextWithValue, PublicKey, SecretKey, VerificationError, alloc::Vec,
17    group::Group, proofs::TranscriptForGroup,
18};
19
/// Zero-knowledge proof that an ElGamal-encrypted value is equal to a sum of squares
/// of one or more other ElGamal-encrypted values.
///
/// # Construction
///
/// Consider the case with a single sum element (i.e., proving that an encrypted value is
/// the square of another encrypted value). The prover wants to prove knowledge of scalars
///
/// ```text
/// r_x, x, r_z:
///   R_x = [r_x]G, X = [x]G + [r_x]K;
///   R_z = [r_z]G, Z = [x^2]G + [r_z]K,
/// ```
///
/// where
///
/// - `G` is the conventional generator of the considered prime-order group
/// - `K` is a group element equivalent to the receiver's public key
/// - `(R_x, X)` and `(R_z, Z)` are ElGamal ciphertexts of values `x` and `x^2`, respectively.
///
/// Observe that
///
/// ```text
/// r'_z := r_z - x * r_x =>
///   R_z = [r'_z]G + [x]R_x; Z = [x]X + [r'_z]K.
/// ```
///
/// and that proving knowledge of `(r_x, x, r'_z)` is equivalent to the initial problem.
/// The new problem can be solved using a conventional sigma protocol:
///
/// 1. **Commitment.** The prover generates random scalars `e_r`, `e_x` and `e_z` and commits
///    to them via `E_r = [e_r]G`, `E_x = [e_x]G + [e_r]K`, `E_rz = [e_x]R_x + [e_z]G` and
///    `E_z = [e_x]X + [e_z]K`.
/// 2. **Challenge.** The verifier sends a random scalar `c` to the prover.
/// 3. **Response.** The prover computes the following scalars and sends them to the verifier.
///
/// ```text
/// s_r = e_r + c * r_x;
/// s_x = e_x + c * x;
/// s_z = e_z + c * (r_z - x * r_x);
/// ```
///
/// The verification equations are
///
/// ```text
/// [s_r]G ?= E_r + [c]R_x;
/// [s_x]G + [s_r]K ?= E_x + [c]X;
/// [s_x]R_x + [s_z]G ?= E_rz + [c]R_z;
/// [s_x]X + [s_z]K ?= E_z + [c]Z.
/// ```
///
/// The case with multiple squares is a straightforward generalization:
///
/// - `e_r`, `E_r`, `e_x`, `E_x`, `s_r` and `s_x` are independently defined for each
///   partial ciphertext in the same way as above.
/// - Commitments `E_rz` and `E_z` sum over `[e_x]R_x` and `[e_x]X` for all ciphertexts,
///   respectively.
/// - Response `s_z` similarly substitutes `x * r_x` with the corresponding sum.
///
/// Partial ciphertexts are bound to the proof in order, so a proof only verifies if
/// the verifier supplies them in the same order that was used when creating it.
///
/// A non-interactive version of the proof is obtained by applying the
/// [Fiat–Shamir transform][fst]. As with [`LogEqualityProof`], it is more efficient to represent
/// a proof as the challenge and responses; in this case, the proof size is `2n + 2` scalars,
/// where `n` is the number of partial ciphertexts.
///
/// [fst]: https://en.wikipedia.org/wiki/Fiat%E2%80%93Shamir_heuristic
/// [`LogEqualityProof`]: crate::LogEqualityProof
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct SumOfSquaresProof<G: Group> {
    /// Fiat–Shamir challenge `c`.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    challenge: G::Scalar,
    /// Responses for the partial ciphertexts, flattened in ciphertext order as
    /// `[s_r_1, s_x_1, s_r_2, s_x_2, ...]` (two scalars per partial ciphertext).
    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
    ciphertext_responses: Vec<G::Scalar>,
    /// Response `s_z` related to the sum-of-squares ciphertext.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    sum_response: G::Scalar,
}
97
98impl<G: Group> SumOfSquaresProof<G> {
99    fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
100        transcript.start_proof(b"sum_of_squares");
101        transcript.append_element_bytes(b"K", receiver.as_bytes());
102    }
103
104    /// Creates a new proof that squares of values encrypted in `ciphertexts` for `receiver` sum up
105    /// to a value encrypted in `sum_of_squares_ciphertext`.
106    ///
107    /// All provided ciphertexts must be encrypted for `receiver`; otherwise, the created proof
108    /// will not verify.
109    #[allow(clippy::needless_collect)] // false positive
110    pub fn new<'a, R: RngCore + CryptoRng>(
111        ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
112        sum_of_squares_ciphertext: &CiphertextWithValue<G>,
113        receiver: &PublicKey<G>,
114        transcript: &mut Transcript,
115        rng: &mut R,
116    ) -> Self {
117        Self::initialize_transcript(transcript, receiver);
118
119        let sum_scalar = SecretKey::<G>::generate(rng);
120        let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
121
122        let partial_scalars: Vec<_> = ciphertexts
123            .map(|ciphertext| {
124                transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
125                transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
126
127                let random_scalar = SecretKey::<G>::generate(rng);
128                let random_commitment = G::mul_generator(random_scalar.expose_scalar());
129                transcript.append_element::<G>(b"[e_r]G", &random_commitment);
130                let value_scalar = SecretKey::<G>::generate(rng);
131                let value_commitment = G::mul_generator(value_scalar.expose_scalar())
132                    + receiver.as_element() * random_scalar.expose_scalar();
133                transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
134
135                let neg_value = Zeroizing::new(-*ciphertext.value());
136                sum_random_scalar += ciphertext.randomness() * &neg_value;
137                (ciphertext, random_scalar, value_scalar)
138            })
139            .collect();
140
141        let scalars = partial_scalars
142            .iter()
143            .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
144            .chain(iter::once(sum_scalar.expose_scalar()));
145        let random_sum_commitment = {
146            let elements = partial_scalars
147                .iter()
148                .map(|(ciphertext, ..)| ciphertext.inner().random_element)
149                .chain(iter::once(G::generator()));
150            G::multi_mul(scalars.clone(), elements)
151        };
152        let value_sum_commitment = {
153            let elements = partial_scalars
154                .iter()
155                .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
156                .chain(iter::once(receiver.as_element()));
157            G::multi_mul(scalars, elements)
158        };
159
160        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
161        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
162        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
163        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
164        let challenge = transcript.challenge_scalar::<G>(b"c");
165
166        let ciphertext_responses = partial_scalars
167            .into_iter()
168            .flat_map(|(ciphertext, random_scalar, value_scalar)| {
169                [
170                    challenge * ciphertext.randomness().expose_scalar()
171                        + random_scalar.expose_scalar(),
172                    challenge * ciphertext.value() + value_scalar.expose_scalar(),
173                ]
174            })
175            .collect();
176        let sum_response =
177            challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
178
179        Self {
180            challenge,
181            ciphertext_responses,
182            sum_response,
183        }
184    }
185
186    /// Verifies this proof against the provided partial ciphertexts and the ciphertext of the
187    /// sum of their squares. The order of partial ciphertexts must correspond to their order
188    /// when creating the proof.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if this proof does not verify.
193    pub fn verify<'a>(
194        &self,
195        ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
196        sum_of_squares_ciphertext: &Ciphertext<G>,
197        receiver: &PublicKey<G>,
198        transcript: &mut Transcript,
199    ) -> Result<(), VerificationError> {
200        let ciphertexts_count = ciphertexts.clone().count();
201        VerificationError::check_lengths(
202            "ciphertext responses",
203            self.ciphertext_responses.len(),
204            ciphertexts_count * 2,
205        )?;
206
207        Self::initialize_transcript(transcript, receiver);
208        let neg_challenge = -self.challenge;
209
210        for (response_chunk, ciphertext) in
211            self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
212        {
213            transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
214            transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
215
216            let r_response = &response_chunk[0];
217            let v_response = &response_chunk[1];
218            let random_commitment = G::vartime_double_mul_generator(
219                &-self.challenge,
220                ciphertext.random_element,
221                r_response,
222            );
223            transcript.append_element::<G>(b"[e_r]G", &random_commitment);
224            let value_commitment = G::vartime_multi_mul(
225                [v_response, r_response, &neg_challenge],
226                [
227                    G::generator(),
228                    receiver.as_element(),
229                    ciphertext.blinded_element,
230                ],
231            );
232            transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
233        }
234
235        let scalars = OddItems::new(self.ciphertext_responses.iter())
236            .chain([&self.sum_response, &neg_challenge]);
237        let random_sum_commitment = {
238            let elements = ciphertexts
239                .clone()
240                .map(|c| c.random_element)
241                .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
242            G::vartime_multi_mul(scalars.clone(), elements)
243        };
244        let value_sum_commitment = {
245            let elements = ciphertexts.map(|c| c.blinded_element).chain([
246                receiver.as_element(),
247                sum_of_squares_ciphertext.blinded_element,
248            ]);
249            G::vartime_multi_mul(scalars, elements)
250        };
251
252        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
253        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
254        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
255        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
256        let expected_challenge = transcript.challenge_scalar::<G>(b"c");
257
258        if expected_challenge == self.challenge {
259            Ok(())
260        } else {
261            Err(VerificationError::ChallengeMismatch)
262        }
263    }
264}
265
/// Thin wrapper around an iterator that drops its even-indexed elements, i.e., it yields only
/// the elements at indices 1, 3, 5, ... of the wrapped iterator.
///
/// This is necessary because `Ristretto::vartime_multi_mul()` panics if the scalar iterator
/// reports an imprecise `Iterator::size_hint()` value (as e.g. `Iterator::filter()` does).
/// This wrapper reports an exact hint whenever the wrapped iterator does.
///
/// Once the wrapped iterator returns `None`, this wrapper is fused: it keeps returning `None`
/// and reports a zero-length `size_hint()`, even if the wrapped iterator is not fused.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    /// Set once the wrapped iterator has returned `None`.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Discard the even-indexed item.
        self.ended = self.iter.next().is_none();
        if self.ended {
            return None;
        }

        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            // `next()` never yields again, regardless of what the wrapped iterator reports.
            return (0, Some(0));
        }
        // Items are always consumed in pairs, so of `k` remaining wrapped items exactly
        // `k / 2` will be yielded.
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

impl<I: Iterator> iter::FusedIterator for OddItems<I> {}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Keypair, group::Ristretto};

    /// A valid proof for a single square (3^2 = 9) verifies, and is rejected if the
    /// sum-of-squares ciphertext, the partial ciphertext or the transcript label changes.
    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertext = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let sq_ciphertext = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&ciphertext].into_iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        // The verifier only sees plain ciphertexts; the conversion target type is inferred
        // from the `verify()` calls below.
        let ciphertext = ciphertext.into();
        let sq_ciphertext = sq_ciphertext.into();
        proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap();

        // Wrong sum-of-squares ciphertext.
        let other_ciphertext = receiver.encrypt(8_u64, &mut rng);
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &other_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Wrong partial ciphertext.
        let err = proof
            .verify(
                [&other_ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Transcript with a different label than the one used by the prover.
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"other_transcript"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    /// A proof for a false relation (3^2 != 10) can still be created, but must not verify.
    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertext = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let sq_ciphertext = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&ciphertext].into_iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let ciphertext = ciphertext.into();
        let sq_ciphertext = sq_ciphertext.into();
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    /// A proof with several squares (3^2 + 1^2 + 4^2 + 1^2 = 27) verifies only when
    /// the partial ciphertexts are supplied in the original order and count.
    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = rand::rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertexts =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        let sq_ciphertext = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            ciphertexts.iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let sq_ciphertext = sq_ciphertext.into();
        proof
            .verify(
                ciphertexts.iter().map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap();

        // The proof will not verify if ciphertexts are rearranged.
        let err = proof
            .verify(
                ciphertexts.iter().rev().map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Supplying fewer ciphertexts than the proof was created for fails the length check
        // before any challenge is recomputed.
        let err = proof
            .verify(
                ciphertexts.iter().take(2).map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::LenMismatch { .. }));
    }

    /// `OddItems` yields the odd-indexed items and reports an exact size hint.
    #[test]
    fn odd_items() {
        let odd_items = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(odd_items.size_hint(), (2, Some(2)));
        assert_eq!(odd_items.collect::<Vec<_>>(), [2, 4]);

        let other_items = OddItems::new(0..7);
        assert_eq!(other_items.size_hint(), (3, Some(3)));
        assert_eq!(other_items.collect::<Vec<_>>(), [1, 3, 5]);
    }
}