1use merlin::Transcript;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use core::iter;
8
9use super::{lagrange_coefficients, Error, Params, PublicPolynomial};
10
11use crate::{
12    alloc::Vec,
13    group::Group,
14    proofs::{LogEqualityProof, ProofOfPossession, TranscriptForGroup, VerificationError},
15    CandidateDecryption, Ciphertext, PublicKey, VerifiableDecryption,
16};
17
/// Public information about a threshold key-sharing setup: the scheme parameters, the shared
/// public key, and the public keys of all participants.
///
/// Instances are built either from a dealer's public polynomial ([`Self::new`]) or from the
/// complete list of participant keys ([`Self::from_participants`]); both constructors check
/// their inputs for consistency before returning a key set.
///
/// NOTE(review): with the `serde` feature, `Deserialize` is derived directly on the fields, so
/// a deserialized key set does not go through the constructors' consistency checks.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Empty bound: suppress the `G: Serialize` / `G: Deserialize` bounds serde would infer from
// the type parameter; only the field types need to be (de)serializable.
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct PublicKeySet<G: Group> {
    /// Scheme parameters: number of shares `n` and threshold `t`.
    params: Params,
    /// Shared public key, i.e., the public polynomial evaluated at zero.
    shared_key: PublicKey<G>,
    /// Participant public keys; the key at 0-based index `i` corresponds to the public
    /// polynomial evaluated at `i + 1`.
    participant_keys: Vec<PublicKey<G>>,
}
28
29impl<G: Group> PublicKeySet<G> {
30    pub(crate) fn validate(
31        params: Params,
32        public_polynomial: &[G::Element],
33        proof_of_possession: &ProofOfPossession<G>,
34    ) -> Result<(), Error> {
35        if public_polynomial.len() != params.threshold {
36            return Err(Error::MalformedDealerPolynomial);
37        }
38
39        let mut transcript = Transcript::new(b"elgamal_share_poly");
40        transcript.append_u64(b"n", params.shares as u64);
41        transcript.append_u64(b"t", params.threshold as u64);
42
43        let public_poly_keys: Vec<_> = public_polynomial
44            .iter()
45            .copied()
46            .map(PublicKey::from_element)
47            .collect();
48        proof_of_possession
49            .verify(public_poly_keys.iter(), &mut transcript)
50            .map_err(Error::InvalidDealerProof)?;
51        Ok(())
52    }
53
54    pub fn new(
62        params: Params,
63        public_polynomial: Vec<G::Element>,
64        proof_of_possession: &ProofOfPossession<G>,
65    ) -> Result<Self, Error> {
66        Self::validate(params, &public_polynomial, proof_of_possession)?;
67
68        let public_poly = PublicPolynomial::<G>(public_polynomial);
69        let shared_key = PublicKey::from_element(public_poly.value_at_zero());
70        let participant_keys = (0..params.shares)
71            .map(|idx| PublicKey::from_element(public_poly.value_at((idx as u64 + 1).into())))
72            .collect();
73
74        Ok(Self {
75            params,
76            shared_key,
77            participant_keys,
78        })
79    }
80
81    pub fn from_participants(
89        params: Params,
90        participant_keys: Vec<PublicKey<G>>,
91    ) -> Result<Self, Error> {
92        if params.shares != participant_keys.len() {
93            return Err(Error::ParticipantCountMismatch);
94        }
95
96        let indexes: Vec<_> = (0..params.threshold).collect();
98        let (denominators, scale) = lagrange_coefficients::<G>(&indexes);
99        let starting_keys = participant_keys
100            .iter()
101            .map(PublicKey::as_element)
102            .take(params.threshold);
103        let shared_key = G::vartime_multi_mul(&denominators, starting_keys.clone());
104        let shared_key = PublicKey::from_element(shared_key * &scale);
105
106        let mut inverses: Vec<_> = (1_u64..=params.shares as u64)
110            .map(G::Scalar::from)
111            .collect();
112        G::invert_scalars(&mut inverses);
113
114        for (x, key) in participant_keys.iter().enumerate().skip(params.threshold) {
115            let mut key_scale = indexes
116                .iter()
117                .map(|&idx| G::Scalar::from((x - idx) as u64))
118                .fold(G::Scalar::from(1), |acc, value| acc * value);
119
120            let key_denominators: Vec<_> = denominators
121                .iter()
122                .enumerate()
123                .map(|(idx, &d)| d * G::Scalar::from(idx as u64 + 1) * inverses[x - idx - 1])
124                .collect();
125
126            if params.threshold % 2 == 0 {
130                key_scale = -key_scale;
131            }
132
133            let interpolated_key = G::vartime_multi_mul(&key_denominators, starting_keys.clone());
134            let interpolated_key = interpolated_key * &key_scale;
135            if interpolated_key != key.as_element() {
136                return Err(Error::MalformedParticipantKeys);
137            }
138        }
139
140        Ok(Self {
141            params,
142            shared_key,
143            participant_keys,
144        })
145    }
146
147    pub fn params(&self) -> Params {
149        self.params
150    }
151
152    pub fn shared_key(&self) -> &PublicKey<G> {
154        &self.shared_key
155    }
156
157    pub fn participant_key(&self, index: usize) -> Option<&PublicKey<G>> {
160        self.participant_keys.get(index)
161    }
162
163    pub fn participant_keys(&self) -> &[PublicKey<G>] {
165        &self.participant_keys
166    }
167
168    pub(super) fn commit(&self, transcript: &mut Transcript) {
169        transcript.append_u64(b"n", self.params.shares as u64);
170        transcript.append_u64(b"t", self.params.threshold as u64);
171        transcript.append_element_bytes(b"K", self.shared_key.as_bytes());
172    }
173
174    pub fn verify_participant(
188        &self,
189        index: usize,
190        proof: &ProofOfPossession<G>,
191    ) -> Result<(), VerificationError> {
192        let participant_key = self.participant_key(index).unwrap_or_else(|| {
193            panic!(
194                "participant index {index} out of bounds, expected a value in 0..{}",
195                self.participant_keys.len()
196            );
197        });
198        let mut transcript = Transcript::new(b"elgamal_participant_pop");
199        self.commit(&mut transcript);
200        transcript.append_u64(b"i", index as u64);
201        proof.verify(iter::once(participant_key), &mut transcript)
202    }
203
204    pub fn verify_share(
211        &self,
212        candidate_share: CandidateDecryption<G>,
213        ciphertext: Ciphertext<G>,
214        index: usize,
215        proof: &LogEqualityProof<G>,
216    ) -> Result<VerifiableDecryption<G>, VerificationError> {
217        let key_share = self.participant_keys[index].as_element();
218        let dh_element = candidate_share.dh_element();
219        let mut transcript = Transcript::new(b"elgamal_decryption_share");
220        self.commit(&mut transcript);
221        transcript.append_u64(b"i", index as u64);
222
223        proof.verify(
224            &PublicKey::from_element(ciphertext.random_element),
225            (key_share, dh_element),
226            &mut transcript,
227        )?;
228        Ok(VerifiableDecryption::from_element(dh_element))
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        group::{ElementOps, Ristretto},
        sharing::Dealer,
    };

    /// Checks that `PublicKeySet::from_participants` restores the correct key set from
    /// honest participant keys and rejects wrong counts, reordered keys, and tampered keys.
    #[test]
    fn restoring_key_set_from_participant_keys_errors() {
        let mut rng = rand::rng();
        let params = Params::new(10, 7);

        let dealer = Dealer::<Ristretto>::new(params, &mut rng);
        let (public_poly, _) = dealer.public_info();
        let public_poly = PublicPolynomial::<Ristretto>(public_poly);
        // Participant `i - 1` (0-based) holds the polynomial value at `x = i`.
        let participant_keys: Vec<PublicKey<Ristretto>> = (1..=params.shares)
            .map(|i| PublicKey::from_element(public_poly.value_at((i as u64).into())))
            .collect();

        // Honest keys must restore the dealer's shared key (the value at zero) and keep the
        // participant keys intact.
        let key_set = PublicKeySet::from_participants(params, participant_keys.clone()).unwrap();
        assert_eq!(
            key_set.shared_key().as_element(),
            public_poly.value_at_zero()
        );
        assert_eq!(key_set.participant_keys().len(), participant_keys.len());
        for (restored, expected) in key_set.participant_keys().iter().zip(&participant_keys) {
            assert_eq!(restored.as_element(), expected.as_element());
        }

        // Missing a participant key.
        let err =
            PublicKeySet::from_participants(params, participant_keys[1..].to_vec()).unwrap_err();
        assert!(matches!(err, Error::ParticipantCountMismatch));

        // Keys in the wrong order no longer lie on a single polynomial.
        let mut bogus_keys = participant_keys.clone();
        bogus_keys.swap(1, 5);
        let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
        assert!(matches!(err, Error::MalformedParticipantKeys));

        // Tampering with any single key, whether among the first `t` used for interpolation
        // or among the checked ones, must be detected.
        for i in 0..params.shares {
            let mut bogus_keys = participant_keys.clone();
            bogus_keys[i] =
                PublicKey::from_element(bogus_keys[i].as_element() + Ristretto::generator());
            let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
            assert!(matches!(err, Error::MalformedParticipantKeys));
        }
    }
}