use merlin::Transcript;
use rand_core::{CryptoRng, RngCore};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
use core::iter;
#[cfg(feature = "serde")]
use crate::serde::{ScalarHelper, VecHelper};
use crate::{
alloc::Vec, group::Group, proofs::TranscriptForGroup, Ciphertext, CiphertextWithValue,
PublicKey, SecretKey, VerificationError,
};
/// Proof that a ciphertext encrypts the sum of squares of the values encrypted by a list of
/// other ciphertexts, all of them encrypted for the same receiver key.
///
/// The challenge is derived from a [`Transcript`] rather than supplied by a verifier, so the
/// proof is non-interactive; the prover and the verifier must use transcripts in the same state.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct SumOfSquaresProof<G: Group> {
    /// Challenge scalar derived from the transcript.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    challenge: G::Scalar,
    /// Two responses per squared ciphertext, in ciphertext order: the response for the
    /// ciphertext randomness, followed by the response for its encrypted value.
    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
    ciphertext_responses: Vec<G::Scalar>,
    /// Response for the combined randomness `r_z - Σ x_i * r_i`, where `r_z` is the randomness
    /// of the sum-of-squares ciphertext and `x_i`, `r_i` are the value and the randomness
    /// of the i-th squared ciphertext.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    sum_response: G::Scalar,
}
impl<G: Group> SumOfSquaresProof<G> {
    /// Starts the proof in `transcript` and binds it to the `receiver` key.
    ///
    /// Called first by both [`Self::new()`] and [`Self::verify()`], so the prover and the
    /// verifier absorb identical data before any proof-specific elements.
    fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
        transcript.start_proof(b"sum_of_squares");
        transcript.append_element_bytes(b"K", receiver.as_bytes());
    }

    /// Creates a proof that `sum_of_squares_ciphertext` encrypts the sum of squares of the
    /// values encrypted by `ciphertexts`, with all ciphertexts encrypted for `receiver`.
    ///
    /// The relation itself is not checked here: if it does not hold, the returned proof is
    /// rejected by [`Self::verify()`]. The proof depends on the order of `ciphertexts` and on
    /// the current state of `transcript`; the verifier must reproduce both.
    // `collect()` is required: the closure below mutably borrows `transcript` and `rng`, and
    // the collected triples are traversed several times afterwards.
    #[allow(clippy::needless_collect)]
    pub fn new<'a, R: RngCore + CryptoRng>(
        ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
        sum_of_squares_ciphertext: &CiphertextWithValue<G>,
        receiver: &PublicKey<G>,
        transcript: &mut Transcript,
        rng: &mut R,
    ) -> Self {
        Self::initialize_transcript(transcript, receiver);

        // Blinding scalar for the response on the combined randomness.
        let sum_scalar = SecretKey::<G>::generate(rng);
        // Accumulates `r_z - Σ x_i * r_i`, where `r_z` is the randomness of the sum-of-squares
        // ciphertext and `x_i`, `r_i` are the value and randomness of each squared ciphertext.
        let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();

        let partial_scalars: Vec<_> = ciphertexts
            .map(|ciphertext| {
                transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
                transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);

                // Commitment for the ciphertext randomness `r_i`.
                let random_scalar = SecretKey::<G>::generate(rng);
                let random_commitment = G::mul_generator(random_scalar.expose_scalar());
                transcript.append_element::<G>(b"[e_r]G", &random_commitment);

                // Commitment for the encrypted value `x_i`.
                let value_scalar = SecretKey::<G>::generate(rng);
                let value_commitment = G::mul_generator(value_scalar.expose_scalar())
                    + receiver.as_element() * random_scalar.expose_scalar();
                transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);

                // `-x_i` is secret; `Zeroizing` wipes it when dropped.
                let neg_value = Zeroizing::new(-*ciphertext.value());
                sum_random_scalar += ciphertext.randomness() * &neg_value;
                (ciphertext, random_scalar, value_scalar)
            })
            .collect();

        // Both aggregated commitments use the same scalars: `e_x_i` for each ciphertext,
        // then `e_z` for the combined randomness.
        let scalars = partial_scalars
            .iter()
            .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
            .chain(iter::once(sum_scalar.expose_scalar()));
        let random_sum_commitment = {
            let elements = partial_scalars
                .iter()
                .map(|(ciphertext, ..)| ciphertext.inner().random_element)
                .chain(iter::once(G::generator()));
            G::multi_mul(scalars.clone(), elements)
        };
        let value_sum_commitment = {
            let elements = partial_scalars
                .iter()
                .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
                .chain(iter::once(receiver.as_element()));
            G::multi_mul(scalars, elements)
        };

        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
        let challenge = transcript.challenge_scalar::<G>(b"c");

        // Responses are laid out as `[s_r, s_x]` pairs, one pair per ciphertext, in order.
        let ciphertext_responses = partial_scalars
            .into_iter()
            .flat_map(|(ciphertext, random_scalar, value_scalar)| {
                [
                    challenge * ciphertext.randomness().expose_scalar()
                        + random_scalar.expose_scalar(),
                    challenge * ciphertext.value() + value_scalar.expose_scalar(),
                ]
            })
            .collect();
        let sum_response =
            challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();

        Self {
            challenge,
            ciphertext_responses,
            sum_response,
        }
    }

    /// Verifies this proof for the given public ciphertexts.
    ///
    /// `ciphertexts` must yield the same ciphertexts in the same order as during proof
    /// creation, and `transcript` must be in the same state as the prover's one.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of ciphertext responses in the proof is not twice the
    /// number of `ciphertexts`, or if the recomputed challenge differs from the proof's
    /// challenge (e.g., the proof is invalid or was created for other inputs).
    pub fn verify<'a>(
        &self,
        ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
        sum_of_squares_ciphertext: &Ciphertext<G>,
        receiver: &PublicKey<G>,
        transcript: &mut Transcript,
    ) -> Result<(), VerificationError> {
        let ciphertexts_count = ciphertexts.clone().count();
        VerificationError::check_lengths(
            "ciphertext responses",
            self.ciphertext_responses.len(),
            ciphertexts_count * 2,
        )?;

        Self::initialize_transcript(transcript, receiver);
        let neg_challenge = -self.challenge;

        // The length check above guarantees that responses split into exactly one
        // `[randomness response, value response]` pair per ciphertext.
        for (response_chunk, ciphertext) in self
            .ciphertext_responses
            .chunks_exact(2)
            .zip(ciphertexts.clone())
        {
            transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
            transcript.append_element::<G>(b"X", &ciphertext.blinded_element);

            let r_response = &response_chunk[0];
            let v_response = &response_chunk[1];
            // Recomputes the randomness commitment from `s_r` and `-c`.
            let random_commitment = G::vartime_double_mul_generator(
                &neg_challenge,
                ciphertext.random_element,
                r_response,
            );
            transcript.append_element::<G>(b"[e_r]G", &random_commitment);
            // Recomputes the value commitment as `[s_x]G + [s_r]K - [c]X`.
            let value_commitment = G::vartime_multi_mul(
                [v_response, r_response, &neg_challenge],
                [
                    G::generator(),
                    receiver.as_element(),
                    ciphertext.blinded_element,
                ],
            );
            transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
        }

        // Value responses `s_x_i` sit at odd indices of `ciphertext_responses`; they are
        // followed by the response for the combined randomness and the negated challenge.
        let scalars = OddItems::new(self.ciphertext_responses.iter())
            .chain([&self.sum_response, &neg_challenge]);
        let random_sum_commitment = {
            let elements = ciphertexts
                .clone()
                .map(|c| c.random_element)
                .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
            G::vartime_multi_mul(scalars.clone(), elements)
        };
        let value_sum_commitment = {
            let elements = ciphertexts.map(|c| c.blinded_element).chain([
                receiver.as_element(),
                sum_of_squares_ciphertext.blinded_element,
            ]);
            G::vartime_multi_mul(scalars, elements)
        };

        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);

        let expected_challenge = transcript.challenge_scalar::<G>(b"c");
        if expected_challenge == self.challenge {
            Ok(())
        } else {
            Err(VerificationError::ChallengeMismatch)
        }
    }
}
/// Iterator adapter yielding the items at odd (0-based) positions of the wrapped iterator,
/// i.e. its 2nd, 4th, 6th, ... items.
///
/// Once the wrapped iterator returns `None`, the adapter never polls it again, so it keeps
/// returning `None` even if the wrapped iterator is not fused.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    /// Wraps `iter`; nothing is consumed until the first call to `next()`.
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Skip the item at the even position.
        self.ended = self.iter.next().is_none();
        if self.ended {
            return None;
        }
        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            // `next()` will never poll `iter` again, so no more items will be produced,
            // whatever the (possibly non-fused) inner iterator reports.
            return (0, Some(0));
        }
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{group::Ristretto, Keypair};
    use rand::thread_rng;

    // Single ciphertext with 3² = 9: the proof verifies, and it is rejected when the
    // sum-of-squares ciphertext, the squared ciphertext, or the transcript label differ.
    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertext = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let sq_ciphertext = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();
        let proof = SumOfSquaresProof::new(
            [&ciphertext].into_iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        // The verifier works with plain ciphertexts; the conversion target is inferred
        // from the `verify` calls below.
        let ciphertext = ciphertext.into();
        let sq_ciphertext = sq_ciphertext.into();
        proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap();

        // Wrong sum-of-squares ciphertext.
        let other_ciphertext = receiver.encrypt(8_u64, &mut rng);
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &other_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Wrong squared ciphertext.
        let err = proof
            .verify(
                [&other_ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Transcript initialized with a different label.
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"other_transcript"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    // 3² != 10: `new` still produces a proof, but it must fail verification.
    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertext = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let sq_ciphertext = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();
        let proof = SumOfSquaresProof::new(
            [&ciphertext].into_iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let ciphertext = ciphertext.into();
        let sq_ciphertext = sq_ciphertext.into();
        let err = proof
            .verify(
                [&ciphertext].into_iter(),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    // 3² + 1² + 4² + 1² = 27. Verification depends on the order of the ciphertexts,
    // and a wrong number of ciphertexts is reported as a length mismatch.
    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let ciphertexts =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        let sq_ciphertext = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();
        let proof = SumOfSquaresProof::new(
            ciphertexts.iter(),
            &sq_ciphertext,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let sq_ciphertext = sq_ciphertext.into();
        proof
            .verify(
                ciphertexts.iter().map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap();

        // Same ciphertexts in reversed order.
        let err = proof
            .verify(
                ciphertexts.iter().rev().map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));

        // Fewer ciphertexts than the proof has responses for.
        let err = proof
            .verify(
                ciphertexts.iter().take(2).map(CiphertextWithValue::inner),
                &sq_ciphertext,
                &receiver,
                &mut Transcript::new(b"test"),
            )
            .unwrap_err();
        assert!(matches!(err, VerificationError::LenMismatch { .. }));
    }

    // `OddItems` yields the 2nd, 4th, ... items, and its size hint is half of the inner one.
    #[test]
    fn odd_items() {
        let odd_items = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(odd_items.size_hint(), (2, Some(2)));
        assert_eq!(odd_items.collect::<Vec<_>>(), [2, 4]);

        let other_items = OddItems::new(0..7);
        assert_eq!(other_items.size_hint(), (3, Some(3)));
        assert_eq!(other_items.collect::<Vec<_>>(), [1, 3, 5]);
    }
}