use merlin::Transcript;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use core::iter;
use super::{lagrange_coefficients, Error, Params, PublicPolynomial};
use crate::{
alloc::Vec,
group::Group,
proofs::{LogEqualityProof, ProofOfPossession, TranscriptForGroup, VerificationError},
CandidateDecryption, Ciphertext, PublicKey, VerifiableDecryption,
};
/// Full public information about a threshold sharing setup: the sharing parameters,
/// the shared public key, and the public keys of every participant.
///
/// NOTE(review): field names and order feed the derived `Serialize` / `Deserialize`
/// and `Debug` impls, so changing them changes the serialized / debug representation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct PublicKeySet<G: Group> {
    /// Sharing parameters (number of shares `n` and threshold `t`).
    params: Params,
    /// Shared public key: the public polynomial's value at zero.
    shared_key: PublicKey<G>,
    /// Participant public keys; the key at index `i` (0-based) is the public
    /// polynomial's value at `i + 1`.
    participant_keys: Vec<PublicKey<G>>,
}
impl<G: Group> PublicKeySet<G> {
    /// Checks that `public_polynomial` is a well-formed dealer polynomial for `params`.
    ///
    /// The polynomial must have exactly `params.threshold` coefficients, and
    /// `proof_of_possession` must verify against those coefficients (each wrapped as a
    /// public key) using a transcript bound to `params.shares` and `params.threshold`.
    ///
    /// # Errors
    ///
    /// - `Error::MalformedDealerPolynomial` if the coefficient count differs from the
    ///   threshold.
    /// - `Error::InvalidDealerProof` if the proof of possession does not verify.
    pub(crate) fn validate(
        params: Params,
        public_polynomial: &[G::Element],
        proof_of_possession: &ProofOfPossession<G>,
    ) -> Result<(), Error> {
        if public_polynomial.len() != params.threshold {
            return Err(Error::MalformedDealerPolynomial);
        }

        let mut transcript = Transcript::new(b"elgamal_share_poly");
        transcript.append_u64(b"n", params.shares as u64);
        transcript.append_u64(b"t", params.threshold as u64);

        let public_poly_keys: Vec<_> = public_polynomial
            .iter()
            .copied()
            .map(PublicKey::from_element)
            .collect();
        proof_of_possession
            .verify(public_poly_keys.iter(), &mut transcript)
            .map_err(Error::InvalidDealerProof)?;
        Ok(())
    }

    /// Creates a key set from the dealer's public polynomial.
    ///
    /// The polynomial and its proof of possession are first checked (see `validate`).
    /// The shared key is the polynomial's value at zero. The key of participant `i`
    /// (0-based, `i` in `0..params.shares`) is the polynomial's value at `i + 1`.
    ///
    /// # Errors
    ///
    /// Returns `Error::MalformedDealerPolynomial` or `Error::InvalidDealerProof` if the
    /// polynomial fails validation.
    pub fn new(
        params: Params,
        public_polynomial: Vec<G::Element>,
        proof_of_possession: &ProofOfPossession<G>,
    ) -> Result<Self, Error> {
        Self::validate(params, &public_polynomial, proof_of_possession)?;

        let public_poly = PublicPolynomial::<G>(public_polynomial);
        let shared_key = PublicKey::from_element(public_poly.value_at_zero());
        // Participant x-coordinates start at 1; the value at 0 is the shared key.
        let participant_keys = (0..params.shares)
            .map(|idx| PublicKey::from_element(public_poly.value_at((idx as u64 + 1).into())))
            .collect();

        Ok(Self {
            params,
            shared_key,
            participant_keys,
        })
    }

    /// Restores a key set from the public keys of all participants.
    ///
    /// The shared key is interpolated from the first `params.threshold` participant keys.
    /// Every remaining key is then checked against the value interpolated from those same
    /// starting keys at its x-coordinate (`index + 1`).
    ///
    /// # Errors
    ///
    /// - `Error::ParticipantCountMismatch` if `participant_keys.len() != params.shares`.
    /// - `Error::MalformedParticipantKeys` if some key after the first `params.threshold`
    ///   does not match the value interpolated from the first `params.threshold` keys.
    pub fn from_participants(
        params: Params,
        participant_keys: Vec<PublicKey<G>>,
    ) -> Result<Self, Error> {
        if params.shares != participant_keys.len() {
            return Err(Error::ParticipantCountMismatch);
        }

        // Interpolate the shared key (value at zero) from the first `t` participant keys.
        let indexes: Vec<_> = (0..params.threshold).collect();
        let (denominators, scale) = lagrange_coefficients::<G>(&indexes);
        let starting_keys = participant_keys
            .iter()
            .map(PublicKey::as_element)
            .take(params.threshold);
        let shared_key = G::vartime_multi_mul(&denominators, starting_keys.clone());
        let shared_key = PublicKey::from_element(shared_key * &scale);

        // Multiplicative inverses of `1..=n`: `inverses[k - 1] == 1 / k`.
        let mut inverses: Vec<_> = (1_u64..=params.shares as u64)
            .map(G::Scalar::from)
            .collect();
        G::invert_scalars(&mut inverses);

        // Check each remaining key (x-coordinate `x + 1`) against the interpolation.
        for (x, key) in participant_keys.iter().enumerate().skip(params.threshold) {
            // Product of `(x - idx)` over all starting indexes; every factor is positive
            // because `x >= t > idx`, so the `usize` subtraction cannot underflow.
            let mut key_scale = indexes
                .iter()
                .map(|&idx| G::Scalar::from((x - idx) as u64))
                .fold(G::Scalar::from(1), |acc, value| acc * value);

            // Rescale the at-zero denominators for evaluation at `x + 1`: multiply by the
            // starting key's x-coordinate `idx + 1` and divide by `x - idx`.
            let key_denominators: Vec<_> = denominators
                .iter()
                .enumerate()
                .map(|(idx, &d)| d * G::Scalar::from(idx as u64 + 1) * inverses[x - idx - 1])
                .collect();

            // Signs were ignored above: the true Lagrange basis differs by a factor of
            // (-1)^(t - 1), i.e. it is negative iff the threshold `t` is even.
            // NOTE(review): relies on `lagrange_coefficients` producing denominators for
            // interpolation at zero.
            if params.threshold % 2 == 0 {
                key_scale = -key_scale;
            }

            let interpolated_key = G::vartime_multi_mul(&key_denominators, starting_keys.clone());
            let interpolated_key = interpolated_key * &key_scale;
            if interpolated_key != key.as_element() {
                return Err(Error::MalformedParticipantKeys);
            }
        }

        Ok(Self {
            params,
            shared_key,
            participant_keys,
        })
    }

    /// Returns the sharing parameters.
    pub fn params(&self) -> Params {
        self.params
    }

    /// Returns the shared public key.
    pub fn shared_key(&self) -> &PublicKey<G> {
        &self.shared_key
    }

    /// Returns the public key of the participant with the 0-based `index`, or `None`
    /// if `index` is out of bounds.
    pub fn participant_key(&self, index: usize) -> Option<&PublicKey<G>> {
        self.participant_keys.get(index)
    }

    /// Returns the public keys of all participants, ordered by participant index.
    pub fn participant_keys(&self) -> &[PublicKey<G>] {
        &self.participant_keys
    }

    /// Returns the key of participant `index`, panicking with a descriptive message
    /// if `index` is out of bounds. Shared by the `verify_*` methods so they fail
    /// consistently on bad indexes.
    fn participant_key_or_panic(&self, index: usize) -> &PublicKey<G> {
        self.participant_key(index).unwrap_or_else(|| {
            panic!(
                "participant index {index} out of bounds, expected a value in 0..{}",
                self.participant_keys.len()
            )
        })
    }

    /// Appends the share count, threshold and shared key to `transcript`.
    ///
    /// Individual participant keys are not appended.
    pub(super) fn commit(&self, transcript: &mut Transcript) {
        transcript.append_u64(b"n", self.params.shares as u64);
        transcript.append_u64(b"t", self.params.threshold as u64);
        transcript.append_element_bytes(b"K", self.shared_key.as_bytes());
    }

    /// Verifies a proof of possession for the key of participant `index`.
    ///
    /// The proof is checked under a transcript bound to this key set and to `index`.
    ///
    /// # Errors
    ///
    /// Returns the `VerificationError` produced if `proof` does not verify.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a valid participant index.
    pub fn verify_participant(
        &self,
        index: usize,
        proof: &ProofOfPossession<G>,
    ) -> Result<(), VerificationError> {
        let participant_key = self.participant_key_or_panic(index);
        let mut transcript = Transcript::new(b"elgamal_participant_pop");
        self.commit(&mut transcript);
        transcript.append_u64(b"i", index as u64);
        proof.verify(iter::once(participant_key), &mut transcript)
    }

    /// Verifies a candidate decryption share from participant `index`.
    ///
    /// `proof` is checked against the ciphertext's random element and the pair
    /// (participant key, candidate share's DH element). The check uses a transcript bound
    /// to this key set and to `index`. On success, the candidate's DH element is returned
    /// as a `VerifiableDecryption`.
    ///
    /// # Errors
    ///
    /// Returns the `VerificationError` produced if `proof` does not verify.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a valid participant index.
    pub fn verify_share(
        &self,
        candidate_share: CandidateDecryption<G>,
        ciphertext: Ciphertext<G>,
        index: usize,
        proof: &LogEqualityProof<G>,
    ) -> Result<VerifiableDecryption<G>, VerificationError> {
        // Same descriptive panic as `verify_participant` instead of a bare slice-index panic.
        let key_share = self.participant_key_or_panic(index).as_element();
        let dh_element = candidate_share.dh_element();

        let mut transcript = Transcript::new(b"elgamal_decryption_share");
        self.commit(&mut transcript);
        transcript.append_u64(b"i", index as u64);

        proof.verify(
            &PublicKey::from_element(ciphertext.random_element),
            (key_share, dh_element),
            &mut transcript,
        )?;
        Ok(VerifiableDecryption::from_element(dh_element))
    }
}
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::{
        group::{ElementOps, Ristretto},
        sharing::Dealer,
    };

    #[test]
    fn restoring_key_set_from_participant_keys_errors() {
        let mut rng = thread_rng();
        let params = Params::new(10, 7);
        let dealer = Dealer::<Ristretto>::new(params, &mut rng);
        let (public_poly, _) = dealer.public_info();
        let public_poly = PublicPolynomial::<Ristretto>(public_poly);
        // Participant `i` (0-based) holds the polynomial value at `i + 1`.
        let participant_keys: Vec<PublicKey<Ristretto>> = (1..=params.shares)
            .map(|i| PublicKey::from_element(public_poly.value_at((i as u64).into())))
            .collect();

        // Honest keys restore the set. The shared key must be the polynomial value at zero,
        // and the participant keys must be kept as given.
        let key_set = PublicKeySet::from_participants(params, participant_keys.clone()).unwrap();
        assert!(key_set.shared_key().as_element() == public_poly.value_at_zero());
        assert!(key_set
            .participant_keys()
            .iter()
            .map(PublicKey::as_element)
            .eq(participant_keys.iter().map(PublicKey::as_element)));

        // Too few keys.
        let err =
            PublicKeySet::from_participants(params, participant_keys[1..].to_vec()).unwrap_err();
        assert!(matches!(err, Error::ParticipantCountMismatch));

        // Keys in the wrong order.
        let mut bogus_keys = participant_keys.clone();
        bogus_keys.swap(1, 5);
        let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
        assert!(matches!(err, Error::MalformedParticipantKeys));

        // A single tampered key at any position must be detected.
        for i in 0..params.shares {
            let mut bogus_keys = participant_keys.clone();
            bogus_keys[i] =
                PublicKey::from_element(bogus_keys[i].as_element() + Ristretto::generator());
            let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
            assert!(matches!(err, Error::MalformedParticipantKeys));
        }
    }
}