elastic_elgamal/group/
ristretto.rs1use core::convert::TryInto;
2
3use elliptic_curve::rand_core::{CryptoRng, RngCore};
4
5use crate::{
6 curve25519::{
7 constants::{RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE},
8 ristretto::{CompressedRistretto, RistrettoPoint},
9 scalar::Scalar,
10 traits::{Identity, IsIdentity, MultiscalarMul, VartimeMultiscalarMul},
11 },
12 group::{ElementOps, Group, RandomBytesProvider, ScalarOps},
13};
14
/// Prime-order group of Ristretto255 elements, built on top of Curve25519.
///
/// This is a data-less marker type used as the group parameter of the crate's generic types
/// (e.g. `Keypair<Ristretto>`). Group elements are [`RistrettoPoint`]s and scalars are
/// Curve25519 [`Scalar`]s; both are serialized to 32 bytes.
///
/// Requires either the `curve25519-dalek` or the `curve25519-dalek-ng` crate feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "curve25519-dalek", feature = "curve25519-dalek-ng")))
)]
pub struct Ristretto(());
22
23impl ScalarOps for Ristretto {
24 type Scalar = Scalar;
25
26 const SCALAR_SIZE: usize = 32;
27
28 fn generate_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Self::Scalar {
29 let mut scalar_bytes = [0_u8; 64];
30 rng.fill_bytes(&mut scalar_bytes[..]);
31 Scalar::from_bytes_mod_order_wide(&scalar_bytes)
32 }
33
34 fn scalar_from_random_bytes(source: RandomBytesProvider<'_>) -> Self::Scalar {
35 let mut scalar_bytes = [0_u8; 64];
36 source.fill_bytes(&mut scalar_bytes);
37 Scalar::from_bytes_mod_order_wide(&scalar_bytes)
38 }
39
40 fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
41 scalar.invert()
42 }
43
44 #[cfg(feature = "curve25519-dalek")]
45 fn invert_scalars(scalars: &mut [Self::Scalar]) {
46 Scalar::invert_batch_alloc(scalars);
47 }
48
49 #[cfg(feature = "curve25519-dalek-ng")]
50 fn invert_scalars(scalars: &mut [Self::Scalar]) {
51 Scalar::batch_invert(scalars);
52 }
53
54 fn serialize_scalar(scalar: &Self::Scalar, buffer: &mut [u8]) {
55 buffer.copy_from_slice(&scalar.to_bytes());
56 }
57
58 #[cfg(feature = "curve25519-dalek")]
59 fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
60 let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
61 Scalar::from_canonical_bytes(*bytes).into()
62 }
63
64 #[cfg(feature = "curve25519-dalek-ng")]
65 fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
66 let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
67 Scalar::from_canonical_bytes(*bytes)
68 }
69}
70
71impl ElementOps for Ristretto {
72 type Element = RistrettoPoint;
73
74 const ELEMENT_SIZE: usize = 32;
75
76 fn identity() -> Self::Element {
77 RistrettoPoint::identity()
78 }
79
80 fn is_identity(element: &Self::Element) -> bool {
81 element.is_identity()
82 }
83
84 fn generator() -> Self::Element {
85 RISTRETTO_BASEPOINT_POINT
86 }
87
88 fn serialize_element(element: &Self::Element, buffer: &mut [u8]) {
89 buffer.copy_from_slice(&element.compress().to_bytes());
90 }
91
92 #[cfg(feature = "curve25519-dalek")]
93 fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
94 CompressedRistretto::from_slice(buffer).ok()?.decompress()
95 }
96
97 #[cfg(feature = "curve25519-dalek-ng")]
98 fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
99 CompressedRistretto::from_slice(buffer).decompress()
100 }
101}
102
103impl Group for Ristretto {
104 #[cfg(feature = "curve25519-dalek")]
105 fn mul_generator(k: &Scalar) -> Self::Element {
106 k * RISTRETTO_BASEPOINT_TABLE
107 }
108
109 #[cfg(feature = "curve25519-dalek-ng")]
110 fn mul_generator(k: &Scalar) -> Self::Element {
111 k * &RISTRETTO_BASEPOINT_TABLE
112 }
113
114 fn vartime_mul_generator(k: &Scalar) -> Self::Element {
115 #[cfg(feature = "curve25519-dalek")]
116 let zero = Scalar::ZERO;
117 #[cfg(feature = "curve25519-dalek-ng")]
118 let zero = Scalar::zero();
119
120 RistrettoPoint::vartime_double_scalar_mul_basepoint(&zero, &RistrettoPoint::identity(), k)
121 }
122
123 fn multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
124 where
125 I: IntoIterator<Item = &'a Self::Scalar>,
126 J: IntoIterator<Item = Self::Element>,
127 {
128 RistrettoPoint::multiscalar_mul(scalars, elements)
129 }
130
131 fn vartime_double_mul_generator(
132 k: &Scalar,
133 k_element: Self::Element,
134 r: &Scalar,
135 ) -> Self::Element {
136 RistrettoPoint::vartime_double_scalar_mul_basepoint(k, &k_element, r)
137 }
138
139 fn vartime_multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
140 where
141 I: IntoIterator<Item = &'a Self::Scalar>,
142 J: IntoIterator<Item = Self::Element>,
143 {
144 RistrettoPoint::vartime_multiscalar_mul(scalars, elements)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DiscreteLogTable,
        app::{ChoiceParams, EncryptedChoice},
        group::Curve25519Subgroup,
    };

    type SecretKey = crate::SecretKey<Ristretto>;
    type Keypair = crate::Keypair<Ristretto>;

    /// Encrypting a scalar and decrypting to an element must yield `value * G`.
    #[test]
    fn encrypt_and_decrypt() {
        let mut rng = rand::rng();
        let keypair = Keypair::generate(&mut rng);
        let value = Ristretto::generate_scalar(&mut rng);
        let encrypted = keypair.public().encrypt(value, &mut rng);
        let decryption = keypair.secret().decrypt_to_element(encrypted);
        assert_eq!(decryption, Ristretto::vartime_mul_generator(&value));
    }

    /// A single-choice ballot over 5 options with option 3 selected must verify and decrypt
    /// to a one-hot vector.
    #[test]
    fn encrypt_choice() {
        let mut rng = rand::rng();
        let (pk, sk) = Keypair::generate(&mut rng).into_tuple();
        let choice_params = ChoiceParams::single(pk, 5);
        let encrypted = EncryptedChoice::single(&choice_params, 3, &mut rng);
        let choices = encrypted.verify(&choice_params).unwrap();
        // Without this check, a short (or empty) result would make the loop below pass vacuously.
        assert_eq!(choices.len(), 5);

        let lookup_table = DiscreteLogTable::new(0..=1);
        for (i, &choice) in choices.iter().enumerate() {
            let decryption = sk.decrypt(choice, &lookup_table);
            assert_eq!(decryption.unwrap(), u64::from(i == 3));
        }
    }

    /// The same secret scalar must yield different public keys in the Ristretto group and in
    /// the Curve25519 prime-order subgroup.
    #[test]
    fn edwards_and_ristretto_public_keys_differ() {
        type SubgroupSecretKey = crate::SecretKey<Curve25519Subgroup>;
        type SubgroupKeypair = crate::Keypair<Curve25519Subgroup>;

        let mut rng = rand::rng();
        for _ in 0..1_000 {
            let secret_key = SecretKey::generate(&mut rng);
            let keypair = Keypair::from(secret_key.clone());
            let secret_key = SubgroupSecretKey::new(*secret_key.expose_scalar());
            let ed_keypair = SubgroupKeypair::from(secret_key);
            assert_ne!(keypair.public().as_bytes(), ed_keypair.public().as_bytes());
        }
    }
}