use merlin::Transcript;
use rand_core::{CryptoRng, RngCore};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use core::{fmt, mem};
#[cfg(feature = "serde")]
use crate::serde::{ScalarHelper, VecHelper};
use crate::{
alloc::{vec, Vec},
encryption::ExtendedCiphertext,
group::Group,
proofs::{TranscriptForGroup, VerificationError},
Ciphertext, PublicKey, SecretKey,
};
/// Prover-side state of a single ring of a [`RingProof`].
///
/// A ring is used to prove that `ciphertext` encrypts one of `admissible_values`; the index of
/// the actually encrypted value is `value_index`.
struct Ring<'a, G: Group> {
    // --- Public data describing the ring ---
    /// Zero-based position of this ring among all rings of the proof.
    index: usize,
    /// Ciphertext whose plaintext is one of `admissible_values`.
    ciphertext: Ciphertext<G>,
    /// Values the encrypted plaintext may take.
    admissible_values: &'a [G::Element],
    /// Index of the actually encrypted value within `admissible_values`.
    value_index: usize,

    // --- Proof state ---
    /// Ring-specific transcript from which the per-value challenges are derived.
    transcript: Transcript,
    /// Output slots for the responses, one per admissible value.
    responses: &'a mut [G::Scalar],
    /// Commitments for the last admissible value; appended to the shared transcript when the
    /// common challenge is derived.
    terminal_commitments: (G::Element, G::Element),

    // --- Secrets (not printed by the `Debug` impl) ---
    /// Random scalar of the `ExtendedCiphertext` this ring was created from.
    discrete_log: SecretKey<G>,
    /// Fresh random scalar used for the commitments of the actually encrypted value.
    random_scalar: SecretKey<G>,
}
impl<G: Group> fmt::Debug for Ring<'_, G> {
    /// Prints only the public part of the ring. The transcript, `value_index` and both secret
    /// scalars are intentionally left out.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            index,
            admissible_values,
            ciphertext,
            responses,
            terminal_commitments,
            ..
        } = self;

        let mut debug_struct = formatter.debug_struct("Ring");
        debug_struct.field("index", index);
        debug_struct.field("admissible_values", admissible_values);
        debug_struct.field("ciphertext", ciphertext);
        debug_struct.field("responses", responses);
        debug_struct.field("terminal_commitments", terminal_commitments);
        debug_struct.finish()
    }
}
impl<'a, G: Group> Ring<'a, G> {
    /// Creates a ring for the `index`-th ciphertext of the proof.
    ///
    /// Draws a fresh random scalar `k` and forms the commitments `(k * G, k * log_base)` for the
    /// admissible value at `value_index`. Then, for each admissible value *after* `value_index`:
    ///
    /// - derives a challenge from the preceding commitments via a fork of the ring transcript;
    /// - draws a random response;
    /// - computes the commitments the verifier will recompute from that response and challenge.
    ///
    /// The commitments of the last admissible value are stored as `terminal_commitments`. The
    /// responses for `value_index` and all values before it are written later by `finalize`.
    ///
    /// `responses` is expected to hold one slot per admissible value.
    ///
    /// # Panics
    ///
    /// Panics if `admissible_values` is empty or if `value_index` is out of bounds. Debug builds
    /// also panic in two further cases:
    ///
    /// - `responses.len() != admissible_values.len()`;
    /// - the blinded element of `ciphertext` is not `log_base * r + admissible_values[value_index]`,
    ///   where `r` is the ciphertext's random scalar.
    // The constructor takes 8 arguments, above clippy's default `too_many_arguments` threshold.
    #[allow(clippy::too_many_arguments)]
    fn new<R: CryptoRng + RngCore>(
        index: usize,
        log_base: G::Element,
        ciphertext: ExtendedCiphertext<G>,
        admissible_values: &'a [G::Element],
        value_index: usize,
        transcript: &Transcript,
        responses: &'a mut [G::Scalar],
        rng: &mut R,
    ) -> Self {
        assert!(
            !admissible_values.is_empty(),
            "No admissible values supplied"
        );
        assert!(
            value_index < admissible_values.len(),
            "Specified value index is out of bounds"
        );
        debug_assert_eq!(
            responses.len(),
            admissible_values.len(),
            "Number of responses doesn't match number of admissible values"
        );

        // `R` (random element) and `B` (blinded element) of the ciphertext, as referred to below.
        let random_element = ciphertext.inner.random_element;
        let blinded_value = ciphertext.inner.blinded_element;
        debug_assert!(
            {
                let expected_blinded_value = log_base * ciphertext.random_scalar.expose_scalar()
                    + admissible_values[value_index];
                expected_blinded_value == blinded_value
            },
            "Specified ciphertext does not match the specified `value_index`"
        );

        // Fork the shared transcript and bind it to this ring's ciphertext and position.
        let mut transcript = transcript.clone();
        transcript.start_proof(b"ring_enc");
        transcript.append_message(b"enc", &ciphertext.inner.to_bytes());
        transcript.append_u64(b"i", index as u64);

        // Commitments for the actually encrypted value. `random_scalar` is retained so that
        // `finalize` can compute the matching response.
        let random_scalar = SecretKey::<G>::generate(rng);
        let mut commitments = (
            G::mul_generator(random_scalar.expose_scalar()),
            log_base * random_scalar.expose_scalar(),
        );

        // Simulate the ring members after `value_index`. The challenge for member `eq_index` is
        // derived from the commitments of member `eq_index - 1`, labeled `j = eq_index - 1`.
        // This is the same labeling the verifier applies after computing each member's
        // commitments.
        let it = admissible_values.iter().enumerate().skip(value_index + 1);
        for (eq_index, &admissible_value) in it {
            let mut eq_transcript = transcript.clone();
            eq_transcript.append_u64(b"j", eq_index as u64 - 1);
            eq_transcript.append_element::<G>(b"R_G", &commitments.0);
            eq_transcript.append_element::<G>(b"R_K", &commitments.1);
            let challenge = eq_transcript.challenge_scalar::<G>(b"c");

            let response = G::generate_scalar(rng);
            responses[eq_index] = response;
            let dh_element = blinded_value - admissible_value;
            // Commitments implied by a random response `s` and challenge `c`:
            // `(s * G - c * R, s * log_base - c * (B - v))`.
            commitments = (
                G::mul_generator(&response) - random_element * &challenge,
                G::multi_mul([&response, &-challenge], [log_base, dh_element]),
            );
        }

        Self {
            index,
            value_index,
            admissible_values,
            ciphertext: ciphertext.inner,
            transcript,
            responses,
            terminal_commitments: commitments,
            discrete_log: ciphertext.random_scalar,
            random_scalar,
        }
    }

    /// Binds the terminal commitments of all `rings` (in order) into `transcript`, derives the
    /// common challenge `c` from it, and finalizes every ring with that challenge.
    ///
    /// Returns the common challenge. `rings` are expected to be ordered by their `index`; this is
    /// only checked in debug builds.
    fn aggregate<R: CryptoRng + RngCore>(
        rings: Vec<Self>,
        log_base: G::Element,
        transcript: &mut Transcript,
        rng: &mut R,
    ) -> G::Scalar {
        debug_assert!(
            rings.iter().enumerate().all(|(i, ring)| i == ring.index),
            "Rings have bogus indexes"
        );

        for ring in &rings {
            let commitments = &ring.terminal_commitments;
            transcript.append_element::<G>(b"R_G", &commitments.0);
            transcript.append_element::<G>(b"R_K", &commitments.1);
        }
        let common_challenge = transcript.challenge_scalar::<G>(b"c");

        for ring in rings {
            ring.finalize(log_base, common_challenge, rng);
        }
        common_challenge
    }

    /// Completes the ring, using `common_challenge` as the challenge for the first admissible
    /// value.
    ///
    /// For every admissible value before `value_index`, draws a random response, computes the
    /// resulting commitments and derives the next challenge from them. The response for
    /// `value_index` is then set to `c * discrete_log + random_scalar`, where `c` is the last
    /// challenge obtained (`common_challenge` itself if `value_index == 0`).
    ///
    /// With `s = c * r + k` this gives `s * log_base - c * (B - v) = k * log_base`, since
    /// `B - v = r * log_base` (checked in `new` in debug builds). Likewise
    /// `s * G - c * R = k * G`, assuming the ciphertext's random element is `r * G`
    /// (NOTE(review): not visible here; confirm in `ExtendedCiphertext`). The verifier thus
    /// recomputes the commitments formed in `new`, closing the ring.
    fn finalize<R: CryptoRng + RngCore>(
        self,
        log_base: G::Element,
        common_challenge: G::Scalar,
        rng: &mut R,
    ) {
        let mut challenge = common_challenge;
        let it = self.admissible_values[..self.value_index]
            .iter()
            .enumerate();
        for (eq_index, &admissible_value) in it {
            let response = G::generate_scalar(rng);
            self.responses[eq_index] = response;
            let dh_element = self.ciphertext.blinded_element - admissible_value;
            let commitments = (
                G::mul_generator(&response) - self.ciphertext.random_element * &challenge,
                G::multi_mul([&response, &-challenge], [log_base, dh_element]),
            );

            // Challenge for the next member, labeled with this member's index.
            let mut eq_transcript = self.transcript.clone();
            eq_transcript.append_u64(b"j", eq_index as u64);
            eq_transcript.append_element::<G>(b"R_G", &commitments.0);
            eq_transcript.append_element::<G>(b"R_K", &commitments.1);
            challenge = eq_transcript.challenge_scalar::<G>(b"c");
        }

        // Neither `new` nor the loop above writes this slot, so it should still hold its initial
        // value. This check assumes the caller zero-initialized `responses`.
        debug_assert_eq!(self.responses[self.value_index], G::Scalar::from(0_u64));
        self.responses[self.value_index] =
            challenge * self.discrete_log.expose_scalar() + self.random_scalar.expose_scalar();
    }
}
/// Proof that each ciphertext in a sequence encrypts one of the admissible values of its ring.
///
/// The proof consists of a single common challenge and the responses of all rings,
/// concatenated in ring order (one response per admissible value).
// NOTE(review): field order matters for serde's serialized layout; do not reorder.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct RingProof<G: Group> {
    /// Challenge shared by all rings, derived from their terminal commitments.
    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
    common_challenge: G::Scalar,
    /// Responses of all rings, concatenated in ring order.
    // The `2` const parameter presumably is the minimum accepted length, matching the
    // `from_bytes` check; confirm against `VecHelper`.
    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
    ring_responses: Vec<G::Scalar>,
}
impl<G: Group> RingProof<G> {
    /// Starts the multi-ring proof in `transcript` and binds it to the `receiver` key.
    ///
    /// Used by both the prover (`RingProofBuilder::new`) and the verifier (`verify`), so that
    /// both start from the same transcript state.
    fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
        transcript.start_proof(b"multi_ring_enc");
        transcript.append_element_bytes(b"K", receiver.as_bytes());
    }

    /// Assembles a proof from the common challenge and the concatenated responses of all rings.
    pub(crate) fn new(common_challenge: G::Scalar, ring_responses: Vec<G::Scalar>) -> Self {
        Self {
            common_challenge,
            ring_responses,
        }
    }

    /// Verifies this proof for `ciphertexts` encrypted for `receiver`.
    ///
    /// `admissible_values` yields the admissible values of each ring. `ciphertexts` yields the
    /// ciphertext of each ring in the same order; ciphertexts beyond the number of rings are
    /// ignored (the iterator is not drained).
    ///
    /// # Errors
    ///
    /// - Returns a length error if the number of responses in the proof differs from the total
    ///   number of admissible values.
    /// - Returns a length error if `ciphertexts` yields fewer items than there are rings.
    /// - Returns [`VerificationError::ChallengeMismatch`] if the recomputed common challenge
    ///   does not match the one in the proof.
    pub(crate) fn verify<'a>(
        &self,
        receiver: &PublicKey<G>,
        admissible_values: impl Iterator<Item = &'a [G::Element]> + Clone,
        ciphertexts: impl Iterator<Item = Ciphertext<G>>,
        transcript: &mut Transcript,
    ) -> Result<(), VerificationError> {
        // Count the rings and their total size in a single pass over a clone of the iterator.
        let (ring_count, total_rings_size) = admissible_values
            .clone()
            .fold((0_usize, 0_usize), |(count, size), values| {
                (count + 1, size + values.len())
            });
        VerificationError::check_lengths(
            "items in all rings",
            self.total_rings_size(),
            total_rings_size,
        )?;

        Self::initialize_transcript(transcript, receiver);
        let initial_ring_transcript = transcript.clone();

        let it = admissible_values.zip(ciphertexts).enumerate();
        let mut verified_ring_count = 0_usize;
        let mut starting_response = 0;
        for (ring_index, (values, ciphertext)) in it {
            let mut challenge = self.common_challenge;
            let mut commitments = (G::generator(), G::generator());

            let mut ring_transcript = initial_ring_transcript.clone();
            ring_transcript.start_proof(b"ring_enc");
            ring_transcript.append_message(b"enc", &ciphertext.to_bytes());
            ring_transcript.append_u64(b"i", ring_index as u64);

            for (eq_index, (&admissible_value, response)) in values
                .iter()
                .zip(&self.ring_responses[starting_response..])
                .enumerate()
            {
                // Recompute `(s * G - c * R, s * K - c * (B - v))` for this ring member.
                let dh_element = ciphertext.blinded_element - admissible_value;
                let neg_challenge = -challenge;
                commitments = (
                    G::vartime_double_mul_generator(
                        &neg_challenge,
                        ciphertext.random_element,
                        response,
                    ),
                    G::vartime_multi_mul(
                        [response, &neg_challenge],
                        [receiver.as_element(), dh_element],
                    ),
                );

                // The last member's commitments go into the shared transcript instead.
                if eq_index + 1 < values.len() {
                    let mut eq_transcript = ring_transcript.clone();
                    eq_transcript.append_u64(b"j", eq_index as u64);
                    eq_transcript.append_element::<G>(b"R_G", &commitments.0);
                    eq_transcript.append_element::<G>(b"R_K", &commitments.1);
                    challenge = eq_transcript.challenge_scalar::<G>(b"c");
                }
            }

            starting_response += values.len();
            verified_ring_count += 1;
            transcript.append_element::<G>(b"R_G", &commitments.0);
            transcript.append_element::<G>(b"R_K", &commitments.1);
        }

        // `zip` stops at the shorter iterator. Without this check, a short `ciphertexts`
        // iterator would make the proof be checked against only a prefix of the rings.
        VerificationError::check_lengths("ciphertexts", ring_count, verified_ring_count)?;

        let expected_challenge = transcript.challenge_scalar::<G>(b"c");
        if expected_challenge == self.common_challenge {
            Ok(())
        } else {
            Err(VerificationError::ChallengeMismatch)
        }
    }

    /// Total number of responses in the proof, i.e. the number of admissible values summed
    /// over all rings.
    pub(crate) fn total_rings_size(&self) -> usize {
        self.ring_responses.len()
    }

    /// Serializes this proof into bytes.
    ///
    /// The output is the common challenge followed by all ring responses, each encoded by
    /// `G::serialize_scalar` into `G::SCALAR_SIZE` bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = vec![0_u8; G::SCALAR_SIZE * (1 + self.total_rings_size())];
        G::serialize_scalar(&self.common_challenge, &mut bytes[..G::SCALAR_SIZE]);

        let chunks = bytes[G::SCALAR_SIZE..].chunks_mut(G::SCALAR_SIZE);
        for (response, buffer) in self.ring_responses.iter().zip(chunks) {
            G::serialize_scalar(response, buffer);
        }
        bytes
    }

    /// Attempts to deserialize a proof from bytes in the format produced by `to_bytes`.
    ///
    /// Returns `None` in any of these cases:
    ///
    /// - the length of `bytes` is not a multiple of `G::SCALAR_SIZE`;
    /// - `bytes` holds fewer than three scalars (the challenge plus at least two responses);
    /// - any scalar fails to deserialize.
    // The only potential panic is the `debug_assert!` below, which cannot fire given the
    // length check at the start of the function.
    #[allow(clippy::missing_panics_doc)]
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() % G::SCALAR_SIZE != 0 || bytes.len() < 3 * G::SCALAR_SIZE {
            return None;
        }
        let common_challenge = G::deserialize_scalar(&bytes[..G::SCALAR_SIZE])?;

        let ring_responses: Option<Vec<_>> = bytes[G::SCALAR_SIZE..]
            .chunks(G::SCALAR_SIZE)
            .map(G::deserialize_scalar)
            .collect();
        let ring_responses = ring_responses?;
        debug_assert!(ring_responses.len() >= 2);

        Some(Self {
            common_challenge,
            ring_responses,
        })
    }
}
/// Incremental builder of a [`RingProof`].
///
/// Rings are added one at a time. `build()` then derives the common challenge and completes
/// every ring.
// NOTE(review): hidden from generated docs; presumably only meant for other proof
// constructions in this crate. Confirm before documenting it as public API.
#[doc(hidden)]
pub struct RingProofBuilder<'a, G: Group, R> {
    /// Public key the values are encrypted for; also the log base of every ring.
    receiver: &'a PublicKey<G>,
    /// Random number generator for encryption randomness, commitments and simulated responses.
    rng: &'a mut R,
    /// Shared proof transcript. Rings fork it, and the common challenge is derived from it.
    transcript: &'a mut Transcript,
    /// Rings added so far, in the order they were added.
    rings: Vec<Ring<'a, G>>,
    /// Portion of the caller's response buffer not yet assigned to a ring.
    ring_responses: &'a mut [G::Scalar],
}
impl<G: Group, R: fmt::Debug> fmt::Debug for RingProofBuilder<'_, G, R> {
    /// Prints the receiver, the rings added so far and the RNG. The transcript and the
    /// unassigned response buffer are left out.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            receiver,
            rings,
            rng,
            ..
        } = self;

        let mut debug_struct = formatter.debug_struct("RingProofBuilder");
        debug_struct.field("receiver", receiver);
        debug_struct.field("rings", rings);
        debug_struct.field("rng", rng);
        debug_struct.finish()
    }
}
impl<'a, G: Group, R: RngCore + CryptoRng> RingProofBuilder<'a, G, R> {
    /// Starts the proof transcript for `receiver` and creates a builder.
    ///
    /// `ring_count` is only used to preallocate storage for rings. `ring_responses` must
    /// provide one slot per admissible value across all rings that will be added.
    ///
    /// `ring_responses` should be zero-initialized: debug builds check that the slot of each
    /// encrypted value is still zero when its ring is finalized.
    pub fn new(
        receiver: &'a PublicKey<G>,
        ring_count: usize,
        ring_responses: &'a mut [G::Scalar],
        transcript: &'a mut Transcript,
        rng: &'a mut R,
    ) -> Self {
        RingProof::<G>::initialize_transcript(transcript, receiver);
        Self {
            receiver,
            transcript,
            rings: Vec::with_capacity(ring_count),
            ring_responses,
            rng,
        }
    }

    /// Encrypts `admissible_values[value_index]` for the receiver and adds a ring proving that
    /// the resulting ciphertext encrypts one of `admissible_values`.
    ///
    /// Returns the ciphertext together with its encryption randomness.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// - `admissible_values` is empty;
    /// - `value_index` is out of bounds;
    /// - the remaining response buffer has fewer slots than `admissible_values.len()`.
    pub fn add_value(
        &mut self,
        admissible_values: &'a [G::Element],
        value_index: usize,
    ) -> ExtendedCiphertext<G> {
        // Validate up front, so misuse yields a descriptive message instead of a generic
        // index-out-of-bounds panic from the lookup below.
        assert!(
            !admissible_values.is_empty(),
            "No admissible values supplied"
        );
        assert!(
            value_index < admissible_values.len(),
            "Specified value index is out of bounds"
        );

        let ext_ciphertext =
            ExtendedCiphertext::new(admissible_values[value_index], self.receiver, self.rng);
        self.add_precomputed_value(ext_ciphertext.clone(), admissible_values, value_index);
        ext_ciphertext
    }

    /// Adds a ring for an already computed `ciphertext` of `admissible_values[value_index]`.
    ///
    /// The next `admissible_values.len()` slots of the response buffer are handed to the ring.
    ///
    /// # Panics
    ///
    /// - Panics if the remaining response buffer is too short.
    /// - Panics under the conditions documented for the ring constructor: empty
    ///   `admissible_values` or an out-of-bounds `value_index`.
    pub(crate) fn add_precomputed_value(
        &mut self,
        ciphertext: ExtendedCiphertext<G>,
        admissible_values: &'a [G::Element],
        value_index: usize,
    ) {
        // Checked before `mem::take` so that the buffer is not left emptied; without this,
        // `split_at_mut` would panic with an opaque message.
        assert!(
            admissible_values.len() <= self.ring_responses.len(),
            "Not enough ring_responses left for the added ring"
        );

        let ring_responses = mem::take(&mut self.ring_responses);
        let (responses_for_ring, rest) = ring_responses.split_at_mut(admissible_values.len());
        self.ring_responses = rest;

        let ring = Ring::new(
            self.rings.len(),
            self.receiver.as_element(),
            ciphertext,
            admissible_values,
            value_index,
            &*self.transcript,
            responses_for_ring,
            self.rng,
        );
        self.rings.push(ring);
    }

    /// Finishes the proof and returns the common challenge.
    ///
    /// All ring responses are written into the `ring_responses` buffer supplied to `new`. In
    /// debug builds, panics if that buffer was not fully consumed by the added rings.
    pub fn build(self) -> G::Scalar {
        debug_assert!(
            self.ring_responses.is_empty(),
            "Not all ring_responses were used"
        );
        Ring::aggregate(
            self.rings,
            self.receiver.as_element(),
            self.transcript,
            self.rng,
        )
    }
}
// Round-trip tests: build ring proofs with the prover-side `Ring`/`RingProofBuilder` API and
// check them with `RingProof::verify` using a fresh transcript with the same label.
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};
    use test_casing::test_casing;

    use core::iter;

    use super::*;
    use crate::{
        curve25519::{ristretto::RistrettoPoint, scalar::Scalar as Scalar25519, traits::Identity},
        group::{ElementOps, Ristretto},
    };

    type Keypair = crate::Keypair<Ristretto>;

    /// A single ring over `{identity, generator}` verifies for either encrypted value.
    #[test]
    fn single_ring_with_2_elements_works() {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let log_base = keypair.public().as_element();
        let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];

        // Encrypted value is `admissible_values[0]` (the identity).
        let value = RistrettoPoint::identity();
        let ext_ciphertext = ExtendedCiphertext::new(value, keypair.public(), &mut rng);
        let ciphertext = ext_ciphertext.inner;
        let mut transcript = Transcript::new(b"test_ring_encryption");
        RingProof::initialize_transcript(&mut transcript, keypair.public());
        // Responses start zeroed; `finalize` debug-asserts this for the encrypted value's slot.
        let mut ring_responses = vec![Scalar25519::default(); 2];
        let signature_ring = Ring::new(
            0,
            log_base,
            ext_ciphertext,
            &admissible_values,
            0,
            &transcript,
            &mut ring_responses,
            &mut rng,
        );
        let common_challenge =
            Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);

        RingProof::new(common_challenge, ring_responses)
            .verify(
                keypair.public(),
                iter::once(&admissible_values as &[_]),
                iter::once(ciphertext),
                &mut Transcript::new(b"test_ring_encryption"),
            )
            .unwrap();

        // Encrypted value is `admissible_values[1]` (the generator).
        let value = Ristretto::generator();
        let ext_ciphertext = ExtendedCiphertext::new(value, keypair.public(), &mut rng);
        let ciphertext = ext_ciphertext.inner;
        let mut transcript = Transcript::new(b"test_ring_encryption");
        RingProof::initialize_transcript(&mut transcript, keypair.public());
        let mut ring_responses = vec![Scalar25519::default(); 2];
        let signature_ring = Ring::new(
            0,
            log_base,
            ext_ciphertext,
            &admissible_values,
            1,
            &transcript,
            &mut ring_responses,
            &mut rng,
        );
        let common_challenge =
            Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);

        RingProof::new(common_challenge, ring_responses)
            .verify(
                keypair.public(),
                iter::once(&admissible_values as &[_]),
                iter::once(ciphertext),
                &mut Transcript::new(b"test_ring_encryption"),
            )
            .unwrap();
    }

    /// A single ring over `{0*G, 1*G, 2*G, 3*G}` verifies for randomly chosen encrypted values.
    #[test]
    fn single_ring_with_4_elements_works() {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let log_base = keypair.public().as_element();
        let admissible_values: Vec<_> = (0_u32..4)
            .map(|i| Ristretto::mul_generator(&Scalar25519::from(i)))
            .collect();

        for _ in 0..100 {
            // `val` doubles as the index of the encrypted value in `admissible_values`.
            let val: u32 = rng.gen_range(0..4);
            let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
            let ext_ciphertext = ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
            let ciphertext = ext_ciphertext.inner;

            let mut transcript = Transcript::new(b"test_ring_encryption");
            RingProof::initialize_transcript(&mut transcript, keypair.public());
            let mut ring_responses = vec![Scalar25519::default(); 4];
            let signature_ring = Ring::new(
                0,
                log_base,
                ext_ciphertext,
                &admissible_values,
                val as usize,
                &transcript,
                &mut ring_responses,
                &mut rng,
            );
            let common_challenge =
                Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);

            RingProof::new(common_challenge, ring_responses)
                .verify(
                    keypair.public(),
                    iter::once(admissible_values.as_slice()),
                    iter::once(ciphertext),
                    &mut Transcript::new(b"test_ring_encryption"),
                )
                .unwrap();
        }
    }

    /// `ring_count` boolean rings (each over `{identity, generator}`) combine into one proof.
    #[test_casing(5, 3..=7)]
    fn multiple_rings_with_boolean_flags_work(ring_count: usize) {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let log_base = keypair.public().as_element();
        let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];

        for _ in 0..20 {
            let mut transcript = Transcript::new(b"test_ring_encryption");
            RingProof::initialize_transcript(&mut transcript, keypair.public());
            // One 2-slot chunk of the shared response buffer per ring.
            let mut ring_responses = vec![Scalar25519::default(); ring_count * 2];

            let (ciphertexts, rings): (Vec<_>, Vec<_>) = ring_responses
                .chunks_mut(2)
                .enumerate()
                .map(|(ring_index, ring_responses)| {
                    let val: u32 = rng.gen_range(0..=1);
                    let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
                    let ext_ciphertext =
                        ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
                    let ciphertext = ext_ciphertext.inner;

                    let signature_ring = Ring::new(
                        ring_index,
                        log_base,
                        ext_ciphertext,
                        &admissible_values,
                        val as usize,
                        &transcript,
                        ring_responses,
                        &mut rng,
                    );

                    (ciphertext, signature_ring)
                })
                .unzip();

            let common_challenge = Ring::aggregate(rings, log_base, &mut transcript, &mut rng);

            RingProof::new(common_challenge, ring_responses)
                .verify(
                    keypair.public(),
                    iter::repeat(&admissible_values as &[_]).take(ring_count),
                    ciphertexts.into_iter(),
                    &mut Transcript::new(b"test_ring_encryption"),
                )
                .unwrap();
        }
    }

    /// Four rings encode a random `u8` in base 4: ring `i` admits `j * 4^i` for `j` in `0..4`.
    #[test]
    fn multiple_rings_with_base4_value_encoding_work() {
        const RING_COUNT: u8 = 4;

        let admissible_values: Vec<_> = (0..RING_COUNT)
            .map(|ring_index| {
                // `power` is `4^ring_index`.
                let power: u32 = 1 << (2 * u32::from(ring_index));
                [
                    RistrettoPoint::identity(),
                    Ristretto::mul_generator(&Scalar25519::from(power)),
                    Ristretto::mul_generator(&Scalar25519::from(power * 2)),
                    Ristretto::mul_generator(&Scalar25519::from(power * 3)),
                ]
            })
            .collect();

        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let log_base = keypair.public().as_element();

        for _ in 0..20 {
            let overall_value: u8 = rng.gen();
            let mut transcript = Transcript::new(b"test_ring_encryption");
            RingProof::initialize_transcript(&mut transcript, keypair.public());
            let mut ring_responses = vec![Scalar25519::default(); RING_COUNT as usize * 4];

            let (ciphertexts, rings): (Vec<_>, Vec<_>) = ring_responses
                .chunks_mut(4)
                .enumerate()
                .map(|(ring_index, ring_responses)| {
                    // `val` is the `ring_index`-th base-4 digit scaled by `4^ring_index`, i.e.
                    // exactly `admissible_values[ring_index][val_index]`.
                    let mask = 3 << (2 * ring_index);
                    let val = overall_value & mask;
                    let val_index = (val >> (2 * ring_index)) as usize;
                    assert!(val_index < 4);

                    let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
                    let ext_ciphertext =
                        ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
                    let ciphertext = ext_ciphertext.inner;

                    let signature_ring = Ring::new(
                        ring_index,
                        log_base,
                        ext_ciphertext,
                        &admissible_values[ring_index],
                        val_index,
                        &transcript,
                        ring_responses,
                        &mut rng,
                    );

                    (ciphertext, signature_ring)
                })
                .unzip();

            let common_challenge = Ring::aggregate(rings, log_base, &mut transcript, &mut rng);

            let admissible_values = admissible_values.iter().map(|values| values as &[_]);
            RingProof::new(common_challenge, ring_responses)
                .verify(
                    keypair.public(),
                    admissible_values,
                    ciphertexts.into_iter(),
                    &mut Transcript::new(b"test_ring_encryption"),
                )
                .unwrap();
        }
    }

    /// The high-level `RingProofBuilder` produces proofs accepted by `RingProof::verify`.
    #[test_casing(5, 3..=7)]
    // `ciphertexts` must be collected: the lazy iterator would keep `builder` mutably borrowed
    // past `builder.build()`.
    #[allow(clippy::needless_collect)]
    fn proof_builder_works(ring_count: usize) {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let mut transcript = Transcript::new(b"test_ring_encryption");
        let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];
        let mut ring_responses = vec![Scalar25519::default(); ring_count * 2];

        let mut builder = RingProofBuilder::new(
            keypair.public(),
            ring_count,
            &mut ring_responses,
            &mut transcript,
            &mut rng,
        );
        // Alternate the encrypted value between the two admissible values.
        let ciphertexts: Vec<_> = (0..ring_count)
            .map(|i| builder.add_value(&admissible_values, i & 1).inner)
            .collect();

        RingProof::new(builder.build(), ring_responses)
            .verify(
                keypair.public(),
                iter::repeat(&admissible_values as &[_]).take(ring_count),
                ciphertexts.into_iter(),
                &mut Transcript::new(b"test_ring_encryption"),
            )
            .unwrap();
    }
}