1use core::iter;
4
5use merlin::Transcript;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use super::{Error, Params, PublicPolynomial, lagrange_coefficients};
10use crate::{
11 CandidateDecryption, Ciphertext, PublicKey, VerifiableDecryption,
12 alloc::Vec,
13 group::Group,
14 proofs::{LogEqualityProof, ProofOfPossession, TranscriptForGroup, VerificationError},
15};
16
/// Public information about a threshold key set: the sharing parameters, the shared
/// public key, and the public key of every participant.
///
/// `participant_keys[i]` is the key of the participant with 0-based index `i`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// `bound = ""` removes the serde bounds that would otherwise be inferred for `G`.
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct PublicKeySet<G: Group> {
    /// Sharing parameters (number of shares and threshold).
    params: Params,
    /// Shared public key of the whole set.
    shared_key: PublicKey<G>,
    /// Per-participant public keys; its length equals `params.shares`.
    participant_keys: Vec<PublicKey<G>>,
}
27
28impl<G: Group> PublicKeySet<G> {
29 pub(crate) fn validate(
30 params: Params,
31 public_polynomial: &[G::Element],
32 proof_of_possession: &ProofOfPossession<G>,
33 ) -> Result<(), Error> {
34 if public_polynomial.len() != params.threshold {
35 return Err(Error::MalformedDealerPolynomial);
36 }
37
38 let mut transcript = Transcript::new(b"elgamal_share_poly");
39 transcript.append_u64(b"n", params.shares as u64);
40 transcript.append_u64(b"t", params.threshold as u64);
41
42 let public_poly_keys: Vec<_> = public_polynomial
43 .iter()
44 .copied()
45 .map(PublicKey::from_element)
46 .collect();
47 proof_of_possession
48 .verify(public_poly_keys.iter(), &mut transcript)
49 .map_err(Error::InvalidDealerProof)?;
50 Ok(())
51 }
52
53 pub fn new(
61 params: Params,
62 public_polynomial: Vec<G::Element>,
63 proof_of_possession: &ProofOfPossession<G>,
64 ) -> Result<Self, Error> {
65 Self::validate(params, &public_polynomial, proof_of_possession)?;
66
67 let public_poly = PublicPolynomial::<G>(public_polynomial);
68 let shared_key = PublicKey::from_element(public_poly.value_at_zero());
69 let participant_keys = (0..params.shares)
70 .map(|idx| PublicKey::from_element(public_poly.value_at((idx as u64 + 1).into())))
71 .collect();
72
73 Ok(Self {
74 params,
75 shared_key,
76 participant_keys,
77 })
78 }
79
80 pub fn from_participants(
88 params: Params,
89 participant_keys: Vec<PublicKey<G>>,
90 ) -> Result<Self, Error> {
91 if params.shares != participant_keys.len() {
92 return Err(Error::ParticipantCountMismatch);
93 }
94
95 let indexes: Vec<_> = (0..params.threshold).collect();
97 let (denominators, scale) = lagrange_coefficients::<G>(&indexes);
98 let starting_keys = participant_keys
99 .iter()
100 .map(PublicKey::as_element)
101 .take(params.threshold);
102 let shared_key = G::vartime_multi_mul(&denominators, starting_keys.clone());
103 let shared_key = PublicKey::from_element(shared_key * &scale);
104
105 let mut inverses: Vec<_> = (1_u64..=params.shares as u64)
109 .map(G::Scalar::from)
110 .collect();
111 G::invert_scalars(&mut inverses);
112
113 for (x, key) in participant_keys.iter().enumerate().skip(params.threshold) {
114 let mut key_scale = indexes
115 .iter()
116 .map(|&idx| G::Scalar::from((x - idx) as u64))
117 .fold(G::Scalar::from(1), |acc, value| acc * value);
118
119 let key_denominators: Vec<_> = denominators
120 .iter()
121 .enumerate()
122 .map(|(idx, &d)| d * G::Scalar::from(idx as u64 + 1) * inverses[x - idx - 1])
123 .collect();
124
125 if params.threshold % 2 == 0 {
129 key_scale = -key_scale;
130 }
131
132 let interpolated_key = G::vartime_multi_mul(&key_denominators, starting_keys.clone());
133 let interpolated_key = interpolated_key * &key_scale;
134 if interpolated_key != key.as_element() {
135 return Err(Error::MalformedParticipantKeys);
136 }
137 }
138
139 Ok(Self {
140 params,
141 shared_key,
142 participant_keys,
143 })
144 }
145
146 pub fn params(&self) -> Params {
148 self.params
149 }
150
151 pub fn shared_key(&self) -> &PublicKey<G> {
153 &self.shared_key
154 }
155
156 pub fn participant_key(&self, index: usize) -> Option<&PublicKey<G>> {
159 self.participant_keys.get(index)
160 }
161
162 pub fn participant_keys(&self) -> &[PublicKey<G>] {
164 &self.participant_keys
165 }
166
167 pub(super) fn commit(&self, transcript: &mut Transcript) {
168 transcript.append_u64(b"n", self.params.shares as u64);
169 transcript.append_u64(b"t", self.params.threshold as u64);
170 transcript.append_element_bytes(b"K", self.shared_key.as_bytes());
171 }
172
173 pub fn verify_participant(
187 &self,
188 index: usize,
189 proof: &ProofOfPossession<G>,
190 ) -> Result<(), VerificationError> {
191 let participant_key = self.participant_key(index).unwrap_or_else(|| {
192 panic!(
193 "participant index {index} out of bounds, expected a value in 0..{}",
194 self.participant_keys.len()
195 );
196 });
197 let mut transcript = Transcript::new(b"elgamal_participant_pop");
198 self.commit(&mut transcript);
199 transcript.append_u64(b"i", index as u64);
200 proof.verify(iter::once(participant_key), &mut transcript)
201 }
202
203 pub fn verify_share(
210 &self,
211 candidate_share: CandidateDecryption<G>,
212 ciphertext: Ciphertext<G>,
213 index: usize,
214 proof: &LogEqualityProof<G>,
215 ) -> Result<VerifiableDecryption<G>, VerificationError> {
216 let key_share = self.participant_keys[index].as_element();
217 let dh_element = candidate_share.dh_element();
218 let mut transcript = Transcript::new(b"elgamal_decryption_share");
219 self.commit(&mut transcript);
220 transcript.append_u64(b"i", index as u64);
221
222 proof.verify(
223 &PublicKey::from_element(ciphertext.random_element),
224 (key_share, dh_element),
225 &mut transcript,
226 )?;
227 Ok(VerifiableDecryption::from_element(dh_element))
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        group::{ElementOps, Ristretto},
        sharing::Dealer,
    };

    /// Deals a random polynomial for `params`, derives the participant keys from it, and
    /// checks that `PublicKeySet::from_participants`:
    /// - restores the correct shared key, and
    /// - rejects wrong key counts, swapped keys, and individually tampered keys.
    fn check_restoring_key_set(params: Params) {
        let mut rng = rand::rng();

        let dealer = Dealer::<Ristretto>::new(params, &mut rng);
        let (public_poly, _) = dealer.public_info();
        let public_poly = PublicPolynomial::<Ristretto>(public_poly);
        let participant_keys: Vec<PublicKey<Ristretto>> = (1..=params.shares)
            .map(|i| PublicKey::from_element(public_poly.value_at((i as u64).into())))
            .collect();

        let key_set = PublicKeySet::from_participants(params, participant_keys.clone()).unwrap();
        // The interpolated shared key must equal the dealer polynomial's value at zero.
        assert!(key_set.shared_key().as_element() == public_poly.value_at_zero());
        assert_eq!(key_set.participant_keys().len(), params.shares);

        let err =
            PublicKeySet::from_participants(params, participant_keys[1..].to_vec()).unwrap_err();
        assert!(matches!(err, Error::ParticipantCountMismatch));

        let mut bogus_keys = participant_keys.clone();
        bogus_keys.swap(1, 5);
        let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
        assert!(matches!(err, Error::MalformedParticipantKeys));

        for i in 0..params.shares {
            let mut bogus_keys = participant_keys.clone();
            bogus_keys[i] =
                PublicKey::from_element(bogus_keys[i].as_element() + Ristretto::generator());
            let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
            assert!(matches!(err, Error::MalformedParticipantKeys));
        }
    }

    #[test]
    fn restoring_key_set_from_participant_keys_errors() {
        check_restoring_key_set(Params::new(10, 7));
    }

    #[test]
    fn restoring_key_set_with_even_threshold() {
        // An even threshold exercises the sign-flip branch in `from_participants`.
        check_restoring_key_set(Params::new(10, 6));
    }
}